diff --git a/zone.go b/zone.go index 4be212b..732f309 100644 --- a/zone.go +++ b/zone.go @@ -33,6 +33,7 @@ type MDNSService struct { serviceAddr string // Fully qualified service address instanceAddr string // Fully qualified instance address + enumAddr string // _services._dns-sd._udp. } // validateFQDN returns an error if the passed string is not a fully qualified @@ -123,6 +124,7 @@ func NewMDNSService(instance, service, domain, hostName string, port int, ips [] TXT: txt, serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), + enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), }, nil } @@ -134,6 +136,8 @@ func trimDot(s string) string { // Records returns DNS records in response to a DNS question. func (m *MDNSService) Records(q dns.Question) []dns.RR { switch q.Name { + case m.enumAddr: + return m.serviceEnum(q) case m.serviceAddr: return m.serviceRecords(q) case m.instanceAddr: @@ -148,6 +152,26 @@ func (m *MDNSService) Records(q dns.Question) []dns.RR { } } +func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Ptr: m.serviceAddr, + } + return []dns.RR{rr} + default: + return nil + } +} + // serviceRecords is called when the query matches the service name func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { switch q.Qtype { diff --git a/zone_test.go b/zone_test.go index 23cf511..082d72d 100644 --- a/zone_test.go +++ b/zone_test.go @@ -256,3 +256,20 @@ func TestMDNSService_HostNameQuery(t *testing.T) { } } } + +func TestMDNSService_serviceEnum_PTR(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_services._dns-sd._udp.local.", + Qtype: dns.TypePTR, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "_http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) + } +}