Merge pull request #12 from richtr/improved_dns_srv_formatting

Pull request for issue #4
This commit is contained in:
Armon Dadgar 2014-09-13 15:38:04 -07:00
commit f5542c2469
5 changed files with 76 additions and 42 deletions

View File

@ -16,7 +16,6 @@ Using the library is very simple, here is an example of publishing a service ent
service := &mdns.MDNSService{ service := &mdns.MDNSService{
Instance: host, Instance: host,
Service: "_foobar._tcp", Service: "_foobar._tcp",
Addr: []byte{127,0,0,1},
Port: 8000, Port: 8000,
Info: "My awesome service", Info: "My awesome service",
} }

View File

@ -15,17 +15,21 @@ import (
// ServiceEntry is returned after we query for a service // ServiceEntry is returned after we query for a service
type ServiceEntry struct { type ServiceEntry struct {
Name string Name string
Addr net.IP Host string
AddrV4 net.IP
AddrV6 net.IP
Port int Port int
Info string Info string
Addr net.IP // @Deprecated
hasTXT bool hasTXT bool
sent bool sent bool
} }
// complete is used to check if we have all the info we need // complete is used to check if we have all the info we need
func (s *ServiceEntry) complete() bool { func (s *ServiceEntry) complete() bool {
return s.Addr != nil && s.Port != 0 && s.hasTXT return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT
} }
// QueryParam is used to customize how a Lookup is performed // QueryParam is used to customize how a Lookup is performed
@ -196,7 +200,7 @@ func (c *client) query(params *QueryParam) error {
// Get the port // Get the port
inp = ensureName(inprogress, rr.Hdr.Name) inp = ensureName(inprogress, rr.Hdr.Name)
inp.Name = rr.Target inp.Host = rr.Target
inp.Port = int(rr.Port) inp.Port = int(rr.Port)
case *dns.TXT: case *dns.TXT:
@ -208,12 +212,14 @@ func (c *client) query(params *QueryParam) error {
case *dns.A: case *dns.A:
// Pull out the IP // Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name) inp = ensureName(inprogress, rr.Hdr.Name)
inp.Addr = rr.A inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA: case *dns.AAAA:
// Pull out the IP // Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name) inp = ensureName(inprogress, rr.Hdr.Name)
inp.Addr = rr.AAAA inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
} }
} }

View File

@ -1,7 +1,6 @@
package mdns package mdns
import ( import (
"bytes"
"testing" "testing"
"time" "time"
) )
@ -33,9 +32,6 @@ func TestServer_Lookup(t *testing.T) {
if e.Name != "hostname._foobar._tcp.local." { if e.Name != "hostname._foobar._tcp.local." {
t.Fatalf("bad: %v", e) t.Fatalf("bad: %v", e)
} }
if !bytes.Equal(e.Addr.To4(), []byte{127, 0, 0, 1}) {
t.Fatalf("bad: %v", e)
}
if e.Port != 80 { if e.Port != 80 {
t.Fatalf("bad: %v", e) t.Fatalf("bad: %v", e)
} }

60
zone.go
View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"net" "net"
"os"
"strings" "strings"
) )
@ -22,10 +23,15 @@ type Zone interface {
type MDNSService struct { type MDNSService struct {
Instance string // Instance name (e.g. host name) Instance string // Instance name (e.g. host name)
Service string // Service name (e.g. _http._tcp.) Service string // Service name (e.g. _http._tcp.)
Addr net.IP // Service IP HostName string // Host machine DNS name
Port int // Service Port Port int // Service Port
Info string // Service info served as a TXT record Info string // Service info served as a TXT record
Domain string // If blank, assumes ".local" Domain string // If blank, assumes "local"
Addr net.IP // @Deprecated. Service IP
ipv4Addr net.IP // Host machine IPv4 address
ipv6Addr net.IP // Host machine IPv6 address
serviceAddr string // Fully qualified service address serviceAddr string // Fully qualified service address
instanceAddr string // Fully qualified instance address instanceAddr string // Fully qualified instance address
@ -45,13 +51,39 @@ func (m *MDNSService) Init() error {
if m.Service == "" { if m.Service == "" {
return fmt.Errorf("Missing service name") return fmt.Errorf("Missing service name")
} }
if m.Addr == nil {
return fmt.Errorf("Missing service address")
}
if m.Port == 0 { if m.Port == 0 {
return fmt.Errorf("Missing service port") return fmt.Errorf("Missing service port")
} }
// Get host information
hostName, err := os.Hostname()
if err == nil {
m.HostName = fmt.Sprintf("%s.", hostName)
addrs, err := net.LookupIP(m.HostName)
if err != nil {
// Try appending the host domain suffix and lookup again
// (required for Linux-based hosts)
tmpHostName := fmt.Sprintf("%s%s.", m.HostName, m.Domain)
addrs, err = net.LookupIP(tmpHostName)
if err != nil {
return fmt.Errorf("Could not determine host IP addresses for %s", m.HostName)
}
}
for i := 0; i < len(addrs); i++ {
if ipv4 := addrs[i].To4(); ipv4 != nil {
m.ipv4Addr = addrs[i]
} else if ipv6 := addrs[i].To16(); ipv6 != nil {
m.ipv6Addr = addrs[i]
}
}
} else {
return fmt.Errorf("Could not determine host")
}
// Create the full addresses // Create the full addresses
m.serviceAddr = fmt.Sprintf("%s.%s.", m.serviceAddr = fmt.Sprintf("%s.%s.",
trimDot(m.Service), trimDot(m.Domain)) trimDot(m.Service), trimDot(m.Domain))
@ -127,34 +159,38 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
case dns.TypeA: case dns.TypeA:
// Only handle if we have a ipv4 addr // Only handle if we have a ipv4 addr
ipv4 := m.Addr.To4() ipv4 := m.Addr.To4()
if ipv4 == nil { if ipv4 != nil {
m.ipv4Addr = ipv4
} else if m.ipv4Addr == nil {
return nil return nil
} }
a := &dns.A{ a := &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: q.Name, Name: m.HostName,
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: defaultTTL,
}, },
A: ipv4, A: m.ipv4Addr,
} }
return []dns.RR{a} return []dns.RR{a}
case dns.TypeAAAA: case dns.TypeAAAA:
// Only handle if we have a ipv6 addr // Only handle if we have a ipv6 addr
ipv6 := m.Addr.To16() ipv6 := m.Addr.To16()
if m.Addr.To4() != nil { if ipv6 != nil && m.Addr.To4() == nil {
m.ipv6Addr = ipv6
} else if m.ipv6Addr == nil && m.Addr.To4() != nil {
return nil return nil
} }
a4 := &dns.AAAA{ a4 := &dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: q.Name, Name: m.HostName,
Rrtype: dns.TypeAAAA, Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: defaultTTL, Ttl: defaultTTL,
}, },
AAAA: ipv6, AAAA: m.ipv6Addr,
} }
return []dns.RR{a4} return []dns.RR{a4}
@ -170,7 +206,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
Priority: 10, Priority: 10,
Weight: 1, Weight: 1,
Port: uint16(m.Port), Port: uint16(m.Port),
Target: q.Name, Target: m.HostName,
} }
recs := []dns.RR{srv} recs := []dns.RR{srv}

View File

@ -11,7 +11,6 @@ func makeService(t *testing.T) *MDNSService {
m := &MDNSService{ m := &MDNSService{
Instance: "hostname.", Instance: "hostname.",
Service: "_http._tcp.", Service: "_http._tcp.",
Addr: []byte{127, 0, 0, 1},
Port: 80, Port: 80,
Info: "Local web server", Info: "Local web server",
Domain: "local.", Domain: "local.",
@ -41,7 +40,7 @@ func TestMDNSService_ServiceAddr(t *testing.T) {
Qtype: dns.TypeANY, Qtype: dns.TypeANY,
} }
recs := s.Records(q) recs := s.Records(q)
if len(recs) != 4 { if len(recs) != 5 {
t.Fatalf("bad: %v", recs) t.Fatalf("bad: %v", recs)
} }
@ -55,9 +54,12 @@ func TestMDNSService_ServiceAddr(t *testing.T) {
if _, ok := recs[2].(*dns.A); !ok { if _, ok := recs[2].(*dns.A); !ok {
t.Fatalf("bad: %v", recs[2]) t.Fatalf("bad: %v", recs[2])
} }
if _, ok := recs[3].(*dns.TXT); !ok { if _, ok := recs[3].(*dns.AAAA); !ok {
t.Fatalf("bad: %v", recs[3]) t.Fatalf("bad: %v", recs[3])
} }
if _, ok := recs[4].(*dns.TXT); !ok {
t.Fatalf("bad: %v", recs[4])
}
if ptr.Ptr != s.instanceAddr { if ptr.Ptr != s.instanceAddr {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
@ -77,7 +79,7 @@ func TestMDNSService_InstanceAddr_ANY(t *testing.T) {
Qtype: dns.TypeANY, Qtype: dns.TypeANY,
} }
recs := s.Records(q) recs := s.Records(q)
if len(recs) != 3 { if len(recs) != 4 {
t.Fatalf("bad: %v", recs) t.Fatalf("bad: %v", recs)
} }
if _, ok := recs[0].(*dns.SRV); !ok { if _, ok := recs[0].(*dns.SRV); !ok {
@ -86,9 +88,12 @@ func TestMDNSService_InstanceAddr_ANY(t *testing.T) {
if _, ok := recs[1].(*dns.A); !ok { if _, ok := recs[1].(*dns.A); !ok {
t.Fatalf("bad: %v", recs[1]) t.Fatalf("bad: %v", recs[1])
} }
if _, ok := recs[2].(*dns.TXT); !ok { if _, ok := recs[2].(*dns.AAAA); !ok {
t.Fatalf("bad: %v", recs[2]) t.Fatalf("bad: %v", recs[2])
} }
if _, ok := recs[3].(*dns.TXT); !ok {
t.Fatalf("bad: %v", recs[3])
}
} }
func TestMDNSService_InstanceAddr_SRV(t *testing.T) { func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
@ -98,7 +103,7 @@ func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
Qtype: dns.TypeSRV, Qtype: dns.TypeSRV,
} }
recs := s.Records(q) recs := s.Records(q)
if len(recs) != 2 { if len(recs) != 3 {
t.Fatalf("bad: %v", recs) t.Fatalf("bad: %v", recs)
} }
srv, ok := recs[0].(*dns.SRV) srv, ok := recs[0].(*dns.SRV)
@ -108,10 +113,10 @@ func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
if _, ok := recs[1].(*dns.A); !ok { if _, ok := recs[1].(*dns.A); !ok {
t.Fatalf("bad: %v", recs[1]) t.Fatalf("bad: %v", recs[1])
} }
if _, ok := recs[2].(*dns.AAAA); !ok {
if srv.Target != s.instanceAddr { t.Fatalf("bad: %v", recs[2])
t.Fatalf("bad: %v", recs[0])
} }
if srv.Port != uint16(s.Port) { if srv.Port != uint16(s.Port) {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
} }
@ -131,26 +136,18 @@ func TestMDNSService_InstanceAddr_A(t *testing.T) {
if !ok { if !ok {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
} }
if !bytes.Equal(a.A, s.Addr) { if !bytes.Equal(a.A, s.ipv4Addr) {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
} }
} }
func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { func TestMDNSService_InstanceAddr_AAAA(t *testing.T) {
s := makeService(t) s := makeService(t)
s.Addr = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16}
q := dns.Question{ q := dns.Question{
Name: "hostname._http._tcp.local.", Name: "hostname._http._tcp.local.",
Qtype: dns.TypeA, Qtype: dns.TypeAAAA,
} }
recs := s.Records(q) recs := s.Records(q)
if len(recs) != 0 {
t.Fatalf("bad: %v", recs)
}
q.Qtype = dns.TypeAAAA
recs = s.Records(q)
if len(recs) != 1 { if len(recs) != 1 {
t.Fatalf("bad: %v", recs) t.Fatalf("bad: %v", recs)
} }
@ -158,7 +155,7 @@ func TestMDNSService_InstanceAddr_AAAA(t *testing.T) {
if !ok { if !ok {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
} }
if !bytes.Equal(a4.AAAA, s.Addr) { if !bytes.Equal(a4.AAAA, s.ipv6Addr) {
t.Fatalf("bad: %v", recs[0]) t.Fatalf("bad: %v", recs[0])
} }
} }