From ebc479ef2ccc965273fe2992ec5f38f6f82dc4be Mon Sep 17 00:00:00 2001 From: magodo Date: Mon, 27 May 2019 21:17:57 +0800 Subject: [PATCH] `Wait()` option now accept *sync.WaitGroup The original signature accept a boolean, and it feel like a little verbose, since when people pass in this option, he/she always want to pass a `true`. Now if input `wg` is nil, it has same effect as passing `true` in original code. Furthermore, if user want's finer grained control during shutdown, one can pass in a predefined `wg`, so that server will wait against it during shutdown. --- function.go | 2 +- server/context.go | 11 ++++++----- server/options.go | 11 +++++++++-- server/rpc_server.go | 23 ++++++++++++++++------- 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/function.go b/function.go index 4379a4b3..a9cc0d89 100644 --- a/function.go +++ b/function.go @@ -54,7 +54,7 @@ func newFunction(opts ...Option) Function { service.Server().Init( // ensure the service waits for requests to finish - server.Wait(true), + server.Wait(nil), // wrap handlers and subscribers to finish execution server.WrapHandler(fnHandlerWrapper(fn)), server.WrapSubscriber(fnSubWrapper(fn)), diff --git a/server/context.go b/server/context.go index de01624d..2bc3aec6 100644 --- a/server/context.go +++ b/server/context.go @@ -2,19 +2,20 @@ package server import ( "context" + "sync" ) type serverKey struct{} -func wait(ctx context.Context) bool { +func wait(ctx context.Context) *sync.WaitGroup { if ctx == nil { - return false + return nil } - wait, ok := ctx.Value("wait").(bool) + wg, ok := ctx.Value("wait").(*sync.WaitGroup) if !ok { - return false + return nil } - return wait + return wg } func FromContext(ctx context.Context) (Server, bool) { diff --git a/server/options.go b/server/options.go index b0bae35a..15de9b83 100644 --- a/server/options.go +++ b/server/options.go @@ -2,6 +2,7 @@ package server import ( "context" + "sync" "time" "github.com/micro/go-micro/broker" @@ -198,12 +199,18 @@ func WithRouter(r Router) Option { } // Wait tells the server to wait for requests to finish before exiting -func Wait(b bool) Option { +// If `wg` is nil, server only wait for completion of rpc handler. +// For user need finer grained control, pass a concrete `wg` here, server will +// wait against it on stop. +func Wait(wg *sync.WaitGroup) Option { return func(o *Options) { if o.Context == nil { o.Context = context.Background() } - o.Context = context.WithValue(o.Context, "wait", b) + if wg == nil { + wg = new(sync.WaitGroup) + } + o.Context = context.WithValue(o.Context, "wait", wg) } } diff --git a/server/rpc_server.go b/server/rpc_server.go index 4b7ef194..69f880b3 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -31,7 +31,7 @@ type rpcServer struct { // used for first registration registered bool // graceful exit - wg sync.WaitGroup + wg *sync.WaitGroup } func newRpcServer(opts ...Option) Server { @@ -42,6 +42,7 @@ func newRpcServer(opts ...Option) Server { handlers: make(map[string]Handler), subscribers: make(map[*subscriber][]broker.Subscriber), exit: make(chan chan error), + wg: wait(options.Context), } } @@ -63,8 +64,10 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { return } - // add to wait group - s.wg.Add(1) + // add to wait group if "wait" is opt-in + if s.wg != nil { + s.wg.Add(1) + } // we use this Timeout header to set a server deadline to := msg.Header["Timeout"] @@ -111,7 +114,9 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { }, Body: []byte(err.Error()), }) - s.wg.Done() + if s.wg != nil { + s.wg.Done() + } return } } @@ -167,12 +172,16 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { if err != nil { log.Logf("rpc: unable to write error response: %v", err) } - s.wg.Done() + if s.wg != nil { + s.wg.Done() + } return } // done - s.wg.Done() + if s.wg != nil { + s.wg.Done() + } } } @@ -555,7 +564,7 @@ func (s *rpcServer) Start() error { } // wait for requests to finish - if wait(s.opts.Context) { + if s.wg != nil { s.wg.Wait() }