Compare commits
	
		
			8 Commits
		
	
	
		
			b6d2d459c5
			...
			v2.9.1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 94bd1025a6 | ||
|  | 7be4a67673 | ||
|  | 3e6ac73cfe | ||
|  | aef6878ee0 | ||
|  | 81aa8e0231 | ||
|  | c28f625cd4 | ||
|  | 5b161b88f7 | ||
|  | cca8826a1f | 
							
								
								
									
										1
									
								
								.github/workflows/docker.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/docker.yml
									
									
									
									
										vendored
									
									
								
							| @@ -19,3 +19,4 @@ jobs: | |||||||
|            name: micro/go-micro |            name: micro/go-micro | ||||||
|            username: ${{ secrets.DOCKER_USERNAME }} |            username: ${{ secrets.DOCKER_USERNAME }} | ||||||
|            password: ${{ secrets.DOCKER_PASSWORD }} |            password: ${{ secrets.DOCKER_PASSWORD }} | ||||||
|  |            tag_names: true | ||||||
| @@ -70,12 +70,32 @@ func (dc *discordConn) Recv(event *input.Event) error { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func ChunkString(s string, chunkSize int) []string { | ||||||
|  | 	var chunks []string | ||||||
|  | 	runes := []rune(s) | ||||||
|  |  | ||||||
|  | 	if len(runes) == 0 { | ||||||
|  | 		return []string{s} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := 0; i < len(runes); i += chunkSize { | ||||||
|  | 		nn := i + chunkSize | ||||||
|  | 		if nn > len(runes) { | ||||||
|  | 			nn = len(runes) | ||||||
|  | 		} | ||||||
|  | 		chunks = append(chunks, string(runes[i:nn])) | ||||||
|  | 	} | ||||||
|  | 	return chunks | ||||||
|  | } | ||||||
|  |  | ||||||
| func (dc *discordConn) Send(e *input.Event) error { | func (dc *discordConn) Send(e *input.Event) error { | ||||||
| 	fields := strings.Split(e.To, ":") | 	fields := strings.Split(e.To, ":") | ||||||
| 	_, err := dc.master.session.ChannelMessageSend(fields[0], string(e.Data)) | 	for _, chunk := range ChunkString(string(e.Data), 2000) { | ||||||
| 	if err != nil { | 		_, err := dc.master.session.ChannelMessageSend(fields[0], chunk) | ||||||
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | 		if err != nil { | ||||||
| 			logger.Error("[bot][loop][send]", err) | 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||||
|  | 				logger.Error("[bot][loop][send]", err) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
|   | |||||||
| @@ -41,7 +41,6 @@ func newConfig(opts ...Option) (Config, error) { | |||||||
|  |  | ||||||
| func (c *config) Init(opts ...Option) error { | func (c *config) Init(opts ...Option) error { | ||||||
| 	c.opts = Options{ | 	c.opts = Options{ | ||||||
| 		Loader: memory.NewLoader(), |  | ||||||
| 		Reader: json.NewReader(), | 		Reader: json.NewReader(), | ||||||
| 	} | 	} | ||||||
| 	c.exit = make(chan bool) | 	c.exit = make(chan bool) | ||||||
| @@ -49,6 +48,11 @@ func (c *config) Init(opts ...Option) error { | |||||||
| 		o(&c.opts) | 		o(&c.opts) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// default loader uses the configured reader | ||||||
|  | 	if c.opts.Loader == nil { | ||||||
|  | 		c.opts.Loader = memory.NewLoader(memory.WithReader(c.opts.Reader)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	err := c.opts.Loader.Load(c.opts.Source...) | 	err := c.opts.Loader.Load(c.opts.Source...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -82,6 +86,11 @@ func (c *config) run() { | |||||||
|  |  | ||||||
| 			c.Lock() | 			c.Lock() | ||||||
|  |  | ||||||
|  | 			if c.snap.Version >= snap.Version { | ||||||
|  | 				c.Unlock() | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// save | 			// save | ||||||
| 			c.snap = snap | 			c.snap = snap | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,12 +4,15 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"runtime" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/micro/go-micro/v2/config/source" | ||||||
| 	"github.com/micro/go-micro/v2/config/source/env" | 	"github.com/micro/go-micro/v2/config/source/env" | ||||||
| 	"github.com/micro/go-micro/v2/config/source/file" | 	"github.com/micro/go-micro/v2/config/source/file" | ||||||
|  | 	"github.com/micro/go-micro/v2/config/source/memory" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func createFileForIssue18(t *testing.T, content string) *os.File { | func createFileForIssue18(t *testing.T, content string) *os.File { | ||||||
| @@ -127,3 +130,37 @@ func TestConfigMerge(t *testing.T) { | |||||||
| 			actualHost) | 			actualHost) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func equalS(t *testing.T, actual, expect string) { | ||||||
|  | 	if actual != expect { | ||||||
|  | 		t.Errorf("Expected %s but got %s", actual, expect) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestConfigWatcherDirtyOverrite(t *testing.T) { | ||||||
|  | 	n := runtime.GOMAXPROCS(0) | ||||||
|  | 	defer runtime.GOMAXPROCS(n) | ||||||
|  |  | ||||||
|  | 	runtime.GOMAXPROCS(1) | ||||||
|  |  | ||||||
|  | 	l := 100 | ||||||
|  |  | ||||||
|  | 	ss := make([]source.Source, l, l) | ||||||
|  |  | ||||||
|  | 	for i := 0; i < l; i++ { | ||||||
|  | 		ss[i] = memory.NewSource(memory.WithJSON([]byte(fmt.Sprintf(`{"key%d": "val%d"}`, i, i)))) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	conf, _ := NewConfig() | ||||||
|  |  | ||||||
|  | 	for _, s := range ss { | ||||||
|  | 		_ = conf.Load(s) | ||||||
|  | 	} | ||||||
|  | 	runtime.Gosched() | ||||||
|  |  | ||||||
|  | 	for i, _ := range ss { | ||||||
|  | 		k := fmt.Sprintf("key%d", i) | ||||||
|  | 		v := fmt.Sprintf("val%d", i) | ||||||
|  | 		equalS(t, conf.Get(k).String(""), v) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -32,19 +32,21 @@ type memory struct { | |||||||
| 	watchers *list.List | 	watchers *list.List | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type updateValue struct { | ||||||
|  | 	version string | ||||||
|  | 	value   reader.Value | ||||||
|  | } | ||||||
|  |  | ||||||
| type watcher struct { | type watcher struct { | ||||||
| 	exit    chan bool | 	exit    chan bool | ||||||
| 	path    []string | 	path    []string | ||||||
| 	value   reader.Value | 	value   reader.Value | ||||||
| 	reader  reader.Reader | 	reader  reader.Reader | ||||||
| 	updates chan reader.Value | 	version string | ||||||
|  | 	updates chan updateValue | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *memory) watch(idx int, s source.Source) { | func (m *memory) watch(idx int, s source.Source) { | ||||||
| 	m.Lock() |  | ||||||
| 	m.sets = append(m.sets, &source.ChangeSet{Source: s.String()}) |  | ||||||
| 	m.Unlock() |  | ||||||
|  |  | ||||||
| 	// watches a source for changes | 	// watches a source for changes | ||||||
| 	watch := func(idx int, s source.Watcher) error { | 	watch := func(idx int, s source.Watcher) error { | ||||||
| 		for { | 		for { | ||||||
| @@ -70,7 +72,7 @@ func (m *memory) watch(idx int, s source.Source) { | |||||||
| 			m.vals, _ = m.opts.Reader.Values(set) | 			m.vals, _ = m.opts.Reader.Values(set) | ||||||
| 			m.snap = &loader.Snapshot{ | 			m.snap = &loader.Snapshot{ | ||||||
| 				ChangeSet: set, | 				ChangeSet: set, | ||||||
| 				Version:   fmt.Sprintf("%d", time.Now().Unix()), | 				Version:   genVer(), | ||||||
| 			} | 			} | ||||||
| 			m.Unlock() | 			m.Unlock() | ||||||
|  |  | ||||||
| @@ -141,7 +143,7 @@ func (m *memory) reload() error { | |||||||
| 	m.vals, _ = m.opts.Reader.Values(set) | 	m.vals, _ = m.opts.Reader.Values(set) | ||||||
| 	m.snap = &loader.Snapshot{ | 	m.snap = &loader.Snapshot{ | ||||||
| 		ChangeSet: set, | 		ChangeSet: set, | ||||||
| 		Version:   fmt.Sprintf("%d", time.Now().Unix()), | 		Version:   genVer(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	m.Unlock() | 	m.Unlock() | ||||||
| @@ -159,11 +161,23 @@ func (m *memory) update() { | |||||||
| 	for e := m.watchers.Front(); e != nil; e = e.Next() { | 	for e := m.watchers.Front(); e != nil; e = e.Next() { | ||||||
| 		watchers = append(watchers, e.Value.(*watcher)) | 		watchers = append(watchers, e.Value.(*watcher)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	vals := m.vals | ||||||
|  | 	snap := m.snap | ||||||
| 	m.RUnlock() | 	m.RUnlock() | ||||||
|  |  | ||||||
| 	for _, w := range watchers { | 	for _, w := range watchers { | ||||||
|  | 		if w.version >= snap.Version { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		uv := updateValue{ | ||||||
|  | 			version: m.snap.Version, | ||||||
|  | 			value:   vals.Get(w.path...), | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		select { | 		select { | ||||||
| 		case w.updates <- m.vals.Get(w.path...): | 		case w.updates <- uv: | ||||||
| 		default: | 		default: | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -226,7 +240,7 @@ func (m *memory) Sync() error { | |||||||
| 	m.vals = vals | 	m.vals = vals | ||||||
| 	m.snap = &loader.Snapshot{ | 	m.snap = &loader.Snapshot{ | ||||||
| 		ChangeSet: set, | 		ChangeSet: set, | ||||||
| 		Version:   fmt.Sprintf("%d", time.Now().Unix()), | 		Version:   genVer(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	m.Unlock() | 	m.Unlock() | ||||||
| @@ -285,6 +299,7 @@ func (m *memory) Get(path ...string) (reader.Value, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// ok we're going hardcore now | 	// ok we're going hardcore now | ||||||
|  |  | ||||||
| 	return nil, errors.New("no values") | 	return nil, errors.New("no values") | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -333,7 +348,8 @@ func (m *memory) Watch(path ...string) (loader.Watcher, error) { | |||||||
| 		path:    path, | 		path:    path, | ||||||
| 		value:   value, | 		value:   value, | ||||||
| 		reader:  m.opts.Reader, | 		reader:  m.opts.Reader, | ||||||
| 		updates: make(chan reader.Value, 1), | 		updates: make(chan updateValue, 1), | ||||||
|  | 		version: m.snap.Version, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	e := m.watchers.PushBack(w) | 	e := m.watchers.PushBack(w) | ||||||
| @@ -355,28 +371,43 @@ func (m *memory) String() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (w *watcher) Next() (*loader.Snapshot, error) { | func (w *watcher) Next() (*loader.Snapshot, error) { | ||||||
|  | 	update := func(v reader.Value) *loader.Snapshot { | ||||||
|  | 		w.value = v | ||||||
|  |  | ||||||
|  | 		cs := &source.ChangeSet{ | ||||||
|  | 			Data:      v.Bytes(), | ||||||
|  | 			Format:    w.reader.String(), | ||||||
|  | 			Source:    "memory", | ||||||
|  | 			Timestamp: time.Now(), | ||||||
|  | 		} | ||||||
|  | 		cs.Checksum = cs.Sum() | ||||||
|  |  | ||||||
|  | 		return &loader.Snapshot{ | ||||||
|  | 			ChangeSet: cs, | ||||||
|  | 			Version:   w.version, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case <-w.exit: | 		case <-w.exit: | ||||||
| 			return nil, errors.New("watcher stopped") | 			return nil, errors.New("watcher stopped") | ||||||
| 		case v := <-w.updates: |  | ||||||
|  | 		case uv := <-w.updates: | ||||||
|  | 			if uv.version <= w.version { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			v := uv.value | ||||||
|  |  | ||||||
|  | 			w.version = uv.version | ||||||
|  |  | ||||||
| 			if bytes.Equal(w.value.Bytes(), v.Bytes()) { | 			if bytes.Equal(w.value.Bytes(), v.Bytes()) { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			w.value = v |  | ||||||
|  |  | ||||||
| 			cs := &source.ChangeSet{ | 			return update(v), nil | ||||||
| 				Data:      v.Bytes(), |  | ||||||
| 				Format:    w.reader.String(), |  | ||||||
| 				Source:    "memory", |  | ||||||
| 				Timestamp: time.Now(), |  | ||||||
| 			} |  | ||||||
| 			cs.Sum() |  | ||||||
|  |  | ||||||
| 			return &loader.Snapshot{ |  | ||||||
| 				ChangeSet: cs, |  | ||||||
| 				Version:   fmt.Sprintf("%d", time.Now().Unix()), |  | ||||||
| 			}, nil |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -386,10 +417,16 @@ func (w *watcher) Stop() error { | |||||||
| 	case <-w.exit: | 	case <-w.exit: | ||||||
| 	default: | 	default: | ||||||
| 		close(w.exit) | 		close(w.exit) | ||||||
|  | 		close(w.updates) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func genVer() string { | ||||||
|  | 	return fmt.Sprintf("%d", time.Now().UnixNano()) | ||||||
|  | } | ||||||
|  |  | ||||||
| func NewLoader(opts ...loader.Option) loader.Loader { | func NewLoader(opts ...loader.Option) loader.Loader { | ||||||
| 	options := loader.Options{ | 	options := loader.Options{ | ||||||
| 		Reader: json.NewReader(), | 		Reader: json.NewReader(), | ||||||
| @@ -406,7 +443,10 @@ func NewLoader(opts ...loader.Option) loader.Loader { | |||||||
| 		sources:  options.Source, | 		sources:  options.Source, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	m.sets = make([]*source.ChangeSet, len(options.Source)) | ||||||
|  |  | ||||||
| 	for i, s := range options.Source { | 	for i, s := range options.Source { | ||||||
|  | 		m.sets[i] = &source.ChangeSet{Source: s.String()} | ||||||
| 		go m.watch(i, s) | 		go m.watch(i, s) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -42,6 +42,7 @@ func (s *memory) Watch() (source.Watcher, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (m *memory) Write(cs *source.ChangeSet) error { | func (m *memory) Write(cs *source.ChangeSet) error { | ||||||
|  | 	m.Update(cs) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -252,7 +252,7 @@ func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		srv, err := mdns.NewServer(&mdns.Config{Zone: s}) | 		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			gerr = err | 			gerr = err | ||||||
| 			continue | 			continue | ||||||
| @@ -563,9 +563,7 @@ func (m *mdnsWatcher) Next() (*Result, error) { | |||||||
| 			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service { | 			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			var action string | 			var action string | ||||||
|  |  | ||||||
| 			if e.TTL == 0 { | 			if e.TTL == 0 { | ||||||
| 				action = "delete" | 				action = "delete" | ||||||
| 			} else { | 			} else { | ||||||
| @@ -584,9 +582,18 @@ func (m *mdnsWatcher) Next() (*Result, error) { | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			var addr string | ||||||
|  | 			if len(e.AddrV4) > 0 { | ||||||
|  | 				addr = e.AddrV4.String() | ||||||
|  | 			} else if len(e.AddrV6) > 0 { | ||||||
|  | 				addr = "[" + e.AddrV6.String() + "]" | ||||||
|  | 			} else { | ||||||
|  | 				addr = e.Addr.String() | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			service.Nodes = append(service.Nodes, &Node{ | 			service.Nodes = append(service.Nodes, &Node{ | ||||||
| 				Id:       strings.TrimSuffix(e.Name, suffix), | 				Id:       strings.TrimSuffix(e.Name, suffix), | ||||||
| 				Address:  fmt.Sprintf("%s:%d", e.AddrV4.String(), e.Port), | 				Address:  fmt.Sprintf("%s:%d", addr, e.Port), | ||||||
| 				Metadata: txt.Metadata, | 				Metadata: txt.Metadata, | ||||||
| 			}) | 			}) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -363,11 +363,12 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | |||||||
| 			r = rpcRouter{h: handler} | 			r = rpcRouter{h: handler} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// wait for two coroutines to exit | ||||||
|  | 		// serve the request and process the outbound messages | ||||||
|  | 		wg.Add(2) | ||||||
|  |  | ||||||
| 		// process the outbound messages from the socket | 		// process the outbound messages from the socket | ||||||
| 		go func(id string, psock *socket.Socket) { | 		go func(id string, psock *socket.Socket) { | ||||||
| 			// wait for processing to exit |  | ||||||
| 			wg.Add(1) |  | ||||||
|  |  | ||||||
| 			defer func() { | 			defer func() { | ||||||
| 				// TODO: don't hack this but if its grpc just break out of the stream | 				// TODO: don't hack this but if its grpc just break out of the stream | ||||||
| 				// We do this because the underlying connection is h2 and its a stream | 				// We do this because the underlying connection is h2 and its a stream | ||||||
| @@ -405,9 +406,6 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { | |||||||
|  |  | ||||||
| 		// serve the request in a go routine as this may be a stream | 		// serve the request in a go routine as this may be a stream | ||||||
| 		go func(id string, psock *socket.Socket) { | 		go func(id string, psock *socket.Socket) { | ||||||
| 			// add to the waitgroup |  | ||||||
| 			wg.Add(1) |  | ||||||
|  |  | ||||||
| 			defer func() { | 			defer func() { | ||||||
| 				// release the socket | 				// release the socket | ||||||
| 				pool.Release(psock) | 				pool.Release(psock) | ||||||
|   | |||||||
| @@ -17,6 +17,15 @@ func init() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // AppendPrivateBlocks append private network blocks | ||||||
|  | func AppendPrivateBlocks(bs ...string) { | ||||||
|  | 	for _, b := range bs { | ||||||
|  | 		if _, block, err := net.ParseCIDR(b); err == nil { | ||||||
|  | 			privateBlocks = append(privateBlocks, block) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func isPrivateIP(ipAddr string) bool { | func isPrivateIP(ipAddr string) bool { | ||||||
| 	ip := net.ParseIP(ipAddr) | 	ip := net.ParseIP(ipAddr) | ||||||
| 	for _, priv := range privateBlocks { | 	for _, priv := range privateBlocks { | ||||||
|   | |||||||
| @@ -56,3 +56,24 @@ func TestExtractor(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestAppendPrivateBlocks(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		addr   string | ||||||
|  | 		expect bool | ||||||
|  | 	}{ | ||||||
|  | 		{addr: "9.134.71.34", expect: true}, | ||||||
|  | 		{addr: "8.10.110.34", expect: false}, // not in private blocks | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	AppendPrivateBlocks("9.134.0.0/16") | ||||||
|  |  | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		t.Run(test.addr, func(t *testing.T) { | ||||||
|  | 			res := isPrivateIP(test.addr) | ||||||
|  | 			if res != test.expect { | ||||||
|  | 				t.Fatalf("expected %t got %t", test.expect, res) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -34,6 +34,7 @@ type ServiceEntry struct { | |||||||
|  |  | ||||||
| // complete is used to check if we have all the info we need | // complete is used to check if we have all the info we need | ||||||
| func (s *ServiceEntry) complete() bool { | func (s *ServiceEntry) complete() bool { | ||||||
|  |  | ||||||
| 	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT | 	return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -347,15 +348,21 @@ func (c *client) query(params *QueryParam) error { | |||||||
| 		select { | 		select { | ||||||
| 		case resp := <-msgCh: | 		case resp := <-msgCh: | ||||||
| 			inp := messageToEntry(resp, inprogress) | 			inp := messageToEntry(resp, inprogress) | ||||||
|  |  | ||||||
| 			if inp == nil { | 			if inp == nil { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  | 			if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { | ||||||
|  | 				// discard anything which we've not asked for | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// Check if this entry is complete | 			// Check if this entry is complete | ||||||
| 			if inp.complete() { | 			if inp.complete() { | ||||||
| 				if inp.sent { | 				if inp.sent { | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				inp.sent = true | 				inp.sent = true | ||||||
| 				select { | 				select { | ||||||
| 				case params.Entries <- inp: | 				case params.Entries <- inp: | ||||||
|   | |||||||
| @@ -2,13 +2,13 @@ package mdns | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"log" |  | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net" | 	"net" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	log "github.com/micro/go-micro/v2/logger" | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/ipv4" | 	"golang.org/x/net/ipv4" | ||||||
| 	"golang.org/x/net/ipv6" | 	"golang.org/x/net/ipv6" | ||||||
| @@ -39,6 +39,10 @@ var ( | |||||||
| 	} | 	} | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // GetMachineIP is a func which returns the outbound IP of this machine. | ||||||
|  | // Used by the server to determine whether to attempt send the response on a local address | ||||||
|  | type GetMachineIP func() net.IP | ||||||
|  |  | ||||||
| // Config is used to configure the mDNS server | // Config is used to configure the mDNS server | ||||||
| type Config struct { | type Config struct { | ||||||
| 	// Zone must be provided to support responding to queries | 	// Zone must be provided to support responding to queries | ||||||
| @@ -51,9 +55,15 @@ type Config struct { | |||||||
|  |  | ||||||
| 	// Port If it is not 0, replace the port 5353 with this port number. | 	// Port If it is not 0, replace the port 5353 with this port number. | ||||||
| 	Port int | 	Port int | ||||||
|  |  | ||||||
|  | 	// GetMachineIP is a function to return the IP of the local machine | ||||||
|  | 	GetMachineIP GetMachineIP | ||||||
|  | 	// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP | ||||||
|  | 	// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports | ||||||
|  | 	LocalhostChecking bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // mDNS server is used to listen for mDNS queries and respond if we | // Server is an mDNS server used to listen for mDNS queries and respond if we | ||||||
| // have a matching local record | // have a matching local record | ||||||
| type Server struct { | type Server struct { | ||||||
| 	config *Config | 	config *Config | ||||||
| @@ -65,6 +75,8 @@ type Server struct { | |||||||
| 	shutdownCh   chan struct{} | 	shutdownCh   chan struct{} | ||||||
| 	shutdownLock sync.Mutex | 	shutdownLock sync.Mutex | ||||||
| 	wg           sync.WaitGroup | 	wg           sync.WaitGroup | ||||||
|  |  | ||||||
|  | 	outboundIP net.IP | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewServer is used to create a new mDNS server from a config | // NewServer is used to create a new mDNS server from a config | ||||||
| @@ -118,11 +130,17 @@ func NewServer(config *Config) (*Server, error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	ipFunc := getOutboundIP | ||||||
|  | 	if config.GetMachineIP != nil { | ||||||
|  | 		ipFunc = config.GetMachineIP | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	s := &Server{ | 	s := &Server{ | ||||||
| 		config:     config, | 		config:     config, | ||||||
| 		ipv4List:   ipv4List, | 		ipv4List:   ipv4List, | ||||||
| 		ipv6List:   ipv6List, | 		ipv6List:   ipv6List, | ||||||
| 		shutdownCh: make(chan struct{}), | 		shutdownCh: make(chan struct{}), | ||||||
|  | 		outboundIP: ipFunc(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go s.recv(s.ipv4List) | 	go s.recv(s.ipv4List) | ||||||
| @@ -176,7 +194,7 @@ func (s *Server) recv(c *net.UDPConn) { | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if err := s.parsePacket(buf[:n], from); err != nil { | 		if err := s.parsePacket(buf[:n], from); err != nil { | ||||||
| 			log.Printf("[ERR] mdns: Failed to handle query: %v", err) | 			log.Errorf("[ERR] mdns: Failed to handle query: %v", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -185,7 +203,7 @@ func (s *Server) recv(c *net.UDPConn) { | |||||||
| func (s *Server) parsePacket(packet []byte, from net.Addr) error { | func (s *Server) parsePacket(packet []byte, from net.Addr) error { | ||||||
| 	var msg dns.Msg | 	var msg dns.Msg | ||||||
| 	if err := msg.Unpack(packet); err != nil { | 	if err := msg.Unpack(packet); err != nil { | ||||||
| 		log.Printf("[ERR] mdns: Failed to unpack packet: %v", err) | 		log.Errorf("[ERR] mdns: Failed to unpack packet: %v", err) | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	// TODO: This is a bit of a hack | 	// TODO: This is a bit of a hack | ||||||
| @@ -278,8 +296,8 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { | |||||||
| 			// caveats in the RFC), so set the Compress bit (part of the dns library | 			// caveats in the RFC), so set the Compress bit (part of the dns library | ||||||
| 			// API, not part of the DNS packet) to true. | 			// API, not part of the DNS packet) to true. | ||||||
| 			Compress: true, | 			Compress: true, | ||||||
|  | 			Question: query.Question, | ||||||
| 			Answer: answer, | 			Answer:   answer, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -302,7 +320,6 @@ func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { | |||||||
| // both.  The return values are DNS records for each transmission type. | // both.  The return values are DNS records for each transmission type. | ||||||
| func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { | func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { | ||||||
| 	records := s.config.Zone.Records(q) | 	records := s.config.Zone.Records(q) | ||||||
|  |  | ||||||
| 	if len(records) == 0 { | 	if len(records) == 0 { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
| @@ -365,7 +382,7 @@ func (s *Server) probe() { | |||||||
|  |  | ||||||
| 	for i := 0; i < 3; i++ { | 	for i := 0; i < 3; i++ { | ||||||
| 		if err := s.SendMulticast(q); err != nil { | 		if err := s.SendMulticast(q); err != nil { | ||||||
| 			log.Println("[ERR] mdns: failed to send probe:", err.Error()) | 			log.Errorf("[ERR] mdns: failed to send probe:", err.Error()) | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) | 		time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) | ||||||
| 	} | 	} | ||||||
| @@ -391,7 +408,7 @@ func (s *Server) probe() { | |||||||
| 	timer := time.NewTimer(timeout) | 	timer := time.NewTimer(timeout) | ||||||
| 	for i := 0; i < 3; i++ { | 	for i := 0; i < 3; i++ { | ||||||
| 		if err := s.SendMulticast(resp); err != nil { | 		if err := s.SendMulticast(resp); err != nil { | ||||||
| 			log.Println("[ERR] mdns: failed to send announcement:", err.Error()) | 			log.Errorf("[ERR] mdns: failed to send announcement:", err.Error()) | ||||||
| 		} | 		} | ||||||
| 		select { | 		select { | ||||||
| 		case <-timer.C: | 		case <-timer.C: | ||||||
| @@ -404,7 +421,7 @@ func (s *Server) probe() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // multicastResponse us used to send a multicast response packet | // SendMulticast us used to send a multicast response packet | ||||||
| func (s *Server) SendMulticast(msg *dns.Msg) error { | func (s *Server) SendMulticast(msg *dns.Msg) error { | ||||||
| 	buf, err := msg.Pack() | 	buf, err := msg.Pack() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -430,13 +447,23 @@ func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { | |||||||
|  |  | ||||||
| 	// Determine the socket to send from | 	// Determine the socket to send from | ||||||
| 	addr := from.(*net.UDPAddr) | 	addr := from.(*net.UDPAddr) | ||||||
| 	if addr.IP.To4() != nil { | 	conn := s.ipv4List | ||||||
| 		_, err = s.ipv4List.WriteToUDP(buf, addr) | 	backupTarget := net.IPv4zero | ||||||
| 		return err |  | ||||||
| 	} else { | 	if addr.IP.To4() == nil { | ||||||
| 		_, err = s.ipv6List.WriteToUDP(buf, addr) | 		conn = s.ipv6List | ||||||
| 		return err | 		backupTarget = net.IPv6zero | ||||||
| 	} | 	} | ||||||
|  | 	_, err = conn.WriteToUDP(buf, addr) | ||||||
|  | 	// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0 | ||||||
|  | 	// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there | ||||||
|  | 	// Sending two responses is OK | ||||||
|  | 	if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) { | ||||||
|  | 		// ignore any errors, this is best efforts | ||||||
|  | 		conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port}) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Server) unregister() error { | func (s *Server) unregister() error { | ||||||
| @@ -474,3 +501,17 @@ func setCustomPort(port int) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // getOutboundIP returns the IP address of this machine as seen when dialling out | ||||||
|  | func getOutboundIP() net.IP { | ||||||
|  | 	conn, err := net.Dial("udp", "8.8.8.8:80") | ||||||
|  | 	if err != nil { | ||||||
|  | 		// no net connectivity maybe so fallback | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	localAddr := conn.LocalAddr().(*net.UDPAddr) | ||||||
|  |  | ||||||
|  | 	return localAddr.IP | ||||||
|  | } | ||||||
|   | |||||||
| @@ -7,7 +7,7 @@ import ( | |||||||
|  |  | ||||||
| func TestServer_StartStop(t *testing.T) { | func TestServer_StartStop(t *testing.T) { | ||||||
| 	s := makeService(t) | 	s := makeService(t) | ||||||
| 	serv, err := NewServer(&Config{Zone: s}) | 	serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %v", err) | 		t.Fatalf("err: %v", err) | ||||||
| 	} | 	} | ||||||
| @@ -15,7 +15,7 @@ func TestServer_StartStop(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestServer_Lookup(t *testing.T) { | func TestServer_Lookup(t *testing.T) { | ||||||
| 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp")}) | 	serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %v", err) | 		t.Fatalf("err: %v", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user