diff --git a/server/context.go b/server/context.go index 88d19257..627caa54 100644 --- a/server/context.go +++ b/server/context.go @@ -6,6 +6,17 @@ import ( type serverKey struct{} +func wait(ctx context.Context) bool { + if ctx == nil { + return false + } + wait, ok := ctx.Value("wait").(bool) + if !ok { + return false + } + return wait +} + func FromContext(ctx context.Context) (Server, bool) { c, ok := ctx.Value(serverKey{}).(Server) return c, ok diff --git a/server/options.go b/server/options.go index edaeceea..c4602498 100644 --- a/server/options.go +++ b/server/options.go @@ -165,6 +165,16 @@ func RegisterTTL(t time.Duration) Option { } } +// Wait tells the server to wait for requests to finish before exiting +func Wait(b bool) Option { + return func(o *Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, "wait", b) + } +} + // Adds a handler Wrapper to a list of options passed into the server func WrapHandler(w HandlerWrapper) Option { return func(o *Options) { diff --git a/server/rpc_server.go b/server/rpc_server.go index 6f97ab9c..6b6af066 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -30,6 +30,8 @@ type rpcServer struct { subscribers map[*subscriber][]broker.Subscriber // used for first registration registered bool + // graceful exit + wg sync.WaitGroup } func newRpcServer(opts ...Option) Server { @@ -100,11 +102,18 @@ func (s *rpcServer) accept(sock transport.Socket) { } } + // add to wait group + s.wg.Add(1) + // TODO: needs better error handling if err := s.rpc.serveRequest(ctx, codec, ct); err != nil { log.Logf("Unexpected error serving request, closing socket: %v", err) + s.wg.Done() return } + + // finish request + s.wg.Done() } } @@ -371,8 +380,18 @@ func (s *rpcServer) Start() error { go ts.Accept(s.accept) go func() { + // wait for exit ch := <-s.exit + + // wait for requests to finish + if wait(s.opts.Context) { + s.wg.Wait() + } + + // close transport listener ch <- ts.Close() + + // disconnect the broker config.Broker.Disconnect() }() diff --git a/server/subscriber.go b/server/subscriber.go index 3c30a40d..30594348 100644 --- a/server/subscriber.go +++ b/server/subscriber.go @@ -226,11 +226,15 @@ func (s *rpcServer) createSubHandler(sb *subscriber, opts Options) broker.Handle fn = opts.SubWrappers[i-1](fn) } - go fn(ctx, &rpcPublication{ - topic: sb.topic, - contentType: ct, - message: req.Interface(), - }) + s.wg.Add(1) + go func() { + fn(ctx, &rpcPublication{ + topic: sb.topic, + contentType: ct, + message: req.Interface(), + }) + s.wg.Done() + }() } return nil }