diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index e2a0f9d3..98d40f56 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -238,15 +238,15 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client // setup the stream response stream := &grpcStream{ - context: ctx, - request: req, + ClientStream: st, + context: ctx, + request: req, response: &response{ conn: cc, stream: st, codec: cf, gcodec: codec, }, - stream: st, conn: cc, cancel: cancel, } diff --git a/client/grpc/stream.go b/client/grpc/stream.go index 5051718d..d27493b3 100644 --- a/client/grpc/stream.go +++ b/client/grpc/stream.go @@ -11,11 +11,13 @@ import ( // Implements the streamer interface type grpcStream struct { + // embed so we can access if need be + grpc.ClientStream + sync.RWMutex closed bool err error conn *grpc.ClientConn - stream grpc.ClientStream request client.Request response client.Response context context.Context @@ -35,7 +37,7 @@ func (g *grpcStream) Response() client.Response { } func (g *grpcStream) Send(msg interface{}) error { - if err := g.stream.SendMsg(msg); err != nil { + if err := g.ClientStream.SendMsg(msg); err != nil { g.setError(err) return err } @@ -44,7 +46,8 @@ func (g *grpcStream) Send(msg interface{}) error { func (g *grpcStream) Recv(msg interface{}) (err error) { defer g.setError(err) - if err = g.stream.RecvMsg(msg); err != nil { + + if err = g.ClientStream.RecvMsg(msg); err != nil { // #202 - inconsistent gRPC stream behavior // the only way to tell if the stream is done is when we get a EOF on the Recv // here we should close the underlying gRPC ClientConn @@ -52,7 +55,10 @@ func (g *grpcStream) Recv(msg interface{}) (err error) { if err == io.EOF && closeErr != nil { err = closeErr } + + return err } + return } @@ -83,6 +89,6 @@ func (g *grpcStream) Close() error { // cancel the context g.cancel() g.closed = true - g.stream.CloseSend() + g.ClientStream.CloseSend() return g.conn.Close() } diff --git a/proxy/mucp/mucp.go b/proxy/mucp/mucp.go index b6a5764e..405c637a 100644 --- a/proxy/mucp/mucp.go +++ b/proxy/mucp/mucp.go @@ -11,7 +11,7 @@ import ( "time" "github.com/micro/go-micro/v3/client" - "github.com/micro/go-micro/v3/client/grpc" + grpcc "github.com/micro/go-micro/v3/client/grpc" "github.com/micro/go-micro/v3/codec" "github.com/micro/go-micro/v3/codec/bytes" "github.com/micro/go-micro/v3/errors" @@ -23,6 +23,7 @@ import ( "github.com/micro/go-micro/v3/selector" "github.com/micro/go-micro/v3/selector/roundrobin" "github.com/micro/go-micro/v3/server" + "google.golang.org/grpc" ) // Proxy will transparently proxy requests to an endpoint. @@ -514,6 +515,9 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e return nil } + // new context with cancel + ctx, cancel := context.WithCancel(ctx) + // create new stream stream, err := link.Stream(ctx, creq, opts...) if err != nil { @@ -542,7 +546,13 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e } // create client request read loop if streaming - go readLoop(req, stream) + go func() { + err := readLoop(req, stream) + if err != nil && err != io.EOF { + // cancel the context + cancel() + } + }() // get raw response resp := stream.Response() @@ -551,6 +561,15 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e for { // read backend response body body, err := resp.Read() + if err != nil { + // when we're done if its a grpc stream we have to set the trailer + if cc, ok := stream.(grpc.ClientStream); ok { + if ss, ok := resp.Codec().(grpc.ServerStream); ok { + ss.SetTrailer(cc.Trailer()) + } + } + } + if err == io.EOF { return nil } else if err != nil { @@ -605,7 +624,7 @@ func NewProxy(opts ...proxy.Option) proxy.Proxy { // set the default client if p.Client == nil { - p.Client = grpc.NewClient() + p.Client = grpcc.NewClient() } // create default router and start it diff --git a/server/grpc/codec.go b/server/grpc/codec.go index 597420bf..89597f07 100644 --- a/server/grpc/codec.go +++ b/server/grpc/codec.go @@ -128,18 +128,18 @@ func (bytesCodec) Name() string { } type grpcCodec struct { + grpc.ServerStream // headers id string target string method string endpoint string - s grpc.ServerStream c encoding.Codec } func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { - md, _ := metadata.FromIncomingContext(g.s.Context()) + md, _ := metadata.FromIncomingContext(g.ServerStream.Context()) if m == nil { m = new(codec.Message) } @@ -159,9 +159,9 @@ func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { func (g *grpcCodec) ReadBody(v interface{}) error { // caller has requested a frame if f, ok := v.(*bytes.Frame); ok { - return g.s.RecvMsg(f) + return g.ServerStream.RecvMsg(f) } - return g.s.RecvMsg(v) + return g.ServerStream.RecvMsg(v) } func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { @@ -174,7 +174,7 @@ func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { m.Body = b } // write the body using the framing codec - return g.s.SendMsg(&bytes.Frame{Data: m.Body}) + return g.ServerStream.SendMsg(&bytes.Frame{Data: m.Body}) } func (g *grpcCodec) Close() error { diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index c75649e6..6dadbfdd 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -265,11 +265,11 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) (err err return errors.InternalServerError(g.opts.Name, err.Error()) } codec := &grpcCodec{ - method: fmt.Sprintf("%s.%s", serviceName, methodName), - endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), - target: g.opts.Name, - s: stream, - c: cc, + ServerStream: stream, + method: fmt.Sprintf("%s.%s", serviceName, methodName), + endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), + target: g.opts.Name, + c: cc, } // create a client.Request @@ -394,8 +394,10 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, for i := len(g.opts.HdlrWrappers); i > 0; i-- { fn = g.opts.HdlrWrappers[i-1](fn) } + statusCode := codes.OK statusDesc := "" + // execute the handler if appErr := fn(ctx, r, replyv.Interface()); appErr != nil { var errStatus *status.Status @@ -411,6 +413,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, // micro.Error now proto based and we can attach it to grpc status statusCode = microError(verr) statusDesc = verr.Error() + errStatus, err = status.New(statusCode, statusDesc).WithDetails(perr) if err != nil { return err @@ -428,6 +431,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, statusCode = convertCode(verr) statusDesc = verr.Error() errStatus = status.New(statusCode, statusDesc) + fmt.Printf("Responding with :%v\n", errStatus) } return errStatus.Err() @@ -436,6 +440,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, if err := stream.SendMsg(replyv.Interface()); err != nil { return err } + return status.New(statusCode, statusDesc).Err() } } @@ -451,8 +456,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m } ss := &rpcStream{ - request: r, - s: stream, + ServerStream: stream, + request: r, } function := mtype.method.Func @@ -507,6 +512,7 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m statusDesc = verr.Error() errStatus = status.New(statusCode, statusDesc) } + return errStatus.Err() } diff --git a/server/grpc/stream.go b/server/grpc/stream.go index 38139c54..adb22936 100644 --- a/server/grpc/stream.go +++ b/server/grpc/stream.go @@ -9,7 +9,9 @@ import ( // rpcStream implements a server side Stream. type rpcStream struct { - s grpc.ServerStream + // embed the grpc stream so we can access it + grpc.ServerStream + request server.Request } @@ -26,13 +28,13 @@ func (r *rpcStream) Request() server.Request { } func (r *rpcStream) Context() context.Context { - return r.s.Context() + return r.ServerStream.Context() } func (r *rpcStream) Send(m interface{}) error { - return r.s.SendMsg(m) + return r.ServerStream.SendMsg(m) } func (r *rpcStream) Recv(m interface{}) error { - return r.s.RecvMsg(m) + return r.ServerStream.RecvMsg(m) }