From 0b029da08c2d04dee3c6110bb2ece5537902abe8 Mon Sep 17 00:00:00 2001 From: Asim Date: Wed, 27 Apr 2016 18:21:05 +0100 Subject: [PATCH] First attempt at mdns --- encoding.go | 124 +++++++++++++++++++++ encoding_test.go | 147 ++++++++++++++++++++++++ mdns.go | 284 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 555 insertions(+) create mode 100644 encoding.go create mode 100644 encoding_test.go create mode 100644 mdns.go diff --git a/encoding.go b/encoding.go new file mode 100644 index 0000000..bb8feff --- /dev/null +++ b/encoding.go @@ -0,0 +1,124 @@ +package mdns + +import ( + "bytes" + "compress/zlib" + "encoding/hex" + "encoding/json" + "io/ioutil" + + "github.com/micro/go-micro/registry" +) + +func encode(buf []byte) string { + var b bytes.Buffer + defer b.Reset() + + w := zlib.NewWriter(&b) + if _, err := w.Write(buf); err != nil { + return "" + } + w.Close() + + return hex.EncodeToString(b.Bytes()) +} + +func decode(d string) []byte { + hr, err := hex.DecodeString(d) + if err != nil { + return nil + } + + br := bytes.NewReader(hr) + zr, err := zlib.NewReader(br) + if err != nil { + return nil + } + + rbuf, err := ioutil.ReadAll(zr) + if err != nil { + return nil + } + + return rbuf +} + +func encodeEndpoints(en []*registry.Endpoint) []string { + var tags []string + for _, e := range en { + if b, err := json.Marshal(e); err == nil { + tags = append(tags, "e-"+encode(b)) + } + } + return tags +} + +func decodeEndpoints(tags []string) []*registry.Endpoint { + var en []*registry.Endpoint + + for _, tag := range tags { + if len(tag) == 0 || tag[0] != 'e' || tag[1] != '-' { + continue + } + + buf := decode(tag[2:]) + + var e *registry.Endpoint + if err := json.Unmarshal(buf, &e); err == nil { + en = append(en, e) + } + } + return en +} + +func encodeMetadata(md map[string]string) []string { + var tags []string + for k, v := range md { + if b, err := json.Marshal(map[string]string{ + k: v, + }); err == nil { + // new encoding + tags = append(tags, "t-"+encode(b)) + } + } + return tags +} + +func decodeMetadata(tags []string) map[string]string { + md := make(map[string]string) + + for _, tag := range tags { + if len(tag) == 0 || tag[0] != 't' || tag[1] != '-' { + continue + } + + buf := decode(tag[2:]) + + var kv map[string]string + + // Now unmarshal + if err := json.Unmarshal(buf, &kv); err == nil { + for k, v := range kv { + md[k] = v + } + } + } + return md +} + +func encodeVersion(v string) []string { + return []string{ + // new encoding, + "v-" + encode([]byte(v)), + } +} + +func decodeVersion(tags []string) (string, bool) { + for _, tag := range tags { + if len(tag) < 2 || tag[0] != 'v' || tag[1] != '-' { + continue + } + return string(decode(tag[2:])), true + } + return "", false +} diff --git a/encoding_test.go b/encoding_test.go new file mode 100644 index 0000000..27f0f67 --- /dev/null +++ b/encoding_test.go @@ -0,0 +1,147 @@ +package mdns + +import ( + "encoding/json" + "testing" + + "github.com/micro/go-micro/registry" +) + +func TestEncodingEndpoints(t *testing.T) { + eps := []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "endpoint1", + Request: ®istry.Value{ + Name: "request", + Type: "request", + }, + Response: ®istry.Value{ + Name: "response", + Type: "response", + }, + Metadata: map[string]string{ + "foo1": "bar1", + }, + }, + ®istry.Endpoint{ + Name: "endpoint2", + Request: ®istry.Value{ + Name: "request", + Type: "request", + }, + Response: ®istry.Value{ + Name: "response", + Type: "response", + }, + Metadata: map[string]string{ + "foo2": "bar2", + }, + }, + ®istry.Endpoint{ + Name: "endpoint3", + Request: ®istry.Value{ + Name: "request", + Type: "request", + }, + Response: ®istry.Value{ + Name: "response", + Type: "response", + }, + Metadata: map[string]string{ + "foo3": "bar3", + }, + }, + } + + testEp := func(ep *registry.Endpoint, enc string) { + // encode endpoint + e := encodeEndpoints([]*registry.Endpoint{ep}) + + // check there are two tags; old and new + if len(e) != 1 { + t.Fatalf("Expected 1 encoded tag, got %v", e) + } + + // check old encoding + var seen bool + + for _, en := range e { + if en == enc { + seen = true + break + } + } + + if !seen { + t.Fatalf("Expected %s but not found", enc) + } + + // decode + d := decodeEndpoints([]string{enc}) + if len(d) == 0 { + t.Fatalf("Expected %v got %v", ep, d) + } + + // check name + if d[0].Name != ep.Name { + t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name) + } + + // check all the metadata exists + for k, v := range ep.Metadata { + if gv := d[0].Metadata[k]; gv != v { + t.Fatalf("Expected key %s val %s got val %s", k, v, gv) + } + } + } + + for _, ep := range eps { + // JSON encoded + jencoded, err := json.Marshal(ep) + if err != nil { + t.Fatal(err) + } + + // HEX encoded + hencoded := encode(jencoded) + // endpoint tag + hepTag := "e-" + hencoded + testEp(ep, hepTag) + } +} + +func TestEncodingVersion(t *testing.T) { + testData := []struct { + decoded string + encoded string + }{ + {"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"}, + {"latest", "v-789cca492c492d2e01040000ffff08cc028e"}, + } + + for _, data := range testData { + e := encodeVersion(data.decoded) + + if e[0] != data.encoded { + t.Fatalf("Expected %s got %s", data.encoded, e) + } + + d, ok := decodeVersion(e) + if !ok { + t.Fatalf("Unexpected %t for %s", ok, data.encoded) + } + + if d != data.decoded { + t.Fatalf("Expected %s got %s", data.decoded, d) + } + + d, ok = decodeVersion([]string{data.encoded}) + if !ok { + t.Fatalf("Unexpected %t for %s", ok, data.encoded) + } + + if d != data.decoded { + t.Fatalf("Expected %s got %s", data.decoded, d) + } + } +} diff --git a/mdns.go b/mdns.go new file mode 100644 index 0000000..cb1e933 --- /dev/null +++ b/mdns.go @@ -0,0 +1,284 @@ +package mdns + +/* + MDNS is a multicast dns registry for service discovery + This creates a zero dependency system which is great + where multicast dns is available. This usually depends + on the ability to leverage udp and multicast/broadcast. +*/ + +import ( + "net" + "strings" + "sync" + "time" + + "github.com/hashicorp/mdns" + "github.com/micro/go-micro/registry" + hash "github.com/mitchellh/hashstructure" +) + +type mdnsEntry struct { + hash uint64 + id string + node *mdns.Server +} + +type mdnsRegistry struct { + opts registry.Options + + sync.Mutex + services map[string][]*mdnsEntry +} + +func newRegistry(opts ...registry.Option) registry.Registry { + options := registry.Options{ + Timeout: time.Millisecond * 100, + } + + return &mdnsRegistry{ + opts: options, + services: make(map[string][]*mdnsEntry), + } +} + +func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.RegisterOption) error { + m.Lock() + defer m.Unlock() + + entries, ok := m.services[service.Name] + // first entry, create wildcard used for list queries + if !ok { + s, err := mdns.NewMDNSService( + service.Name, + "_services", + "", + "", + 9999, + []net.IP{net.ParseIP("0.0.0.0")}, + nil, + ) + if err != nil { + return err + } + + srv, err := mdns.NewServer(&mdns.Config{Zone: s}) + if err != nil { + return err + } + + // append the wildcard entry + entries = append(entries, &mdnsEntry{id: "*", node: srv}) + } + + var gerr error + + for _, node := range service.Nodes { + // create hash of service; uint64 + h, err := hash.Hash(node, nil) + if err != nil { + gerr = err + continue + } + + var seen bool + var e *mdnsEntry + + for _, entry := range entries { + if node.Id == entry.id { + seen = true + e = entry + break + } + } + + // already registered, continue + if seen && e.hash == h { + continue + // hash doesn't match, shutdown + } else if seen { + e.node.Shutdown() + // doesn't exist + } else { + e = &mdnsEntry{hash: h} + } + + var txt []string + txt = append(txt, encodeVersion(service.Version)...) + txt = append(txt, encodeMetadata(node.Metadata)...) + // txt = append(txt, encodeEndpoints(service.Endpoints)...) + + // we got here, new node + s, err := mdns.NewMDNSService( + node.Id, + service.Name, + "", + "", + node.Port, + []net.IP{net.ParseIP(node.Address)}, + txt, + ) + if err != nil { + gerr = err + continue + } + + srv, err := mdns.NewServer(&mdns.Config{Zone: s}) + if err != nil { + gerr = err + continue + } + + e.id = node.Id + e.node = srv + entries = append(entries, e) + } + + // save + m.services[service.Name] = entries + + return gerr +} + +func (m *mdnsRegistry) Deregister(service *registry.Service) error { + m.Lock() + defer m.Unlock() + + var newEntries []*mdnsEntry + + // loop existing entries, check if any match, shutdown those that do + for _, entry := range m.services[service.Name] { + var remove bool + + for _, node := range service.Nodes { + if node.Id == entry.id { + entry.node.Shutdown() + remove = true + break + } + } + + // keep it? + if !remove { + newEntries = append(newEntries, entry) + } + } + + // last entry is the wildcard for list queries. Remove it. + if len(newEntries) == 1 && newEntries[0].id == "*" { + newEntries[0].node.Shutdown() + delete(m.services, service.Name) + } else { + m.services[service.Name] = newEntries + } + + return nil +} + +func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) { + p := mdns.DefaultParams(service) + p.Timeout = m.opts.Timeout + entryCh := make(chan *mdns.ServiceEntry, 10) + p.Entries = entryCh + + exit := make(chan bool) + defer close(exit) + + serviceMap := make(map[string]*registry.Service) + + go func() { + for { + select { + case e := <-entryCh: + // list record so skip + if p.Service == "_services" { + continue + } + + version, exists := decodeVersion(e.InfoFields) + if !exists { + continue + } + + s, ok := serviceMap[version] + if !ok { + s = ®istry.Service{ + Name: service, + Version: version, + // Endpoints: decodeEndpoints(e.InfoFields), + } + } + + s.Nodes = append(s.Nodes, ®istry.Node{ + Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."), + Address: e.AddrV4.String(), + Port: e.Port, + Metadata: decodeMetadata(e.InfoFields), + }) + + serviceMap[version] = s + case <-exit: + return + } + } + }() + + if err := mdns.Query(p); err != nil { + return nil, err + } + + // create list and return + var services []*registry.Service + + for _, service := range serviceMap { + services = append(services, service) + } + + return services, nil +} + +func (m *mdnsRegistry) ListServices() ([]*registry.Service, error) { + p := mdns.DefaultParams("_services") + p.Timeout = m.opts.Timeout + entryCh := make(chan *mdns.ServiceEntry, 10) + p.Entries = entryCh + + exit := make(chan bool) + defer close(exit) + + serviceMap := make(map[string]bool) + var services []*registry.Service + + go func() { + for { + select { + case e := <-entryCh: + name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".") + if !serviceMap[name] { + serviceMap[name] = true + services = append(services, ®istry.Service{Name: name}) + } + case <-exit: + return + } + } + }() + + if err := mdns.Query(p); err != nil { + return nil, err + } + + return services, nil +} + +func (m *mdnsRegistry) Watch() (registry.Watcher, error) { + return nil, nil +} + +func (m *mdnsRegistry) String() string { + return "mdns" +} + +func NewRegistry(opts ...registry.Option) registry.Registry { + return newRegistry(opts...) +}