Merge pull request #488 from magodo/wait_accept_custom_wg
`Wait()` option now accept *sync.WaitGroup
This commit is contained in:
		| @@ -54,7 +54,7 @@ func newFunction(opts ...Option) Function { | |||||||
|  |  | ||||||
| 	service.Server().Init( | 	service.Server().Init( | ||||||
| 		// ensure the service waits for requests to finish | 		// ensure the service waits for requests to finish | ||||||
| 		server.Wait(true), | 		server.Wait(nil), | ||||||
| 		// wrap handlers and subscribers to finish execution | 		// wrap handlers and subscribers to finish execution | ||||||
| 		server.WrapHandler(fnHandlerWrapper(fn)), | 		server.WrapHandler(fnHandlerWrapper(fn)), | ||||||
| 		server.WrapSubscriber(fnSubWrapper(fn)), | 		server.WrapSubscriber(fnSubWrapper(fn)), | ||||||
|   | |||||||
| @@ -2,19 +2,20 @@ package server | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"sync" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type serverKey struct{} | type serverKey struct{} | ||||||
|  |  | ||||||
| func wait(ctx context.Context) bool { | func wait(ctx context.Context) *sync.WaitGroup { | ||||||
| 	if ctx == nil { | 	if ctx == nil { | ||||||
| 		return false | 		return nil | ||||||
| 	} | 	} | ||||||
| 	wait, ok := ctx.Value("wait").(bool) | 	wg, ok := ctx.Value("wait").(*sync.WaitGroup) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return false | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return wait | 	return wg | ||||||
| } | } | ||||||
|  |  | ||||||
| func FromContext(ctx context.Context) (Server, bool) { | func FromContext(ctx context.Context) (Server, bool) { | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package server | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/micro/go-micro/broker" | 	"github.com/micro/go-micro/broker" | ||||||
| @@ -198,12 +199,18 @@ func WithRouter(r Router) Option { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Wait tells the server to wait for requests to finish before exiting | // Wait tells the server to wait for requests to finish before exiting | ||||||
| func Wait(b bool) Option { | // If `wg` is nil, server only wait for completion of rpc handler. | ||||||
|  | // For user need finer grained control, pass a concrete `wg` here, server will | ||||||
|  | // wait against it on stop. | ||||||
|  | func Wait(wg *sync.WaitGroup) Option { | ||||||
| 	return func(o *Options) { | 	return func(o *Options) { | ||||||
| 		if o.Context == nil { | 		if o.Context == nil { | ||||||
| 			o.Context = context.Background() | 			o.Context = context.Background() | ||||||
| 		} | 		} | ||||||
| 		o.Context = context.WithValue(o.Context, "wait", b) | 		if wg == nil { | ||||||
|  | 			wg = new(sync.WaitGroup) | ||||||
|  | 		} | ||||||
|  | 		o.Context = context.WithValue(o.Context, "wait", wg) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ type rpcServer struct { | |||||||
| 	// used for first registration | 	// used for first registration | ||||||
| 	registered bool | 	registered bool | ||||||
| 	// graceful exit | 	// graceful exit | ||||||
| 	wg sync.WaitGroup | 	wg *sync.WaitGroup | ||||||
| } | } | ||||||
|  |  | ||||||
| func newRpcServer(opts ...Option) Server { | func newRpcServer(opts ...Option) Server { | ||||||
| @@ -42,6 +42,7 @@ func newRpcServer(opts ...Option) Server { | |||||||
| 		handlers:    make(map[string]Handler), | 		handlers:    make(map[string]Handler), | ||||||
| 		subscribers: make(map[*subscriber][]broker.Subscriber), | 		subscribers: make(map[*subscriber][]broker.Subscriber), | ||||||
| 		exit:        make(chan chan error), | 		exit:        make(chan chan error), | ||||||
|  | 		wg:          wait(options.Context), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -63,8 +64,10 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// add to wait group | 		// add to wait group if "wait" is opt-in | ||||||
| 		s.wg.Add(1) | 		if s.wg != nil { | ||||||
|  | 			s.wg.Add(1) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// we use this Timeout header to set a server deadline | 		// we use this Timeout header to set a server deadline | ||||||
| 		to := msg.Header["Timeout"] | 		to := msg.Header["Timeout"] | ||||||
| @@ -111,7 +114,9 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | |||||||
| 					}, | 					}, | ||||||
| 					Body: []byte(err.Error()), | 					Body: []byte(err.Error()), | ||||||
| 				}) | 				}) | ||||||
| 				s.wg.Done() | 				if s.wg != nil { | ||||||
|  | 					s.wg.Done() | ||||||
|  | 				} | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -167,12 +172,16 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Logf("rpc: unable to write error response: %v", err) | 				log.Logf("rpc: unable to write error response: %v", err) | ||||||
| 			} | 			} | ||||||
| 			s.wg.Done() | 			if s.wg != nil { | ||||||
|  | 				s.wg.Done() | ||||||
|  | 			} | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// done | 		// done | ||||||
| 		s.wg.Done() | 		if s.wg != nil { | ||||||
|  | 			s.wg.Done() | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -555,7 +564,7 @@ func (s *rpcServer) Start() error { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// wait for requests to finish | 		// wait for requests to finish | ||||||
| 		if wait(s.opts.Context) { | 		if s.wg != nil { | ||||||
| 			s.wg.Wait() | 			s.wg.Wait() | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user