From 7be4a6767335e21a043215c4ebc959ad997585a3 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Tue, 30 Jun 2020 11:12:52 +0100 Subject: [PATCH] MDNS registry fix for users on VPNs (#1759) * filter out unsolicited responses * send to local ip in case * allow ip func to be passed in. add option for sending to 0.0.0.0 --- registry/mdns_registry.go | 2 +- util/mdns/client.go | 7 ++++ util/mdns/server.go | 73 ++++++++++++++++++++++++++++++--------- util/mdns/server_test.go | 4 +-- 4 files changed, 67 insertions(+), 19 deletions(-) diff --git a/registry/mdns_registry.go b/registry/mdns_registry.go index 9496fe0a..cd441992 100644 --- a/registry/mdns_registry.go +++ b/registry/mdns_registry.go @@ -252,7 +252,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error continue } - srv, err := mdns.NewServer(&mdns.Config{Zone: s}) + srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) if err != nil { gerr = err continue diff --git a/util/mdns/client.go b/util/mdns/client.go index c7b84655..ba88cea2 100644 --- a/util/mdns/client.go +++ b/util/mdns/client.go @@ -34,6 +34,7 @@ type ServiceEntry struct { // complete is used to check if we have all the info we need func (s *ServiceEntry) complete() bool { + return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT } @@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error { select { case resp := <-msgCh: inp := messageToEntry(resp, inprogress) + if inp == nil { continue } + if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { + // discard anything which we've not asked for + continue + } // Check if this entry is complete if inp.complete() { if inp.sent { continue } + inp.sent = true select { case params.Entries <- inp: diff --git a/util/mdns/server.go b/util/mdns/server.go index 909b39c5..e62d3e92 100644 --- a/util/mdns/server.go +++ b/util/mdns/server.go @@ -2,13 +2,13 @@ package mdns import ( "fmt" - "log" "math/rand" "net" "sync" "sync/atomic" "time" + log "github.com/micro/go-micro/v2/logger" "github.com/miekg/dns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -39,6 +39,10 @@ var ( } ) +// GetMachineIP is a func which returns the outbound IP of this machine. +// Used by the server to determine whether to attempt send the response on a local address +type GetMachineIP func() net.IP + // Config is used to configure the mDNS server type Config struct { // Zone must be provided to support responding to queries @@ -51,9 +55,15 @@ type Config struct { // Port If it is not 0, replace the port 5353 with this port number. Port int + + // GetMachineIP is a function to return the IP of the local machine + GetMachineIP GetMachineIP + // LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP + // is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports + LocalhostChecking bool } -// mDNS server is used to listen for mDNS queries and respond if we +// Server is an mDNS server used to listen for mDNS queries and respond if we // have a matching local record type Server struct { config *Config @@ -65,6 +75,8 @@ type Server struct { shutdownCh chan struct{} shutdownLock sync.Mutex wg sync.WaitGroup + + outboundIP net.IP } // NewServer is used to create a new mDNS server from a config @@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) { } } + ipFunc := getOutboundIP + if config.GetMachineIP != nil { + ipFunc = config.GetMachineIP + } + s := &Server{ config: config, ipv4List: ipv4List, ipv6List: ipv6List, shutdownCh: make(chan struct{}), + outboundIP: ipFunc(), } go s.recv(s.ipv4List) @@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) { continue } if err := s.parsePacket(buf[:n], from); err != nil { - log.Printf("[ERR] mdns: Failed to handle query: %v", err) + log.Errorf("[ERR] mdns: Failed to handle query: %v", err) } } } @@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) { func (s *Server) parsePacket(packet []byte, from net.Addr) error { var msg dns.Msg if err := msg.Unpack(packet); err != nil { - log.Printf("[ERR] mdns: Failed to unpack packet: %v", err) + log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err) return err } // TODO: This is a bit of a hack @@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { // caveats in the RFC), so set the Compress bit (part of the dns library // API, not part of the DNS packet) to true. Compress: true, - - Answer: answer, + Question: query.Question, + Answer: answer, } } @@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { // both. The return values are DNS records for each transmission type. func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { records := s.config.Zone.Records(q) - if len(records) == 0 { return nil, nil } @@ -365,7 +382,7 @@ func (s *Server) probe() { for i := 0; i < 3; i++ { if err := s.SendMulticast(q); err != nil { - log.Println("[ERR] mdns: failed to send probe:", err.Error()) + log.Errorf("[ERR] mdns: failed to send probe:", err.Error()) } time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) } @@ -391,7 +408,7 @@ func (s *Server) probe() { timer := time.NewTimer(timeout) for i := 0; i < 3; i++ { if err := s.SendMulticast(resp); err != nil { - log.Println("[ERR] mdns: failed to send announcement:", err.Error()) + log.Errorf("[ERR] mdns: failed to send announcement:", err.Error()) } select { case <-timer.C: @@ -404,7 +421,7 @@ func (s *Server) probe() { } } -// multicastResponse us used to send a multicast response packet +// SendMulticast us used to send a multicast response packet func (s *Server) SendMulticast(msg *dns.Msg) error { buf, err := msg.Pack() if err != nil { @@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { // Determine the socket to send from addr := from.(*net.UDPAddr) - if addr.IP.To4() != nil { - _, err = s.ipv4List.WriteToUDP(buf, addr) - return err - } else { - _, err = s.ipv6List.WriteToUDP(buf, addr) - return err + conn := s.ipv4List + backupTarget := net.IPv4zero + + if addr.IP.To4() == nil { + conn = s.ipv6List + backupTarget = net.IPv6zero } + _, err = conn.WriteToUDP(buf, addr) + // If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0 + // This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there + // Sending two responses is OK + if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) { + // ignore any errors, this is best efforts + conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port}) + } + return err + } func (s *Server) unregister() error { @@ -474,3 +501,17 @@ func setCustomPort(port int) { } } } + +// getOutboundIP returns the IP address of this machine as seen when dialling out +func getOutboundIP() net.IP { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + // no net connectivity maybe so fallback + return nil + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + + return localAddr.IP +} diff --git a/util/mdns/server_test.go b/util/mdns/server_test.go index 6fb00fa2..b1488c16 100644 --- a/util/mdns/server_test.go +++ b/util/mdns/server_test.go @@ -7,7 +7,7 @@ import ( func TestServer_StartStop(t *testing.T) { s := makeService(t) - serv, err := NewServer(&Config{Zone: s}) + serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true}) if err != nil { t.Fatalf("err: %v", err) } @@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) { } func TestServer_Lookup(t *testing.T) { - serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) + serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true}) if err != nil { t.Fatalf("err: %v", err) }