diff options
Diffstat (limited to 'go/vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | go/vendor/google.golang.org/grpc/server.go | 216 |
1 files changed, 81 insertions, 135 deletions
diff --git a/go/vendor/google.golang.org/grpc/server.go b/go/vendor/google.golang.org/grpc/server.go index b15f71c..985226d 100644 --- a/go/vendor/google.golang.org/grpc/server.go +++ b/go/vendor/google.golang.org/grpc/server.go @@ -53,10 +53,8 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" - "google.golang.org/grpc/status" "google.golang.org/grpc/tap" "google.golang.org/grpc/transport" ) @@ -118,9 +116,6 @@ type options struct { statsHandler stats.Handler maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server - unknownStreamDesc *StreamDesc - keepaliveParams keepalive.ServerParameters - keepalivePolicy keepalive.EnforcementPolicy } var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit @@ -128,20 +123,6 @@ var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size l // A ServerOption sets options. type ServerOption func(*options) -// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. -func KeepaliveParams(kp keepalive.ServerParameters) ServerOption { - return func(o *options) { - o.keepaliveParams = kp - } -} - -// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. -func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { - return func(o *options) { - o.keepalivePolicy = kep - } -} - // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. func CustomCodec(codec Codec) ServerOption { return func(o *options) { @@ -227,24 +208,6 @@ func StatsHandler(h stats.Handler) ServerOption { } } -// UnknownServiceHandler returns a ServerOption that allows for adding a custom -// unknown service handler. The provided method is a bidi-streaming RPC service -// handler that will be invoked instead of returning the the "unimplemented" gRPC -// error whenever a request is received for an unregistered service or method. -// The handling function has full access to the Context of the request and the -// stream, and the invocation passes through interceptors. -func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { - return func(o *options) { - o.unknownStreamDesc = &StreamDesc{ - StreamName: "unknown_service_handler", - Handler: streamHandler, - // We need to assume that the users of the streamHandler will want to use both. - ClientStreams: true, - ServerStreams: true, - } - } -} - // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -483,12 +446,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { // transport.NewServerTransport). func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { config := &transport.ServerConfig{ - MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, - InTapHandle: s.opts.inTapHandle, - StatsHandler: s.opts.statsHandler, - KeepaliveParams: s.opts.keepaliveParams, - KeepalivePolicy: s.opts.keepalivePolicy, + MaxStreams: s.opts.maxConcurrentStreams, + AuthInfo: authInfo, + InTapHandle: s.opts.inTapHandle, + StatsHandler: s.opts.statsHandler, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -672,7 +633,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. stream.SetSendCompress(s.opts.cp.Type()) } p := &parser{r: stream} - for { // TODO: delete + for { pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). @@ -682,37 +643,36 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { - if st, ok := status.FromError(err); ok { - if e := t.WriteStatus(stream, st); e != nil { + switch err := err.(type) { + case *rpcError: + if e := t.WriteStatus(stream, err.code, err.desc); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } - } else { - switch st := err.(type) { - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - default: - panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) } return err } if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - if st, ok := status.FromError(err); ok { - if e := t.WriteStatus(stream, st); e != nil { + switch err := err.(type) { + case *rpcError: + if e := t.WriteStatus(stream, err.code, err.desc); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } return err + default: + if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } - if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - - // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } var inPayload *stats.InPayload if sh != nil { @@ -720,6 +680,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. RecvTime: time.Now(), } } + statusCode := codes.OK + statusDesc := "" df := func(v interface{}) error { if inPayload != nil { inPayload.WireLength = len(req) @@ -728,16 +690,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. var err error req, err = s.opts.dc.Do(bytes.NewReader(req)) if err != nil { + if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } return Errorf(codes.Internal, err.Error()) } } if len(req) > s.opts.maxMsgSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. - return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + statusCode = codes.Internal + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { - return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) + return err } if inPayload != nil { inPayload.Payload = v @@ -752,20 +718,21 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - appStatus, ok := status.FromError(appErr) - if !ok { - // Convert appErr if it is not a grpc status error. - appErr = status.Error(convertCode(appErr), appErr.Error()) - appStatus, _ = status.FromError(appErr) + if err, ok := appErr.(*rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() } - if trInfo != nil { - trInfo.tr.LazyLog(stringer(appStatus.Message()), true) + if trInfo != nil && statusCode != codes.OK { + trInfo.tr.LazyLog(stringer(statusDesc), true) trInfo.tr.SetError() } - if e := t.WriteStatus(stream, appStatus); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) } - return appErr + return Errorf(statusCode, statusDesc) } if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) @@ -775,35 +742,26 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Delay: false, } if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { - if err == io.EOF { - // The entire stream is done (for unary RPC only). - return err - } - if s, ok := status.FromError(err); ok { - if e := t.WriteStatus(stream, s); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) - } - } else { - switch st := err.(type) { - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - default: - panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) - } + switch err := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + statusCode = err.Code + statusDesc = err.Desc + default: + statusCode = codes.Unknown + statusDesc = err.Error() } return err } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } - // TODO: Should we be logging if writing status failed here, like above? - // Should the logging be in WriteStatus? Should we ignore the WriteStatus - // error or allow the stats handler to see it? - return t.WriteStatus(stream, status.New(codes.OK, "")) + errWrite := t.WriteStatus(stream, statusCode, statusDesc) + if statusCode != codes.OK { + return Errorf(statusCode, statusDesc) + } + return errWrite } } @@ -857,47 +815,43 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp }() } var appErr error - var server interface{} - if srv != nil { - server = srv.server - } if s.opts.streamInt == nil { - appErr = sd.Handler(server, ss) + appErr = sd.Handler(srv.server, ss) } else { info := &StreamServerInfo{ FullMethod: stream.Method(), IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - appErr = s.opts.streamInt(server, ss, info, sd.Handler) + appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) } if appErr != nil { - appStatus, ok := status.FromError(appErr) - if !ok { - switch err := appErr.(type) { - case transport.StreamError: - appStatus = status.New(err.Code, err.Desc) - default: - appStatus = status.New(convertCode(appErr), appErr.Error()) - } - appErr = appStatus.Err() - } - if trInfo != nil { - ss.mu.Lock() - ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) - ss.trInfo.tr.SetError() - ss.mu.Unlock() + if err, ok := appErr.(*rpcError); ok { + ss.statusCode = err.code + ss.statusDesc = err.desc + } else if err, ok := appErr.(transport.StreamError); ok { + ss.statusCode = err.Code + ss.statusDesc = err.Desc + } else { + ss.statusCode = convertCode(appErr) + ss.statusDesc = appErr.Error() } - t.WriteStatus(ss.s, appStatus) - // TODO: Should we log an error from WriteStatus here and below? - return appErr } if trInfo != nil { ss.mu.Lock() - ss.trInfo.tr.LazyLog(stringer("OK"), false) + if ss.statusCode != codes.OK { + ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true) + ss.trInfo.tr.SetError() + } else { + ss.trInfo.tr.LazyLog(stringer("OK"), false) + } ss.mu.Unlock() } - return t.WriteStatus(ss.s, status.New(codes.OK, "")) + errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) + if ss.statusCode != codes.OK { + return Errorf(ss.statusCode, ss.statusDesc) + } + return errWrite } @@ -913,7 +867,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { + if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -929,16 +883,12 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str method := sm[pos+1:] srv, ok := s.m[service] if !ok { - if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { - s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) - return - } if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) trInfo.tr.SetError() } errDesc := fmt.Sprintf("unknown service %v", service) - if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { + if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -963,12 +913,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.SetError() } - if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { - s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) - return - } errDesc := fmt.Sprintf("unknown method %v", method) - if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { + if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() |