354 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			354 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package registry provides service registration and discovery; this file implements the multicast DNS (mdns) registry.
 | |
| package registry
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/micro/mdns"
 | |
| 	hash "github.com/mitchellh/hashstructure"
 | |
| )
 | |
| 
 | |
| type mdnsTxt struct {
 | |
| 	Service   string
 | |
| 	Version   string
 | |
| 	Endpoints []*Endpoint
 | |
| 	Metadata  map[string]string
 | |
| }
 | |
| 
 | |
| type mdnsEntry struct {
 | |
| 	hash uint64
 | |
| 	id   string
 | |
| 	node *mdns.Server
 | |
| }
 | |
| 
 | |
| type mdnsRegistry struct {
 | |
| 	opts Options
 | |
| 
 | |
| 	sync.Mutex
 | |
| 	services map[string][]*mdnsEntry
 | |
| }
 | |
| 
 | |
| func newRegistry(opts ...Option) Registry {
 | |
| 	options := Options{
 | |
| 		Timeout: time.Millisecond * 100,
 | |
| 	}
 | |
| 
 | |
| 	return &mdnsRegistry{
 | |
| 		opts:     options,
 | |
| 		services: make(map[string][]*mdnsEntry),
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) Init(opts ...Option) error {
 | |
| 	for _, o := range opts {
 | |
| 		o(&m.opts)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) Options() Options {
 | |
| 	return m.opts
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
 | |
| 	m.Lock()
 | |
| 	defer m.Unlock()
 | |
| 
 | |
| 	entries, ok := m.services[service.Name]
 | |
| 	// first entry, create wildcard used for list queries
 | |
| 	if !ok {
 | |
| 		s, err := mdns.NewMDNSService(
 | |
| 			service.Name,
 | |
| 			"_services",
 | |
| 			"",
 | |
| 			"",
 | |
| 			9999,
 | |
| 			[]net.IP{net.ParseIP("0.0.0.0")},
 | |
| 			nil,
 | |
| 		)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{s}})
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// append the wildcard entry
 | |
| 		entries = append(entries, &mdnsEntry{id: "*", node: srv})
 | |
| 	}
 | |
| 
 | |
| 	var gerr error
 | |
| 
 | |
| 	for _, node := range service.Nodes {
 | |
| 		// create hash of service; uint64
 | |
| 		h, err := hash.Hash(node, nil)
 | |
| 		if err != nil {
 | |
| 			gerr = err
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		var seen bool
 | |
| 		var e *mdnsEntry
 | |
| 
 | |
| 		for _, entry := range entries {
 | |
| 			if node.Id == entry.id {
 | |
| 				seen = true
 | |
| 				e = entry
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// already registered, continue
 | |
| 		if seen && e.hash == h {
 | |
| 			continue
 | |
| 			// hash doesn't match, shutdown
 | |
| 		} else if seen {
 | |
| 			e.node.Shutdown()
 | |
| 			// doesn't exist
 | |
| 		} else {
 | |
| 			e = &mdnsEntry{hash: h}
 | |
| 		}
 | |
| 
 | |
| 		txt, err := encode(&mdnsTxt{
 | |
| 			Service:   service.Name,
 | |
| 			Version:   service.Version,
 | |
| 			Endpoints: service.Endpoints,
 | |
| 			Metadata:  node.Metadata,
 | |
| 		})
 | |
| 
 | |
| 		if err != nil {
 | |
| 			gerr = err
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		//
 | |
| 		host, pt, err := net.SplitHostPort(node.Address)
 | |
| 		if err != nil {
 | |
| 			gerr = err
 | |
| 			continue
 | |
| 		}
 | |
| 		port, _ := strconv.Atoi(pt)
 | |
| 
 | |
| 		// we got here, new node
 | |
| 		s, err := mdns.NewMDNSService(
 | |
| 			node.Id,
 | |
| 			service.Name,
 | |
| 			"",
 | |
| 			"",
 | |
| 			port,
 | |
| 			[]net.IP{net.ParseIP(host)},
 | |
| 			txt,
 | |
| 		)
 | |
| 		if err != nil {
 | |
| 			gerr = err
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s})
 | |
| 		if err != nil {
 | |
| 			gerr = err
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		e.id = node.Id
 | |
| 		e.node = srv
 | |
| 		entries = append(entries, e)
 | |
| 	}
 | |
| 
 | |
| 	// save
 | |
| 	m.services[service.Name] = entries
 | |
| 
 | |
| 	return gerr
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) Deregister(service *Service) error {
 | |
| 	m.Lock()
 | |
| 	defer m.Unlock()
 | |
| 
 | |
| 	var newEntries []*mdnsEntry
 | |
| 
 | |
| 	// loop existing entries, check if any match, shutdown those that do
 | |
| 	for _, entry := range m.services[service.Name] {
 | |
| 		var remove bool
 | |
| 
 | |
| 		for _, node := range service.Nodes {
 | |
| 			if node.Id == entry.id {
 | |
| 				entry.node.Shutdown()
 | |
| 				remove = true
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// keep it?
 | |
| 		if !remove {
 | |
| 			newEntries = append(newEntries, entry)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// last entry is the wildcard for list queries. Remove it.
 | |
| 	if len(newEntries) == 1 && newEntries[0].id == "*" {
 | |
| 		newEntries[0].node.Shutdown()
 | |
| 		delete(m.services, service.Name)
 | |
| 	} else {
 | |
| 		m.services[service.Name] = newEntries
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
 | |
| 	serviceMap := make(map[string]*Service)
 | |
| 	entries := make(chan *mdns.ServiceEntry, 10)
 | |
| 	done := make(chan bool)
 | |
| 
 | |
| 	p := mdns.DefaultParams(service)
 | |
| 	// set context with timeout
 | |
| 	p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
 | |
| 	// set entries channel
 | |
| 	p.Entries = entries
 | |
| 
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			select {
 | |
| 			case e := <-entries:
 | |
| 				// list record so skip
 | |
| 				if p.Service == "_services" {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				if e.TTL == 0 {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				txt, err := decode(e.InfoFields)
 | |
| 				if err != nil {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				if txt.Service != service {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				s, ok := serviceMap[txt.Version]
 | |
| 				if !ok {
 | |
| 					s = &Service{
 | |
| 						Name:      txt.Service,
 | |
| 						Version:   txt.Version,
 | |
| 						Endpoints: txt.Endpoints,
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				s.Nodes = append(s.Nodes, &Node{
 | |
| 					Id:       strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
 | |
| 					Address:  fmt.Sprintf("%s:%d", e.AddrV4.String(), e.Port),
 | |
| 					Metadata: txt.Metadata,
 | |
| 				})
 | |
| 
 | |
| 				serviceMap[txt.Version] = s
 | |
| 			case <-p.Context.Done():
 | |
| 				close(done)
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// execute the query
 | |
| 	if err := mdns.Query(p); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// wait for completion
 | |
| 	<-done
 | |
| 
 | |
| 	// create list and return
 | |
| 	var services []*Service
 | |
| 
 | |
| 	for _, service := range serviceMap {
 | |
| 		services = append(services, service)
 | |
| 	}
 | |
| 
 | |
| 	return services, nil
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) ListServices() ([]*Service, error) {
 | |
| 	serviceMap := make(map[string]bool)
 | |
| 	entries := make(chan *mdns.ServiceEntry, 10)
 | |
| 	done := make(chan bool)
 | |
| 
 | |
| 	p := mdns.DefaultParams("_services")
 | |
| 	// set context with timeout
 | |
| 	p.Context, _ = context.WithTimeout(context.Background(), m.opts.Timeout)
 | |
| 	// set entries channel
 | |
| 	p.Entries = entries
 | |
| 
 | |
| 	var services []*Service
 | |
| 
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			select {
 | |
| 			case e := <-entries:
 | |
| 				if e.TTL == 0 {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
 | |
| 				if !serviceMap[name] {
 | |
| 					serviceMap[name] = true
 | |
| 					services = append(services, &Service{Name: name})
 | |
| 				}
 | |
| 			case <-p.Context.Done():
 | |
| 				close(done)
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// execute query
 | |
| 	if err := mdns.Query(p); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// wait till done
 | |
| 	<-done
 | |
| 
 | |
| 	return services, nil
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) {
 | |
| 	var wo WatchOptions
 | |
| 	for _, o := range opts {
 | |
| 		o(&wo)
 | |
| 	}
 | |
| 
 | |
| 	md := &mdnsWatcher{
 | |
| 		wo:   wo,
 | |
| 		ch:   make(chan *mdns.ServiceEntry, 32),
 | |
| 		exit: make(chan struct{}),
 | |
| 	}
 | |
| 
 | |
| 	go func() {
 | |
| 		if err := mdns.Listen(md.ch, md.exit); err != nil {
 | |
| 			md.Stop()
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	return md, nil
 | |
| }
 | |
| 
 | |
| func (m *mdnsRegistry) String() string {
 | |
| 	return "mdns"
 | |
| }
 | |
| 
 | |
| // NewRegistry returns a new default registry which is mdns
 | |
| func NewRegistry(opts ...Option) Registry {
 | |
| 	return newRegistry(opts...)
 | |
| }
 |