diff --git a/client/rpc_client.go b/client/rpc_client.go index cdf2eeca..9d96491d 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -63,6 +63,9 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp } } + // set timeout in nanoseconds + msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout) + // set the content type for the request msg.Header["Content-Type"] = req.ContentType() cf, err := r.newCodec(req.ContentType()) @@ -81,13 +84,11 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp closed: make(chan bool), codec: newRpcPlusCodec(msg, c, cf), } + defer stream.Close() ch := make(chan error, 1) go func() { - // defer stream close - defer stream.Close() - // send request if err := stream.Send(req.Request()); err != nil { ch <- err @@ -105,12 +106,11 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp }() select { - case err = <-ch: - case <-time.After(opts.RequestTimeout): - err = errors.New("go.micro.client", "request timeout", 408) + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() } - - return err } func (r *rpcClient) stream(ctx context.Context, address string, req Request, opts CallOptions) (Streamer, error) { @@ -125,6 +125,9 @@ func (r *rpcClient) stream(ctx context.Context, address string, req Request, opt } } + // set timeout in nanoseconds + msg.Header["Timeout"] = fmt.Sprintf("%d", opts.RequestTimeout) + // set the content type for the request msg.Header["Content-Type"] = req.ContentType() cf, err := r.newCodec(req.ContentType()) @@ -150,13 +153,21 @@ func (r *rpcClient) stream(ctx context.Context, address string, req Request, opt ch <- stream.Send(req.Request()) }() + var grr error + select { - case err = <-ch: - case <-time.After(opts.RequestTimeout): - err = errors.New("go.micro.client", "request timeout", 408) + case err := <-ch: + grr = err + case <-ctx.Done(): + grr = ctx.Err() } - return stream, err + if grr != nil { + stream.Close() + return nil, grr + } + + return stream, nil } func (r *rpcClient) Init(opts ...Option) error { @@ -173,22 +184,20 @@ func (r *rpcClient) Options() Options { func (r *rpcClient) CallRemote(ctx context.Context, address string, request Request, response interface{}, opts ...CallOption) error { // make a copy of call opts callOpts := r.opts.CallOptions - for _, opt := range opts { opt(&callOpts) } - return r.call(ctx, address, request, response, callOpts) } func (r *rpcClient) Call(ctx context.Context, request Request, response interface{}, opts ...CallOption) error { // make a copy of call opts callOpts := r.opts.CallOptions - for _, opt := range opts { opt(&callOpts) } + // get next nodes from the selector next, err := r.opts.Selector.Select(request.Service(), callOpts.SelectOptions...) if err != nil && err == selector.ErrNotFound { return errors.NotFound("go.micro.client", err.Error()) @@ -196,9 +205,20 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return errors.InternalServerError("go.micro.client", err.Error()) } - var grr error + // check if we already have a deadline + d, ok := ctx.Deadline() + if !ok { + // no deadline so we create a new one + ctx, _ = context.WithTimeout(ctx, callOpts.RequestTimeout) + } else { + // got a deadline so no need to setup context + // but we need to set the timeout we pass along + opt := WithRequestTimeout(d.Sub(time.Now())) + opt(&callOpts) + } - for i := 0; i < callOpts.Retries; i++ { + // return errors.New("go.micro.client", "request timeout", 408) + call := func(i int) error { // call backoff first. Someone may want an initial start delay t, err := callOpts.Backoff(ctx, request, i) if err != nil { @@ -210,6 +230,7 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac time.Sleep(t) } + // select next node node, err := next() if err != nil && err == selector.ErrNotFound { return errors.NotFound("go.micro.client", err.Error()) @@ -217,42 +238,58 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return errors.InternalServerError("go.micro.client", err.Error()) } + // set the address address := node.Address if node.Port > 0 { address = fmt.Sprintf("%s:%d", address, node.Port) } - grr = r.call(ctx, address, request, response, callOpts) - r.opts.Selector.Mark(request.Service(), node, grr) + // make the call + err = r.call(ctx, address, request, response, callOpts) + r.opts.Selector.Mark(request.Service(), node, err) + return err + } - // if the call succeeded lets bail early - if grr == nil { - return nil + ch := make(chan error, callOpts.Retries) + var gerr error + + for i := 0; i < callOpts.Retries; i++ { + go func() { + ch <- call(i) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-ch: + // if the call succeeded lets bail early + if err == nil { + return nil + } + gerr = err } } - return grr + return gerr } func (r *rpcClient) StreamRemote(ctx context.Context, address string, request Request, opts ...CallOption) (Streamer, error) { // make a copy of call opts callOpts := r.opts.CallOptions - for _, opt := range opts { opt(&callOpts) } - return r.stream(ctx, address, request, callOpts) } func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Streamer, error) { // make a copy of call opts callOpts := r.opts.CallOptions - for _, opt := range opts { opt(&callOpts) } + // get next nodes from the selector next, err := r.opts.Selector.Select(request.Service(), callOpts.SelectOptions...) if err != nil && err == selector.ErrNotFound { return nil, errors.NotFound("go.micro.client", err.Error()) @@ -260,10 +297,19 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt return nil, errors.InternalServerError("go.micro.client", err.Error()) } - var stream Streamer - var grr error + // check if we already have a deadline + d, ok := ctx.Deadline() + if !ok { + // no deadline so we create a new one + ctx, _ = context.WithTimeout(ctx, callOpts.RequestTimeout) + } else { + // got a deadline so no need to setup context + // but we need to set the timeout we pass along + opt := WithRequestTimeout(d.Sub(time.Now())) + opt(&callOpts) + } - for i := 0; i < callOpts.Retries; i++ { + call := func(i int) (Streamer, error) { // call backoff first. Someone may want an initial start delay t, err := callOpts.Backoff(ctx, request, i) if err != nil { @@ -287,16 +333,38 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt address = fmt.Sprintf("%s:%d", address, node.Port) } - stream, grr = r.stream(ctx, address, request, callOpts) - r.opts.Selector.Mark(request.Service(), node, grr) + stream, err := r.stream(ctx, address, request, callOpts) + r.opts.Selector.Mark(request.Service(), node, err) + return stream, err + } - // bail early if succeeds - if grr == nil { - return stream, nil + type response struct { + stream Streamer + err error + } + + ch := make(chan response, callOpts.Retries) + var grr error + + for i := 0; i < callOpts.Retries; i++ { + go func() { + s, err := call(i) + ch <- response{s, err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case rsp := <-ch: + // if the call succeeded lets bail early + if rsp.err == nil { + return rsp.stream, nil + } + grr = rsp.err } } - return stream, grr + return nil, grr } func (r *rpcClient) Publish(ctx context.Context, p Publication, opts ...PublishOption) error { diff --git a/server/rpc_server.go b/server/rpc_server.go index ea607fec..3a70d54d 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/micro/go-micro/broker" "github.com/micro/go-micro/codec" @@ -59,7 +60,11 @@ func (s *rpcServer) accept(sock transport.Socket) { return } + // we use this Timeout header to set a server deadline + to := msg.Header["Timeout"] + // we use this Content-Type header to identify the codec needed ct := msg.Header["Content-Type"] + cf, err := s.newCodec(ct) // TODO: needs better error handling if err != nil { @@ -80,9 +85,17 @@ func (s *rpcServer) accept(sock transport.Socket) { hdr[k] = v } delete(hdr, "Content-Type") + delete(hdr, "Timeout") ctx := metadata.NewContext(context.Background(), hdr) + // set the timeout if we have it + if len(to) > 0 { + if n, err := strconv.ParseUint(to, 10, 64); err == nil { + ctx, _ = context.WithTimeout(ctx, time.Duration(n)) + } + } + // TODO: needs better error handling if err := s.rpc.serveRequest(ctx, codec, ct); err != nil { log.Printf("Unexpected error serving request, closing socket: %v", err)