Adding zones with tests
This commit is contained in:
parent
a26d5c2b94
commit
60a10ee1d7
198
zone.go
Normal file
198
zone.go
Normal file
@ -0,0 +1,198 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Zone is the interface used to integrate with the server and
|
||||
// to serve records dynamically
|
||||
type Zone interface {
|
||||
Records(q dns.Question) []dns.RR
|
||||
}
|
||||
|
||||
// MDNSService is used to export a named service by implementing a Zone
|
||||
type MDNSService struct {
|
||||
Instance string // Instance name (e.g. host name)
|
||||
Service string // Service name (e.g. _http._tcp.)
|
||||
Addr net.IP // Service IP
|
||||
Port int // Service Port
|
||||
Info string // Service info served as a TXT record
|
||||
Domain string // If blank, assumes ".local"
|
||||
|
||||
serviceAddr string // Fully qualified service address
|
||||
instanceAddr string // Fully qualified instance address
|
||||
}
|
||||
|
||||
// Init should be called to setup the internal state
|
||||
func (m *MDNSService) Init() error {
|
||||
// Setup default domain
|
||||
if m.Domain == "" {
|
||||
m.Domain = "local"
|
||||
}
|
||||
|
||||
// Sanity check inputs
|
||||
if m.Instance == "" {
|
||||
return fmt.Errorf("Missing service instance name")
|
||||
}
|
||||
if m.Service == "" {
|
||||
return fmt.Errorf("Missing service name")
|
||||
}
|
||||
if m.Addr == nil {
|
||||
return fmt.Errorf("Missing service address")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return fmt.Errorf("Missing service port")
|
||||
}
|
||||
|
||||
// Create the full addresses
|
||||
m.serviceAddr = fmt.Sprintf("%s.%s.",
|
||||
trimDot(m.Service), trimDot(m.Domain))
|
||||
m.instanceAddr = fmt.Sprintf("%s.%s",
|
||||
trimDot(m.Instance), m.serviceAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// trimDot is used to trim the dots from the start or end of a string
|
||||
func trimDot(s string) string {
|
||||
return strings.Trim(s, ".")
|
||||
}
|
||||
|
||||
func (m *MDNSService) Records(q dns.Question) []dns.RR {
|
||||
switch q.Name {
|
||||
case m.serviceAddr:
|
||||
return m.serviceRecords(q)
|
||||
case m.instanceAddr:
|
||||
return m.instanceRecords(q)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// serviceRecords is called when the query matches the service name
|
||||
func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
|
||||
switch q.Qtype {
|
||||
case dns.TypeANY:
|
||||
fallthrough
|
||||
case dns.TypePTR:
|
||||
// Build a PTR response for the service
|
||||
rr := &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Ptr: m.instanceAddr,
|
||||
}
|
||||
servRec := []dns.RR{rr}
|
||||
|
||||
// Get the isntance records
|
||||
instRecs := m.instanceRecords(dns.Question{
|
||||
Name: m.instanceAddr,
|
||||
Qtype: dns.TypeANY,
|
||||
})
|
||||
|
||||
// Return the service record with the instance records
|
||||
return append(servRec, instRecs...)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// serviceRecords is called when the query matches the instance name
|
||||
func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
|
||||
switch q.Qtype {
|
||||
case dns.TypeANY:
|
||||
// Get the SRV, which includes A and AAAA
|
||||
recs := m.instanceRecords(dns.Question{
|
||||
Name: m.instanceAddr,
|
||||
Qtype: dns.TypeSRV,
|
||||
})
|
||||
|
||||
// Add the TXT record
|
||||
recs = append(recs, m.instanceRecords(dns.Question{
|
||||
Name: m.instanceAddr,
|
||||
Qtype: dns.TypeTXT,
|
||||
})...)
|
||||
return recs
|
||||
|
||||
case dns.TypeA:
|
||||
// Only handle if we have a ipv4 addr
|
||||
ipv4 := m.Addr.To4()
|
||||
if ipv4 == nil {
|
||||
return nil
|
||||
}
|
||||
a := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: ipv4,
|
||||
}
|
||||
return []dns.RR{a}
|
||||
|
||||
case dns.TypeAAAA:
|
||||
// Only handle if we have a ipv6 addr
|
||||
ipv6 := m.Addr.To16()
|
||||
if m.Addr.To4() != nil {
|
||||
return nil
|
||||
}
|
||||
a4 := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
AAAA: ipv6,
|
||||
}
|
||||
return []dns.RR{a4}
|
||||
|
||||
case dns.TypeSRV:
|
||||
// Create the SRV Record
|
||||
srv := &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Priority: 10,
|
||||
Weight: 1,
|
||||
Port: uint16(m.Port),
|
||||
Target: q.Name,
|
||||
}
|
||||
recs := []dns.RR{srv}
|
||||
|
||||
// Add the A record
|
||||
recs = append(recs, m.instanceRecords(dns.Question{
|
||||
Name: m.instanceAddr,
|
||||
Qtype: dns.TypeA,
|
||||
})...)
|
||||
|
||||
// Add the AAAA record
|
||||
recs = append(recs, m.instanceRecords(dns.Question{
|
||||
Name: m.instanceAddr,
|
||||
Qtype: dns.TypeAAAA,
|
||||
})...)
|
||||
return recs
|
||||
|
||||
case dns.TypeTXT:
|
||||
txt := &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Txt: []string{m.Info},
|
||||
}
|
||||
return []dns.RR{txt}
|
||||
}
|
||||
return nil
|
||||
}
|
183
zone_test.go
Normal file
183
zone_test.go
Normal file
@ -0,0 +1,183 @@
|
||||
package mdns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/miekg/dns"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeService(t *testing.T) *MDNSService {
|
||||
m := &MDNSService{
|
||||
Instance: "hostname.",
|
||||
Service: "_http._tcp.",
|
||||
Addr: []byte{127, 0, 0, 1},
|
||||
Port: 80,
|
||||
Info: "Local web server",
|
||||
Domain: "local.",
|
||||
}
|
||||
if err := m.Init(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestMDNSService_BadAddr(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "random",
|
||||
Qtype: dns.TypeANY,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 0 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_ServiceAddr(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "_http._tcp.local.",
|
||||
Qtype: dns.TypeANY,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 4 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
|
||||
ptr, ok := recs[0].(*dns.PTR)
|
||||
if !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if _, ok := recs[1].(*dns.SRV); !ok {
|
||||
t.Fatalf("bad: %v", recs[1])
|
||||
}
|
||||
if _, ok := recs[2].(*dns.A); !ok {
|
||||
t.Fatalf("bad: %v", recs[2])
|
||||
}
|
||||
if _, ok := recs[3].(*dns.TXT); !ok {
|
||||
t.Fatalf("bad: %v", recs[3])
|
||||
}
|
||||
|
||||
if ptr.Ptr != s.instanceAddr {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
|
||||
q.Qtype = dns.TypePTR
|
||||
recs2 := s.Records(q)
|
||||
if !reflect.DeepEqual(recs, recs2) {
|
||||
t.Fatalf("no match: %v %v", recs, recs2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_InstanceAddr_ANY(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "hostname._http._tcp.local.",
|
||||
Qtype: dns.TypeANY,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 3 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
if _, ok := recs[0].(*dns.SRV); !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if _, ok := recs[1].(*dns.A); !ok {
|
||||
t.Fatalf("bad: %v", recs[1])
|
||||
}
|
||||
if _, ok := recs[2].(*dns.TXT); !ok {
|
||||
t.Fatalf("bad: %v", recs[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "hostname._http._tcp.local.",
|
||||
Qtype: dns.TypeSRV,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 2 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
srv, ok := recs[0].(*dns.SRV)
|
||||
if !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if _, ok := recs[1].(*dns.A); !ok {
|
||||
t.Fatalf("bad: %v", recs[1])
|
||||
}
|
||||
|
||||
if srv.Target != s.instanceAddr {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if srv.Port != uint16(s.Port) {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_InstanceAddr_A(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "hostname._http._tcp.local.",
|
||||
Qtype: dns.TypeA,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 1 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
a, ok := recs[0].(*dns.A)
|
||||
if !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if !bytes.Equal(a.A, s.Addr) {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_InstanceAddr_AAAA(t *testing.T) {
|
||||
s := makeService(t)
|
||||
s.Addr = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16}
|
||||
q := dns.Question{
|
||||
Name: "hostname._http._tcp.local.",
|
||||
Qtype: dns.TypeA,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 0 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
|
||||
q.Qtype = dns.TypeAAAA
|
||||
recs = s.Records(q)
|
||||
if len(recs) != 1 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
a4, ok := recs[0].(*dns.AAAA)
|
||||
if !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if !bytes.Equal(a4.AAAA, s.Addr) {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMDNSService_InstanceAddr_TXT(t *testing.T) {
|
||||
s := makeService(t)
|
||||
q := dns.Question{
|
||||
Name: "hostname._http._tcp.local.",
|
||||
Qtype: dns.TypeTXT,
|
||||
}
|
||||
recs := s.Records(q)
|
||||
if len(recs) != 1 {
|
||||
t.Fatalf("bad: %v", recs)
|
||||
}
|
||||
txt, ok := recs[0].(*dns.TXT)
|
||||
if !ok {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
if txt.Txt[0] != s.Info {
|
||||
t.Fatalf("bad: %v", recs[0])
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user