From c258ff3ca4a8b0833689c50721296d094c625fbf Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Sun, 9 Aug 2020 15:43:41 +0100 Subject: [PATCH] embed grpc server stream and client so they can be accessed (#1916) --- codec.go | 10 +++++----- grpc.go | 20 +++++++++++++------- stream.go | 10 ++++++---- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/codec.go b/codec.go index 597420b..89597f0 100644 --- a/codec.go +++ b/codec.go @@ -128,18 +128,18 @@ func (bytesCodec) Name() string { } type grpcCodec struct { + grpc.ServerStream // headers id string target string method string endpoint string - s grpc.ServerStream c encoding.Codec } func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { - md, _ := metadata.FromIncomingContext(g.s.Context()) + md, _ := metadata.FromIncomingContext(g.ServerStream.Context()) if m == nil { m = new(codec.Message) } @@ -159,9 +159,9 @@ func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { func (g *grpcCodec) ReadBody(v interface{}) error { // caller has requested a frame if f, ok := v.(*bytes.Frame); ok { - return g.s.RecvMsg(f) + return g.ServerStream.RecvMsg(f) } - return g.s.RecvMsg(v) + return g.ServerStream.RecvMsg(v) } func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { @@ -174,7 +174,7 @@ func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { m.Body = b } // write the body using the framing codec - return g.s.SendMsg(&bytes.Frame{Data: m.Body}) + return g.ServerStream.SendMsg(&bytes.Frame{Data: m.Body}) } func (g *grpcCodec) Close() error { diff --git a/grpc.go b/grpc.go index c75649e..6dadbfd 100644 --- a/grpc.go +++ b/grpc.go @@ -265,11 +265,11 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) (err err return errors.InternalServerError(g.opts.Name, err.Error()) } codec := &grpcCodec{ - method: fmt.Sprintf("%s.%s", serviceName, methodName), - endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), - target: g.opts.Name, - s: stream, - c: cc, + ServerStream: stream, + method: fmt.Sprintf("%s.%s", serviceName, methodName), + endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), + target: g.opts.Name, + c: cc, } // create a client.Request @@ -394,8 +394,10 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, for i := len(g.opts.HdlrWrappers); i > 0; i-- { fn = g.opts.HdlrWrappers[i-1](fn) } + statusCode := codes.OK statusDesc := "" + // execute the handler if appErr := fn(ctx, r, replyv.Interface()); appErr != nil { var errStatus *status.Status @@ -411,6 +413,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, // micro.Error now proto based and we can attach it to grpc status statusCode = microError(verr) statusDesc = verr.Error() + errStatus, err = status.New(statusCode, statusDesc).WithDetails(perr) if err != nil { return err @@ -428,6 +431,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, statusCode = convertCode(verr) statusDesc = verr.Error() errStatus = status.New(statusCode, statusDesc) + fmt.Printf("Responding with :%v\n", errStatus) } return errStatus.Err() @@ -436,6 +440,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, if err := stream.SendMsg(replyv.Interface()); err != nil { return err } + return status.New(statusCode, statusDesc).Err() } } @@ -451,8 +456,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m } ss := &rpcStream{ - request: r, - s: stream, + ServerStream: stream, + request: r, } function := mtype.method.Func @@ -507,6 +512,7 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m statusDesc = verr.Error() errStatus = status.New(statusCode, statusDesc) } + return errStatus.Err() } diff --git a/stream.go b/stream.go index 38139c5..adb2293 100644 --- a/stream.go +++ b/stream.go @@ -9,7 +9,9 @@ import ( // rpcStream implements a server side Stream. type rpcStream struct { - s grpc.ServerStream + // embed the grpc stream so we can access it + grpc.ServerStream + request server.Request } @@ -26,13 +28,13 @@ func (r *rpcStream) Request() server.Request { } func (r *rpcStream) Context() context.Context { - return r.s.Context() + return r.ServerStream.Context() } func (r *rpcStream) Send(m interface{}) error { - return r.s.SendMsg(m) + return r.ServerStream.SendMsg(m) } func (r *rpcStream) Recv(m interface{}) error { - return r.s.RecvMsg(m) + return r.ServerStream.RecvMsg(m) }