diff --git a/README.md b/README.md index bd902f5..d9550fd 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,6 @@ Using the library is very simple, here is an example of publishing a service ent service := &mdns.MDNSService{ Instance: host, Service: "_foobar._tcp", - Addr: []byte{127,0,0,1}, Port: 8000, Info: "My awesome service", } diff --git a/client.go b/client.go index 3697533..5b3cbda 100644 --- a/client.go +++ b/client.go @@ -15,17 +15,21 @@ import ( // ServiceEntry is returned after we query for a service type ServiceEntry struct { Name string - Addr net.IP + Host string + AddrV4 net.IP + AddrV6 net.IP Port int Info string + Addr net.IP // @Deprecated + hasTXT bool sent bool } // complete is used to check if we have all the info we need func (s *ServiceEntry) complete() bool { - return s.Addr != nil && s.Port != 0 && s.hasTXT + return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT } // QueryParam is used to customize how a Lookup is performed @@ -196,7 +200,7 @@ func (c *client) query(params *QueryParam) error { // Get the port inp = ensureName(inprogress, rr.Hdr.Name) - inp.Name = rr.Target + inp.Host = rr.Target inp.Port = int(rr.Port) case *dns.TXT: @@ -208,12 +212,14 @@ func (c *client) query(params *QueryParam) error { case *dns.A: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) - inp.Addr = rr.A + inp.Addr = rr.A // @Deprecated + inp.AddrV4 = rr.A case *dns.AAAA: // Pull out the IP inp = ensureName(inprogress, rr.Hdr.Name) - inp.Addr = rr.AAAA + inp.Addr = rr.AAAA // @Deprecated + inp.AddrV6 = rr.AAAA } } diff --git a/server_test.go b/server_test.go index 57b21cf..52e24a6 100644 --- a/server_test.go +++ b/server_test.go @@ -1,7 +1,6 @@ package mdns import ( - "bytes" "testing" "time" ) @@ -33,9 +32,6 @@ func TestServer_Lookup(t *testing.T) { if e.Name != "hostname._foobar._tcp.local." { t.Fatalf("bad: %v", e) } - if !bytes.Equal(e.Addr.To4(), []byte{127, 0, 0, 1}) { - t.Fatalf("bad: %v", e) - } if e.Port != 80 { t.Fatalf("bad: %v", e) } diff --git a/zone.go b/zone.go index 431676e..2c6bd24 100644 --- a/zone.go +++ b/zone.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/miekg/dns" "net" + "os" "strings" ) @@ -22,10 +23,15 @@ type Zone interface { type MDNSService struct { Instance string // Instance name (e.g. host name) Service string // Service name (e.g. _http._tcp.) - Addr net.IP // Service IP + HostName string // Host machine DNS name Port int // Service Port Info string // Service info served as a TXT record - Domain string // If blank, assumes ".local" + Domain string // If blank, assumes "local" + + Addr net.IP // @Deprecated. Service IP + + ipv4Addr net.IP // Host machine IPv4 address + ipv6Addr net.IP // Host machine IPv6 address serviceAddr string // Fully qualified service address instanceAddr string // Fully qualified instance address @@ -45,13 +51,39 @@ func (m *MDNSService) Init() error { if m.Service == "" { return fmt.Errorf("Missing service name") } - if m.Addr == nil { - return fmt.Errorf("Missing service address") - } if m.Port == 0 { return fmt.Errorf("Missing service port") } + // Get host information + hostName, err := os.Hostname() + if err == nil { + m.HostName = fmt.Sprintf("%s.", hostName) + + addrs, err := net.LookupIP(m.HostName) + if err != nil { + // Try appending the host domain suffix and lookup again + // (required for Linux-based hosts) + tmpHostName := fmt.Sprintf("%s%s.", m.HostName, m.Domain) + + addrs, err = net.LookupIP(tmpHostName) + + if err != nil { + return fmt.Errorf("Could not determine host IP addresses for %s", m.HostName) + } + } + + for i := 0; i < len(addrs); i++ { + if ipv4 := addrs[i].To4(); ipv4 != nil { + m.ipv4Addr = addrs[i] + } else if ipv6 := addrs[i].To16(); ipv6 != nil { + m.ipv6Addr = addrs[i] + } + } + } else { + return fmt.Errorf("Could not determine host") + } + // Create the full addresses m.serviceAddr = fmt.Sprintf("%s.%s.", trimDot(m.Service), trimDot(m.Domain)) @@ -127,34 +159,38 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { case dns.TypeA: // Only handle if we have a ipv4 addr ipv4 := m.Addr.To4() - if ipv4 == nil { + if ipv4 != nil { + m.ipv4Addr = ipv4 + } else if m.ipv4Addr == nil { return nil } a := &dns.A{ Hdr: dns.RR_Header{ - Name: q.Name, + Name: m.HostName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: defaultTTL, }, - A: ipv4, + A: m.ipv4Addr, } return []dns.RR{a} case dns.TypeAAAA: // Only handle if we have a ipv6 addr ipv6 := m.Addr.To16() - if m.Addr.To4() != nil { + if ipv6 != nil && m.Addr.To4() == nil { + m.ipv6Addr = ipv6 + } else if m.ipv6Addr == nil && m.Addr.To4() != nil { return nil } a4 := &dns.AAAA{ Hdr: dns.RR_Header{ - Name: q.Name, + Name: m.HostName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: defaultTTL, }, - AAAA: ipv6, + AAAA: m.ipv6Addr, } return []dns.RR{a4} @@ -170,7 +206,7 @@ func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { Priority: 10, Weight: 1, Port: uint16(m.Port), - Target: q.Name, + Target: m.HostName, } recs := []dns.RR{srv} diff --git a/zone_test.go b/zone_test.go index 1a964ed..a613dce 100644 --- a/zone_test.go +++ b/zone_test.go @@ -11,7 +11,6 @@ func makeService(t *testing.T) *MDNSService { m := &MDNSService{ Instance: "hostname.", Service: "_http._tcp.", - Addr: []byte{127, 0, 0, 1}, Port: 80, Info: "Local web server", Domain: "local.", @@ -41,7 +40,7 @@ func TestMDNSService_ServiceAddr(t *testing.T) { Qtype: dns.TypeANY, } recs := s.Records(q) - if len(recs) != 4 { + if len(recs) != 5 { t.Fatalf("bad: %v", recs) } @@ -55,9 +54,12 @@ func TestMDNSService_ServiceAddr(t *testing.T) { if _, ok := recs[2].(*dns.A); !ok { t.Fatalf("bad: %v", recs[2]) } - if _, ok := recs[3].(*dns.TXT); !ok { + if _, ok := recs[3].(*dns.AAAA); !ok { t.Fatalf("bad: %v", recs[3]) } + if _, ok := recs[4].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[4]) + } if ptr.Ptr != s.instanceAddr { t.Fatalf("bad: %v", recs[0]) @@ -77,7 +79,7 @@ func TestMDNSService_InstanceAddr_ANY(t *testing.T) { Qtype: dns.TypeANY, } recs := s.Records(q) - if len(recs) != 3 { + if len(recs) != 4 { t.Fatalf("bad: %v", recs) } if _, ok := recs[0].(*dns.SRV); !ok { @@ -86,9 +88,12 @@ func TestMDNSService_InstanceAddr_ANY(t *testing.T) { if _, ok := recs[1].(*dns.A); !ok { t.Fatalf("bad: %v", recs[1]) } - if _, ok := recs[2].(*dns.TXT); !ok { + if _, ok := recs[2].(*dns.AAAA); !ok { t.Fatalf("bad: %v", recs[2]) } + if _, ok := recs[3].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[3]) + } } func TestMDNSService_InstanceAddr_SRV(t *testing.T) { @@ -98,7 +103,7 @@ func TestMDNSService_InstanceAddr_SRV(t *testing.T) { Qtype: dns.TypeSRV, } recs := s.Records(q) - if len(recs) != 2 { + if len(recs) != 3 { t.Fatalf("bad: %v", recs) } srv, ok := recs[0].(*dns.SRV) @@ -108,10 +113,10 @@ func TestMDNSService_InstanceAddr_SRV(t *testing.T) { if _, ok := recs[1].(*dns.A); !ok { t.Fatalf("bad: %v", recs[1]) } - - if srv.Target != s.instanceAddr { - t.Fatalf("bad: %v", recs[0]) + if _, ok := recs[2].(*dns.AAAA); !ok { + t.Fatalf("bad: %v", recs[2]) } + if srv.Port != uint16(s.Port) { t.Fatalf("bad: %v", recs[0]) } @@ -131,26 +136,18 @@ func TestMDNSService_InstanceAddr_A(t *testing.T) { if !ok { t.Fatalf("bad: %v", recs[0]) } - if !bytes.Equal(a.A, s.Addr) { + if !bytes.Equal(a.A, s.ipv4Addr) { t.Fatalf("bad: %v", recs[0]) } } func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { s := makeService(t) - s.Addr = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16} q := dns.Question{ Name: "hostname._http._tcp.local.", - Qtype: dns.TypeA, + Qtype: dns.TypeAAAA, } recs := s.Records(q) - if len(recs) != 0 { - t.Fatalf("bad: %v", recs) - } - - q.Qtype = dns.TypeAAAA - recs = s.Records(q) if len(recs) != 1 { t.Fatalf("bad: %v", recs) } @@ -158,7 +155,7 @@ func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { if !ok { t.Fatalf("bad: %v", recs[0]) } - if !bytes.Equal(a4.AAAA, s.Addr) { + if !bytes.Equal(a4.AAAA, s.ipv6Addr) { t.Fatalf("bad: %v", recs[0]) } }