MDNS registry fix for users on VPNs (#1759)

* filter out unsolicited responses
* send to local ip in case
* allow ip func to be passed in. add option for sending to 0.0.0.0
This commit is contained in:
Dominic Wong 2020-06-30 11:12:52 +01:00
parent 3e6ac73cfe
commit 7be4a67673
4 changed files with 67 additions and 19 deletions

View File

@ -252,7 +252,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
continue
}
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
if err != nil {
gerr = err
continue

View File

@ -34,6 +34,7 @@ type ServiceEntry struct {
// complete is used to check if we have all the info we need
func (s *ServiceEntry) complete() bool {
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
}
@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error {
select {
case resp := <-msgCh:
inp := messageToEntry(resp, inprogress)
if inp == nil {
continue
}
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
// discard anything which we've not asked for
continue
}
// Check if this entry is complete
if inp.complete() {
if inp.sent {
continue
}
inp.sent = true
select {
case params.Entries <- inp:

View File

@ -2,13 +2,13 @@ package mdns
import (
"fmt"
"log"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
log "github.com/micro/go-micro/v2/logger"
"github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -39,6 +39,10 @@ var (
}
)
// GetMachineIP is a func which returns the outbound IP of this machine.
// Used by the server to determine whether to attempt send the response on a local address
type GetMachineIP func() net.IP
// Config is used to configure the mDNS server
type Config struct {
// Zone must be provided to support responding to queries
@ -51,9 +55,15 @@ type Config struct {
// Port If it is not 0, replace the port 5353 with this port number.
Port int
// GetMachineIP is a function to return the IP of the local machine
GetMachineIP GetMachineIP
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
LocalhostChecking bool
}
// mDNS server is used to listen for mDNS queries and respond if we
// Server is an mDNS server used to listen for mDNS queries and respond if we
// have a matching local record
type Server struct {
config *Config
@ -65,6 +75,8 @@ type Server struct {
shutdownCh chan struct{}
shutdownLock sync.Mutex
wg sync.WaitGroup
outboundIP net.IP
}
// NewServer is used to create a new mDNS server from a config
@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) {
}
}
ipFunc := getOutboundIP
if config.GetMachineIP != nil {
ipFunc = config.GetMachineIP
}
s := &Server{
config: config,
ipv4List: ipv4List,
ipv6List: ipv6List,
shutdownCh: make(chan struct{}),
outboundIP: ipFunc(),
}
go s.recv(s.ipv4List)
@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) {
continue
}
if err := s.parsePacket(buf[:n], from); err != nil {
log.Printf("[ERR] mdns: Failed to handle query: %v", err)
log.Errorf("[ERR] mdns: Failed to handle query: %v", err)
}
}
}
@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) {
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
var msg dns.Msg
if err := msg.Unpack(packet); err != nil {
log.Printf("[ERR] mdns: Failed to unpack packet: %v", err)
log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err)
return err
}
// TODO: This is a bit of a hack
@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
// caveats in the RFC), so set the Compress bit (part of the dns library
// API, not part of the DNS packet) to true.
Compress: true,
Answer: answer,
Question: query.Question,
Answer: answer,
}
}
@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
// both. The return values are DNS records for each transmission type.
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
records := s.config.Zone.Records(q)
if len(records) == 0 {
return nil, nil
}
@ -365,7 +382,7 @@ func (s *Server) probe() {
for i := 0; i < 3; i++ {
if err := s.SendMulticast(q); err != nil {
log.Println("[ERR] mdns: failed to send probe:", err.Error())
log.Errorf("[ERR] mdns: failed to send probe:", err.Error())
}
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
}
@ -391,7 +408,7 @@ func (s *Server) probe() {
timer := time.NewTimer(timeout)
for i := 0; i < 3; i++ {
if err := s.SendMulticast(resp); err != nil {
log.Println("[ERR] mdns: failed to send announcement:", err.Error())
log.Errorf("[ERR] mdns: failed to send announcement:", err.Error())
}
select {
case <-timer.C:
@ -404,7 +421,7 @@ func (s *Server) probe() {
}
}
// multicastResponse us used to send a multicast response packet
// SendMulticast us used to send a multicast response packet
func (s *Server) SendMulticast(msg *dns.Msg) error {
buf, err := msg.Pack()
if err != nil {
@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
// Determine the socket to send from
addr := from.(*net.UDPAddr)
if addr.IP.To4() != nil {
_, err = s.ipv4List.WriteToUDP(buf, addr)
return err
} else {
_, err = s.ipv6List.WriteToUDP(buf, addr)
return err
conn := s.ipv4List
backupTarget := net.IPv4zero
if addr.IP.To4() == nil {
conn = s.ipv6List
backupTarget = net.IPv6zero
}
_, err = conn.WriteToUDP(buf, addr)
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
// Sending two responses is OK
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
// ignore any errors, this is best efforts
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
}
return err
}
func (s *Server) unregister() error {
@ -474,3 +501,17 @@ func setCustomPort(port int) {
}
}
}
// getOutboundIP returns the IP address of this machine as seen when dialling out
func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
// no net connectivity maybe so fallback
return nil
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}

View File

@ -7,7 +7,7 @@ import (
func TestServer_StartStop(t *testing.T) {
s := makeService(t)
serv, err := NewServer(&Config{Zone: s})
serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
if err != nil {
t.Fatalf("err: %v", err)
}
@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) {
}
func TestServer_Lookup(t *testing.T) {
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")})
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
if err != nil {
t.Fatalf("err: %v", err)
}