diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index e60e1268..d5f13222 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -1,4 +1,3 @@ -// Package mdns is a multicast dns registry package registry import ( @@ -44,6 +43,7 @@ type mdnsEntry struct { // used for listing type services map[string][]*mdnsEntry +// mdsRegistry is a multicast dns registry type mdnsRegistry struct { opts Options @@ -136,6 +136,7 @@ func decode(record []string) (*mdnsTxt, error) { return txt, nil } + func newRegistry(opts ...Option) Registry { options := Options{ Context: context.Background(), @@ -148,9 +149,7 @@ func newRegistry(opts ...Option) Registry { // set the domain defaultDomain := DefaultDomain - - d, ok := options.Context.Value("mdns.domain").(string) - if ok { + if d, ok := options.Context.Value("mdns.domain").(string); ok { defaultDomain = d } @@ -192,35 +191,23 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { return &mdnsEntry{id: "*", node: srv}, nil } -func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { - m.Lock() - - // parse the options - var options RegisterOptions - for _, o := range opts { - o(&options) - } - if len(options.Domain) == 0 { - options.Domain = m.defaultDomain - } - - // create the domain in the memory store if it doesn't yet exist - if _, ok := m.domains[options.Domain]; !ok { - m.domains[options.Domain] = make(services) +func (m *mdnsRegistry) getMdnsEntries(domain, serviceName string) ([]*mdnsEntry, error) { + entries, ok := m.domains[domain][serviceName] + if ok { + return entries, nil } // create the wildcard entry used for list queries in this domain - entries, ok := m.domains[options.Domain][service.Name] - if !ok { - entry, err := createServiceMDNSEntry(service.Name, options.Domain) - if err != nil { - m.Unlock() - return err - } - entries = append(entries, entry) + entry, err := createServiceMDNSEntry(serviceName, domain) + if err != nil { + return nil, err } - var gerr error + return []*mdnsEntry{entry}, nil +} + +func registerService(service *Service, entries []*mdnsEntry, options RegisterOptions) ([]*mdnsEntry, error) { + var lastError error for _, node := range service.Nodes { var seen bool @@ -244,13 +231,13 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error }) if err != nil { - gerr = err + lastError = err continue } host, pt, err := net.SplitHostPort(node.Address) if err != nil { - gerr = err + lastError = err continue } port, _ := strconv.Atoi(pt) @@ -269,42 +256,75 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error txt, ) if err != nil { - gerr = err + lastError = err continue } srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) if err != nil { - gerr = err + lastError = err continue } entries = append(entries, &mdnsEntry{id: node.Id, node: srv}) } + return entries, lastError +} + +func createGlobalDomainService(service *Service, options RegisterOptions) *Service { + srv := *service + srv.Nodes = nil + + for _, n := range service.Nodes { + node := n + + // set the original domain in node metadata + if node.Metadata == nil { + node.Metadata = map[string]string{"domain": options.Domain} + } else { + node.Metadata["domain"] = options.Domain + } + + srv.Nodes = append(srv.Nodes, node) + } + + return &srv +} + +func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { + m.Lock() + + // parse the options + var options RegisterOptions + for _, o := range opts { + o(&options) + } + if len(options.Domain) == 0 { + options.Domain = m.defaultDomain + } + + // create the domain in the memory store if it doesn't yet exist + if _, ok := m.domains[options.Domain]; !ok { + m.domains[options.Domain] = make(services) + } + + entries, err := m.getMdnsEntries(options.Domain, service.Name) + if err != nil { + m.Unlock() + return err + } + + entries, gerr := registerService(service, entries, options) + // save the mdns entry m.domains[options.Domain][service.Name] = entries m.Unlock() // register in the global Domain so it can be queried as one if options.Domain != m.globalDomain { - srv := *service - srv.Nodes = nil - - for _, n := range service.Nodes { - node := n - - // set the original domain in node metadata - if node.Metadata == nil { - node.Metadata = map[string]string{"domain": options.Domain} - } else { - node.Metadata["domain"] = options.Domain - } - - srv.Nodes = append(srv.Nodes, node) - } - - if err := m.Register(&srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { + srv := createGlobalDomainService(service, options) + if err := m.Register(srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { gerr = err } } diff --git a/registry/mdns_test.go b/registry/mdns_test.go index 5f5cc617..e1cb12da 100644 --- a/registry/mdns_test.go +++ b/registry/mdns_test.go @@ -79,7 +79,6 @@ func TestMDNS(t *testing.T) { if len(s) != 1 { t.Fatalf("Expected one result for %s got %d", service.Name, len(s)) - } if s[0].Name != service.Name {