embed grpc server stream and client so they can be accessed (#1916)

This commit is contained in:
Asim Aslam 2020-08-09 15:43:41 +01:00 committed by GitHub
parent 69a2032dd7
commit 51f8b4ae3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 26 deletions

View File

@ -238,15 +238,15 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client
// setup the stream response // setup the stream response
stream := &grpcStream{ stream := &grpcStream{
context: ctx, ClientStream: st,
request: req, context: ctx,
request: req,
response: &response{ response: &response{
conn: cc, conn: cc,
stream: st, stream: st,
codec: cf, codec: cf,
gcodec: codec, gcodec: codec,
}, },
stream: st,
conn: cc, conn: cc,
cancel: cancel, cancel: cancel,
} }

View File

@ -11,11 +11,13 @@ import (
// Implements the streamer interface // Implements the streamer interface
type grpcStream struct { type grpcStream struct {
// embed so we can access if need be
grpc.ClientStream
sync.RWMutex sync.RWMutex
closed bool closed bool
err error err error
conn *grpc.ClientConn conn *grpc.ClientConn
stream grpc.ClientStream
request client.Request request client.Request
response client.Response response client.Response
context context.Context context context.Context
@ -35,7 +37,7 @@ func (g *grpcStream) Response() client.Response {
} }
func (g *grpcStream) Send(msg interface{}) error { func (g *grpcStream) Send(msg interface{}) error {
if err := g.stream.SendMsg(msg); err != nil { if err := g.ClientStream.SendMsg(msg); err != nil {
g.setError(err) g.setError(err)
return err return err
} }
@ -44,7 +46,8 @@ func (g *grpcStream) Send(msg interface{}) error {
func (g *grpcStream) Recv(msg interface{}) (err error) { func (g *grpcStream) Recv(msg interface{}) (err error) {
defer g.setError(err) defer g.setError(err)
if err = g.stream.RecvMsg(msg); err != nil {
if err = g.ClientStream.RecvMsg(msg); err != nil {
// #202 - inconsistent gRPC stream behavior // #202 - inconsistent gRPC stream behavior
// the only way to tell if the stream is done is when we get a EOF on the Recv // the only way to tell if the stream is done is when we get a EOF on the Recv
// here we should close the underlying gRPC ClientConn // here we should close the underlying gRPC ClientConn
@ -52,7 +55,10 @@ func (g *grpcStream) Recv(msg interface{}) (err error) {
if err == io.EOF && closeErr != nil { if err == io.EOF && closeErr != nil {
err = closeErr err = closeErr
} }
return err
} }
return return
} }
@ -83,6 +89,6 @@ func (g *grpcStream) Close() error {
// cancel the context // cancel the context
g.cancel() g.cancel()
g.closed = true g.closed = true
g.stream.CloseSend() g.ClientStream.CloseSend()
return g.conn.Close() return g.conn.Close()
} }

View File

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/micro/go-micro/v3/client" "github.com/micro/go-micro/v3/client"
"github.com/micro/go-micro/v3/client/grpc" grpcc "github.com/micro/go-micro/v3/client/grpc"
"github.com/micro/go-micro/v3/codec" "github.com/micro/go-micro/v3/codec"
"github.com/micro/go-micro/v3/codec/bytes" "github.com/micro/go-micro/v3/codec/bytes"
"github.com/micro/go-micro/v3/errors" "github.com/micro/go-micro/v3/errors"
@ -23,6 +23,7 @@ import (
"github.com/micro/go-micro/v3/selector" "github.com/micro/go-micro/v3/selector"
"github.com/micro/go-micro/v3/selector/roundrobin" "github.com/micro/go-micro/v3/selector/roundrobin"
"github.com/micro/go-micro/v3/server" "github.com/micro/go-micro/v3/server"
"google.golang.org/grpc"
) )
// Proxy will transparently proxy requests to an endpoint. // Proxy will transparently proxy requests to an endpoint.
@ -514,6 +515,9 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e
return nil return nil
} }
// new context with cancel
ctx, cancel := context.WithCancel(ctx)
// create new stream // create new stream
stream, err := link.Stream(ctx, creq, opts...) stream, err := link.Stream(ctx, creq, opts...)
if err != nil { if err != nil {
@ -542,7 +546,13 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e
} }
// create client request read loop if streaming // create client request read loop if streaming
go readLoop(req, stream) go func() {
err := readLoop(req, stream)
if err != nil && err != io.EOF {
// cancel the context
cancel()
}
}()
// get raw response // get raw response
resp := stream.Response() resp := stream.Response()
@ -551,6 +561,15 @@ func (p *Proxy) serveRequest(ctx context.Context, link client.Client, service, e
for { for {
// read backend response body // read backend response body
body, err := resp.Read() body, err := resp.Read()
if err != nil {
// when we're done if its a grpc stream we have to set the trailer
if cc, ok := stream.(grpc.ClientStream); ok {
if ss, ok := resp.Codec().(grpc.ServerStream); ok {
ss.SetTrailer(cc.Trailer())
}
}
}
if err == io.EOF { if err == io.EOF {
return nil return nil
} else if err != nil { } else if err != nil {
@ -605,7 +624,7 @@ func NewProxy(opts ...proxy.Option) proxy.Proxy {
// set the default client // set the default client
if p.Client == nil { if p.Client == nil {
p.Client = grpc.NewClient() p.Client = grpcc.NewClient()
} }
// create default router and start it // create default router and start it

View File

@ -128,18 +128,18 @@ func (bytesCodec) Name() string {
} }
type grpcCodec struct { type grpcCodec struct {
grpc.ServerStream
// headers // headers
id string id string
target string target string
method string method string
endpoint string endpoint string
s grpc.ServerStream
c encoding.Codec c encoding.Codec
} }
func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
md, _ := metadata.FromIncomingContext(g.s.Context()) md, _ := metadata.FromIncomingContext(g.ServerStream.Context())
if m == nil { if m == nil {
m = new(codec.Message) m = new(codec.Message)
} }
@ -159,9 +159,9 @@ func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
func (g *grpcCodec) ReadBody(v interface{}) error { func (g *grpcCodec) ReadBody(v interface{}) error {
// caller has requested a frame // caller has requested a frame
if f, ok := v.(*bytes.Frame); ok { if f, ok := v.(*bytes.Frame); ok {
return g.s.RecvMsg(f) return g.ServerStream.RecvMsg(f)
} }
return g.s.RecvMsg(v) return g.ServerStream.RecvMsg(v)
} }
func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { func (g *grpcCodec) Write(m *codec.Message, v interface{}) error {
@ -174,7 +174,7 @@ func (g *grpcCodec) Write(m *codec.Message, v interface{}) error {
m.Body = b m.Body = b
} }
// write the body using the framing codec // write the body using the framing codec
return g.s.SendMsg(&bytes.Frame{Data: m.Body}) return g.ServerStream.SendMsg(&bytes.Frame{Data: m.Body})
} }
func (g *grpcCodec) Close() error { func (g *grpcCodec) Close() error {

View File

@ -265,11 +265,11 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) (err err
return errors.InternalServerError(g.opts.Name, err.Error()) return errors.InternalServerError(g.opts.Name, err.Error())
} }
codec := &grpcCodec{ codec := &grpcCodec{
method: fmt.Sprintf("%s.%s", serviceName, methodName), ServerStream: stream,
endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), method: fmt.Sprintf("%s.%s", serviceName, methodName),
target: g.opts.Name, endpoint: fmt.Sprintf("%s.%s", serviceName, methodName),
s: stream, target: g.opts.Name,
c: cc, c: cc,
} }
// create a client.Request // create a client.Request
@ -394,8 +394,10 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service,
for i := len(g.opts.HdlrWrappers); i > 0; i-- { for i := len(g.opts.HdlrWrappers); i > 0; i-- {
fn = g.opts.HdlrWrappers[i-1](fn) fn = g.opts.HdlrWrappers[i-1](fn)
} }
statusCode := codes.OK statusCode := codes.OK
statusDesc := "" statusDesc := ""
// execute the handler // execute the handler
if appErr := fn(ctx, r, replyv.Interface()); appErr != nil { if appErr := fn(ctx, r, replyv.Interface()); appErr != nil {
var errStatus *status.Status var errStatus *status.Status
@ -411,6 +413,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service,
// micro.Error now proto based and we can attach it to grpc status // micro.Error now proto based and we can attach it to grpc status
statusCode = microError(verr) statusCode = microError(verr)
statusDesc = verr.Error() statusDesc = verr.Error()
errStatus, err = status.New(statusCode, statusDesc).WithDetails(perr) errStatus, err = status.New(statusCode, statusDesc).WithDetails(perr)
if err != nil { if err != nil {
return err return err
@ -428,6 +431,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service,
statusCode = convertCode(verr) statusCode = convertCode(verr)
statusDesc = verr.Error() statusDesc = verr.Error()
errStatus = status.New(statusCode, statusDesc) errStatus = status.New(statusCode, statusDesc)
fmt.Printf("Responding with :%v\n", errStatus)
} }
return errStatus.Err() return errStatus.Err()
@ -436,6 +440,7 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service,
if err := stream.SendMsg(replyv.Interface()); err != nil { if err := stream.SendMsg(replyv.Interface()); err != nil {
return err return err
} }
return status.New(statusCode, statusDesc).Err() return status.New(statusCode, statusDesc).Err()
} }
} }
@ -451,8 +456,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m
} }
ss := &rpcStream{ ss := &rpcStream{
request: r, ServerStream: stream,
s: stream, request: r,
} }
function := mtype.method.Func function := mtype.method.Func
@ -507,6 +512,7 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m
statusDesc = verr.Error() statusDesc = verr.Error()
errStatus = status.New(statusCode, statusDesc) errStatus = status.New(statusCode, statusDesc)
} }
return errStatus.Err() return errStatus.Err()
} }

View File

@ -9,7 +9,9 @@ import (
// rpcStream implements a server side Stream. // rpcStream implements a server side Stream.
type rpcStream struct { type rpcStream struct {
s grpc.ServerStream // embed the grpc stream so we can access it
grpc.ServerStream
request server.Request request server.Request
} }
@ -26,13 +28,13 @@ func (r *rpcStream) Request() server.Request {
} }
func (r *rpcStream) Context() context.Context { func (r *rpcStream) Context() context.Context {
return r.s.Context() return r.ServerStream.Context()
} }
func (r *rpcStream) Send(m interface{}) error { func (r *rpcStream) Send(m interface{}) error {
return r.s.SendMsg(m) return r.ServerStream.SendMsg(m)
} }
func (r *rpcStream) Recv(m interface{}) error { func (r *rpcStream) Recv(m interface{}) error {
return r.s.RecvMsg(m) return r.ServerStream.RecvMsg(m)
} }