diff --git a/encoding.go b/encoding.go index bb8feff..d7323da 100644 --- a/encoding.go +++ b/encoding.go @@ -6,119 +6,68 @@ import ( "encoding/hex" "encoding/json" "io/ioutil" - - "github.com/micro/go-micro/registry" + "strings" ) -func encode(buf []byte) string { - var b bytes.Buffer - defer b.Reset() +func encode(txt *mdnsTxt) ([]string, error) { + b, err := json.Marshal(txt) + if err != nil { + return nil, err + } - w := zlib.NewWriter(&b) - if _, err := w.Write(buf); err != nil { - return "" + var buf bytes.Buffer + defer buf.Reset() + + w := zlib.NewWriter(&buf) + if _, err := w.Write(b); err != nil { + return nil, err } w.Close() - return hex.EncodeToString(b.Bytes()) + encoded := hex.EncodeToString(buf.Bytes()) + + // individual txt limit + if len(encoded) <= 255 { + return []string{encoded}, nil + } + + // split encoded string + var record []string + + for len(encoded) > 255 { + record = append(record, encoded[:255]) + encoded = encoded[255:] + } + + record = append(record, encoded) + + return record, nil } -func decode(d string) []byte { - hr, err := hex.DecodeString(d) +func decode(record []string) (*mdnsTxt, error) { + encoded := strings.Join(record, "") + + hr, err := hex.DecodeString(encoded) if err != nil { - return nil + return nil, err } br := bytes.NewReader(hr) zr, err := zlib.NewReader(br) if err != nil { - return nil + return nil, err } rbuf, err := ioutil.ReadAll(zr) if err != nil { - return nil + return nil, err } - return rbuf -} + var txt *mdnsTxt -func encodeEndpoints(en []*registry.Endpoint) []string { - var tags []string - for _, e := range en { - if b, err := json.Marshal(e); err == nil { - tags = append(tags, "e-"+encode(b)) - } + if err := json.Unmarshal(rbuf, &txt); err != nil { + return nil, err } - return tags -} - -func decodeEndpoints(tags []string) []*registry.Endpoint { - var en []*registry.Endpoint - - for _, tag := range tags { - if len(tag) == 0 || tag[0] != 'e' || tag[1] != '-' { - continue - } - - buf := decode(tag[2:]) - - var e *registry.Endpoint - if err := json.Unmarshal(buf, &e); err == nil { - en = append(en, e) - } - } - return en -} - -func encodeMetadata(md map[string]string) []string { - var tags []string - for k, v := range md { - if b, err := json.Marshal(map[string]string{ - k: v, - }); err == nil { - // new encoding - tags = append(tags, "t-"+encode(b)) - } - } - return tags -} - -func decodeMetadata(tags []string) map[string]string { - md := make(map[string]string) - - for _, tag := range tags { - if len(tag) == 0 || tag[0] != 't' || tag[1] != '-' { - continue - } - - buf := decode(tag[2:]) - - var kv map[string]string - - // Now unmarshal - if err := json.Unmarshal(buf, &kv); err == nil { - for k, v := range kv { - md[k] = v - } - } - } - return md -} - -func encodeVersion(v string) []string { - return []string{ - // new encoding, - "v-" + encode([]byte(v)), - } -} - -func decodeVersion(tags []string) (string, bool) { - for _, tag := range tags { - if len(tag) < 2 || tag[0] != 'v' || tag[1] != '-' { - continue - } - return string(decode(tag[2:])), true - } - return "", false + + return txt, nil } diff --git a/encoding_test.go b/encoding_test.go index 27f0f67..baea1c5 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -1,147 +1,67 @@ package mdns import ( - "encoding/json" "testing" "github.com/micro/go-micro/registry" ) -func TestEncodingEndpoints(t *testing.T) { - eps := []*registry.Endpoint{ - ®istry.Endpoint{ - Name: "endpoint1", - Request: ®istry.Value{ - Name: "request", - Type: "request", - }, - Response: ®istry.Value{ - Name: "response", - Type: "response", - }, +func TestEncoding(t *testing.T) { + testData := []*mdnsTxt{ + &mdnsTxt{ + Version: "1.0.0", Metadata: map[string]string{ - "foo1": "bar1", + "foo": "bar", }, - }, - ®istry.Endpoint{ - Name: "endpoint2", - Request: ®istry.Value{ - Name: "request", - Type: "request", - }, - Response: ®istry.Value{ - Name: "response", - Type: "response", - }, - Metadata: map[string]string{ - "foo2": "bar2", - }, - }, - ®istry.Endpoint{ - Name: "endpoint3", - Request: ®istry.Value{ - Name: "request", - Type: "request", - }, - Response: ®istry.Value{ - Name: "response", - Type: "response", - }, - Metadata: map[string]string{ - "foo3": "bar3", + Endpoints: []*registry.Endpoint{ + ®istry.Endpoint{ + Name: "endpoint1", + Request: ®istry.Value{ + Name: "request", + Type: "request", + }, + Response: ®istry.Value{ + Name: "response", + Type: "response", + }, + Metadata: map[string]string{ + "foo1": "bar1", + }, + }, }, }, } - testEp := func(ep *registry.Endpoint, enc string) { - // encode endpoint - e := encodeEndpoints([]*registry.Endpoint{ep}) - - // check there are two tags; old and new - if len(e) != 1 { - t.Fatalf("Expected 1 encoded tag, got %v", e) - } - - // check old encoding - var seen bool - - for _, en := range e { - if en == enc { - seen = true - break - } - } - - if !seen { - t.Fatalf("Expected %s but not found", enc) - } - - // decode - d := decodeEndpoints([]string{enc}) - if len(d) == 0 { - t.Fatalf("Expected %v got %v", ep, d) - } - - // check name - if d[0].Name != ep.Name { - t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name) - } - - // check all the metadata exists - for k, v := range ep.Metadata { - if gv := d[0].Metadata[k]; gv != v { - t.Fatalf("Expected key %s val %s got val %s", k, v, gv) - } - } - } - - for _, ep := range eps { - // JSON encoded - jencoded, err := json.Marshal(ep) + for _, d := range testData { + encoded, err := encode(d) if err != nil { t.Fatal(err) } - // HEX encoded - hencoded := encode(jencoded) - // endpoint tag - hepTag := "e-" + hencoded - testEp(ep, hepTag) - } -} - -func TestEncodingVersion(t *testing.T) { - testData := []struct { - decoded string - encoded string - }{ - {"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"}, - {"latest", "v-789cca492c492d2e01040000ffff08cc028e"}, - } - - for _, data := range testData { - e := encodeVersion(data.decoded) - - if e[0] != data.encoded { - t.Fatalf("Expected %s got %s", data.encoded, e) - } - - d, ok := decodeVersion(e) - if !ok { - t.Fatalf("Unexpected %t for %s", ok, data.encoded) - } - - if d != data.decoded { - t.Fatalf("Expected %s got %s", data.decoded, d) - } - - d, ok = decodeVersion([]string{data.encoded}) - if !ok { - t.Fatalf("Unexpected %t for %s", ok, data.encoded) - } - - if d != data.decoded { - t.Fatalf("Expected %s got %s", data.decoded, d) - } + for _, txt := range encoded { + if len(txt) > 255 { + t.Fatalf("One of parts for txt is %d characters", len(txt)) + } + } + + decoded, err := decode(encoded) + if err != nil { + t.Fatal(err) + } + + if decoded.Version != d.Version { + t.Fatalf("Expected version %s got %s", d.Version, decoded.Version) + } + + if len(decoded.Endpoints) != len(d.Endpoints) { + t.Fatalf("Expected %d endpoints, got %d", len(d.Endpoints), len(decoded.Endpoints)) + } + + for k, v := range d.Metadata { + if val := decoded.Metadata[k]; val != v { + t.Fatalf("Expected %s=%s got %s=%s", k, v, k, val) + } + } } + } diff --git a/mdns.go b/mdns.go index cb1e933..b7d056e 100644 --- a/mdns.go +++ b/mdns.go @@ -18,6 +18,12 @@ import ( hash "github.com/mitchellh/hashstructure" ) +type mdnsTxt struct { + Version string + Endpoints []*registry.Endpoint + Metadata map[string]string +} + type mdnsEntry struct { hash uint64 id string @@ -103,10 +109,16 @@ func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.Regi e = &mdnsEntry{hash: h} } - var txt []string - txt = append(txt, encodeVersion(service.Version)...) - txt = append(txt, encodeMetadata(node.Metadata)...) - // txt = append(txt, encodeEndpoints(service.Endpoints)...) + txt, err := encode(&mdnsTxt{ + Version: service.Version, + Endpoints: service.Endpoints, + Metadata: node.Metadata, + }) + + if err != nil { + gerr = err + continue + } // we got here, new node s, err := mdns.NewMDNSService( @@ -195,17 +207,17 @@ func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) { continue } - version, exists := decodeVersion(e.InfoFields) - if !exists { + txt, err := decode(e.InfoFields) + if err != nil { continue } - s, ok := serviceMap[version] + s, ok := serviceMap[txt.Version] if !ok { s = ®istry.Service{ - Name: service, - Version: version, - // Endpoints: decodeEndpoints(e.InfoFields), + Name: service, + Version: txt.Version, + Endpoints: txt.Endpoints, } } @@ -213,10 +225,10 @@ func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) { Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."), Address: e.AddrV4.String(), Port: e.Port, - Metadata: decodeMetadata(e.InfoFields), + Metadata: txt.Metadata, }) - serviceMap[version] = s + serviceMap[txt.Version] = s case <-exit: return }