feat: refactor register func (#1807)
Co-authored-by: huanghuan.27@bytedance.com <huanghuan.27@bytedance.com> Co-authored-by: Asim Aslam <asim@aslam.me>
This commit is contained in:
		| @@ -1,4 +1,3 @@ | ||||
| // Package mdns is a multicast dns registry | ||||
| package registry | ||||
|  | ||||
| import ( | ||||
| @@ -44,6 +43,7 @@ type mdnsEntry struct { | ||||
| // used for listing | ||||
| type services map[string][]*mdnsEntry | ||||
|  | ||||
| // mdsRegistry is a multicast dns registry | ||||
| type mdnsRegistry struct { | ||||
| 	opts Options | ||||
|  | ||||
| @@ -136,6 +136,7 @@ func decode(record []string) (*mdnsTxt, error) { | ||||
|  | ||||
| 	return txt, nil | ||||
| } | ||||
|  | ||||
| func newRegistry(opts ...Option) Registry { | ||||
| 	options := Options{ | ||||
| 		Context: context.Background(), | ||||
| @@ -148,9 +149,7 @@ func newRegistry(opts ...Option) Registry { | ||||
|  | ||||
| 	// set the domain | ||||
| 	defaultDomain := DefaultDomain | ||||
|  | ||||
| 	d, ok := options.Context.Value("mdns.domain").(string) | ||||
| 	if ok { | ||||
| 	if d, ok := options.Context.Value("mdns.domain").(string); ok { | ||||
| 		defaultDomain = d | ||||
| 	} | ||||
|  | ||||
| @@ -192,35 +191,23 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { | ||||
| 	return &mdnsEntry{id: "*", node: srv}, nil | ||||
| } | ||||
|  | ||||
| func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { | ||||
| 	m.Lock() | ||||
|  | ||||
| 	// parse the options | ||||
| 	var options RegisterOptions | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	if len(options.Domain) == 0 { | ||||
| 		options.Domain = m.defaultDomain | ||||
| 	} | ||||
|  | ||||
| 	// create the domain in the memory store if it doesn't yet exist | ||||
| 	if _, ok := m.domains[options.Domain]; !ok { | ||||
| 		m.domains[options.Domain] = make(services) | ||||
| func (m *mdnsRegistry) getMdnsEntries(domain, serviceName string) ([]*mdnsEntry, error) { | ||||
| 	entries, ok := m.domains[domain][serviceName] | ||||
| 	if ok { | ||||
| 		return entries, nil | ||||
| 	} | ||||
|  | ||||
| 	// create the wildcard entry used for list queries in this domain | ||||
| 	entries, ok := m.domains[options.Domain][service.Name] | ||||
| 	if !ok { | ||||
| 		entry, err := createServiceMDNSEntry(service.Name, options.Domain) | ||||
| 		if err != nil { | ||||
| 			m.Unlock() | ||||
| 			return err | ||||
| 		} | ||||
| 		entries = append(entries, entry) | ||||
| 	entry, err := createServiceMDNSEntry(serviceName, domain) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var gerr error | ||||
| 	return []*mdnsEntry{entry}, nil | ||||
| } | ||||
|  | ||||
| func registerService(service *Service, entries []*mdnsEntry, options RegisterOptions) ([]*mdnsEntry, error) { | ||||
| 	var lastError error | ||||
| 	for _, node := range service.Nodes { | ||||
| 		var seen bool | ||||
|  | ||||
| @@ -244,13 +231,13 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | ||||
| 		}) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			gerr = err | ||||
| 			lastError = err | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		host, pt, err := net.SplitHostPort(node.Address) | ||||
| 		if err != nil { | ||||
| 			gerr = err | ||||
| 			lastError = err | ||||
| 			continue | ||||
| 		} | ||||
| 		port, _ := strconv.Atoi(pt) | ||||
| @@ -269,42 +256,75 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | ||||
| 			txt, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			gerr = err | ||||
| 			lastError = err | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) | ||||
| 		if err != nil { | ||||
| 			gerr = err | ||||
| 			lastError = err | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		entries = append(entries, &mdnsEntry{id: node.Id, node: srv}) | ||||
| 	} | ||||
|  | ||||
| 	return entries, lastError | ||||
| } | ||||
|  | ||||
| func createGlobalDomainService(service *Service, options RegisterOptions) *Service { | ||||
| 	srv := *service | ||||
| 	srv.Nodes = nil | ||||
|  | ||||
| 	for _, n := range service.Nodes { | ||||
| 		node := n | ||||
|  | ||||
| 		// set the original domain in node metadata | ||||
| 		if node.Metadata == nil { | ||||
| 			node.Metadata = map[string]string{"domain": options.Domain} | ||||
| 		} else { | ||||
| 			node.Metadata["domain"] = options.Domain | ||||
| 		} | ||||
|  | ||||
| 		srv.Nodes = append(srv.Nodes, node) | ||||
| 	} | ||||
|  | ||||
| 	return &srv | ||||
| } | ||||
|  | ||||
| func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { | ||||
| 	m.Lock() | ||||
|  | ||||
| 	// parse the options | ||||
| 	var options RegisterOptions | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	if len(options.Domain) == 0 { | ||||
| 		options.Domain = m.defaultDomain | ||||
| 	} | ||||
|  | ||||
| 	// create the domain in the memory store if it doesn't yet exist | ||||
| 	if _, ok := m.domains[options.Domain]; !ok { | ||||
| 		m.domains[options.Domain] = make(services) | ||||
| 	} | ||||
|  | ||||
| 	entries, err := m.getMdnsEntries(options.Domain, service.Name) | ||||
| 	if err != nil { | ||||
| 		m.Unlock() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	entries, gerr := registerService(service, entries, options) | ||||
|  | ||||
| 	// save the mdns entry | ||||
| 	m.domains[options.Domain][service.Name] = entries | ||||
| 	m.Unlock() | ||||
|  | ||||
| 	// register in the global Domain so it can be queried as one | ||||
| 	if options.Domain != m.globalDomain { | ||||
| 		srv := *service | ||||
| 		srv.Nodes = nil | ||||
|  | ||||
| 		for _, n := range service.Nodes { | ||||
| 			node := n | ||||
|  | ||||
| 			// set the original domain in node metadata | ||||
| 			if node.Metadata == nil { | ||||
| 				node.Metadata = map[string]string{"domain": options.Domain} | ||||
| 			} else { | ||||
| 				node.Metadata["domain"] = options.Domain | ||||
| 			} | ||||
|  | ||||
| 			srv.Nodes = append(srv.Nodes, node) | ||||
| 		} | ||||
|  | ||||
| 		if err := m.Register(&srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { | ||||
| 		srv := createGlobalDomainService(service, options) | ||||
| 		if err := m.Register(srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { | ||||
| 			gerr = err | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -79,7 +79,6 @@ func TestMDNS(t *testing.T) { | ||||
|  | ||||
| 		if len(s) != 1 { | ||||
| 			t.Fatalf("Expected one result for %s got %d", service.Name, len(s)) | ||||
|  | ||||
| 		} | ||||
|  | ||||
| 		if s[0].Name != service.Name { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user