diff --git a/server_test.go b/server_test.go index 52e24a6..3f20bc6 100644 --- a/server_test.go +++ b/server_test.go @@ -15,10 +15,7 @@ func TestServer_StartStop(t *testing.T) { } func TestServer_Lookup(t *testing.T) { - s := makeService(t) - s.Service = "_foobar._tcp" - s.Init() - serv, err := NewServer(&Config{Zone: s}) + serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) if err != nil { t.Fatalf("err: %v", err) } diff --git a/zone.go b/zone.go index 4195263..432fcf8 100644 --- a/zone.go +++ b/zone.go @@ -2,10 +2,11 @@ package mdns import ( "fmt" - "github.com/miekg/dns" "net" "os" "strings" + + "github.com/miekg/dns" ) const ( @@ -16,80 +17,113 @@ const ( // Zone is the interface used to integrate with the server and // to serve records dynamically type Zone interface { + // Records returns DNS records in response to a DNS question. Records(q dns.Question) []dns.RR } // MDNSService is used to export a named service by implementing a Zone type MDNSService struct { - Instance string // Instance name (e.g. host name) - Service string // Service name (e.g. _http._tcp.) - HostName string // Host machine DNS name - Port int // Service Port - Info string // Service info served as a TXT record - Domain string // If blank, assumes "local" - - Addr net.IP // @Deprecated. Service IP - - ipv4Addr net.IP // Host machine IPv4 address - ipv6Addr net.IP // Host machine IPv6 address + Instance string // Instance name (e.g. "hostService name") + Service string // Service name (e.g. "_http._tcp.") + Domain string // If blank, assumes "local" + HostName string // Host machine DNS name (e.g. "mymachine.net.") + Port int // Service Port + IPs []net.IP // IP addresses for the service's host + TXT []string // Service TXT records serviceAddr string // Fully qualified service address instanceAddr string // Fully qualified instance address } -// Init should be called to setup the internal state -func (m *MDNSService) Init() error { - // Setup default domain - if m.Domain == "" { - m.Domain = "local" +// validateFQDN returns an error if the passed string is not a fully qualified +// hdomain name (more specifically, a hostname). +func validateFQDN(s string) error { + if len(s) == 0 { + return fmt.Errorf("FQDN must not be blank") } + if s[len(s)-1] != '.' { + return fmt.Errorf("FQDN must end in period: %s", s) + } + // TODO(reddaly): Perform full validation. + return nil +} + +// NewMDNSService returns a new instance of MDNSService. +// +// If domain, hostName, or ips is set to the zero value, then a default value +// will be inferred from the operating system. +// +// TODO(reddaly): This interface may need to change to account for "unique +// record" conflict rules of the mDNS protocol. Upon startup, the server should +// check to ensure that the instance name does not conflict with other instance +// names, and, if required, select a new name. There may also be conflicting +// hostName A/AAAA records. +func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) { // Sanity check inputs - if m.Instance == "" { - return fmt.Errorf("Missing service instance name") + if instance == "" { + return nil, fmt.Errorf("missing service instance name") } - if m.Service == "" { - return fmt.Errorf("Missing service name") + if service == "" { + return nil, fmt.Errorf("missing service name") } - if m.Port == 0 { - return fmt.Errorf("Missing service port") + if port == 0 { + return nil, fmt.Errorf("missing service port") } - // Get host information - hostName, err := os.Hostname() - if err == nil { - m.HostName = fmt.Sprintf("%s.", hostName) + // Set default domain + if domain == "" { + domain = "local." + } + if err := validateFQDN(domain); err != nil { + return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err) + } - addrs, err := net.LookupIP(m.HostName) + // Get host information if no host is specified. + if hostName == "" { + var err error + hostName, err = os.Hostname() + if err != nil { + return nil, fmt.Errorf("could not determine host: %v", err) + } + hostName = fmt.Sprintf("%s.", hostName) + } + if err := validateFQDN(hostName); err != nil { + return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err) + } + + if len(ips) == 0 { + var err error + ips, err = net.LookupIP(hostName) if err != nil { // Try appending the host domain suffix and lookup again // (required for Linux-based hosts) - tmpHostName := fmt.Sprintf("%s%s.", m.HostName, m.Domain) + tmpHostName := fmt.Sprintf("%s%s.", hostName, domain) - addrs, err = net.LookupIP(tmpHostName) + ips, err = net.LookupIP(tmpHostName) if err != nil { - return fmt.Errorf("Could not determine host IP addresses for %s", m.HostName) + return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName) } } - - for i := 0; i < len(addrs); i++ { - if ipv4 := addrs[i].To4(); ipv4 != nil { - m.ipv4Addr = addrs[i] - } else if ipv6 := addrs[i].To16(); ipv6 != nil { - m.ipv6Addr = addrs[i] - } + } + for _, ip := range ips { + if ip.To4() == nil && ip.To16() == nil { + return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip) } - } else { - return fmt.Errorf("Could not determine host") } - // Create the full addresses - m.serviceAddr = fmt.Sprintf("%s.%s.", - trimDot(m.Service), trimDot(m.Domain)) - m.instanceAddr = fmt.Sprintf("%s.%s", - trimDot(m.Instance), m.serviceAddr) - return nil + return &MDNSService{ + Instance: instance, + Service: service, + Domain: domain, + HostName: hostName, + Port: port, + IPs: ips, + TXT: txt, + serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), + instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), + }, nil } // trimDot is used to trim the dots from the start or end of a string @@ -97,6 +131,7 @@ func trimDot(s string) string { return strings.Trim(s, ".") } +// Records returns DNS records in response to a DNS question. func (m *MDNSService) Records(q dns.Question) []dns.RR { switch q.Name { case m.serviceAddr: @@ -126,7 +161,7 @@ func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { } servRec := []dns.RR{rr} - // Get the isntance records + // Get the instance records instRecs := m.instanceRecords(dns.Question{ Name: m.instanceAddr, Qtype: dns.TypeANY, @@ -157,42 +192,46 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { return recs case dns.TypeA: - // Only handle if we have a ipv4 addr - ipv4 := m.Addr.To4() - if ipv4 != nil { - m.ipv4Addr = ipv4 - } else if m.ipv4Addr == nil { - return nil + var rr []dns.RR + for _, ip := range m.IPs { + if ip4 := ip.To4(); ip4 != nil { + rr = append(rr, &dns.A{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + A: ip4, + }) + } } - a := &dns.A{ - Hdr: dns.RR_Header{ - Name: m.HostName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: defaultTTL, - }, - A: m.ipv4Addr, - } - return []dns.RR{a} + return rr case dns.TypeAAAA: - // Only handle if we have a ipv6 addr - ipv6 := m.Addr.To16() - if ipv6 != nil && m.Addr.To4() == nil { - m.ipv6Addr = ipv6 - } else if m.ipv6Addr == nil && m.Addr.To4() != nil { - return nil + var rr []dns.RR + for _, ip := range m.IPs { + if ip.To4() != nil { + // TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and + // putinto AAAA records, but the current logic puts ipv4-encodable + // addresses into the A records exclusively. Perhaps this should be + // configurable? + continue + } + + if ip16 := ip.To16(); ip16 != nil { + rr = append(rr, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + AAAA: ip16, + }) + } } - a4 := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: m.HostName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: defaultTTL, - }, - AAAA: m.ipv6Addr, - } - return []dns.RR{a4} + return rr case dns.TypeSRV: // Create the SRV Record @@ -231,7 +270,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Class: dns.ClassINET, Ttl: defaultTTL, }, - Txt: []string{m.Info}, + Txt: m.TXT, } return []dns.RR{txt} } diff --git a/zone_test.go b/zone_test.go index a613dce..ff345d2 100644 --- a/zone_test.go +++ b/zone_test.go @@ -2,25 +2,65 @@ package mdns import ( "bytes" - "github.com/miekg/dns" + "net" "reflect" "testing" + + "github.com/miekg/dns" ) func makeService(t *testing.T) *MDNSService { - m := &MDNSService{ - Instance: "hostname.", - Service: "_http._tcp.", - Port: 80, - Info: "Local web server", - Domain: "local.", - } - if err := m.Init(); err != nil { + return makeServiceWithServiceName(t, "_http._tcp") +} + +func makeServiceWithServiceName(t *testing.T, service string) *MDNSService { + m, err := NewMDNSService( + "hostname", + service, + "local.", + "testhost.", + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")}, + []string{"Local web server"}) // TXT + + if err != nil { t.Fatalf("err: %v", err) } + return m } +func TestNewMDNSService_BadParams(t *testing.T) { + for _, test := range []struct { + testName string + hostName string + domain string + }{ + { + "NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name", + "hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc. + "local.", // legal + }, + { + "NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name", + "hostname.", // legal + "local", // should be "local." + }, + } { + _, err := NewMDNSService( + "instance name", + "_http._tcp", + test.domain, + test.hostName, + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42})}, + []string{"Local web server"}) // TXT + if err == nil { + t.Fatalf("%s: error expected, but got none", test.testName) + } + } +} + func TestMDNSService_BadAddr(t *testing.T) { s := makeService(t) q := dns.Question{ @@ -40,35 +80,32 @@ func TestMDNSService_ServiceAddr(t *testing.T) { Qtype: dns.TypeANY, } recs := s.Records(q) - if len(recs) != 5 { - t.Fatalf("bad: %v", recs) + if got, want := len(recs), 5; got != want { + t.Fatalf("got %d records, want %d: %v", got, want, recs) } - ptr, ok := recs[0].(*dns.PTR) - if !ok { - t.Fatalf("bad: %v", recs[0]) + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) } + if _, ok := recs[1].(*dns.SRV); !ok { - t.Fatalf("bad: %v", recs[1]) + t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs) } if _, ok := recs[2].(*dns.A); !ok { - t.Fatalf("bad: %v", recs[2]) + t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs) } if _, ok := recs[3].(*dns.AAAA); !ok { - t.Fatalf("bad: %v", recs[3]) + t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs) } if _, ok := recs[4].(*dns.TXT); !ok { - t.Fatalf("bad: %v", recs[4]) - } - - if ptr.Ptr != s.instanceAddr { - t.Fatalf("bad: %v", recs[0]) + t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs) } q.Qtype = dns.TypePTR - recs2 := s.Records(q) - if !reflect.DeepEqual(recs, recs2) { - t.Fatalf("no match: %v %v", recs, recs2) + if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) { + t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2) } } @@ -136,7 +173,7 @@ func TestMDNSService_InstanceAddr_A(t *testing.T) { if !ok { t.Fatalf("bad: %v", recs[0]) } - if !bytes.Equal(a.A, s.ipv4Addr) { + if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) { t.Fatalf("bad: %v", recs[0]) } } @@ -155,7 +192,11 @@ func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { if !ok { t.Fatalf("bad: %v", recs[0]) } - if !bytes.Equal(a4.AAAA, s.ipv6Addr) { + ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc") + if got := len(ip6); got != net.IPv6len { + t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len) + } + if !bytes.Equal(a4.AAAA, ip6) { t.Fatalf("bad: %v", recs[0]) } } @@ -174,7 +215,7 @@ func TestMDNSService_InstanceAddr_TXT(t *testing.T) { if !ok { t.Fatalf("bad: %v", recs[0]) } - if txt.Txt[0] != s.Info { - t.Fatalf("bad: %v", recs[0]) + if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) { + t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want) } }