MDNS registry fix for users on VPNs (#1759)
* filter out unsolicited responses * send to local ip in case * allow ip func to be passed in. add option for sending to 0.0.0.0
This commit is contained in:
parent
0f5c53b6e4
commit
6532b6208b
@ -184,7 +184,7 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}})
|
||||
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}, LocalhostChecking: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -273,7 +273,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error
|
||||
continue
|
||||
}
|
||||
|
||||
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
|
||||
srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
|
||||
if err != nil {
|
||||
gerr = err
|
||||
continue
|
||||
|
@ -34,6 +34,7 @@ type ServiceEntry struct {
|
||||
|
||||
// complete is used to check if we have all the info we need
|
||||
func (s *ServiceEntry) complete() bool {
|
||||
|
||||
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
|
||||
}
|
||||
|
||||
@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error {
|
||||
select {
|
||||
case resp := <-msgCh:
|
||||
inp := messageToEntry(resp, inprogress)
|
||||
|
||||
if inp == nil {
|
||||
continue
|
||||
}
|
||||
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
|
||||
// discard anything which we've not asked for
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this entry is complete
|
||||
if inp.complete() {
|
||||
if inp.sent {
|
||||
continue
|
||||
}
|
||||
|
||||
inp.sent = true
|
||||
select {
|
||||
case params.Entries <- inp:
|
||||
|
@ -2,13 +2,13 @@ package mdns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/micro/go-micro/v2/logger"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
@ -39,6 +39,10 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// GetMachineIP is a func which returns the outbound IP of this machine.
|
||||
// Used by the server to determine whether to attempt send the response on a local address
|
||||
type GetMachineIP func() net.IP
|
||||
|
||||
// Config is used to configure the mDNS server
|
||||
type Config struct {
|
||||
// Zone must be provided to support responding to queries
|
||||
@ -51,9 +55,15 @@ type Config struct {
|
||||
|
||||
// Port If it is not 0, replace the port 5353 with this port number.
|
||||
Port int
|
||||
|
||||
// GetMachineIP is a function to return the IP of the local machine
|
||||
GetMachineIP GetMachineIP
|
||||
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
|
||||
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
|
||||
LocalhostChecking bool
|
||||
}
|
||||
|
||||
// mDNS server is used to listen for mDNS queries and respond if we
|
||||
// Server is an mDNS server used to listen for mDNS queries and respond if we
|
||||
// have a matching local record
|
||||
type Server struct {
|
||||
config *Config
|
||||
@ -65,6 +75,8 @@ type Server struct {
|
||||
shutdownCh chan struct{}
|
||||
shutdownLock sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
|
||||
outboundIP net.IP
|
||||
}
|
||||
|
||||
// NewServer is used to create a new mDNS server from a config
|
||||
@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
ipFunc := getOutboundIP
|
||||
if config.GetMachineIP != nil {
|
||||
ipFunc = config.GetMachineIP
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
config: config,
|
||||
ipv4List: ipv4List,
|
||||
ipv6List: ipv6List,
|
||||
shutdownCh: make(chan struct{}),
|
||||
outboundIP: ipFunc(),
|
||||
}
|
||||
|
||||
go s.recv(s.ipv4List)
|
||||
@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) {
|
||||
continue
|
||||
}
|
||||
if err := s.parsePacket(buf[:n], from); err != nil {
|
||||
log.Printf("[ERR] mdns: Failed to handle query: %v", err)
|
||||
log.Errorf("[ERR] mdns: Failed to handle query: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) {
|
||||
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
|
||||
var msg dns.Msg
|
||||
if err := msg.Unpack(packet); err != nil {
|
||||
log.Printf("[ERR] mdns: Failed to unpack packet: %v", err)
|
||||
log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err)
|
||||
return err
|
||||
}
|
||||
// TODO: This is a bit of a hack
|
||||
@ -278,7 +296,7 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
|
||||
// caveats in the RFC), so set the Compress bit (part of the dns library
|
||||
// API, not part of the DNS packet) to true.
|
||||
Compress: true,
|
||||
|
||||
Question: query.Question,
|
||||
Answer: answer,
|
||||
}
|
||||
}
|
||||
@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
|
||||
// both. The return values are DNS records for each transmission type.
|
||||
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
|
||||
records := s.config.Zone.Records(q)
|
||||
|
||||
if len(records) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@ -365,7 +382,7 @@ func (s *Server) probe() {
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := s.SendMulticast(q); err != nil {
|
||||
log.Println("[ERR] mdns: failed to send probe:", err.Error())
|
||||
log.Errorf("[ERR] mdns: failed to send probe:", err.Error())
|
||||
}
|
||||
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
|
||||
}
|
||||
@ -391,7 +408,7 @@ func (s *Server) probe() {
|
||||
timer := time.NewTimer(timeout)
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := s.SendMulticast(resp); err != nil {
|
||||
log.Println("[ERR] mdns: failed to send announcement:", err.Error())
|
||||
log.Errorf("[ERR] mdns: failed to send announcement:", err.Error())
|
||||
}
|
||||
select {
|
||||
case <-timer.C:
|
||||
@ -404,7 +421,7 @@ func (s *Server) probe() {
|
||||
}
|
||||
}
|
||||
|
||||
// multicastResponse us used to send a multicast response packet
|
||||
// SendMulticast us used to send a multicast response packet
|
||||
func (s *Server) SendMulticast(msg *dns.Msg) error {
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
|
||||
|
||||
// Determine the socket to send from
|
||||
addr := from.(*net.UDPAddr)
|
||||
if addr.IP.To4() != nil {
|
||||
_, err = s.ipv4List.WriteToUDP(buf, addr)
|
||||
return err
|
||||
} else {
|
||||
_, err = s.ipv6List.WriteToUDP(buf, addr)
|
||||
return err
|
||||
conn := s.ipv4List
|
||||
backupTarget := net.IPv4zero
|
||||
|
||||
if addr.IP.To4() == nil {
|
||||
conn = s.ipv6List
|
||||
backupTarget = net.IPv6zero
|
||||
}
|
||||
_, err = conn.WriteToUDP(buf, addr)
|
||||
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
|
||||
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
|
||||
// Sending two responses is OK
|
||||
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
|
||||
// ignore any errors, this is best efforts
|
||||
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) unregister() error {
|
||||
@ -474,3 +501,17 @@ func setCustomPort(port int) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getOutboundIP returns the IP address of this machine as seen when dialling out
|
||||
func getOutboundIP() net.IP {
|
||||
conn, err := net.Dial("udp", "8.8.8.8:80")
|
||||
if err != nil {
|
||||
// no net connectivity maybe so fallback
|
||||
return nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
return localAddr.IP
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
s := makeService(t)
|
||||
serv, err := NewServer(&Config{Zone: s})
|
||||
serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_Lookup(t *testing.T) {
|
||||
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")})
|
||||
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user