Merge pull request #386 from micro/mdns

Set MDNS as default registry
This commit is contained in:
Asim Aslam 2019-01-15 16:59:07 +00:00 committed by GitHub
commit 7bd0bd14c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1183 additions and 1194 deletions

View File

@ -202,8 +202,8 @@ var (
defaultClient = "rpc" defaultClient = "rpc"
defaultServer = "rpc" defaultServer = "rpc"
defaultBroker = "http" defaultBroker = "http"
defaultRegistry = "consul" defaultRegistry = "mdns"
defaultSelector = "cache" defaultSelector = "registry"
defaultTransport = "http" defaultTransport = "http"
) )

View File

@ -1,11 +1,387 @@
// Package consul provides a consul based registry and is the default discovery system
package consul package consul
import ( import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"sync"
"time"
consul "github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry" "github.com/micro/go-micro/registry"
hash "github.com/mitchellh/hashstructure"
) )
// NewRegistry returns a new consul registry type consulRegistry struct {
func NewRegistry(opts ...registry.Option) registry.Registry { Address string
return registry.NewRegistry(opts...) Client *consul.Client
opts registry.Options
// connect enabled
connect bool
queryOptions *consul.QueryOptions
sync.Mutex
register map[string]uint64
// lastChecked tracks when a node was last checked as existing in Consul
lastChecked map[string]time.Time
}
func getDeregisterTTL(t time.Duration) time.Duration {
// splay slightly for the watcher?
splay := time.Second * 5
deregTTL := t + splay
// consul has a minimum timeout on deregistration of 1 minute.
if t < time.Minute {
deregTTL = time.Minute + splay
}
return deregTTL
}
func newTransport(config *tls.Config) *http.Transport {
if config == nil {
config = &tls.Config{
InsecureSkipVerify: true,
}
}
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: config,
}
runtime.SetFinalizer(&t, func(tr **http.Transport) {
(*tr).CloseIdleConnections()
})
return t
}
func configure(c *consulRegistry, opts ...registry.Option) {
// set opts
for _, o := range opts {
o(&c.opts)
}
// use default config
config := consul.DefaultConfig()
if c.opts.Context != nil {
// Use the consul config passed in the options, if available
if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok {
config = co
}
if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok {
c.connect = cn
}
// Use the consul query options passed in the options, if available
if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil {
c.queryOptions = qo
}
if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok {
c.queryOptions.AllowStale = as
}
}
// check if there are any addrs
if len(c.opts.Addrs) > 0 {
addr, port, err := net.SplitHostPort(c.opts.Addrs[0])
if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
port = "8500"
addr = c.opts.Addrs[0]
config.Address = fmt.Sprintf("%s:%s", addr, port)
} else if err == nil {
config.Address = fmt.Sprintf("%s:%s", addr, port)
}
}
// requires secure connection?
if c.opts.Secure || c.opts.TLSConfig != nil {
if config.HttpClient == nil {
config.HttpClient = new(http.Client)
}
config.Scheme = "https"
// We're going to support InsecureSkipVerify
config.HttpClient.Transport = newTransport(c.opts.TLSConfig)
}
// set timeout
if c.opts.Timeout > 0 {
config.HttpClient.Timeout = c.opts.Timeout
}
// create the client
client, _ := consul.NewClient(config)
// set address/client
c.Address = config.Address
c.Client = client
}
func (c *consulRegistry) Init(opts ...registry.Option) error {
configure(c, opts...)
return nil
}
func (c *consulRegistry) Deregister(s *registry.Service) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
// delete our hash and time check of the service
c.Lock()
delete(c.register, s.Name)
delete(c.lastChecked, s.Name)
c.Unlock()
node := s.Nodes[0]
return c.Client.Agent().ServiceDeregister(node.Id)
}
func (c *consulRegistry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
var regTCPCheck bool
var regInterval time.Duration
var options registry.RegisterOptions
for _, o := range opts {
o(&options)
}
if c.opts.Context != nil {
if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok {
regTCPCheck = true
regInterval = tcpCheckInterval
}
}
// create hash of service; uint64
h, err := hash.Hash(s, nil)
if err != nil {
return err
}
// use first node
node := s.Nodes[0]
// get existing hash and last checked time
c.Lock()
v, ok := c.register[s.Name]
lastChecked := c.lastChecked[s.Name]
c.Unlock()
// if it's already registered and matches then just pass the check
if ok && v == h {
if options.TTL == time.Duration(0) {
// ensure that our service hasn't been deregistered by Consul
if time.Since(lastChecked) <= getDeregisterTTL(regInterval) {
return nil
}
services, _, err := c.Client.Health().Checks(s.Name, c.queryOptions)
if err == nil {
for _, v := range services {
if v.ServiceID == node.Id {
return nil
}
}
}
} else {
// if the err is nil we're all good, bail out
// if not, we don't know what the state is, so full re-register
if err := c.Client.Agent().PassTTL("service:"+node.Id, ""); err == nil {
return nil
}
}
}
// encode the tags
tags := encodeMetadata(node.Metadata)
tags = append(tags, encodeEndpoints(s.Endpoints)...)
tags = append(tags, encodeVersion(s.Version)...)
var check *consul.AgentServiceCheck
if regTCPCheck {
deregTTL := getDeregisterTTL(regInterval)
check = &consul.AgentServiceCheck{
TCP: fmt.Sprintf("%s:%d", node.Address, node.Port),
Interval: fmt.Sprintf("%v", regInterval),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
// if the TTL is greater than 0 create an associated check
} else if options.TTL > time.Duration(0) {
deregTTL := getDeregisterTTL(options.TTL)
check = &consul.AgentServiceCheck{
TTL: fmt.Sprintf("%v", options.TTL),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
}
// register the service
asr := &consul.AgentServiceRegistration{
ID: node.Id,
Name: s.Name,
Tags: tags,
Port: node.Port,
Address: node.Address,
Check: check,
}
// Specify consul connect
if c.connect {
asr.Connect = &consul.AgentServiceConnect{
Native: true,
}
}
if err := c.Client.Agent().ServiceRegister(asr); err != nil {
return err
}
// save our hash and time check of the service
c.Lock()
c.register[s.Name] = h
c.lastChecked[s.Name] = time.Now()
c.Unlock()
// if the TTL is 0 we don't mess with the checks
if options.TTL == time.Duration(0) {
return nil
}
// pass the healthcheck
return c.Client.Agent().PassTTL("service:"+node.Id, "")
}
func (c *consulRegistry) GetService(name string) ([]*registry.Service, error) {
var rsp []*consul.ServiceEntry
var err error
// if we're connect enabled only get connect services
if c.connect {
rsp, _, err = c.Client.Health().Connect(name, "", false, c.queryOptions)
} else {
rsp, _, err = c.Client.Health().Service(name, "", false, c.queryOptions)
}
if err != nil {
return nil, err
}
serviceMap := map[string]*registry.Service{}
for _, s := range rsp {
if s.Service.Service != name {
continue
}
// version is now a tag
version, _ := decodeVersion(s.Service.Tags)
// service ID is now the node id
id := s.Service.ID
// key is always the version
key := version
// address is service address
address := s.Service.Address
// use node address
if len(address) == 0 {
address = s.Node.Address
}
svc, ok := serviceMap[key]
if !ok {
svc = &registry.Service{
Endpoints: decodeEndpoints(s.Service.Tags),
Name: s.Service.Service,
Version: version,
}
serviceMap[key] = svc
}
var del bool
for _, check := range s.Checks {
// delete the node if the status is critical
if check.Status == "critical" {
del = true
break
}
}
// if delete then skip the node
if del {
continue
}
svc.Nodes = append(svc.Nodes, &registry.Node{
Id: id,
Address: address,
Port: s.Service.Port,
Metadata: decodeMetadata(s.Service.Tags),
})
}
var services []*registry.Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (c *consulRegistry) ListServices() ([]*registry.Service, error) {
rsp, _, err := c.Client.Catalog().Services(c.queryOptions)
if err != nil {
return nil, err
}
var services []*registry.Service
for service := range rsp {
services = append(services, &registry.Service{Name: service})
}
return services, nil
}
func (c *consulRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
return newConsulWatcher(c, opts...)
}
func (c *consulRegistry) String() string {
return "consul"
}
func (c *consulRegistry) Options() registry.Options {
return c.opts
}
func NewRegistry(opts ...registry.Option) registry.Registry {
cr := &consulRegistry{
opts: registry.Options{},
register: make(map[string]uint64),
lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{
AllowStale: true,
},
}
configure(cr, opts...)
return cr
} }

170
registry/consul/encoding.go Normal file
View File

@ -0,0 +1,170 @@
package consul
import (
"bytes"
"compress/zlib"
"encoding/hex"
"encoding/json"
"io/ioutil"
"github.com/micro/go-micro/registry"
)
func encode(buf []byte) string {
var b bytes.Buffer
defer b.Reset()
w := zlib.NewWriter(&b)
if _, err := w.Write(buf); err != nil {
return ""
}
w.Close()
return hex.EncodeToString(b.Bytes())
}
func decode(d string) []byte {
hr, err := hex.DecodeString(d)
if err != nil {
return nil
}
br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br)
if err != nil {
return nil
}
rbuf, err := ioutil.ReadAll(zr)
if err != nil {
return nil
}
return rbuf
}
func encodeEndpoints(en []*registry.Endpoint) []string {
var tags []string
for _, e := range en {
if b, err := json.Marshal(e); err == nil {
tags = append(tags, "e-"+encode(b))
}
}
return tags
}
func decodeEndpoints(tags []string) []*registry.Endpoint {
var en []*registry.Endpoint
// use the first format you find
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 'e' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var e *registry.Endpoint
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
if err := json.Unmarshal(buf, &e); err == nil {
en = append(en, e)
}
// set version
ver = tag[1]
}
return en
}
func encodeMetadata(md map[string]string) []string {
var tags []string
for k, v := range md {
if b, err := json.Marshal(map[string]string{
k: v,
}); err == nil {
// new encoding
tags = append(tags, "t-"+encode(b))
}
}
return tags
}
func decodeMetadata(tags []string) map[string]string {
md := make(map[string]string)
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 't' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var kv map[string]string
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
// Now unmarshal
if err := json.Unmarshal(buf, &kv); err == nil {
for k, v := range kv {
md[k] = v
}
}
// set version
ver = tag[1]
}
return md
}
func encodeVersion(v string) []string {
return []string{"v-" + encode([]byte(v))}
}
func decodeVersion(tags []string) (string, bool) {
for _, tag := range tags {
if len(tag) < 2 || tag[0] != 'v' {
continue
}
// Old encoding was plain
if tag[1] == '=' {
return tag[2:], true
}
// New encoding is hex
if tag[1] == '-' {
return string(decode(tag[2:])), true
}
}
return "", false
}

View File

@ -0,0 +1,147 @@
package consul
import (
"encoding/json"
"testing"
"github.com/micro/go-micro/registry"
)
func TestEncodingEndpoints(t *testing.T) {
eps := []*registry.Endpoint{
&registry.Endpoint{
Name: "endpoint1",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo1": "bar1",
},
},
&registry.Endpoint{
Name: "endpoint2",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo2": "bar2",
},
},
&registry.Endpoint{
Name: "endpoint3",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo3": "bar3",
},
},
}
testEp := func(ep *registry.Endpoint, enc string) {
// encode endpoint
e := encodeEndpoints([]*registry.Endpoint{ep})
// check there are two tags; old and new
if len(e) != 1 {
t.Fatalf("Expected 1 encoded tags, got %v", e)
}
// check old encoding
var seen bool
for _, en := range e {
if en == enc {
seen = true
break
}
}
if !seen {
t.Fatalf("Expected %s but not found", enc)
}
// decode
d := decodeEndpoints([]string{enc})
if len(d) == 0 {
t.Fatalf("Expected %v got %v", ep, d)
}
// check name
if d[0].Name != ep.Name {
t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name)
}
// check all the metadata exists
for k, v := range ep.Metadata {
if gv := d[0].Metadata[k]; gv != v {
t.Fatalf("Expected key %s val %s got val %s", k, v, gv)
}
}
}
for _, ep := range eps {
// JSON encoded
jencoded, err := json.Marshal(ep)
if err != nil {
t.Fatal(err)
}
// HEX encoded
hencoded := encode(jencoded)
// endpoint tag
hepTag := "e-" + hencoded
testEp(ep, hepTag)
}
}
func TestEncodingVersion(t *testing.T) {
testData := []struct {
decoded string
encoded string
}{
{"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"},
{"latest", "v-789cca492c492d2e01040000ffff08cc028e"},
}
for _, data := range testData {
e := encodeVersion(data.decoded)
if e[0] != data.encoded {
t.Fatalf("Expected %s got %s", data.encoded, e)
}
d, ok := decodeVersion(e)
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
d, ok = decodeVersion([]string{data.encoded})
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
}
}

View File

@ -1,4 +1,4 @@
package registry package consul
import ( import (
"bytes" "bytes"
@ -10,6 +10,7 @@ import (
"time" "time"
consul "github.com/hashicorp/consul/api" consul "github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry"
) )
type mockRegistry struct { type mockRegistry struct {
@ -56,7 +57,7 @@ func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) {
return &consulRegistry{ return &consulRegistry{
Address: cfg.Address, Address: cfg.Address,
Client: cl, Client: cl,
opts: Options{}, opts: registry.Options{},
register: make(map[string]uint64), register: make(map[string]uint64),
lastChecked: make(map[string]time.Time), lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{ queryOptions: &consul.QueryOptions{

View File

@ -1,4 +1,4 @@
package registry package consul
import ( import (
"errors" "errors"
@ -6,23 +6,24 @@ import (
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/watch" "github.com/hashicorp/consul/watch"
"github.com/micro/go-micro/registry"
) )
type consulWatcher struct { type consulWatcher struct {
r *consulRegistry r *consulRegistry
wo WatchOptions wo registry.WatchOptions
wp *watch.Plan wp *watch.Plan
watchers map[string]*watch.Plan watchers map[string]*watch.Plan
next chan *Result next chan *registry.Result
exit chan bool exit chan bool
sync.RWMutex sync.RWMutex
services map[string][]*Service services map[string][]*registry.Service
} }
func newConsulWatcher(cr *consulRegistry, opts ...WatchOption) (Watcher, error) { func newConsulWatcher(cr *consulRegistry, opts ...registry.WatchOption) (registry.Watcher, error) {
var wo WatchOptions var wo registry.WatchOptions
for _, o := range opts { for _, o := range opts {
o(&wo) o(&wo)
} }
@ -31,9 +32,9 @@ func newConsulWatcher(cr *consulRegistry, opts ...WatchOption) (Watcher, error)
r: cr, r: cr,
wo: wo, wo: wo,
exit: make(chan bool), exit: make(chan bool),
next: make(chan *Result, 10), next: make(chan *registry.Result, 10),
watchers: make(map[string]*watch.Plan), watchers: make(map[string]*watch.Plan),
services: make(map[string][]*Service), services: make(map[string][]*registry.Service),
} }
wp, err := watch.Parse(map[string]interface{}{"type": "services"}) wp, err := watch.Parse(map[string]interface{}{"type": "services"})
@ -54,7 +55,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
return return
} }
serviceMap := map[string]*Service{} serviceMap := map[string]*registry.Service{}
serviceName := "" serviceName := ""
for _, e := range entries { for _, e := range entries {
@ -75,7 +76,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
svc, ok := serviceMap[key] svc, ok := serviceMap[key]
if !ok { if !ok {
svc = &Service{ svc = &registry.Service{
Endpoints: decodeEndpoints(e.Service.Tags), Endpoints: decodeEndpoints(e.Service.Tags),
Name: e.Service.Service, Name: e.Service.Service,
Version: version, Version: version,
@ -98,7 +99,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
continue continue
} }
svc.Nodes = append(svc.Nodes, &Node{ svc.Nodes = append(svc.Nodes, &registry.Node{
Id: id, Id: id,
Address: address, Address: address,
Port: e.Service.Port, Port: e.Service.Port,
@ -108,13 +109,13 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
cw.RLock() cw.RLock()
// make a copy // make a copy
rservices := make(map[string][]*Service) rservices := make(map[string][]*registry.Service)
for k, v := range cw.services { for k, v := range cw.services {
rservices[k] = v rservices[k] = v
} }
cw.RUnlock() cw.RUnlock()
var newServices []*Service var newServices []*registry.Service
// serviceMap is the new set of services keyed by name+version // serviceMap is the new set of services keyed by name+version
for _, newService := range serviceMap { for _, newService := range serviceMap {
@ -125,7 +126,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
oldServices, ok := rservices[serviceName] oldServices, ok := rservices[serviceName]
if !ok { if !ok {
// does not exist? then we're creating brand new entries // does not exist? then we're creating brand new entries
cw.next <- &Result{Action: "create", Service: newService} cw.next <- &registry.Result{Action: "create", Service: newService}
continue continue
} }
@ -142,7 +143,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
// yes? then it's an update // yes? then it's an update
action = "update" action = "update"
var nodes []*Node var nodes []*registry.Node
// check the old nodes to see if they've been deleted // check the old nodes to see if they've been deleted
for _, oldNode := range oldService.Nodes { for _, oldNode := range oldService.Nodes {
var seen bool var seen bool
@ -163,11 +164,11 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
if len(nodes) > 0 { if len(nodes) > 0 {
delService := oldService delService := oldService
delService.Nodes = nodes delService.Nodes = nodes
cw.next <- &Result{Action: "delete", Service: delService} cw.next <- &registry.Result{Action: "delete", Service: delService}
} }
} }
cw.next <- &Result{Action: action, Service: newService} cw.next <- &registry.Result{Action: action, Service: newService}
} }
// Now check old versions that may not be in new services map // Now check old versions that may not be in new services map
@ -175,7 +176,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
// old version does not exist in new version map // old version does not exist in new version map
// kill it with fire! // kill it with fire!
if _, ok := serviceMap[old.Version]; !ok { if _, ok := serviceMap[old.Version]; !ok {
cw.next <- &Result{Action: "delete", Service: old} cw.next <- &registry.Result{Action: "delete", Service: old}
} }
} }
@ -209,13 +210,13 @@ func (cw *consulWatcher) handle(idx uint64, data interface{}) {
wp.Handler = cw.serviceHandler wp.Handler = cw.serviceHandler
go wp.Run(cw.r.Address) go wp.Run(cw.r.Address)
cw.watchers[service] = wp cw.watchers[service] = wp
cw.next <- &Result{Action: "create", Service: &Service{Name: service}} cw.next <- &registry.Result{Action: "create", Service: &registry.Service{Name: service}}
} }
} }
cw.RLock() cw.RLock()
// make a copy // make a copy
rservices := make(map[string][]*Service) rservices := make(map[string][]*registry.Service)
for k, v := range cw.services { for k, v := range cw.services {
rservices[k] = v rservices[k] = v
} }
@ -235,12 +236,12 @@ func (cw *consulWatcher) handle(idx uint64, data interface{}) {
if _, ok := services[service]; !ok { if _, ok := services[service]; !ok {
w.Stop() w.Stop()
delete(cw.watchers, service) delete(cw.watchers, service)
cw.next <- &Result{Action: "delete", Service: &Service{Name: service}} cw.next <- &registry.Result{Action: "delete", Service: &registry.Service{Name: service}}
} }
} }
} }
func (cw *consulWatcher) Next() (*Result, error) { func (cw *consulWatcher) Next() (*registry.Result, error) {
select { select {
case <-cw.exit: case <-cw.exit:
return nil, errors.New("result chan closed") return nil, errors.New("result chan closed")

View File

@ -1,9 +1,10 @@
package registry package consul
import ( import (
"testing" "testing"
"github.com/hashicorp/consul/api" "github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry"
) )
func TestHealthyServiceHandler(t *testing.T) { func TestHealthyServiceHandler(t *testing.T) {
@ -58,8 +59,8 @@ func TestUnhealthyNodeServiceHandler(t *testing.T) {
func newWatcher() *consulWatcher { func newWatcher() *consulWatcher {
return &consulWatcher{ return &consulWatcher{
exit: make(chan bool), exit: make(chan bool),
next: make(chan *Result, 10), next: make(chan *registry.Result, 10),
services: make(map[string][]*Service), services: make(map[string][]*registry.Service),
} }
} }

View File

@ -1,386 +0,0 @@
package registry
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"sync"
"time"
consul "github.com/hashicorp/consul/api"
hash "github.com/mitchellh/hashstructure"
)
type consulRegistry struct {
Address string
Client *consul.Client
opts Options
// connect enabled
connect bool
queryOptions *consul.QueryOptions
sync.Mutex
register map[string]uint64
// lastChecked tracks when a node was last checked as existing in Consul
lastChecked map[string]time.Time
}
func getDeregisterTTL(t time.Duration) time.Duration {
// splay slightly for the watcher?
splay := time.Second * 5
deregTTL := t + splay
// consul has a minimum timeout on deregistration of 1 minute.
if t < time.Minute {
deregTTL = time.Minute + splay
}
return deregTTL
}
func newTransport(config *tls.Config) *http.Transport {
if config == nil {
config = &tls.Config{
InsecureSkipVerify: true,
}
}
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: config,
}
runtime.SetFinalizer(&t, func(tr **http.Transport) {
(*tr).CloseIdleConnections()
})
return t
}
func configure(c *consulRegistry, opts ...Option) {
// set opts
for _, o := range opts {
o(&c.opts)
}
// use default config
config := consul.DefaultConfig()
if c.opts.Context != nil {
// Use the consul config passed in the options, if available
if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok {
config = co
}
if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok {
c.connect = cn
}
// Use the consul query options passed in the options, if available
if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil {
c.queryOptions = qo
}
if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok {
c.queryOptions.AllowStale = as
}
}
// check if there are any addrs
if len(c.opts.Addrs) > 0 {
addr, port, err := net.SplitHostPort(c.opts.Addrs[0])
if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
port = "8500"
addr = c.opts.Addrs[0]
config.Address = fmt.Sprintf("%s:%s", addr, port)
} else if err == nil {
config.Address = fmt.Sprintf("%s:%s", addr, port)
}
}
// requires secure connection?
if c.opts.Secure || c.opts.TLSConfig != nil {
if config.HttpClient == nil {
config.HttpClient = new(http.Client)
}
config.Scheme = "https"
// We're going to support InsecureSkipVerify
config.HttpClient.Transport = newTransport(c.opts.TLSConfig)
}
// set timeout
if c.opts.Timeout > 0 {
config.HttpClient.Timeout = c.opts.Timeout
}
// create the client
client, _ := consul.NewClient(config)
// set address/client
c.Address = config.Address
c.Client = client
}
func newConsulRegistry(opts ...Option) Registry {
cr := &consulRegistry{
opts: Options{},
register: make(map[string]uint64),
lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{
AllowStale: true,
},
}
configure(cr, opts...)
return cr
}
func (c *consulRegistry) Init(opts ...Option) error {
configure(c, opts...)
return nil
}
func (c *consulRegistry) Deregister(s *Service) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
// delete our hash and time check of the service
c.Lock()
delete(c.register, s.Name)
delete(c.lastChecked, s.Name)
c.Unlock()
node := s.Nodes[0]
return c.Client.Agent().ServiceDeregister(node.Id)
}
func (c *consulRegistry) Register(s *Service, opts ...RegisterOption) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
var regTCPCheck bool
var regInterval time.Duration
var options RegisterOptions
for _, o := range opts {
o(&options)
}
if c.opts.Context != nil {
if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok {
regTCPCheck = true
regInterval = tcpCheckInterval
}
}
// create hash of service; uint64
h, err := hash.Hash(s, nil)
if err != nil {
return err
}
// use first node
node := s.Nodes[0]
// get existing hash and last checked time
c.Lock()
v, ok := c.register[s.Name]
lastChecked := c.lastChecked[s.Name]
c.Unlock()
// if it's already registered and matches then just pass the check
if ok && v == h {
if options.TTL == time.Duration(0) {
// ensure that our service hasn't been deregistered by Consul
if time.Since(lastChecked) <= getDeregisterTTL(regInterval) {
return nil
}
services, _, err := c.Client.Health().Checks(s.Name, c.queryOptions)
if err == nil {
for _, v := range services {
if v.ServiceID == node.Id {
return nil
}
}
}
} else {
// if the err is nil we're all good, bail out
// if not, we don't know what the state is, so full re-register
if err := c.Client.Agent().PassTTL("service:"+node.Id, ""); err == nil {
return nil
}
}
}
// encode the tags
tags := encodeMetadata(node.Metadata)
tags = append(tags, encodeEndpoints(s.Endpoints)...)
tags = append(tags, encodeVersion(s.Version)...)
var check *consul.AgentServiceCheck
if regTCPCheck {
deregTTL := getDeregisterTTL(regInterval)
check = &consul.AgentServiceCheck{
TCP: fmt.Sprintf("%s:%d", node.Address, node.Port),
Interval: fmt.Sprintf("%v", regInterval),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
// if the TTL is greater than 0 create an associated check
} else if options.TTL > time.Duration(0) {
deregTTL := getDeregisterTTL(options.TTL)
check = &consul.AgentServiceCheck{
TTL: fmt.Sprintf("%v", options.TTL),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
}
// register the service
asr := &consul.AgentServiceRegistration{
ID: node.Id,
Name: s.Name,
Tags: tags,
Port: node.Port,
Address: node.Address,
Check: check,
}
// Specify consul connect
if c.connect {
asr.Connect = &consul.AgentServiceConnect{
Native: true,
}
}
if err := c.Client.Agent().ServiceRegister(asr); err != nil {
return err
}
// save our hash and time check of the service
c.Lock()
c.register[s.Name] = h
c.lastChecked[s.Name] = time.Now()
c.Unlock()
// if the TTL is 0 we don't mess with the checks
if options.TTL == time.Duration(0) {
return nil
}
// pass the healthcheck
return c.Client.Agent().PassTTL("service:"+node.Id, "")
}
func (c *consulRegistry) GetService(name string) ([]*Service, error) {
var rsp []*consul.ServiceEntry
var err error
// if we're connect enabled only get connect services
if c.connect {
rsp, _, err = c.Client.Health().Connect(name, "", false, c.queryOptions)
} else {
rsp, _, err = c.Client.Health().Service(name, "", false, c.queryOptions)
}
if err != nil {
return nil, err
}
serviceMap := map[string]*Service{}
for _, s := range rsp {
if s.Service.Service != name {
continue
}
// version is now a tag
version, _ := decodeVersion(s.Service.Tags)
// service ID is now the node id
id := s.Service.ID
// key is always the version
key := version
// address is service address
address := s.Service.Address
// use node address
if len(address) == 0 {
address = s.Node.Address
}
svc, ok := serviceMap[key]
if !ok {
svc = &Service{
Endpoints: decodeEndpoints(s.Service.Tags),
Name: s.Service.Service,
Version: version,
}
serviceMap[key] = svc
}
var del bool
for _, check := range s.Checks {
// delete the node if the status is critical
if check.Status == "critical" {
del = true
break
}
}
// if delete then skip the node
if del {
continue
}
svc.Nodes = append(svc.Nodes, &Node{
Id: id,
Address: address,
Port: s.Service.Port,
Metadata: decodeMetadata(s.Service.Tags),
})
}
var services []*Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (c *consulRegistry) ListServices() ([]*Service, error) {
rsp, _, err := c.Client.Catalog().Services(c.queryOptions)
if err != nil {
return nil, err
}
var services []*Service
for service := range rsp {
services = append(services, &Service{Name: service})
}
return services, nil
}
func (c *consulRegistry) Watch(opts ...WatchOption) (Watcher, error) {
return newConsulWatcher(c, opts...)
}
func (c *consulRegistry) String() string {
return "consul"
}
func (c *consulRegistry) Options() Options {
return c.opts
}

View File

@ -6,163 +6,68 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"strings"
) )
func encode(buf []byte) string { func encode(txt *mdnsTxt) ([]string, error) {
var b bytes.Buffer b, err := json.Marshal(txt)
defer b.Reset() if err != nil {
return nil, err
}
w := zlib.NewWriter(&b) var buf bytes.Buffer
if _, err := w.Write(buf); err != nil { defer buf.Reset()
return ""
w := zlib.NewWriter(&buf)
if _, err := w.Write(b); err != nil {
return nil, err
} }
w.Close() w.Close()
return hex.EncodeToString(b.Bytes()) encoded := hex.EncodeToString(buf.Bytes())
// individual txt limit
if len(encoded) <= 255 {
return []string{encoded}, nil
} }
func decode(d string) []byte { // split encoded string
hr, err := hex.DecodeString(d) var record []string
for len(encoded) > 255 {
record = append(record, encoded[:255])
encoded = encoded[255:]
}
record = append(record, encoded)
return record, nil
}
func decode(record []string) (*mdnsTxt, error) {
encoded := strings.Join(record, "")
hr, err := hex.DecodeString(encoded)
if err != nil { if err != nil {
return nil return nil, err
} }
br := bytes.NewReader(hr) br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br) zr, err := zlib.NewReader(br)
if err != nil { if err != nil {
return nil return nil, err
} }
rbuf, err := ioutil.ReadAll(zr) rbuf, err := ioutil.ReadAll(zr)
if err != nil { if err != nil {
return nil return nil, err
} }
return rbuf var txt *mdnsTxt
if err := json.Unmarshal(rbuf, &txt); err != nil {
return nil, err
} }
func encodeEndpoints(en []*Endpoint) []string { return txt, nil
var tags []string
for _, e := range en {
if b, err := json.Marshal(e); err == nil {
tags = append(tags, "e-"+encode(b))
}
}
return tags
}
func decodeEndpoints(tags []string) []*Endpoint {
var en []*Endpoint
// use the first format you find
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 'e' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var e *Endpoint
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
if err := json.Unmarshal(buf, &e); err == nil {
en = append(en, e)
}
// set version
ver = tag[1]
}
return en
}
func encodeMetadata(md map[string]string) []string {
var tags []string
for k, v := range md {
if b, err := json.Marshal(map[string]string{
k: v,
}); err == nil {
// new encoding
tags = append(tags, "t-"+encode(b))
}
}
return tags
}
func decodeMetadata(tags []string) map[string]string {
md := make(map[string]string)
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 't' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var kv map[string]string
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
// Now unmarshal
if err := json.Unmarshal(buf, &kv); err == nil {
for k, v := range kv {
md[k] = v
}
}
// set version
ver = tag[1]
}
return md
}
func encodeVersion(v string) []string {
return []string{"v-" + encode([]byte(v))}
}
func decodeVersion(tags []string) (string, bool) {
for _, tag := range tags {
if len(tag) < 2 || tag[0] != 'v' {
continue
}
// Old encoding was plain
if tag[1] == '=' {
return tag[2:], true
}
// New encoding is hex
if tag[1] == '-' {
return string(decode(tag[2:])), true
}
}
return "", false
} }

View File

@ -1,13 +1,17 @@
package registry package registry
import ( import (
"encoding/json"
"testing" "testing"
) )
func TestEncodingEndpoints(t *testing.T) { func TestEncoding(t *testing.T) {
eps := []*Endpoint{ testData := []*mdnsTxt{
&mdnsTxt{
Version: "1.0.0",
Metadata: map[string]string{
"foo": "bar",
},
Endpoints: []*Endpoint{
&Endpoint{ &Endpoint{
Name: "endpoint1", Name: "endpoint1",
Request: &Value{ Request: &Value{
@ -22,125 +26,40 @@ func TestEncodingEndpoints(t *testing.T) {
"foo1": "bar1", "foo1": "bar1",
}, },
}, },
&Endpoint{
Name: "endpoint2",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo2": "bar2",
},
},
&Endpoint{
Name: "endpoint3",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo3": "bar3",
}, },
}, },
} }
testEp := func(ep *Endpoint, enc string) { for _, d := range testData {
// encode endpoint encoded, err := encode(d)
e := encodeEndpoints([]*Endpoint{ep})
// check there are two tags; old and new
if len(e) != 1 {
t.Fatalf("Expected 1 encoded tags, got %v", e)
}
// check old encoding
var seen bool
for _, en := range e {
if en == enc {
seen = true
break
}
}
if !seen {
t.Fatalf("Expected %s but not found", enc)
}
// decode
d := decodeEndpoints([]string{enc})
if len(d) == 0 {
t.Fatalf("Expected %v got %v", ep, d)
}
// check name
if d[0].Name != ep.Name {
t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name)
}
// check all the metadata exists
for k, v := range ep.Metadata {
if gv := d[0].Metadata[k]; gv != v {
t.Fatalf("Expected key %s val %s got val %s", k, v, gv)
}
}
}
for _, ep := range eps {
// JSON encoded
jencoded, err := json.Marshal(ep)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// HEX encoded for _, txt := range encoded {
hencoded := encode(jencoded) if len(txt) > 255 {
// endpoint tag t.Fatalf("One of parts for txt is %d characters", len(txt))
hepTag := "e-" + hencoded
testEp(ep, hepTag)
} }
} }
func TestEncodingVersion(t *testing.T) { decoded, err := decode(encoded)
testData := []struct { if err != nil {
decoded string t.Fatal(err)
encoded string
}{
{"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"},
{"latest", "v-789cca492c492d2e01040000ffff08cc028e"},
} }
for _, data := range testData { if decoded.Version != d.Version {
e := encodeVersion(data.decoded) t.Fatalf("Expected version %s got %s", d.Version, decoded.Version)
if e[0] != data.encoded {
t.Fatalf("Expected %s got %s", data.encoded, e)
} }
d, ok := decodeVersion(e) if len(decoded.Endpoints) != len(d.Endpoints) {
if !ok { t.Fatalf("Expected %d endpoints, got %d", len(d.Endpoints), len(decoded.Endpoints))
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
} }
if d != data.decoded { for k, v := range d.Metadata {
t.Fatalf("Expected %s got %s", data.decoded, d) if val := decoded.Metadata[k]; val != v {
t.Fatalf("Expected %s=%s got %s=%s", k, v, k, val)
}
}
} }
d, ok = decodeVersion([]string{data.encoded})
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
}
} }

View File

@ -1,73 +0,0 @@
package mdns
import (
"bytes"
"compress/zlib"
"encoding/hex"
"encoding/json"
"io/ioutil"
"strings"
)
func encode(txt *mdnsTxt) ([]string, error) {
b, err := json.Marshal(txt)
if err != nil {
return nil, err
}
var buf bytes.Buffer
defer buf.Reset()
w := zlib.NewWriter(&buf)
if _, err := w.Write(b); err != nil {
return nil, err
}
w.Close()
encoded := hex.EncodeToString(buf.Bytes())
// individual txt limit
if len(encoded) <= 255 {
return []string{encoded}, nil
}
// split encoded string
var record []string
for len(encoded) > 255 {
record = append(record, encoded[:255])
encoded = encoded[255:]
}
record = append(record, encoded)
return record, nil
}
func decode(record []string) (*mdnsTxt, error) {
encoded := strings.Join(record, "")
hr, err := hex.DecodeString(encoded)
if err != nil {
return nil, err
}
br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br)
if err != nil {
return nil, err
}
rbuf, err := ioutil.ReadAll(zr)
if err != nil {
return nil, err
}
var txt *mdnsTxt
if err := json.Unmarshal(rbuf, &txt); err != nil {
return nil, err
}
return txt, nil
}

View File

@ -1,67 +0,0 @@
package mdns
import (
"testing"
"github.com/micro/go-micro/registry"
)
func TestEncoding(t *testing.T) {
testData := []*mdnsTxt{
&mdnsTxt{
Version: "1.0.0",
Metadata: map[string]string{
"foo": "bar",
},
Endpoints: []*registry.Endpoint{
&registry.Endpoint{
Name: "endpoint1",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo1": "bar1",
},
},
},
},
}
for _, d := range testData {
encoded, err := encode(d)
if err != nil {
t.Fatal(err)
}
for _, txt := range encoded {
if len(txt) > 255 {
t.Fatalf("One of parts for txt is %d characters", len(txt))
}
}
decoded, err := decode(encoded)
if err != nil {
t.Fatal(err)
}
if decoded.Version != d.Version {
t.Fatalf("Expected version %s got %s", d.Version, decoded.Version)
}
if len(decoded.Endpoints) != len(d.Endpoints) {
t.Fatalf("Expected %d endpoints, got %d", len(d.Endpoints), len(decoded.Endpoints))
}
for k, v := range d.Metadata {
if val := decoded.Metadata[k]; val != v {
t.Fatalf("Expected %s=%s got %s=%s", k, v, k, val)
}
}
}
}

View File

@ -1,339 +1,11 @@
// Package mdns is a multicast dns registry // Package mdns provides a multicast dns registry
package mdns package mdns
/*
MDNS is a multicast dns registry for service discovery
This creates a zero dependency system which is great
where multicast dns is available. This usually depends
on the ability to leverage udp and multicast/broadcast.
*/
import ( import (
"net"
"strings"
"sync"
"time"
"github.com/micro/go-micro/registry" "github.com/micro/go-micro/registry"
"github.com/micro/mdns"
hash "github.com/mitchellh/hashstructure"
) )
type mdnsTxt struct { // NewRegistry returns a new mdns registry
Service string
Version string
Endpoints []*registry.Endpoint
Metadata map[string]string
}
type mdnsEntry struct {
hash uint64
id string
node *mdns.Server
}
type mdnsRegistry struct {
opts registry.Options
sync.Mutex
services map[string][]*mdnsEntry
}
func newRegistry(opts ...registry.Option) registry.Registry {
options := registry.Options{
Timeout: time.Millisecond * 100,
}
return &mdnsRegistry{
opts: options,
services: make(map[string][]*mdnsEntry),
}
}
func (m *mdnsRegistry) Init(opts ...registry.Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mdnsRegistry) Options() registry.Options {
return m.opts
}
func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.RegisterOption) error {
m.Lock()
defer m.Unlock()
entries, ok := m.services[service.Name]
// first entry, create wildcard used for list queries
if !ok {
s, err := mdns.NewMDNSService(
service.Name,
"_services",
"",
"",
9999,
[]net.IP{net.ParseIP("0.0.0.0")},
nil,
)
if err != nil {
return err
}
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{s}})
if err != nil {
return err
}
// append the wildcard entry
entries = append(entries, &mdnsEntry{id: "*", node: srv})
}
var gerr error
for _, node := range service.Nodes {
// create hash of service; uint64
h, err := hash.Hash(node, nil)
if err != nil {
gerr = err
continue
}
var seen bool
var e *mdnsEntry
for _, entry := range entries {
if node.Id == entry.id {
seen = true
e = entry
break
}
}
// already registered, continue
if seen && e.hash == h {
continue
// hash doesn't match, shutdown
} else if seen {
e.node.Shutdown()
// doesn't exist
} else {
e = &mdnsEntry{hash: h}
}
txt, err := encode(&mdnsTxt{
Service: service.Name,
Version: service.Version,
Endpoints: service.Endpoints,
Metadata: node.Metadata,
})
if err != nil {
gerr = err
continue
}
// we got here, new node
s, err := mdns.NewMDNSService(
node.Id,
service.Name,
"",
"",
node.Port,
[]net.IP{net.ParseIP(node.Address)},
txt,
)
if err != nil {
gerr = err
continue
}
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
if err != nil {
gerr = err
continue
}
e.id = node.Id
e.node = srv
entries = append(entries, e)
}
// save
m.services[service.Name] = entries
return gerr
}
func (m *mdnsRegistry) Deregister(service *registry.Service) error {
m.Lock()
defer m.Unlock()
var newEntries []*mdnsEntry
// loop existing entries, check if any match, shutdown those that do
for _, entry := range m.services[service.Name] {
var remove bool
for _, node := range service.Nodes {
if node.Id == entry.id {
entry.node.Shutdown()
remove = true
break
}
}
// keep it?
if !remove {
newEntries = append(newEntries, entry)
}
}
// last entry is the wildcard for list queries. Remove it.
if len(newEntries) == 1 && newEntries[0].id == "*" {
newEntries[0].node.Shutdown()
delete(m.services, service.Name)
} else {
m.services[service.Name] = newEntries
}
return nil
}
func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*registry.Service)
go func() {
for {
select {
case e := <-entryCh:
// list record so skip
if p.Service == "_services" {
continue
}
if e.TTL == 0 {
continue
}
txt, err := decode(e.InfoFields)
if err != nil {
continue
}
if txt.Service != service {
continue
}
s, ok := serviceMap[txt.Version]
if !ok {
s = &registry.Service{
Name: txt.Service,
Version: txt.Version,
Endpoints: txt.Endpoints,
}
}
s.Nodes = append(s.Nodes, &registry.Node{
Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
Address: e.AddrV4.String(),
Port: e.Port,
Metadata: txt.Metadata,
})
serviceMap[txt.Version] = s
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
// create list and return
var services []*registry.Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (m *mdnsRegistry) ListServices() ([]*registry.Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool)
var services []*registry.Service
go func() {
for {
select {
case e := <-entryCh:
if e.TTL == 0 {
continue
}
name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
if !serviceMap[name] {
serviceMap[name] = true
services = append(services, &registry.Service{Name: name})
}
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
return services, nil
}
func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
var wo registry.WatchOptions
for _, o := range opts {
o(&wo)
}
md := &mdnsWatcher{
wo: wo,
ch: make(chan *mdns.ServiceEntry, 32),
exit: make(chan struct{}),
}
go func() {
if err := mdns.Listen(md.ch, md.exit); err != nil {
md.Stop()
}
}()
return md, nil
}
func (m *mdnsRegistry) String() string {
return "mdns"
}
func NewRegistry(opts ...registry.Option) registry.Registry { func NewRegistry(opts ...registry.Option) registry.Registry {
return newRegistry(opts...) return registry.NewRegistry(opts...)
} }

332
registry/mdns_registry.go Normal file
View File

@ -0,0 +1,332 @@
// Package mdns is a multicast dns registry
package registry
import (
"net"
"strings"
"sync"
"time"
"github.com/micro/mdns"
hash "github.com/mitchellh/hashstructure"
)
type mdnsTxt struct {
Service string
Version string
Endpoints []*Endpoint
Metadata map[string]string
}
type mdnsEntry struct {
hash uint64
id string
node *mdns.Server
}
type mdnsRegistry struct {
opts Options
sync.Mutex
services map[string][]*mdnsEntry
}
func newRegistry(opts ...Option) Registry {
options := Options{
Timeout: time.Millisecond * 100,
}
return &mdnsRegistry{
opts: options,
services: make(map[string][]*mdnsEntry),
}
}
func (m *mdnsRegistry) Init(opts ...Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mdnsRegistry) Options() Options {
return m.opts
}
func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
m.Lock()
defer m.Unlock()
entries, ok := m.services[service.Name]
// first entry, create wildcard used for list queries
if !ok {
s, err := mdns.NewMDNSService(
service.Name,
"_services",
"",
"",
9999,
[]net.IP{net.ParseIP("0.0.0.0")},
nil,
)
if err != nil {
return err
}
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{s}})
if err != nil {
return err
}
// append the wildcard entry
entries = append(entries, &mdnsEntry{id: "*", node: srv})
}
var gerr error
for _, node := range service.Nodes {
// create hash of service; uint64
h, err := hash.Hash(node, nil)
if err != nil {
gerr = err
continue
}
var seen bool
var e *mdnsEntry
for _, entry := range entries {
if node.Id == entry.id {
seen = true
e = entry
break
}
}
// already registered, continue
if seen && e.hash == h {
continue
// hash doesn't match, shutdown
} else if seen {
e.node.Shutdown()
// doesn't exist
} else {
e = &mdnsEntry{hash: h}
}
txt, err := encode(&mdnsTxt{
Service: service.Name,
Version: service.Version,
Endpoints: service.Endpoints,
Metadata: node.Metadata,
})
if err != nil {
gerr = err
continue
}
// we got here, new node
s, err := mdns.NewMDNSService(
node.Id,
service.Name,
"",
"",
node.Port,
[]net.IP{net.ParseIP(node.Address)},
txt,
)
if err != nil {
gerr = err
continue
}
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
if err != nil {
gerr = err
continue
}
e.id = node.Id
e.node = srv
entries = append(entries, e)
}
// save
m.services[service.Name] = entries
return gerr
}
func (m *mdnsRegistry) Deregister(service *Service) error {
m.Lock()
defer m.Unlock()
var newEntries []*mdnsEntry
// loop existing entries, check if any match, shutdown those that do
for _, entry := range m.services[service.Name] {
var remove bool
for _, node := range service.Nodes {
if node.Id == entry.id {
entry.node.Shutdown()
remove = true
break
}
}
// keep it?
if !remove {
newEntries = append(newEntries, entry)
}
}
// last entry is the wildcard for list queries. Remove it.
if len(newEntries) == 1 && newEntries[0].id == "*" {
newEntries[0].node.Shutdown()
delete(m.services, service.Name)
} else {
m.services[service.Name] = newEntries
}
return nil
}
func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*Service)
go func() {
for {
select {
case e := <-entryCh:
// list record so skip
if p.Service == "_services" {
continue
}
if e.TTL == 0 {
continue
}
txt, err := decode(e.InfoFields)
if err != nil {
continue
}
if txt.Service != service {
continue
}
s, ok := serviceMap[txt.Version]
if !ok {
s = &Service{
Name: txt.Service,
Version: txt.Version,
Endpoints: txt.Endpoints,
}
}
s.Nodes = append(s.Nodes, &Node{
Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
Address: e.AddrV4.String(),
Port: e.Port,
Metadata: txt.Metadata,
})
serviceMap[txt.Version] = s
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
// create list and return
var services []*Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (m *mdnsRegistry) ListServices() ([]*Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool)
var services []*Service
go func() {
for {
select {
case e := <-entryCh:
if e.TTL == 0 {
continue
}
name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
if !serviceMap[name] {
serviceMap[name] = true
services = append(services, &Service{Name: name})
}
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
return services, nil
}
func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) {
var wo WatchOptions
for _, o := range opts {
o(&wo)
}
md := &mdnsWatcher{
wo: wo,
ch: make(chan *mdns.ServiceEntry, 32),
exit: make(chan struct{}),
}
go func() {
if err := mdns.Listen(md.ch, md.exit); err != nil {
md.Stop()
}
}()
return md, nil
}
func (m *mdnsRegistry) String() string {
return "mdns"
}
// NewRegistry returns a new default registry which is mdns
func NewRegistry(opts ...Option) Registry {
return newRegistry(opts...)
}

View File

@ -1,19 +1,17 @@
package mdns package registry
import ( import (
"testing" "testing"
"time" "time"
"github.com/micro/go-micro/registry"
) )
func TestMDNS(t *testing.T) { func TestMDNS(t *testing.T) {
testData := []*registry.Service{ testData := []*Service{
&registry.Service{ &Service{
Name: "test1", Name: "test1",
Version: "1.0.1", Version: "1.0.1",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test1-1", Id: "test1-1",
Address: "10.0.0.1", Address: "10.0.0.1",
Port: 10001, Port: 10001,
@ -23,11 +21,11 @@ func TestMDNS(t *testing.T) {
}, },
}, },
}, },
&registry.Service{ &Service{
Name: "test2", Name: "test2",
Version: "1.0.2", Version: "1.0.2",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test2-1", Id: "test2-1",
Address: "10.0.0.2", Address: "10.0.0.2",
Port: 10002, Port: 10002,
@ -37,11 +35,11 @@ func TestMDNS(t *testing.T) {
}, },
}, },
}, },
&registry.Service{ &Service{
Name: "test3", Name: "test3",
Version: "1.0.3", Version: "1.0.3",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test3-1", Id: "test3-1",
Address: "10.0.0.3", Address: "10.0.0.3",
Port: 10003, Port: 10003,

View File

@ -1,20 +1,19 @@
package mdns package registry
import ( import (
"errors" "errors"
"strings" "strings"
"github.com/micro/go-micro/registry"
"github.com/micro/mdns" "github.com/micro/mdns"
) )
type mdnsWatcher struct { type mdnsWatcher struct {
wo registry.WatchOptions wo WatchOptions
ch chan *mdns.ServiceEntry ch chan *mdns.ServiceEntry
exit chan struct{} exit chan struct{}
} }
func (m *mdnsWatcher) Next() (*registry.Result, error) { func (m *mdnsWatcher) Next() (*Result, error) {
for { for {
select { select {
case e := <-m.ch: case e := <-m.ch:
@ -41,7 +40,7 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) {
action = "create" action = "create"
} }
service := &registry.Service{ service := &Service{
Name: txt.Service, Name: txt.Service,
Version: txt.Version, Version: txt.Version,
Endpoints: txt.Endpoints, Endpoints: txt.Endpoints,
@ -52,14 +51,14 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) {
continue continue
} }
service.Nodes = append(service.Nodes, &registry.Node{ service.Nodes = append(service.Nodes, &Node{
Id: strings.TrimSuffix(e.Name, "."+service.Name+".local."), Id: strings.TrimSuffix(e.Name, "."+service.Name+".local."),
Address: e.AddrV4.String(), Address: e.AddrV4.String(),
Port: e.Port, Port: e.Port,
Metadata: txt.Metadata, Metadata: txt.Metadata,
}) })
return &registry.Result{ return &Result{
Action: action, Action: action,
Service: service, Service: service,
}, nil }, nil

View File

@ -26,7 +26,7 @@ type RegisterOption func(*RegisterOptions)
type WatchOption func(*WatchOptions) type WatchOption func(*WatchOptions)
var ( var (
DefaultRegistry = newConsulRegistry() DefaultRegistry = NewRegistry()
// Not found error when GetService is called // Not found error when GetService is called
ErrNotFound = errors.New("not found") ErrNotFound = errors.New("not found")
@ -34,10 +34,6 @@ var (
ErrWatcherStopped = errors.New("watcher stopped") ErrWatcherStopped = errors.New("watcher stopped")
) )
func NewRegistry(opts ...Option) Registry {
return newConsulRegistry(opts...)
}
// Register a service node. Additionally supply options such as TTL. // Register a service node. Additionally supply options such as TTL.
func Register(s *Service, opts ...RegisterOption) error { func Register(s *Service, opts ...RegisterOption) error {
return DefaultRegistry.Register(s, opts...) return DefaultRegistry.Register(s, opts...)

View File

@ -1,18 +1,16 @@
package mdns package registry
import ( import (
"testing" "testing"
"github.com/micro/go-micro/registry"
) )
func TestWatcher(t *testing.T) { func TestWatcher(t *testing.T) {
testData := []*registry.Service{ testData := []*Service{
&registry.Service{ &Service{
Name: "test1", Name: "test1",
Version: "1.0.1", Version: "1.0.1",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test1-1", Id: "test1-1",
Address: "10.0.0.1", Address: "10.0.0.1",
Port: 10001, Port: 10001,
@ -22,11 +20,11 @@ func TestWatcher(t *testing.T) {
}, },
}, },
}, },
&registry.Service{ &Service{
Name: "test2", Name: "test2",
Version: "1.0.2", Version: "1.0.2",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test2-1", Id: "test2-1",
Address: "10.0.0.2", Address: "10.0.0.2",
Port: 10002, Port: 10002,
@ -36,11 +34,11 @@ func TestWatcher(t *testing.T) {
}, },
}, },
}, },
&registry.Service{ &Service{
Name: "test3", Name: "test3",
Version: "1.0.3", Version: "1.0.3",
Nodes: []*registry.Node{ Nodes: []*Node{
&registry.Node{ &Node{
Id: "test3-1", Id: "test3-1",
Address: "10.0.0.3", Address: "10.0.0.3",
Port: 10003, Port: 10003,
@ -52,7 +50,7 @@ func TestWatcher(t *testing.T) {
}, },
} }
testFn := func(service, s *registry.Service) { testFn := func(service, s *Service) {
if s == nil { if s == nil {
t.Fatalf("Expected one result for %s got nil", service.Name) t.Fatalf("Expected one result for %s got nil", service.Name)