From 60a10ee1d7e590f646661a26647bc9bba58f8070 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 29 Jan 2014 14:28:46 -0800 Subject: [PATCH] Adding zones with tests --- zone.go | 198 +++++++++++++++++++++++++++++++++++++++++++++++++++ zone_test.go | 183 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 zone.go create mode 100644 zone_test.go diff --git a/zone.go b/zone.go new file mode 100644 index 0000000..3e69a57 --- /dev/null +++ b/zone.go @@ -0,0 +1,198 @@ +package mdns + +import ( + "fmt" + "github.com/miekg/dns" + "net" + "strings" +) + +// Zone is the interface used to integrate with the server and +// to serve records dynamically +type Zone interface { + Records(q dns.Question) []dns.RR +} + +// MDNSService is used to export a named service by implementing a Zone +type MDNSService struct { + Instance string // Instance name (e.g. host name) + Service string // Service name (e.g. _http._tcp.) + Addr net.IP // Service IP + Port int // Service Port + Info string // Service info served as a TXT record + Domain string // If blank, assumes ".local" + + serviceAddr string // Fully qualified service address + instanceAddr string // Fully qualified instance address +} + +// Init should be called to setup the internal state +func (m *MDNSService) Init() error { + // Setup default domain + if m.Domain == "" { + m.Domain = "local" + } + + // Sanity check inputs + if m.Instance == "" { + return fmt.Errorf("Missing service instance name") + } + if m.Service == "" { + return fmt.Errorf("Missing service name") + } + if m.Addr == nil { + return fmt.Errorf("Missing service address") + } + if m.Port == 0 { + return fmt.Errorf("Missing service port") + } + + // Create the full addresses + m.serviceAddr = fmt.Sprintf("%s.%s.", + trimDot(m.Service), trimDot(m.Domain)) + m.instanceAddr = fmt.Sprintf("%s.%s", + trimDot(m.Instance), m.serviceAddr) + return nil +} + +// trimDot is used to trim the dots from the start or end of a string +func trimDot(s string) string { + return strings.Trim(s, ".") +} + +func (m *MDNSService) Records(q dns.Question) []dns.RR { + switch q.Name { + case m.serviceAddr: + return m.serviceRecords(q) + case m.instanceAddr: + return m.instanceRecords(q) + default: + return nil + } +} + +// serviceRecords is called when the query matches the service name +func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + // Build a PTR response for the service + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 0, + }, + Ptr: m.instanceAddr, + } + servRec := []dns.RR{rr} + + // Get the isntance records + instRecs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeANY, + }) + + // Return the service record with the instance records + return append(servRec, instRecs...) + default: + return nil + } +} + +// serviceRecords is called when the query matches the instance name +func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + // Get the SRV, which includes A and AAAA + recs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeSRV, + }) + + // Add the TXT record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeTXT, + })...) + return recs + + case dns.TypeA: + // Only handle if we have a ipv4 addr + ipv4 := m.Addr.To4() + if ipv4 == nil { + return nil + } + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ipv4, + } + return []dns.RR{a} + + case dns.TypeAAAA: + // Only handle if we have a ipv6 addr + ipv6 := m.Addr.To16() + if m.Addr.To4() != nil { + return nil + } + a4 := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: ipv6, + } + return []dns.RR{a4} + + case dns.TypeSRV: + // Create the SRV Record + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: 0, + }, + Priority: 10, + Weight: 1, + Port: uint16(m.Port), + Target: q.Name, + } + recs := []dns.RR{srv} + + // Add the A record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeA, + })...) + + // Add the AAAA record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeAAAA, + })...) + return recs + + case dns.TypeTXT: + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 0, + }, + Txt: []string{m.Info}, + } + return []dns.RR{txt} + } + return nil +} diff --git a/zone_test.go b/zone_test.go new file mode 100644 index 0000000..1a964ed --- /dev/null +++ b/zone_test.go @@ -0,0 +1,183 @@ +package mdns + +import ( + "bytes" + "github.com/miekg/dns" + "reflect" + "testing" +) + +func makeService(t *testing.T) *MDNSService { + m := &MDNSService{ + Instance: "hostname.", + Service: "_http._tcp.", + Addr: []byte{127, 0, 0, 1}, + Port: 80, + Info: "Local web server", + Domain: "local.", + } + if err := m.Init(); err != nil { + t.Fatalf("err: %v", err) + } + return m +} + +func TestMDNSService_BadAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "random", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 0 { + t.Fatalf("bad: %v", recs) + } +} + +func TestMDNSService_ServiceAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 4 { + t.Fatalf("bad: %v", recs) + } + + ptr, ok := recs[0].(*dns.PTR) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.SRV); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[2]) + } + if _, ok := recs[3].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[3]) + } + + if ptr.Ptr != s.instanceAddr { + t.Fatalf("bad: %v", recs[0]) + } + + q.Qtype = dns.TypePTR + recs2 := s.Records(q) + if !reflect.DeepEqual(recs, recs2) { + t.Fatalf("no match: %v %v", recs, recs2) + } +} + +func TestMDNSService_InstanceAddr_ANY(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 3 { + t.Fatalf("bad: %v", recs) + } + if _, ok := recs[0].(*dns.SRV); !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[2]) + } +} + +func TestMDNSService_InstanceAddr_SRV(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeSRV, + } + recs := s.Records(q) + if len(recs) != 2 { + t.Fatalf("bad: %v", recs) + } + srv, ok := recs[0].(*dns.SRV) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + + if srv.Target != s.instanceAddr { + t.Fatalf("bad: %v", recs[0]) + } + if srv.Port != uint16(s.Port) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_A(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeA, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a, ok := recs[0].(*dns.A) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if !bytes.Equal(a.A, s.Addr) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { + s := makeService(t) + s.Addr = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16} + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeA, + } + recs := s.Records(q) + if len(recs) != 0 { + t.Fatalf("bad: %v", recs) + } + + q.Qtype = dns.TypeAAAA + recs = s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a4, ok := recs[0].(*dns.AAAA) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if !bytes.Equal(a4.AAAA, s.Addr) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_TXT(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeTXT, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + txt, ok := recs[0].(*dns.TXT) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if txt.Txt[0] != s.Info { + t.Fatalf("bad: %v", recs[0]) + } +}