diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 9496fe0a..449a2293 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -21,8 +21,12 @@ import ( ) var ( - // use a .micro domain rather than .local - mdnsDomain = "micro" + // use a .micro tld rather than .local by default + defaultDomain = "micro" + // every service is written to the global domain so * domain queries work, e.g. + // calling mdns.List(registry.ListDomain("*")) will list the services across all + // domains + globalDomain = "global" ) type mdnsTxt struct { @@ -37,13 +41,20 @@ type mdnsEntry struct { node *mdns.Server } +// services are a key/value map, with the service name as a key and the value being a +// slice of mdns entries, representing the nodes with a single _services entry to be +// used for listing +type services map[string][]*mdnsEntry + type mdnsRegistry struct { opts Options - // the mdns domain - domain string + + // the top level domains, these can be overriden using options + defaultDomain string + globalDomain string sync.Mutex - services map[string][]*mdnsEntry + domains map[string]services mtx sync.RWMutex @@ -138,18 +149,19 @@ func newRegistry(opts ...Option) Registry { } // set the domain - domain := mdnsDomain + defaultDomain := defaultDomain d, ok := options.Context.Value("mdns.domain").(string) if ok { - domain = d + defaultDomain = d } return &mdnsRegistry{ - opts: options, - domain: domain, - services: make(map[string][]*mdnsEntry), - watchers: make(map[string]*mdnsWatcher), + defaultDomain: defaultDomain, + globalDomain: globalDomain, + opts: options, + domains: make(map[string]services), + watchers: make(map[string]*mdnsWatcher), } } @@ -164,55 +176,66 @@ func (m *mdnsRegistry) Options() Options { return m.opts } +// createServiceMDNSEntry will create a new wildcard mdns entry for the service in the +// given domain. This wildcard mdns entry is used when listing services. +func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { + ip := net.ParseIP("0.0.0.0") + + s, err := mdns.NewMDNSService(name, "_services", domain+".", "", 9999, []net.IP{ip}, nil) + if err != nil { + return nil, err + } + + srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}}) + if err != nil { + return nil, err + } + + return &mdnsEntry{id: "*", node: srv}, nil +} + func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { m.Lock() - defer m.Unlock() - entries, ok := m.services[service.Name] - // first entry, create wildcard used for list queries + // parse the options + var options RegisterOptions + for _, o := range opts { + o(&options) + } + if len(options.Domain) == 0 { + options.Domain = m.defaultDomain + } + + // create the domain in the memory store if it doesn't yet exist + if _, ok := m.domains[options.Domain]; !ok { + m.domains[options.Domain] = make(services) + } + + // create the wildcard entry used for list queries in this domain + entries, ok := m.domains[options.Domain][service.Name] if !ok { - s, err := mdns.NewMDNSService( - service.Name, - "_services", - m.domain+".", - "", - 9999, - []net.IP{net.ParseIP("0.0.0.0")}, - nil, - ) + entry, err := createServiceMDNSEntry(service.Name, options.Domain) if err != nil { + m.Unlock() return err } - - srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}}) - if err != nil { - return err - } - - // append the wildcard entry - entries = append(entries, &mdnsEntry{id: "*", node: srv}) + entries = append(entries, entry) } var gerr error - for _, node := range service.Nodes { var seen bool - var e *mdnsEntry for _, entry := range entries { if node.Id == entry.id { seen = true - e = entry break } } - // already registered, continue + // this node has already been registered, continue if seen { continue - // doesn't exist - } else { - e = &mdnsEntry{} } txt, err := encode(&mdnsTxt{ @@ -241,7 +264,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error s, err := mdns.NewMDNSService( node.Id, service.Name, - m.domain+".", + options.Domain+".", "", port, []net.IP{net.ParseIP(host)}, @@ -258,25 +281,70 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error continue } - e.id = node.Id - e.node = srv - entries = append(entries, e) + entries = append(entries, &mdnsEntry{id: node.Id, node: srv}) } - // save - m.services[service.Name] = entries + // save the mdns entry + m.domains[options.Domain][service.Name] = entries + m.Unlock() + + // register in the global Domain so it can be queried as one + if options.Domain != m.globalDomain { + srv := *service + srv.Nodes = nil + + for _, n := range service.Nodes { + node := n + + // set the original domain in node metadata + if node.Metadata == nil { + node.Metadata = map[string]string{"domain": options.Domain} + } else { + node.Metadata["domain"] = options.Domain + } + + srv.Nodes = append(srv.Nodes, node) + } + + if err := m.Register(service, append(opts, RegisterDomain(m.globalDomain))...); err != nil { + gerr = err + } + } return gerr } func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error { + // parse the options + var options DeregisterOptions + for _, o := range opts { + o(&options) + } + if len(options.Domain) == 0 { + options.Domain = m.defaultDomain + } + + // register in the global Domain + var err error + if options.Domain != m.globalDomain { + defer func() { + err = m.Deregister(service, append(opts, DeregisterDomain(m.globalDomain))...) + }() + } + + // we want to unlock before we call deregister on the global domain, so it's important this unlock + // is applied after the defer m.Deregister is called above m.Lock() defer m.Unlock() - var newEntries []*mdnsEntry + // the service wasn't registered, we can safely exist + if _, ok := m.domains[options.Domain]; !ok { + return err + } // loop existing entries, check if any match, shutdown those that do - for _, entry := range m.services[service.Name] { + var newEntries []*mdnsEntry + for _, entry := range m.domains[options.Domain][service.Name] { var remove bool for _, node := range service.Nodes { @@ -293,18 +361,43 @@ func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) er } } - // last entry is the wildcard for list queries. Remove it. - if len(newEntries) == 1 && newEntries[0].id == "*" { - newEntries[0].node.Shutdown() - delete(m.services, service.Name) - } else { - m.services[service.Name] = newEntries + // we have more than one entry remaining, we can exit + if len(newEntries) > 1 { + m.domains[options.Domain][service.Name] = newEntries + return err } - return nil + // our remaining entry is not a wildcard, we can exit + if len(newEntries) == 1 && newEntries[0].id != "*" { + m.domains[options.Domain][service.Name] = newEntries + return err + } + + // last entry is the wildcard for list queries. Remove it. + newEntries[0].node.Shutdown() + delete(m.domains[options.Domain], service.Name) + + // check to see if we can delete the domain entry + if len(m.domains[options.Domain]) == 0 { + delete(m.domains, options.Domain) + } + + return err } func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) { + // parse the options + var options GetOptions + for _, o := range opts { + o(&options) + } + if len(options.Domain) == 0 { + options.Domain = m.defaultDomain + } + if options.Domain == WildcardDomain { + options.Domain = m.globalDomain + } + serviceMap := make(map[string]*Service) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) @@ -317,17 +410,14 @@ func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service // set entries channel p.Entries = entries // set the domain - p.Domain = m.domain + p.Domain = options.Domain go func() { for { select { case e := <-entries: // list record so skip - if p.Service == "_services" { - continue - } - if p.Domain != m.domain { + if e.Name == "_services" { continue } if e.TTL == 0 { @@ -397,6 +487,18 @@ func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service } func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) { + // parse the options + var options ListOptions + for _, o := range opts { + o(&options) + } + if len(options.Domain) == 0 { + options.Domain = m.defaultDomain + } + if options.Domain == WildcardDomain { + options.Domain = m.globalDomain + } + serviceMap := make(map[string]bool) entries := make(chan *mdns.ServiceEntry, 10) done := make(chan bool) @@ -409,7 +511,7 @@ func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) { // set entries channel p.Entries = entries // set domain - p.Domain = m.domain + p.Domain = options.Domain var services []*Service @@ -451,13 +553,19 @@ func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) { for _, o := range opts { o(&wo) } + if len(wo.Domain) == 0 { + wo.Domain = m.defaultDomain + } + if wo.Domain == WildcardDomain { + wo.Domain = m.globalDomain + } md := &mdnsWatcher{ id: uuid.New().String(), wo: wo, ch: make(chan *mdns.ServiceEntry, 32), exit: make(chan struct{}), - domain: m.domain, + domain: wo.Domain, registry: m, } diff --git a/registry/options.go b/registry/options.go index 03fd5eb4..3cc5cc43 100644 --- a/registry/options.go +++ b/registry/options.go @@ -21,6 +21,8 @@ type RegisterOptions struct { // Other options for implementations of the interface // can be stored in a context Context context.Context + // Domain to register the service in + Domain string } type WatchOptions struct { @@ -30,18 +32,26 @@ type WatchOptions struct { // Other options for implementations of the interface // can be stored in a context Context context.Context + // Domain to watch + Domain string } type DeregisterOptions struct { Context context.Context + // Domain the service was registered in + Domain string } type GetOptions struct { Context context.Context + // Domain to scope the request to + Domain string } type ListOptions struct { Context context.Context + // Domain to scope the request to + Domain string } // Addrs is the registry addresses to use @@ -83,6 +93,12 @@ func RegisterContext(ctx context.Context) RegisterOption { } } +func RegisterDomain(d string) RegisterOption { + return func(o *RegisterOptions) { + o.Domain = d + } +} + // Watch a service func WatchService(name string) WatchOption { return func(o *WatchOptions) { @@ -96,20 +112,44 @@ func WatchContext(ctx context.Context) WatchOption { } } +func WatchDomain(d string) WatchOption { + return func(o *WatchOptions) { + o.Domain = d + } +} + func DeregisterContext(ctx context.Context) DeregisterOption { return func(o *DeregisterOptions) { o.Context = ctx } } +func DeregisterDomain(d string) DeregisterOption { + return func(o *DeregisterOptions) { + o.Domain = d + } +} + func GetContext(ctx context.Context) GetOption { return func(o *GetOptions) { o.Context = ctx } } +func GetDomain(d string) GetOption { + return func(o *GetOptions) { + o.Domain = d + } +} + func ListContext(ctx context.Context) ListOption { return func(o *ListOptions) { o.Context = ctx } } + +func ListDomain(d string) ListOption { + return func(o *ListOptions) { + o.Domain = d + } +} diff --git a/registry/registry.go b/registry/registry.go index dce9b431..0bbf1f50 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -5,6 +5,11 @@ import ( "errors" ) +const ( + // WildcardDomain indicates any domain + WildcardDomain = "*" +) + var ( DefaultRegistry = NewRegistry()