MDNS registry fix for users on VPNs (#1759)

* filter out unsolicited responses
* send to local ip in case
* allow ip func to be passed in. add option for sending to 0.0.0.0
This commit is contained in:
Dominic Wong 2020-06-30 11:12:52 +01:00 committed by GitHub
parent 0f5c53b6e4
commit 6532b6208b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 20 deletions

View File

@ -184,7 +184,7 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) {
return nil, err return nil, err
} }
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}}) srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}, LocalhostChecking: true})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -273,7 +273,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
continue continue
} }
srv, err := mdns.NewServer(&mdns.Config{Zone: s}) srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
if err != nil { if err != nil {
gerr = err gerr = err
continue continue

View File

@ -34,6 +34,7 @@ type ServiceEntry struct {
// complete is used to check if we have all the info we need // complete is used to check if we have all the info we need
func (s *ServiceEntry) complete() bool { func (s *ServiceEntry) complete() bool {
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
} }
@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error {
select { select {
case resp := <-msgCh: case resp := <-msgCh:
inp := messageToEntry(resp, inprogress) inp := messageToEntry(resp, inprogress)
if inp == nil { if inp == nil {
continue continue
} }
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
// discard anything which we've not asked for
continue
}
// Check if this entry is complete // Check if this entry is complete
if inp.complete() { if inp.complete() {
if inp.sent { if inp.sent {
continue continue
} }
inp.sent = true inp.sent = true
select { select {
case params.Entries <- inp: case params.Entries <- inp:

View File

@ -2,13 +2,13 @@ package mdns
import ( import (
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
log "github.com/micro/go-micro/v2/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -39,6 +39,10 @@ var (
} }
) )
// GetMachineIP is a func which returns the outbound IP of this machine.
// Used by the server to determine whether to attempt send the response on a local address
type GetMachineIP func() net.IP
// Config is used to configure the mDNS server // Config is used to configure the mDNS server
type Config struct { type Config struct {
// Zone must be provided to support responding to queries // Zone must be provided to support responding to queries
@ -51,9 +55,15 @@ type Config struct {
// Port If it is not 0, replace the port 5353 with this port number. // Port If it is not 0, replace the port 5353 with this port number.
Port int Port int
// GetMachineIP is a function to return the IP of the local machine
GetMachineIP GetMachineIP
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
LocalhostChecking bool
} }
// mDNS server is used to listen for mDNS queries and respond if we // Server is an mDNS server used to listen for mDNS queries and respond if we
// have a matching local record // have a matching local record
type Server struct { type Server struct {
config *Config config *Config
@ -65,6 +75,8 @@ type Server struct {
shutdownCh chan struct{} shutdownCh chan struct{}
shutdownLock sync.Mutex shutdownLock sync.Mutex
wg sync.WaitGroup wg sync.WaitGroup
outboundIP net.IP
} }
// NewServer is used to create a new mDNS server from a config // NewServer is used to create a new mDNS server from a config
@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) {
} }
} }
ipFunc := getOutboundIP
if config.GetMachineIP != nil {
ipFunc = config.GetMachineIP
}
s := &Server{ s := &Server{
config: config, config: config,
ipv4List: ipv4List, ipv4List: ipv4List,
ipv6List: ipv6List, ipv6List: ipv6List,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
outboundIP: ipFunc(),
} }
go s.recv(s.ipv4List) go s.recv(s.ipv4List)
@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) {
continue continue
} }
if err := s.parsePacket(buf[:n], from); err != nil { if err := s.parsePacket(buf[:n], from); err != nil {
log.Printf("[ERR] mdns: Failed to handle query: %v", err) log.Errorf("[ERR] mdns: Failed to handle query: %v", err)
} }
} }
} }
@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) {
func (s *Server) parsePacket(packet []byte, from net.Addr) error { func (s *Server) parsePacket(packet []byte, from net.Addr) error {
var msg dns.Msg var msg dns.Msg
if err := msg.Unpack(packet); err != nil { if err := msg.Unpack(packet); err != nil {
log.Printf("[ERR] mdns: Failed to unpack packet: %v", err) log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err)
return err return err
} }
// TODO: This is a bit of a hack // TODO: This is a bit of a hack
@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
// caveats in the RFC), so set the Compress bit (part of the dns library // caveats in the RFC), so set the Compress bit (part of the dns library
// API, not part of the DNS packet) to true. // API, not part of the DNS packet) to true.
Compress: true, Compress: true,
Question: query.Question,
Answer: answer, Answer: answer,
} }
} }
@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
// both. The return values are DNS records for each transmission type. // both. The return values are DNS records for each transmission type.
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
records := s.config.Zone.Records(q) records := s.config.Zone.Records(q)
if len(records) == 0 { if len(records) == 0 {
return nil, nil return nil, nil
} }
@ -365,7 +382,7 @@ func (s *Server) probe() {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if err := s.SendMulticast(q); err != nil { if err := s.SendMulticast(q); err != nil {
log.Println("[ERR] mdns: failed to send probe:", err.Error()) log.Errorf("[ERR] mdns: failed to send probe:", err.Error())
} }
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
} }
@ -391,7 +408,7 @@ func (s *Server) probe() {
timer := time.NewTimer(timeout) timer := time.NewTimer(timeout)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if err := s.SendMulticast(resp); err != nil { if err := s.SendMulticast(resp); err != nil {
log.Println("[ERR] mdns: failed to send announcement:", err.Error()) log.Errorf("[ERR] mdns: failed to send announcement:", err.Error())
} }
select { select {
case <-timer.C: case <-timer.C:
@ -404,7 +421,7 @@ func (s *Server) probe() {
} }
} }
// multicastResponse us used to send a multicast response packet // SendMulticast us used to send a multicast response packet
func (s *Server) SendMulticast(msg *dns.Msg) error { func (s *Server) SendMulticast(msg *dns.Msg) error {
buf, err := msg.Pack() buf, err := msg.Pack()
if err != nil { if err != nil {
@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
// Determine the socket to send from // Determine the socket to send from
addr := from.(*net.UDPAddr) addr := from.(*net.UDPAddr)
if addr.IP.To4() != nil { conn := s.ipv4List
_, err = s.ipv4List.WriteToUDP(buf, addr) backupTarget := net.IPv4zero
return err
} else { if addr.IP.To4() == nil {
_, err = s.ipv6List.WriteToUDP(buf, addr) conn = s.ipv6List
return err backupTarget = net.IPv6zero
} }
_, err = conn.WriteToUDP(buf, addr)
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
// Sending two responses is OK
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
// ignore any errors, this is best efforts
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
}
return err
} }
func (s *Server) unregister() error { func (s *Server) unregister() error {
@ -474,3 +501,17 @@ func setCustomPort(port int) {
} }
} }
} }
// getOutboundIP returns the IP address of this machine as seen when dialling out
func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
// no net connectivity maybe so fallback
return nil
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}

View File

@ -7,7 +7,7 @@ import (
func TestServer_StartStop(t *testing.T) { func TestServer_StartStop(t *testing.T) {
s := makeService(t) s := makeService(t)
serv, err := NewServer(&Config{Zone: s}) serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) {
} }
func TestServer_Lookup(t *testing.T) { func TestServer_Lookup(t *testing.T) {
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }