feat: refactor register func (#1807)
Co-authored-by: huanghuan.27@bytedance.com <huanghuan.27@bytedance.com> Co-authored-by: Asim Aslam <asim@aslam.me>
This commit is contained in:
		| @@ -1,4 +1,3 @@ | |||||||
| // Package mdns is a multicast dns registry |  | ||||||
| package registry | package registry | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| @@ -44,6 +43,7 @@ type mdnsEntry struct { | |||||||
| // used for listing | // used for listing | ||||||
| type services map[string][]*mdnsEntry | type services map[string][]*mdnsEntry | ||||||
|  |  | ||||||
|  | // mdsRegistry is a multicast dns registry | ||||||
| type mdnsRegistry struct { | type mdnsRegistry struct { | ||||||
| 	opts Options | 	opts Options | ||||||
|  |  | ||||||
| @@ -136,6 +136,7 @@ func decode(record []string) (*mdnsTxt, error) { | |||||||
|  |  | ||||||
| 	return txt, nil | 	return txt, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func newRegistry(opts ...Option) Registry { | func newRegistry(opts ...Option) Registry { | ||||||
| 	options := Options{ | 	options := Options{ | ||||||
| 		Context: context.Background(), | 		Context: context.Background(), | ||||||
| @@ -148,9 +149,7 @@ func newRegistry(opts ...Option) Registry { | |||||||
|  |  | ||||||
| 	// set the domain | 	// set the domain | ||||||
| 	defaultDomain := DefaultDomain | 	defaultDomain := DefaultDomain | ||||||
|  | 	if d, ok := options.Context.Value("mdns.domain").(string); ok { | ||||||
| 	d, ok := options.Context.Value("mdns.domain").(string) |  | ||||||
| 	if ok { |  | ||||||
| 		defaultDomain = d | 		defaultDomain = d | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -192,35 +191,23 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { | |||||||
| 	return &mdnsEntry{id: "*", node: srv}, nil | 	return &mdnsEntry{id: "*", node: srv}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { | func (m *mdnsRegistry) getMdnsEntries(domain, serviceName string) ([]*mdnsEntry, error) { | ||||||
| 	m.Lock() | 	entries, ok := m.domains[domain][serviceName] | ||||||
|  | 	if ok { | ||||||
| 	// parse the options | 		return entries, nil | ||||||
| 	var options RegisterOptions |  | ||||||
| 	for _, o := range opts { |  | ||||||
| 		o(&options) |  | ||||||
| 	} |  | ||||||
| 	if len(options.Domain) == 0 { |  | ||||||
| 		options.Domain = m.defaultDomain |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// create the domain in the memory store if it doesn't yet exist |  | ||||||
| 	if _, ok := m.domains[options.Domain]; !ok { |  | ||||||
| 		m.domains[options.Domain] = make(services) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// create the wildcard entry used for list queries in this domain | 	// create the wildcard entry used for list queries in this domain | ||||||
| 	entries, ok := m.domains[options.Domain][service.Name] | 	entry, err := createServiceMDNSEntry(serviceName, domain) | ||||||
| 	if !ok { | 	if err != nil { | ||||||
| 		entry, err := createServiceMDNSEntry(service.Name, options.Domain) | 		return nil, err | ||||||
| 		if err != nil { |  | ||||||
| 			m.Unlock() |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		entries = append(entries, entry) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var gerr error | 	return []*mdnsEntry{entry}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func registerService(service *Service, entries []*mdnsEntry, options RegisterOptions) ([]*mdnsEntry, error) { | ||||||
|  | 	var lastError error | ||||||
| 	for _, node := range service.Nodes { | 	for _, node := range service.Nodes { | ||||||
| 		var seen bool | 		var seen bool | ||||||
|  |  | ||||||
| @@ -244,13 +231,13 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | |||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			gerr = err | 			lastError = err | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		host, pt, err := net.SplitHostPort(node.Address) | 		host, pt, err := net.SplitHostPort(node.Address) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			gerr = err | 			lastError = err | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		port, _ := strconv.Atoi(pt) | 		port, _ := strconv.Atoi(pt) | ||||||
| @@ -269,42 +256,75 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | |||||||
| 			txt, | 			txt, | ||||||
| 		) | 		) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			gerr = err | 			lastError = err | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) | 		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			gerr = err | 			lastError = err | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		entries = append(entries, &mdnsEntry{id: node.Id, node: srv}) | 		entries = append(entries, &mdnsEntry{id: node.Id, node: srv}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	return entries, lastError | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func createGlobalDomainService(service *Service, options RegisterOptions) *Service { | ||||||
|  | 	srv := *service | ||||||
|  | 	srv.Nodes = nil | ||||||
|  |  | ||||||
|  | 	for _, n := range service.Nodes { | ||||||
|  | 		node := n | ||||||
|  |  | ||||||
|  | 		// set the original domain in node metadata | ||||||
|  | 		if node.Metadata == nil { | ||||||
|  | 			node.Metadata = map[string]string{"domain": options.Domain} | ||||||
|  | 		} else { | ||||||
|  | 			node.Metadata["domain"] = options.Domain | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		srv.Nodes = append(srv.Nodes, node) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &srv | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error { | ||||||
|  | 	m.Lock() | ||||||
|  |  | ||||||
|  | 	// parse the options | ||||||
|  | 	var options RegisterOptions | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  | 	if len(options.Domain) == 0 { | ||||||
|  | 		options.Domain = m.defaultDomain | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// create the domain in the memory store if it doesn't yet exist | ||||||
|  | 	if _, ok := m.domains[options.Domain]; !ok { | ||||||
|  | 		m.domains[options.Domain] = make(services) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entries, err := m.getMdnsEntries(options.Domain, service.Name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		m.Unlock() | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	entries, gerr := registerService(service, entries, options) | ||||||
|  |  | ||||||
| 	// save the mdns entry | 	// save the mdns entry | ||||||
| 	m.domains[options.Domain][service.Name] = entries | 	m.domains[options.Domain][service.Name] = entries | ||||||
| 	m.Unlock() | 	m.Unlock() | ||||||
|  |  | ||||||
| 	// register in the global Domain so it can be queried as one | 	// register in the global Domain so it can be queried as one | ||||||
| 	if options.Domain != m.globalDomain { | 	if options.Domain != m.globalDomain { | ||||||
| 		srv := *service | 		srv := createGlobalDomainService(service, options) | ||||||
| 		srv.Nodes = nil | 		if err := m.Register(srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { | ||||||
|  |  | ||||||
| 		for _, n := range service.Nodes { |  | ||||||
| 			node := n |  | ||||||
|  |  | ||||||
| 			// set the original domain in node metadata |  | ||||||
| 			if node.Metadata == nil { |  | ||||||
| 				node.Metadata = map[string]string{"domain": options.Domain} |  | ||||||
| 			} else { |  | ||||||
| 				node.Metadata["domain"] = options.Domain |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			srv.Nodes = append(srv.Nodes, node) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if err := m.Register(&srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil { |  | ||||||
| 			gerr = err | 			gerr = err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -79,7 +79,6 @@ func TestMDNS(t *testing.T) { | |||||||
|  |  | ||||||
| 		if len(s) != 1 { | 		if len(s) != 1 { | ||||||
| 			t.Fatalf("Expected one result for %s got %d", service.Name, len(s)) | 			t.Fatalf("Expected one result for %s got %d", service.Name, len(s)) | ||||||
|  |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if s[0].Name != service.Name { | 		if s[0].Name != service.Name { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user