Compare commits
	
		
			8 Commits
		
	
	
		
			dee7bc9c38
			...
			v2.9.1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 94bd1025a6 | ||
|  | 7be4a67673 | ||
|  | 3e6ac73cfe | ||
|  | aef6878ee0 | ||
|  | 81aa8e0231 | ||
|  | c28f625cd4 | ||
|  | 5b161b88f7 | ||
|  | cca8826a1f | 
							
								
								
									
										1
									
								
								.github/workflows/docker.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/docker.yml
									
									
									
									
										vendored
									
									
								
							| @@ -19,3 +19,4 @@ jobs: | ||||
|            name: micro/go-micro | ||||
|            username: ${{ secrets.DOCKER_USERNAME }} | ||||
|            password: ${{ secrets.DOCKER_PASSWORD }} | ||||
|            tag_names: true | ||||
| @@ -70,12 +70,32 @@ func (dc *discordConn) Recv(event *input.Event) error { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ChunkString(s string, chunkSize int) []string { | ||||
| 	var chunks []string | ||||
| 	runes := []rune(s) | ||||
|  | ||||
| 	if len(runes) == 0 { | ||||
| 		return []string{s} | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < len(runes); i += chunkSize { | ||||
| 		nn := i + chunkSize | ||||
| 		if nn > len(runes) { | ||||
| 			nn = len(runes) | ||||
| 		} | ||||
| 		chunks = append(chunks, string(runes[i:nn])) | ||||
| 	} | ||||
| 	return chunks | ||||
| } | ||||
|  | ||||
| func (dc *discordConn) Send(e *input.Event) error { | ||||
| 	fields := strings.Split(e.To, ":") | ||||
| 	_, err := dc.master.session.ChannelMessageSend(fields[0], string(e.Data)) | ||||
| 	if err != nil { | ||||
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 			logger.Error("[bot][loop][send]", err) | ||||
| 	for _, chunk := range ChunkString(string(e.Data), 2000) { | ||||
| 		_, err := dc.master.session.ChannelMessageSend(fields[0], chunk) | ||||
| 		if err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error("[bot][loop][send]", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
|   | ||||
| @@ -41,7 +41,6 @@ func newConfig(opts ...Option) (Config, error) { | ||||
|  | ||||
| func (c *config) Init(opts ...Option) error { | ||||
| 	c.opts = Options{ | ||||
| 		Loader: memory.NewLoader(), | ||||
| 		Reader: json.NewReader(), | ||||
| 	} | ||||
| 	c.exit = make(chan bool) | ||||
| @@ -49,6 +48,11 @@ func (c *config) Init(opts ...Option) error { | ||||
| 		o(&c.opts) | ||||
| 	} | ||||
|  | ||||
| 	// default loader uses the configured reader | ||||
| 	if c.opts.Loader == nil { | ||||
| 		c.opts.Loader = memory.NewLoader(memory.WithReader(c.opts.Reader)) | ||||
| 	} | ||||
|  | ||||
| 	err := c.opts.Loader.Load(c.opts.Source...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @@ -82,6 +86,11 @@ func (c *config) run() { | ||||
|  | ||||
| 			c.Lock() | ||||
|  | ||||
| 			if c.snap.Version >= snap.Version { | ||||
| 				c.Unlock() | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// save | ||||
| 			c.snap = snap | ||||
|  | ||||
|   | ||||
| @@ -4,12 +4,15 @@ import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/micro/go-micro/v2/config/source" | ||||
| 	"github.com/micro/go-micro/v2/config/source/env" | ||||
| 	"github.com/micro/go-micro/v2/config/source/file" | ||||
| 	"github.com/micro/go-micro/v2/config/source/memory" | ||||
| ) | ||||
|  | ||||
| func createFileForIssue18(t *testing.T, content string) *os.File { | ||||
| @@ -127,3 +130,37 @@ func TestConfigMerge(t *testing.T) { | ||||
| 			actualHost) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func equalS(t *testing.T, actual, expect string) { | ||||
| 	if actual != expect { | ||||
| 		t.Errorf("Expected %s but got %s", actual, expect) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConfigWatcherDirtyOverrite(t *testing.T) { | ||||
| 	n := runtime.GOMAXPROCS(0) | ||||
| 	defer runtime.GOMAXPROCS(n) | ||||
|  | ||||
| 	runtime.GOMAXPROCS(1) | ||||
|  | ||||
| 	l := 100 | ||||
|  | ||||
| 	ss := make([]source.Source, l, l) | ||||
|  | ||||
| 	for i := 0; i < l; i++ { | ||||
| 		ss[i] = memory.NewSource(memory.WithJSON([]byte(fmt.Sprintf(`{"key%d": "val%d"}`, i, i)))) | ||||
| 	} | ||||
|  | ||||
| 	conf, _ := NewConfig() | ||||
|  | ||||
| 	for _, s := range ss { | ||||
| 		_ = conf.Load(s) | ||||
| 	} | ||||
| 	runtime.Gosched() | ||||
|  | ||||
| 	for i, _ := range ss { | ||||
| 		k := fmt.Sprintf("key%d", i) | ||||
| 		v := fmt.Sprintf("val%d", i) | ||||
| 		equalS(t, conf.Get(k).String(""), v) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -32,19 +32,21 @@ type memory struct { | ||||
| 	watchers *list.List | ||||
| } | ||||
|  | ||||
| type updateValue struct { | ||||
| 	version string | ||||
| 	value   reader.Value | ||||
| } | ||||
|  | ||||
| type watcher struct { | ||||
| 	exit    chan bool | ||||
| 	path    []string | ||||
| 	value   reader.Value | ||||
| 	reader  reader.Reader | ||||
| 	updates chan reader.Value | ||||
| 	version string | ||||
| 	updates chan updateValue | ||||
| } | ||||
|  | ||||
| func (m *memory) watch(idx int, s source.Source) { | ||||
| 	m.Lock() | ||||
| 	m.sets = append(m.sets, &source.ChangeSet{Source: s.String()}) | ||||
| 	m.Unlock() | ||||
|  | ||||
| 	// watches a source for changes | ||||
| 	watch := func(idx int, s source.Watcher) error { | ||||
| 		for { | ||||
| @@ -70,7 +72,7 @@ func (m *memory) watch(idx int, s source.Source) { | ||||
| 			m.vals, _ = m.opts.Reader.Values(set) | ||||
| 			m.snap = &loader.Snapshot{ | ||||
| 				ChangeSet: set, | ||||
| 				Version:   fmt.Sprintf("%d", time.Now().Unix()), | ||||
| 				Version:   genVer(), | ||||
| 			} | ||||
| 			m.Unlock() | ||||
|  | ||||
| @@ -141,7 +143,7 @@ func (m *memory) reload() error { | ||||
| 	m.vals, _ = m.opts.Reader.Values(set) | ||||
| 	m.snap = &loader.Snapshot{ | ||||
| 		ChangeSet: set, | ||||
| 		Version:   fmt.Sprintf("%d", time.Now().Unix()), | ||||
| 		Version:   genVer(), | ||||
| 	} | ||||
|  | ||||
| 	m.Unlock() | ||||
| @@ -159,11 +161,23 @@ func (m *memory) update() { | ||||
| 	for e := m.watchers.Front(); e != nil; e = e.Next() { | ||||
| 		watchers = append(watchers, e.Value.(*watcher)) | ||||
| 	} | ||||
|  | ||||
| 	vals := m.vals | ||||
| 	snap := m.snap | ||||
| 	m.RUnlock() | ||||
|  | ||||
| 	for _, w := range watchers { | ||||
| 		if w.version >= snap.Version { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		uv := updateValue{ | ||||
| 			version: m.snap.Version, | ||||
| 			value:   vals.Get(w.path...), | ||||
| 		} | ||||
|  | ||||
| 		select { | ||||
| 		case w.updates <- m.vals.Get(w.path...): | ||||
| 		case w.updates <- uv: | ||||
| 		default: | ||||
| 		} | ||||
| 	} | ||||
| @@ -226,7 +240,7 @@ func (m *memory) Sync() error { | ||||
| 	m.vals = vals | ||||
| 	m.snap = &loader.Snapshot{ | ||||
| 		ChangeSet: set, | ||||
| 		Version:   fmt.Sprintf("%d", time.Now().Unix()), | ||||
| 		Version:   genVer(), | ||||
| 	} | ||||
|  | ||||
| 	m.Unlock() | ||||
| @@ -285,6 +299,7 @@ func (m *memory) Get(path ...string) (reader.Value, error) { | ||||
| 	} | ||||
|  | ||||
| 	// ok we're going hardcore now | ||||
|  | ||||
| 	return nil, errors.New("no values") | ||||
| } | ||||
|  | ||||
| @@ -333,7 +348,8 @@ func (m *memory) Watch(path ...string) (loader.Watcher, error) { | ||||
| 		path:    path, | ||||
| 		value:   value, | ||||
| 		reader:  m.opts.Reader, | ||||
| 		updates: make(chan reader.Value, 1), | ||||
| 		updates: make(chan updateValue, 1), | ||||
| 		version: m.snap.Version, | ||||
| 	} | ||||
|  | ||||
| 	e := m.watchers.PushBack(w) | ||||
| @@ -355,28 +371,43 @@ func (m *memory) String() string { | ||||
| } | ||||
|  | ||||
| func (w *watcher) Next() (*loader.Snapshot, error) { | ||||
| 	update := func(v reader.Value) *loader.Snapshot { | ||||
| 		w.value = v | ||||
|  | ||||
| 		cs := &source.ChangeSet{ | ||||
| 			Data:      v.Bytes(), | ||||
| 			Format:    w.reader.String(), | ||||
| 			Source:    "memory", | ||||
| 			Timestamp: time.Now(), | ||||
| 		} | ||||
| 		cs.Checksum = cs.Sum() | ||||
|  | ||||
| 		return &loader.Snapshot{ | ||||
| 			ChangeSet: cs, | ||||
| 			Version:   w.version, | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-w.exit: | ||||
| 			return nil, errors.New("watcher stopped") | ||||
| 		case v := <-w.updates: | ||||
|  | ||||
| 		case uv := <-w.updates: | ||||
| 			if uv.version <= w.version { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			v := uv.value | ||||
|  | ||||
| 			w.version = uv.version | ||||
|  | ||||
| 			if bytes.Equal(w.value.Bytes(), v.Bytes()) { | ||||
| 				continue | ||||
| 			} | ||||
| 			w.value = v | ||||
|  | ||||
| 			cs := &source.ChangeSet{ | ||||
| 				Data:      v.Bytes(), | ||||
| 				Format:    w.reader.String(), | ||||
| 				Source:    "memory", | ||||
| 				Timestamp: time.Now(), | ||||
| 			} | ||||
| 			cs.Sum() | ||||
|  | ||||
| 			return &loader.Snapshot{ | ||||
| 				ChangeSet: cs, | ||||
| 				Version:   fmt.Sprintf("%d", time.Now().Unix()), | ||||
| 			}, nil | ||||
| 			return update(v), nil | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -386,10 +417,16 @@ func (w *watcher) Stop() error { | ||||
| 	case <-w.exit: | ||||
| 	default: | ||||
| 		close(w.exit) | ||||
| 		close(w.updates) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func genVer() string { | ||||
| 	return fmt.Sprintf("%d", time.Now().UnixNano()) | ||||
| } | ||||
|  | ||||
| func NewLoader(opts ...loader.Option) loader.Loader { | ||||
| 	options := loader.Options{ | ||||
| 		Reader: json.NewReader(), | ||||
| @@ -406,7 +443,10 @@ func NewLoader(opts ...loader.Option) loader.Loader { | ||||
| 		sources:  options.Source, | ||||
| 	} | ||||
|  | ||||
| 	m.sets = make([]*source.ChangeSet, len(options.Source)) | ||||
|  | ||||
| 	for i, s := range options.Source { | ||||
| 		m.sets[i] = &source.ChangeSet{Source: s.String()} | ||||
| 		go m.watch(i, s) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -42,6 +42,7 @@ func (s *memory) Watch() (source.Watcher, error) { | ||||
| } | ||||
|  | ||||
| func (m *memory) Write(cs *source.ChangeSet) error { | ||||
| 	m.Update(cs) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -252,7 +252,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s}) | ||||
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) | ||||
| 		if err != nil { | ||||
| 			gerr = err | ||||
| 			continue | ||||
| @@ -563,9 +563,7 @@ func (m *mdnsWatcher) Next() (*Result, error) { | ||||
| 			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			var action string | ||||
|  | ||||
| 			if e.TTL == 0 { | ||||
| 				action = "delete" | ||||
| 			} else { | ||||
| @@ -584,9 +582,18 @@ func (m *mdnsWatcher) Next() (*Result, error) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			var addr string | ||||
| 			if len(e.AddrV4) > 0 { | ||||
| 				addr = e.AddrV4.String() | ||||
| 			} else if len(e.AddrV6) > 0 { | ||||
| 				addr = "[" + e.AddrV6.String() + "]" | ||||
| 			} else { | ||||
| 				addr = e.Addr.String() | ||||
| 			} | ||||
|  | ||||
| 			service.Nodes = append(service.Nodes, &Node{ | ||||
| 				Id:       strings.TrimSuffix(e.Name, suffix), | ||||
| 				Address:  fmt.Sprintf("%s:%d", e.AddrV4.String(), e.Port), | ||||
| 				Address:  fmt.Sprintf("%s:%d", addr, e.Port), | ||||
| 				Metadata: txt.Metadata, | ||||
| 			}) | ||||
|  | ||||
|   | ||||
| @@ -363,11 +363,12 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | ||||
| 			r = rpcRouter{h: handler} | ||||
| 		} | ||||
|  | ||||
| 		// wait for two coroutines to exit | ||||
| 		// serve the request and process the outbound messages | ||||
| 		wg.Add(2) | ||||
|  | ||||
| 		// process the outbound messages from the socket | ||||
| 		go func(id string, psock *socket.Socket) { | ||||
| 			// wait for processing to exit | ||||
| 			wg.Add(1) | ||||
|  | ||||
| 			defer func() { | ||||
| 				// TODO: don't hack this but if its grpc just break out of the stream | ||||
| 				// We do this because the underlying connection is h2 and its a stream | ||||
| @@ -405,9 +406,6 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | ||||
|  | ||||
| 		// serve the request in a go routine as this may be a stream | ||||
| 		go func(id string, psock *socket.Socket) { | ||||
| 			// add to the waitgroup | ||||
| 			wg.Add(1) | ||||
|  | ||||
| 			defer func() { | ||||
| 				// release the socket | ||||
| 				pool.Release(psock) | ||||
|   | ||||
| @@ -17,6 +17,15 @@ func init() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // AppendPrivateBlocks append private network blocks | ||||
| func AppendPrivateBlocks(bs ...string) { | ||||
| 	for _, b := range bs { | ||||
| 		if _, block, err := net.ParseCIDR(b); err == nil { | ||||
| 			privateBlocks = append(privateBlocks, block) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func isPrivateIP(ipAddr string) bool { | ||||
| 	ip := net.ParseIP(ipAddr) | ||||
| 	for _, priv := range privateBlocks { | ||||
|   | ||||
| @@ -56,3 +56,24 @@ func TestExtractor(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func TestAppendPrivateBlocks(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		addr   string | ||||
| 		expect bool | ||||
| 	}{ | ||||
| 		{addr: "9.134.71.34", expect: true}, | ||||
| 		{addr: "8.10.110.34", expect: false}, // not in private blocks | ||||
| 	} | ||||
|  | ||||
| 	AppendPrivateBlocks("9.134.0.0/16") | ||||
|  | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.addr, func(t *testing.T) { | ||||
| 			res := isPrivateIP(test.addr) | ||||
| 			if res != test.expect { | ||||
| 				t.Fatalf("expected %t got %t", test.expect, res) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -34,6 +34,7 @@ type ServiceEntry struct { | ||||
|  | ||||
| // complete is used to check if we have all the info we need | ||||
| func (s *ServiceEntry) complete() bool { | ||||
|  | ||||
| 	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT | ||||
| } | ||||
|  | ||||
| @@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error { | ||||
| 		select { | ||||
| 		case resp := <-msgCh: | ||||
| 			inp := messageToEntry(resp, inprogress) | ||||
|  | ||||
| 			if inp == nil { | ||||
| 				continue | ||||
| 			} | ||||
| 			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { | ||||
| 				// discard anything which we've not asked for | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// Check if this entry is complete | ||||
| 			if inp.complete() { | ||||
| 				if inp.sent { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				inp.sent = true | ||||
| 				select { | ||||
| 				case params.Entries <- inp: | ||||
|   | ||||
| @@ -2,13 +2,13 @@ package mdns | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	log "github.com/micro/go-micro/v2/logger" | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/ipv4" | ||||
| 	"golang.org/x/net/ipv6" | ||||
| @@ -39,6 +39,10 @@ var ( | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| // GetMachineIP is a func which returns the outbound IP of this machine. | ||||
| // Used by the server to determine whether to attempt send the response on a local address | ||||
| type GetMachineIP func() net.IP | ||||
|  | ||||
| // Config is used to configure the mDNS server | ||||
| type Config struct { | ||||
| 	// Zone must be provided to support responding to queries | ||||
| @@ -51,9 +55,15 @@ type Config struct { | ||||
|  | ||||
| 	// Port If it is not 0, replace the port 5353 with this port number. | ||||
| 	Port int | ||||
|  | ||||
| 	// GetMachineIP is a function to return the IP of the local machine | ||||
| 	GetMachineIP GetMachineIP | ||||
| 	// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP | ||||
| 	// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports | ||||
| 	LocalhostChecking bool | ||||
| } | ||||
|  | ||||
| // mDNS server is used to listen for mDNS queries and respond if we | ||||
| // Server is an mDNS server used to listen for mDNS queries and respond if we | ||||
| // have a matching local record | ||||
| type Server struct { | ||||
| 	config *Config | ||||
| @@ -65,6 +75,8 @@ type Server struct { | ||||
| 	shutdownCh   chan struct{} | ||||
| 	shutdownLock sync.Mutex | ||||
| 	wg           sync.WaitGroup | ||||
|  | ||||
| 	outboundIP net.IP | ||||
| } | ||||
|  | ||||
| // NewServer is used to create a new mDNS server from a config | ||||
| @@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ipFunc := getOutboundIP | ||||
| 	if config.GetMachineIP != nil { | ||||
| 		ipFunc = config.GetMachineIP | ||||
| 	} | ||||
|  | ||||
| 	s := &Server{ | ||||
| 		config:     config, | ||||
| 		ipv4List:   ipv4List, | ||||
| 		ipv6List:   ipv6List, | ||||
| 		shutdownCh: make(chan struct{}), | ||||
| 		outboundIP: ipFunc(), | ||||
| 	} | ||||
|  | ||||
| 	go s.recv(s.ipv4List) | ||||
| @@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) { | ||||
| 			continue | ||||
| 		} | ||||
| 		if err := s.parsePacket(buf[:n], from); err != nil { | ||||
| 			log.Printf("[ERR] mdns: Failed to handle query: %v", err) | ||||
| 			log.Errorf("[ERR] mdns: Failed to handle query: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) { | ||||
| func (s *Server) parsePacket(packet []byte, from net.Addr) error { | ||||
| 	var msg dns.Msg | ||||
| 	if err := msg.Unpack(packet); err != nil { | ||||
| 		log.Printf("[ERR] mdns: Failed to unpack packet: %v", err) | ||||
| 		log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err) | ||||
| 		return err | ||||
| 	} | ||||
| 	// TODO: This is a bit of a hack | ||||
| @@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { | ||||
| 			// caveats in the RFC), so set the Compress bit (part of the dns library | ||||
| 			// API, not part of the DNS packet) to true. | ||||
| 			Compress: true, | ||||
|  | ||||
| 			Answer: answer, | ||||
| 			Question: query.Question, | ||||
| 			Answer:   answer, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { | ||||
| // both.  The return values are DNS records for each transmission type. | ||||
| func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { | ||||
| 	records := s.config.Zone.Records(q) | ||||
|  | ||||
| 	if len(records) == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| @@ -365,7 +382,7 @@ func (s *Server) probe() { | ||||
|  | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		if err := s.SendMulticast(q); err != nil { | ||||
| 			log.Println("[ERR] mdns: failed to send probe:", err.Error()) | ||||
| 			log.Errorf("[ERR] mdns: failed to send probe:", err.Error()) | ||||
| 		} | ||||
| 		time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) | ||||
| 	} | ||||
| @@ -391,7 +408,7 @@ func (s *Server) probe() { | ||||
| 	timer := time.NewTimer(timeout) | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		if err := s.SendMulticast(resp); err != nil { | ||||
| 			log.Println("[ERR] mdns: failed to send announcement:", err.Error()) | ||||
| 			log.Errorf("[ERR] mdns: failed to send announcement:", err.Error()) | ||||
| 		} | ||||
| 		select { | ||||
| 		case <-timer.C: | ||||
| @@ -404,7 +421,7 @@ func (s *Server) probe() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // multicastResponse us used to send a multicast response packet | ||||
| // SendMulticast us used to send a multicast response packet | ||||
| func (s *Server) SendMulticast(msg *dns.Msg) error { | ||||
| 	buf, err := msg.Pack() | ||||
| 	if err != nil { | ||||
| @@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { | ||||
|  | ||||
| 	// Determine the socket to send from | ||||
| 	addr := from.(*net.UDPAddr) | ||||
| 	if addr.IP.To4() != nil { | ||||
| 		_, err = s.ipv4List.WriteToUDP(buf, addr) | ||||
| 		return err | ||||
| 	} else { | ||||
| 		_, err = s.ipv6List.WriteToUDP(buf, addr) | ||||
| 		return err | ||||
| 	conn := s.ipv4List | ||||
| 	backupTarget := net.IPv4zero | ||||
|  | ||||
| 	if addr.IP.To4() == nil { | ||||
| 		conn = s.ipv6List | ||||
| 		backupTarget = net.IPv6zero | ||||
| 	} | ||||
| 	_, err = conn.WriteToUDP(buf, addr) | ||||
| 	// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0 | ||||
| 	// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there | ||||
| 	// Sending two responses is OK | ||||
| 	if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) { | ||||
| 		// ignore any errors, this is best efforts | ||||
| 		conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port}) | ||||
| 	} | ||||
| 	return err | ||||
|  | ||||
| } | ||||
|  | ||||
| func (s *Server) unregister() error { | ||||
| @@ -474,3 +501,17 @@ func setCustomPort(port int) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // getOutboundIP returns the IP address of this machine as seen when dialling out | ||||
| func getOutboundIP() net.IP { | ||||
| 	conn, err := net.Dial("udp", "8.8.8.8:80") | ||||
| 	if err != nil { | ||||
| 		// no net connectivity maybe so fallback | ||||
| 		return nil | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	localAddr := conn.LocalAddr().(*net.UDPAddr) | ||||
|  | ||||
| 	return localAddr.IP | ||||
| } | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import ( | ||||
|  | ||||
| func TestServer_StartStop(t *testing.T) { | ||||
| 	s := makeService(t) | ||||
| 	serv, err := NewServer(&Config{Zone: s}) | ||||
| 	serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestServer_Lookup(t *testing.T) { | ||||
| 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) | ||||
| 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user