diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 1932f9ce..7571e5bf 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "fmt" "net" + "reflect" "strings" "sync/atomic" "time" @@ -173,7 +174,7 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R return grr } -func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client.Request, opts client.CallOptions) (client.Stream, error) { +func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client.Request, rsp interface{}, opts client.CallOptions) error { var header map[string]string address := node.Address @@ -199,7 +200,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client cf, err := g.newGRPCCodec(req.ContentType()) if err != nil { - return nil, errors.InternalServerError("go.micro.client", err.Error()) + return errors.InternalServerError("go.micro.client", err.Error()) } var dialCtx context.Context @@ -224,7 +225,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client cc, err := grpc.DialContext(dialCtx, address, grpcDialOptions...) if err != nil { - return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) + return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } desc := &grpc.StreamDesc{ @@ -252,7 +253,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client // close the connection cc.Close() // now return the error - return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error creating stream: %v", err)) + return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error creating stream: %v", err)) } codec := &grpcCodec{ @@ -265,21 +266,25 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client r.codec = codec } - rsp := &response{ - conn: cc, + // setup the stream response + stream := &grpcStream{ + context: ctx, + request: req, + response: &response{ + conn: cc, + stream: st, + codec: cf, + gcodec: codec, + }, stream: st, - codec: cf, - gcodec: codec, + conn: cc, + cancel: cancel, } - return &grpcStream{ - context: ctx, - request: req, - response: rsp, - stream: st, - conn: cc, - cancel: cancel, - }, nil + // set the stream as the response + val := reflect.ValueOf(rsp).Elem() + val.Set(reflect.ValueOf(stream).Elem()) + return nil } func (g *grpcClient) poolMaxStreams() int { @@ -506,6 +511,14 @@ func (g *grpcClient) Stream(ctx context.Context, req client.Request, opts ...cli default: } + // make a copy of stream + gstream := g.stream + + // wrap the call in reverse + for i := len(callOpts.CallWrappers); i > 0; i-- { + gstream = callOpts.CallWrappers[i-1](gstream) + } + call := func(i int) (client.Stream, error) { // call backoff first. Someone may want an initial start delay t, err := callOpts.Backoff(ctx, req, i) @@ -527,7 +540,10 @@ func (g *grpcClient) Stream(ctx context.Context, req client.Request, opts ...cli return nil, errors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error()) } - stream, err := g.stream(ctx, node, req, callOpts) + // make the call + stream := &grpcStream{} + err = g.stream(ctx, node, req, stream, callOpts) + g.opts.Selector.Mark(service, node, err) return stream, err }