diff --git a/client/options.go b/client/options.go index 9f363d74..0ba57ad7 100644 --- a/client/options.go +++ b/client/options.go @@ -59,6 +59,9 @@ type CallOptions struct { // Middleware for low level call func CallWrappers []CallWrapper + // SendEOS specifies whether to send EOS + SendEOS bool + // Other options for implementations of the interface // can be stored in a context Context context.Context @@ -305,6 +308,13 @@ func WithDialTimeout(d time.Duration) CallOption { } } +// SendEOS specifies whether to send the end of stream message +func SendEOS(b bool) CallOption { + return func(o *CallOptions) { + o.SendEOS = b + } +} + // Request Options func WithContentType(ct string) RequestOption { diff --git a/client/rpc_client.go b/client/rpc_client.go index 91d11a40..5449dcd9 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -96,15 +96,10 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, } } - var grr error c, err := r.pool.Get(address, transport.WithTimeout(opts.DialTimeout)) if err != nil { return errors.InternalServerError("go.micro.client", "connection error: %v", err) } - defer func() { - // defer execution of release - r.pool.Release(c, grr) - }() seq := atomic.LoadUint64(&r.seq) atomic.AddUint64(&r.seq, 1) @@ -116,15 +111,19 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, } stream := &rpcStream{ + id: fmt.Sprintf("%v", seq), context: ctx, request: req, response: rsp, codec: codec, closed: make(chan bool), - id: fmt.Sprintf("%v", seq), + release: func(err error) { r.pool.Release(c, err) }, + sendEOS: opts.SendEOS, } + // close the stream on exiting this function defer stream.Close() + // wait for error response ch := make(chan error, 1) go func() { @@ -150,14 +149,26 @@ func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, ch <- nil }() + var grr error + select { case err := <-ch: grr = err return err case <-ctx.Done(): - grr = ctx.Err() - return errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) + grr = errors.Timeout("go.micro.client", fmt.Sprintf("%v", ctx.Err())) } + + // set the stream error + if grr != nil { + stream.Lock() + stream.err = grr + stream.Unlock() + + return grr + } + + return nil } func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) { @@ -201,7 +212,7 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request dOpts = append(dOpts, transport.WithTimeout(opts.DialTimeout)) } - c, err := r.opts.Transport.Dial(address, dOpts...) + c, err := r.pool.Get(address, dOpts...) if err != nil { return nil, errors.InternalServerError("go.micro.client", "connection error: %v", err) } @@ -225,19 +236,24 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request } stream := &rpcStream{ + id: id, context: ctx, request: req, response: rsp, - closed: make(chan bool), codec: codec, - id: id, + // used to close the stream + closed: make(chan bool), // signal the end of stream, - eos: true, + sendEOS: opts.SendEOS, + // release func + release: func(err error) { r.pool.Release(c, err) }, } + // wait for error response ch := make(chan error, 1) go func() { + // send the first message ch <- stream.Send(req.Body()) }() @@ -251,6 +267,12 @@ func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request } if grr != nil { + // set the error + stream.Lock() + stream.err = grr + stream.Unlock() + + // close the stream stream.Close() return nil, grr } diff --git a/client/rpc_stream.go b/client/rpc_stream.go index 269e6299..f904d2dd 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -20,7 +20,10 @@ type rpcStream struct { context context.Context // signal whether we should send EOS - eos bool + sendEOS bool + + // release releases the connection back to the pool + release func(err error) } func (r *rpcStream) isClosed() bool { @@ -125,7 +128,7 @@ func (r *rpcStream) Close() error { close(r.closed) // send the end of stream message - if r.eos { + if r.sendEOS { // no need to check for error r.codec.Write(&codec.Message{ Id: r.id, @@ -137,6 +140,12 @@ func (r *rpcStream) Close() error { }, nil) } - return r.codec.Close() + err := r.codec.Close() + + // release the connection + r.release(r.Error()) + + // return the codec error + return err } } diff --git a/proxy/mucp/mucp.go b/proxy/mucp/mucp.go index 9df28aec..e4303d16 100644 --- a/proxy/mucp/mucp.go +++ b/proxy/mucp/mucp.go @@ -42,11 +42,6 @@ type Proxy struct { // read client request and write to server func readLoop(r server.Request, s client.Stream) error { - // we don't loop unless its a stream - if !r.Stream() { - return nil - } - // request to backend server req := s.Request() @@ -225,6 +220,11 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server // create new request with raw bytes body creq := p.Client.NewRequest(service, endpoint, &bytes.Frame{body}, client.WithContentType(req.ContentType())) + if !req.Stream() { + // specify not to send eos + opts = append(opts, client.SendEOS(false)) + } + // create new stream stream, err := p.Client.Stream(ctx, creq, opts...) if err != nil { @@ -232,8 +232,10 @@ func (p *Proxy) ServeRequest(ctx context.Context, req server.Request, rsp server } defer stream.Close() - // create client request read loop - go readLoop(req, stream) + // create client request read loop if streaming + if req.Stream() { + go readLoop(req, stream) + } // get raw response resp := stream.Response()