Merge pull request #17 from gonzojive/refactor-zone
Refactor mdns.MDNSService API in zone.go.
This commit is contained in:
		| @@ -15,10 +15,7 @@ func TestServer_StartStop(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestServer_Lookup(t *testing.T) { | ||||
| 	s := makeService(t) | ||||
| 	s.Service = "_foobar._tcp" | ||||
| 	s.Init() | ||||
| 	serv, err := NewServer(&Config{Zone: s}) | ||||
| 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										199
									
								
								zone.go
									
									
									
									
									
								
							
							
						
						
									
										199
									
								
								zone.go
									
									
									
									
									
								
							| @@ -2,10 +2,11 @@ package mdns | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/miekg/dns" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -16,80 +17,113 @@ const ( | ||||
| // Zone is the interface used to integrate with the server and | ||||
| // to serve records dynamically | ||||
| type Zone interface { | ||||
| 	// Records returns DNS records in response to a DNS question. | ||||
| 	Records(q dns.Question) []dns.RR | ||||
| } | ||||
|  | ||||
| // MDNSService is used to export a named service by implementing a Zone | ||||
| type MDNSService struct { | ||||
| 	Instance string // Instance name (e.g. host name) | ||||
| 	Service  string // Service name (e.g. _http._tcp.) | ||||
| 	HostName string // Host machine DNS name | ||||
| 	Port     int    // Service Port | ||||
| 	Info     string // Service info served as a TXT record | ||||
| 	Domain   string // If blank, assumes "local" | ||||
|  | ||||
| 	Addr     net.IP // @Deprecated. Service IP | ||||
|  | ||||
| 	ipv4Addr net.IP // Host machine IPv4 address | ||||
| 	ipv6Addr net.IP // Host machine IPv6 address | ||||
| 	Instance string   // Instance name (e.g. "hostService name") | ||||
| 	Service  string   // Service name (e.g. "_http._tcp.") | ||||
| 	Domain   string   // If blank, assumes "local" | ||||
| 	HostName string   // Host machine DNS name (e.g. "mymachine.net.") | ||||
| 	Port     int      // Service Port | ||||
| 	IPs      []net.IP // IP addresses for the service's host | ||||
| 	TXT      []string // Service TXT records | ||||
|  | ||||
| 	serviceAddr  string // Fully qualified service address | ||||
| 	instanceAddr string // Fully qualified instance address | ||||
| } | ||||
|  | ||||
| // Init should be called to setup the internal state | ||||
| func (m *MDNSService) Init() error { | ||||
| 	// Setup default domain | ||||
| 	if m.Domain == "" { | ||||
| 		m.Domain = "local" | ||||
| // validateFQDN returns an error if the passed string is not a fully qualified | ||||
| // hdomain name (more specifically, a hostname). | ||||
| func validateFQDN(s string) error { | ||||
| 	if len(s) == 0 { | ||||
| 		return fmt.Errorf("FQDN must not be blank") | ||||
| 	} | ||||
| 	if s[len(s)-1] != '.' { | ||||
| 		return fmt.Errorf("FQDN must end in period: %s", s) | ||||
| 	} | ||||
| 	// TODO(reddaly): Perform full validation. | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // NewMDNSService returns a new instance of MDNSService. | ||||
| // | ||||
| // If domain, hostName, or ips is set to the zero value, then a default value | ||||
| // will be inferred from the operating system. | ||||
| // | ||||
| // TODO(reddaly): This interface may need to change to account for "unique | ||||
| // record" conflict rules of the mDNS protocol.  Upon startup, the server should | ||||
| // check to ensure that the instance name does not conflict with other instance | ||||
| // names, and, if required, select a new name.  There may also be conflicting | ||||
| // hostName A/AAAA records. | ||||
| func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) { | ||||
| 	// Sanity check inputs | ||||
| 	if m.Instance == "" { | ||||
| 		return fmt.Errorf("Missing service instance name") | ||||
| 	if instance == "" { | ||||
| 		return nil, fmt.Errorf("missing service instance name") | ||||
| 	} | ||||
| 	if m.Service == "" { | ||||
| 		return fmt.Errorf("Missing service name") | ||||
| 	if service == "" { | ||||
| 		return nil, fmt.Errorf("missing service name") | ||||
| 	} | ||||
| 	if m.Port == 0 { | ||||
| 		return fmt.Errorf("Missing service port") | ||||
| 	if port == 0 { | ||||
| 		return nil, fmt.Errorf("missing service port") | ||||
| 	} | ||||
|  | ||||
| 	// Get host information | ||||
| 	hostName, err := os.Hostname() | ||||
| 	if err == nil { | ||||
| 		m.HostName = fmt.Sprintf("%s.", hostName) | ||||
| 	// Set default domain | ||||
| 	if domain == "" { | ||||
| 		domain = "local." | ||||
| 	} | ||||
| 	if err := validateFQDN(domain); err != nil { | ||||
| 		return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err) | ||||
| 	} | ||||
|  | ||||
| 		addrs, err := net.LookupIP(m.HostName) | ||||
| 	// Get host information if no host is specified. | ||||
| 	if hostName == "" { | ||||
| 		var err error | ||||
| 		hostName, err = os.Hostname() | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("could not determine host: %v", err) | ||||
| 		} | ||||
| 		hostName = fmt.Sprintf("%s.", hostName) | ||||
| 	} | ||||
| 	if err := validateFQDN(hostName); err != nil { | ||||
| 		return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err) | ||||
| 	} | ||||
|  | ||||
| 	if len(ips) == 0 { | ||||
| 		var err error | ||||
| 		ips, err = net.LookupIP(hostName) | ||||
| 		if err != nil { | ||||
| 			// Try appending the host domain suffix and lookup again | ||||
| 			// (required for Linux-based hosts) | ||||
| 			tmpHostName := fmt.Sprintf("%s%s.", m.HostName, m.Domain) | ||||
| 			tmpHostName := fmt.Sprintf("%s%s.", hostName, domain) | ||||
|  | ||||
| 			addrs, err = net.LookupIP(tmpHostName) | ||||
| 			ips, err = net.LookupIP(tmpHostName) | ||||
|  | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("Could not determine host IP addresses for %s", m.HostName) | ||||
| 				return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		for i := 0; i < len(addrs); i++ { | ||||
| 			if ipv4 := addrs[i].To4(); ipv4 != nil { | ||||
| 				m.ipv4Addr = addrs[i] | ||||
| 			} else if ipv6 := addrs[i].To16(); ipv6 != nil { | ||||
| 				m.ipv6Addr = addrs[i] | ||||
| 			} | ||||
| 	} | ||||
| 	for _, ip := range ips { | ||||
| 		if ip.To4() == nil && ip.To16() == nil { | ||||
| 			return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip) | ||||
| 		} | ||||
| 	} else { | ||||
| 		return fmt.Errorf("Could not determine host") | ||||
| 	} | ||||
|  | ||||
| 	// Create the full addresses | ||||
| 	m.serviceAddr = fmt.Sprintf("%s.%s.", | ||||
| 		trimDot(m.Service), trimDot(m.Domain)) | ||||
| 	m.instanceAddr = fmt.Sprintf("%s.%s", | ||||
| 		trimDot(m.Instance), m.serviceAddr) | ||||
| 	return nil | ||||
| 	return &MDNSService{ | ||||
| 		Instance:     instance, | ||||
| 		Service:      service, | ||||
| 		Domain:       domain, | ||||
| 		HostName:     hostName, | ||||
| 		Port:         port, | ||||
| 		IPs:          ips, | ||||
| 		TXT:          txt, | ||||
| 		serviceAddr:  fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), | ||||
| 		instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // trimDot is used to trim the dots from the start or end of a string | ||||
| @@ -97,6 +131,7 @@ func trimDot(s string) string { | ||||
| 	return strings.Trim(s, ".") | ||||
| } | ||||
|  | ||||
| // Records returns DNS records in response to a DNS question. | ||||
| func (m *MDNSService) Records(q dns.Question) []dns.RR { | ||||
| 	switch q.Name { | ||||
| 	case m.serviceAddr: | ||||
| @@ -126,7 +161,7 @@ func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { | ||||
| 		} | ||||
| 		servRec := []dns.RR{rr} | ||||
|  | ||||
| 		// Get the isntance records | ||||
| 		// Get the instance records | ||||
| 		instRecs := m.instanceRecords(dns.Question{ | ||||
| 			Name:  m.instanceAddr, | ||||
| 			Qtype: dns.TypeANY, | ||||
| @@ -157,42 +192,46 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { | ||||
| 		return recs | ||||
|  | ||||
| 	case dns.TypeA: | ||||
| 		// Only handle if we have a ipv4 addr | ||||
| 		ipv4 := m.Addr.To4() | ||||
| 		if ipv4 != nil { | ||||
| 			m.ipv4Addr = ipv4 | ||||
| 		} else if m.ipv4Addr == nil { | ||||
| 			return nil | ||||
| 		var rr []dns.RR | ||||
| 		for _, ip := range m.IPs { | ||||
| 			if ip4 := ip.To4(); ip4 != nil { | ||||
| 				rr = append(rr, &dns.A{ | ||||
| 					Hdr: dns.RR_Header{ | ||||
| 						Name:   m.HostName, | ||||
| 						Rrtype: dns.TypeA, | ||||
| 						Class:  dns.ClassINET, | ||||
| 						Ttl:    defaultTTL, | ||||
| 					}, | ||||
| 					A: ip4, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 		a := &dns.A{ | ||||
| 			Hdr: dns.RR_Header{ | ||||
| 				Name:   m.HostName, | ||||
| 				Rrtype: dns.TypeA, | ||||
| 				Class:  dns.ClassINET, | ||||
| 				Ttl:    defaultTTL, | ||||
| 			}, | ||||
| 			A: m.ipv4Addr, | ||||
| 		} | ||||
| 		return []dns.RR{a} | ||||
| 		return rr | ||||
|  | ||||
| 	case dns.TypeAAAA: | ||||
| 		// Only handle if we have a ipv6 addr | ||||
| 		ipv6 := m.Addr.To16() | ||||
| 		if ipv6 != nil && m.Addr.To4() == nil { | ||||
| 			m.ipv6Addr = ipv6 | ||||
| 		} else if m.ipv6Addr == nil && m.Addr.To4() != nil { | ||||
| 			return nil | ||||
| 		var rr []dns.RR | ||||
| 		for _, ip := range m.IPs { | ||||
| 			if ip.To4() != nil { | ||||
| 				// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and | ||||
| 				// putinto AAAA records, but the current logic puts ipv4-encodable | ||||
| 				// addresses into the A records exclusively.  Perhaps this should be | ||||
| 				// configurable? | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if ip16 := ip.To16(); ip16 != nil { | ||||
| 				rr = append(rr, &dns.AAAA{ | ||||
| 					Hdr: dns.RR_Header{ | ||||
| 						Name:   m.HostName, | ||||
| 						Rrtype: dns.TypeAAAA, | ||||
| 						Class:  dns.ClassINET, | ||||
| 						Ttl:    defaultTTL, | ||||
| 					}, | ||||
| 					AAAA: ip16, | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
| 		a4 := &dns.AAAA{ | ||||
| 			Hdr: dns.RR_Header{ | ||||
| 				Name:   m.HostName, | ||||
| 				Rrtype: dns.TypeAAAA, | ||||
| 				Class:  dns.ClassINET, | ||||
| 				Ttl:    defaultTTL, | ||||
| 			}, | ||||
| 			AAAA: m.ipv6Addr, | ||||
| 		} | ||||
| 		return []dns.RR{a4} | ||||
| 		return rr | ||||
|  | ||||
| 	case dns.TypeSRV: | ||||
| 		// Create the SRV Record | ||||
| @@ -231,7 +270,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { | ||||
| 				Class:  dns.ClassINET, | ||||
| 				Ttl:    defaultTTL, | ||||
| 			}, | ||||
| 			Txt: []string{m.Info}, | ||||
| 			Txt: m.TXT, | ||||
| 		} | ||||
| 		return []dns.RR{txt} | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										99
									
								
								zone_test.go
									
									
									
									
									
								
							
							
						
						
									
										99
									
								
								zone_test.go
									
									
									
									
									
								
							| @@ -2,25 +2,65 @@ package mdns | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"github.com/miekg/dns" | ||||
| 	"net" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func makeService(t *testing.T) *MDNSService { | ||||
| 	m := &MDNSService{ | ||||
| 		Instance: "hostname.", | ||||
| 		Service:  "_http._tcp.", | ||||
| 		Port:     80, | ||||
| 		Info:     "Local web server", | ||||
| 		Domain:   "local.", | ||||
| 	} | ||||
| 	if err := m.Init(); err != nil { | ||||
| 	return makeServiceWithServiceName(t, "_http._tcp") | ||||
| } | ||||
|  | ||||
| func makeServiceWithServiceName(t *testing.T, service string) *MDNSService { | ||||
| 	m, err := NewMDNSService( | ||||
| 		"hostname", | ||||
| 		service, | ||||
| 		"local.", | ||||
| 		"testhost.", | ||||
| 		80, // port | ||||
| 		[]net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")}, | ||||
| 		[]string{"Local web server"}) // TXT | ||||
|  | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| func TestNewMDNSService_BadParams(t *testing.T) { | ||||
| 	for _, test := range []struct { | ||||
| 		testName string | ||||
| 		hostName string | ||||
| 		domain   string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name", | ||||
| 			"hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc. | ||||
| 			"local.",   // legal | ||||
| 		}, | ||||
| 		{ | ||||
| 			"NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name", | ||||
| 			"hostname.", // legal | ||||
| 			"local",     // should be "local." | ||||
| 		}, | ||||
| 	} { | ||||
| 		_, err := NewMDNSService( | ||||
| 			"instance name", | ||||
| 			"_http._tcp", | ||||
| 			test.domain, | ||||
| 			test.hostName, | ||||
| 			80, // port | ||||
| 			[]net.IP{net.IP([]byte{192, 168, 0, 42})}, | ||||
| 			[]string{"Local web server"}) // TXT | ||||
| 		if err == nil { | ||||
| 			t.Fatalf("%s: error expected, but got none", test.testName) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMDNSService_BadAddr(t *testing.T) { | ||||
| 	s := makeService(t) | ||||
| 	q := dns.Question{ | ||||
| @@ -40,35 +80,32 @@ func TestMDNSService_ServiceAddr(t *testing.T) { | ||||
| 		Qtype: dns.TypeANY, | ||||
| 	} | ||||
| 	recs := s.Records(q) | ||||
| 	if len(recs) != 5 { | ||||
| 		t.Fatalf("bad: %v", recs) | ||||
| 	if got, want := len(recs), 5; got != want { | ||||
| 		t.Fatalf("got %d records, want %d: %v", got, want, recs) | ||||
| 	} | ||||
|  | ||||
| 	ptr, ok := recs[0].(*dns.PTR) | ||||
| 	if !ok { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	if ptr, ok := recs[0].(*dns.PTR); !ok { | ||||
| 		t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) | ||||
| 	} else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want { | ||||
| 		t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := recs[1].(*dns.SRV); !ok { | ||||
| 		t.Fatalf("bad: %v", recs[1]) | ||||
| 		t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs) | ||||
| 	} | ||||
| 	if _, ok := recs[2].(*dns.A); !ok { | ||||
| 		t.Fatalf("bad: %v", recs[2]) | ||||
| 		t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs) | ||||
| 	} | ||||
| 	if _, ok := recs[3].(*dns.AAAA); !ok { | ||||
| 		t.Fatalf("bad: %v", recs[3]) | ||||
| 		t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs) | ||||
| 	} | ||||
| 	if _, ok := recs[4].(*dns.TXT); !ok { | ||||
| 		t.Fatalf("bad: %v", recs[4]) | ||||
| 	} | ||||
|  | ||||
| 	if ptr.Ptr != s.instanceAddr { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 		t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs) | ||||
| 	} | ||||
|  | ||||
| 	q.Qtype = dns.TypePTR | ||||
| 	recs2 := s.Records(q) | ||||
| 	if !reflect.DeepEqual(recs, recs2) { | ||||
| 		t.Fatalf("no match: %v %v", recs, recs2) | ||||
| 	if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) { | ||||
| 		t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -136,7 +173,7 @@ func TestMDNSService_InstanceAddr_A(t *testing.T) { | ||||
| 	if !ok { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	} | ||||
| 	if !bytes.Equal(a.A, s.ipv4Addr) { | ||||
| 	if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	} | ||||
| } | ||||
| @@ -155,7 +192,11 @@ func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { | ||||
| 	if !ok { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	} | ||||
| 	if !bytes.Equal(a4.AAAA, s.ipv6Addr) { | ||||
| 	ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc") | ||||
| 	if got := len(ip6); got != net.IPv6len { | ||||
| 		t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len) | ||||
| 	} | ||||
| 	if !bytes.Equal(a4.AAAA, ip6) { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	} | ||||
| } | ||||
| @@ -174,7 +215,7 @@ func TestMDNSService_InstanceAddr_TXT(t *testing.T) { | ||||
| 	if !ok { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	} | ||||
| 	if txt.Txt[0] != s.Info { | ||||
| 		t.Fatalf("bad: %v", recs[0]) | ||||
| 	if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) { | ||||
| 		t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user