feat: refactor register func (#1807)
Co-authored-by: huanghuan.27@bytedance.com <huanghuan.27@bytedance.com> Co-authored-by: Asim Aslam <asim@aslam.me>
This commit is contained in:
parent
1bac08cc0e
commit
3468331506
@ -1,4 +1,3 @@
|
||||
// Package mdns is a multicast dns registry
|
||||
package registry
|
||||
|
||||
import (
|
||||
@ -44,6 +43,7 @@ type mdnsEntry struct {
|
||||
// used for listing
|
||||
type services map[string][]*mdnsEntry
|
||||
|
||||
// mdsRegistry is a multicast dns registry
|
||||
type mdnsRegistry struct {
|
||||
opts Options
|
||||
|
||||
@ -136,6 +136,7 @@ func decode(record []string) (*mdnsTxt, error) {
|
||||
|
||||
return txt, nil
|
||||
}
|
||||
|
||||
func newRegistry(opts ...Option) Registry {
|
||||
options := Options{
|
||||
Context: context.Background(),
|
||||
@ -148,9 +149,7 @@ func newRegistry(opts ...Option) Registry {
|
||||
|
||||
// set the domain
|
||||
defaultDomain := DefaultDomain
|
||||
|
||||
d, ok := options.Context.Value("mdns.domain").(string)
|
||||
if ok {
|
||||
if d, ok := options.Context.Value("mdns.domain").(string); ok {
|
||||
defaultDomain = d
|
||||
}
|
||||
|
||||
@ -192,35 +191,23 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) {
|
||||
return &mdnsEntry{id: "*", node: srv}, nil
|
||||
}
|
||||
|
||||
func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
|
||||
m.Lock()
|
||||
|
||||
// parse the options
|
||||
var options RegisterOptions
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
if len(options.Domain) == 0 {
|
||||
options.Domain = m.defaultDomain
|
||||
}
|
||||
|
||||
// create the domain in the memory store if it doesn't yet exist
|
||||
if _, ok := m.domains[options.Domain]; !ok {
|
||||
m.domains[options.Domain] = make(services)
|
||||
func (m *mdnsRegistry) getMdnsEntries(domain, serviceName string) ([]*mdnsEntry, error) {
|
||||
entries, ok := m.domains[domain][serviceName]
|
||||
if ok {
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// create the wildcard entry used for list queries in this domain
|
||||
entries, ok := m.domains[options.Domain][service.Name]
|
||||
if !ok {
|
||||
entry, err := createServiceMDNSEntry(service.Name, options.Domain)
|
||||
entry, err := createServiceMDNSEntry(serviceName, domain)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gerr error
|
||||
return []*mdnsEntry{entry}, nil
|
||||
}
|
||||
|
||||
func registerService(service *Service, entries []*mdnsEntry, options RegisterOptions) ([]*mdnsEntry, error) {
|
||||
var lastError error
|
||||
for _, node := range service.Nodes {
|
||||
var seen bool
|
||||
|
||||
@ -244,13 +231,13 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
gerr = err
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
|
||||
host, pt, err := net.SplitHostPort(node.Address)
|
||||
if err != nil {
|
||||
gerr = err
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
port, _ := strconv.Atoi(pt)
|
||||
@ -269,25 +256,23 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
|
||||
txt,
|
||||
)
|
||||
if err != nil {
|
||||
gerr = err
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
|
||||
srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
|
||||
if err != nil {
|
||||
gerr = err
|
||||
lastError = err
|
||||
continue
|
||||
}
|
||||
|
||||
entries = append(entries, &mdnsEntry{id: node.Id, node: srv})
|
||||
}
|
||||
|
||||
// save the mdns entry
|
||||
m.domains[options.Domain][service.Name] = entries
|
||||
m.Unlock()
|
||||
return entries, lastError
|
||||
}
|
||||
|
||||
// register in the global Domain so it can be queried as one
|
||||
if options.Domain != m.globalDomain {
|
||||
func createGlobalDomainService(service *Service, options RegisterOptions) *Service {
|
||||
srv := *service
|
||||
srv.Nodes = nil
|
||||
|
||||
@ -304,7 +289,42 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
|
||||
srv.Nodes = append(srv.Nodes, node)
|
||||
}
|
||||
|
||||
if err := m.Register(&srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil {
|
||||
return &srv
|
||||
}
|
||||
|
||||
func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
|
||||
m.Lock()
|
||||
|
||||
// parse the options
|
||||
var options RegisterOptions
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
if len(options.Domain) == 0 {
|
||||
options.Domain = m.defaultDomain
|
||||
}
|
||||
|
||||
// create the domain in the memory store if it doesn't yet exist
|
||||
if _, ok := m.domains[options.Domain]; !ok {
|
||||
m.domains[options.Domain] = make(services)
|
||||
}
|
||||
|
||||
entries, err := m.getMdnsEntries(options.Domain, service.Name)
|
||||
if err != nil {
|
||||
m.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
entries, gerr := registerService(service, entries, options)
|
||||
|
||||
// save the mdns entry
|
||||
m.domains[options.Domain][service.Name] = entries
|
||||
m.Unlock()
|
||||
|
||||
// register in the global Domain so it can be queried as one
|
||||
if options.Domain != m.globalDomain {
|
||||
srv := createGlobalDomainService(service, options)
|
||||
if err := m.Register(srv, append(opts, RegisterDomain(m.globalDomain))...); err != nil {
|
||||
gerr = err
|
||||
}
|
||||
}
|
||||
|
@ -79,7 +79,6 @@ func TestMDNS(t *testing.T) {
|
||||
|
||||
if len(s) != 1 {
|
||||
t.Fatalf("Expected one result for %s got %d", service.Name, len(s))
|
||||
|
||||
}
|
||||
|
||||
if s[0].Name != service.Name {
|
||||
|
Loading…
Reference in New Issue
Block a user