diff --git a/server/options.go b/server/options.go index 6acb5960..b0bae35a 100644 --- a/server/options.go +++ b/server/options.go @@ -25,6 +25,8 @@ type Options struct { HdlrWrappers []HandlerWrapper SubWrappers []SubscriberWrapper + // RegisterCheck runs a check function before registering the service + RegisterCheck func(context.Context) error // The register expiry time RegisterTTL time.Duration // The interval on which to register @@ -67,6 +69,10 @@ func newOptions(opt ...Option) Options { opts.DebugHandler = debug.DefaultDebugHandler } + if opts.RegisterCheck == nil { + opts.RegisterCheck = DefaultRegisterCheck + } + if len(opts.Address) == 0 { opts.Address = DefaultAddress } @@ -163,6 +169,13 @@ func Metadata(md map[string]string) Option { } } +// RegisterCheck run func before registry service +func RegisterCheck(fn func(context.Context) error) Option { + return func(o *Options) { + o.RegisterCheck = fn + } +} + // Register the service with a TTL func RegisterTTL(t time.Duration) Option { return func(o *Options) { diff --git a/server/rpc_server.go b/server/rpc_server.go index 57fd0685..4f1049c1 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -468,9 +468,14 @@ func (s *rpcServer) Start() error { log.Logf("Broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address()) - // announce self to the world - if err := s.Register(); err != nil { - log.Log("Server register error: ", err) + // use RegisterCheck func before register + if err = s.opts.RegisterCheck(s.opts.Context); err != nil { + log.Logf("Server %s-%s register check error: %s", config.Name, config.Id, err) + } else { + // announce self to the world + if err = s.Register(); err != nil { + log.Log("Server %s-%s register error: %s", config.Name, config.Id, err) + } } exit := make(chan bool) @@ -518,8 +523,19 @@ func (s *rpcServer) Start() error { select { // register self on interval case <-t.C: - if err := s.Register(); err != nil { - log.Log("Server register error: ", err) + s.RLock() + registered := s.registered + s.RUnlock() + if err = s.opts.RegisterCheck(s.opts.Context); err != nil && registered { + log.Logf("Server %s-%s register check error: %s, deregister it", config.Name, config.Id, err) + // deregister self in case of error + if err := s.Deregister(); err != nil { + log.Logf("Server %s-%s deregister error: %s", config.Name, config.Id, err) + } + } else { + if err := s.Register(); err != nil { + log.Logf("Server %s-%s register error: %s", config.Name, config.Id, err) + } } // wait for exit case ch = <-s.exit: @@ -531,7 +547,7 @@ func (s *rpcServer) Start() error { // deregister self if err := s.Deregister(); err != nil { - log.Log("Server deregister error: ", err) + log.Logf("Server %s-%s deregister error: %s", config.Name, config.Id, err) } // wait for requests to finish diff --git a/server/server.go b/server/server.go index 2810f253..706e6b7e 100644 --- a/server/server.go +++ b/server/server.go @@ -8,7 +8,7 @@ import ( "syscall" "github.com/google/uuid" - "github.com/micro/go-log" + log "github.com/micro/go-log" "github.com/micro/go-micro/codec" "github.com/micro/go-micro/registry" ) @@ -115,12 +115,13 @@ type Subscriber interface { type Option func(*Options) var ( - DefaultAddress = ":0" - DefaultName = "server" - DefaultVersion = "latest" - DefaultId = uuid.New().String() - DefaultServer Server = newRpcServer() - DefaultRouter = newRpcRouter() + DefaultAddress = ":0" + DefaultName = "server" + DefaultVersion = "latest" + DefaultId = uuid.New().String() + DefaultServer Server = newRpcServer() + DefaultRouter = newRpcRouter() + DefaultRegisterCheck = func(context.Context) error { return nil } ) // DefaultOptions returns config options for the default service