Apply wrappers to gRPC streams (#1675)

* Add wrappers to grpc streams

* Fix typo
This commit is contained in:
ben-toogood 2020-06-02 17:56:26 +01:00 committed by Dominic Wong
parent b270860b79
commit f45cdba9ba

View File

@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"reflect"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -173,7 +174,7 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R
return grr return grr
} }
func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client.Request, opts client.CallOptions) (client.Stream, error) { func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client.Request, rsp interface{}, opts client.CallOptions) error {
var header map[string]string var header map[string]string
address := node.Address address := node.Address
@ -199,7 +200,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client
cf, err := g.newGRPCCodec(req.ContentType()) cf, err := g.newGRPCCodec(req.ContentType())
if err != nil { if err != nil {
return nil, errors.InternalServerError("go.micro.client", err.Error()) return errors.InternalServerError("go.micro.client", err.Error())
} }
var dialCtx context.Context var dialCtx context.Context
@ -224,7 +225,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client
cc, err := grpc.DialContext(dialCtx, address, grpcDialOptions...) cc, err := grpc.DialContext(dialCtx, address, grpcDialOptions...)
if err != nil { if err != nil {
return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err))
} }
desc := &grpc.StreamDesc{ desc := &grpc.StreamDesc{
@ -252,7 +253,7 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client
// close the connection // close the connection
cc.Close() cc.Close()
// now return the error // now return the error
return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error creating stream: %v", err)) return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error creating stream: %v", err))
} }
codec := &grpcCodec{ codec := &grpcCodec{
@ -265,21 +266,25 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client
r.codec = codec r.codec = codec
} }
rsp := &response{ // setup the stream response
conn: cc, stream := &grpcStream{
context: ctx,
request: req,
response: &response{
conn: cc,
stream: st,
codec: cf,
gcodec: codec,
},
stream: st, stream: st,
codec: cf, conn: cc,
gcodec: codec, cancel: cancel,
} }
return &grpcStream{ // set the stream as the response
context: ctx, val := reflect.ValueOf(rsp).Elem()
request: req, val.Set(reflect.ValueOf(stream).Elem())
response: rsp, return nil
stream: st,
conn: cc,
cancel: cancel,
}, nil
} }
func (g *grpcClient) poolMaxStreams() int { func (g *grpcClient) poolMaxStreams() int {
@ -506,6 +511,14 @@ func (g *grpcClient) Stream(ctx context.Context, req client.Request, opts ...cli
default: default:
} }
// make a copy of stream
gstream := g.stream
// wrap the call in reverse
for i := len(callOpts.CallWrappers); i > 0; i-- {
gstream = callOpts.CallWrappers[i-1](gstream)
}
call := func(i int) (client.Stream, error) { call := func(i int) (client.Stream, error) {
// call backoff first. Someone may want an initial start delay // call backoff first. Someone may want an initial start delay
t, err := callOpts.Backoff(ctx, req, i) t, err := callOpts.Backoff(ctx, req, i)
@ -527,7 +540,10 @@ func (g *grpcClient) Stream(ctx context.Context, req client.Request, opts ...cli
return nil, errors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error()) return nil, errors.InternalServerError("go.micro.client", "error selecting %s node: %s", service, err.Error())
} }
stream, err := g.stream(ctx, node, req, callOpts) // make the call
stream := &grpcStream{}
err = g.stream(ctx, node, req, stream, callOpts)
g.opts.Selector.Mark(service, node, err) g.opts.Selector.Mark(service, node, err)
return stream, err return stream, err
} }