diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..1899438 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,24 @@ +--- +name: Bug report +about: For reporting bugs in go-micro +title: "[BUG]" +labels: '' +assignees: '' + +--- + +**Describe the bug** + +1. What are you trying to do? +2. What did you expect to happen? +3. What happens instead? + +**How to reproduce the bug:** + +If possible, please include a minimal code snippet here. + +**Environment:** +Go Version: please paste `go version` output here +``` +please paste `go env` output here +``` diff --git a/.github/ISSUE_TEMPLATE/feature-request---enhancement.md b/.github/ISSUE_TEMPLATE/feature-request---enhancement.md new file mode 100644 index 0000000..459817f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request---enhancement.md @@ -0,0 +1,17 @@ +--- +name: Feature request / Enhancement +about: If you have a need not served by go-micro +title: "[FEATURE]" +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000..1daf48b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,14 @@ +--- +name: Question +about: Ask a question about go-micro +title: '' +labels: '' +assignees: '' + +--- + +Before asking, please check if your question has already been answered: + +1. Check the documentation - https://micro.mu/docs/ +2. Check the examples and plugins - https://github.com/micro/examples & https://github.com/micro/go-plugins +3. Search existing issues diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..cba3cbc --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,9 @@ +## Pull Request template +Please, go through these steps before clicking submit on this PR. + +1. Give a descriptive title to your PR. +2. Provide a description of your changes. +3. Make sure you have some relevant tests. +4. Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if applicable). + +**PLEASE REMOVE THIS TEMPLATE BEFORE SUBMITTING** diff --git a/.github/renovate.json b/.github/renovate.json new file mode 100644 index 0000000..52d2918 --- /dev/null +++ b/.github/renovate.json @@ -0,0 +1,19 @@ +{ + "extends": [ + "config:base" + ], + "packageRules": [ + { + "matchUpdateTypes": ["minor", "patch", "pin", "digest"], + "automerge": true + }, + { + "groupName": "all deps", + "separateMajorMinor": true, + "groupSlug": "all", + "packagePatterns": [ + "*" + ] + } + ] +} diff --git a/.github/stale.sh b/.github/stale.sh new file mode 100755 index 0000000..8a345c4 --- /dev/null +++ b/.github/stale.sh @@ -0,0 +1,13 @@ +#!/bin/bash -ex + +export PATH=$PATH:$(pwd)/bin +export GO111MODULE=on +export GOBIN=$(pwd)/bin + +#go get github.com/rvflash/goup@v0.4.1 + +#goup -v ./... +#go get github.com/psampaz/go-mod-outdated@v0.6.0 +go list -u -m -mod=mod -json all | go-mod-outdated -update -direct -ci || true + +#go list -u -m -json all | go-mod-outdated -update diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..eb19b68 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,46 @@ +name: build +on: + push: + branches: + - master +jobs: + test: + name: test + runs-on: ubuntu-latest + steps: + - name: setup + uses: actions/setup-go@v2 + with: + go-version: 1.16 + - name: checkout + uses: actions/checkout@v2 + - name: cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: ${{ runner.os }}-go- + - name: deps + run: go get -v -t -d ./... + - name: test + env: + INTEGRATION_TESTS: yes + run: go test -mod readonly -v ./... + lint: + name: lint + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v2 + - name: lint + uses: golangci/golangci-lint-action@v2 + continue-on-error: true + with: + # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. + version: v1.30 + # Optional: working directory, useful for monorepos + # working-directory: somedir + # Optional: golangci-lint command line arguments. + # args: --issues-exit-code=0 + # Optional: show only new issues if it's a pull request. The default value is `false`. + # only-new-issues: true diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 0000000..545baf2 --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,46 @@ +name: prbuild +on: + pull_request: + branches: + - master +jobs: + test: + name: test + runs-on: ubuntu-latest + steps: + - name: setup + uses: actions/setup-go@v2 + with: + go-version: 1.16 + - name: checkout + uses: actions/checkout@v2 + - name: cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: ${{ runner.os }}-go- + - name: deps + run: go get -v -t -d ./... + - name: test + env: + INTEGRATION_TESTS: yes + run: go test -mod readonly -v ./... + lint: + name: lint + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v2 + - name: lint + uses: golangci/golangci-lint-action@v2 + continue-on-error: true + with: + # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. + version: v1.30 + # Optional: working directory, useful for monorepos + # working-directory: somedir + # Optional: golangci-lint command line arguments. + # args: --issues-exit-code=0 + # Optional: show only new issues if it's a pull request. The default value is `false`. + # only-new-issues: true diff --git a/.synced b/.synced new file mode 100644 index 0000000..472f779 --- /dev/null +++ b/.synced @@ -0,0 +1 @@ +8975184b88a75a692780bbd4d6b18f0a28e99819 \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ff7ed4b --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module github.com/unistack-org/micro-register-mdns/v3 + +go 1.15 + +require ( + github.com/google/uuid v1.2.0 + github.com/miekg/dns v1.1.31 + github.com/unistack-org/micro/v3 v3.2.14 + golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a // indirect + golang.org/x/net v0.0.0-20210119194325-5f4716e94777 + golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..79c5522 --- /dev/null +++ b/go.sum @@ -0,0 +1,46 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/ef-ds/deque v1.0.4/go.mod h1:gXDnTC3yqvBcHbq2lcExjtAcVrOnJCbMcZXmuj8Z4tg= +github.com/go-test/deep v1.0.7/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V36o8= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/heimdalr/dag v1.0.1/go.mod h1:t+ZkR+sjKL4xhlE1B9rwpvwfo+x+2R0363efS+Oghns= +github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/miekg/dns v1.1.31 h1:sJFOl9BgwbYAWOGEwr61FU28pqsBNdpRBnhGXtO06Oo= +github.com/miekg/dns v1.1.31/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/silas/dag v0.0.0-20210121180416-41cf55125c34/go.mod h1:7RTUFBdIRC9nZ7/3RyRNH1bdqIShrDejd1YbLwgPS+I= +github.com/unistack-org/micro/v3 v3.2.14 h1:BD7JR2W0WlJvJgHN3uPWrE/vNAGyxhIQrIODeDCfoSk= +github.com/unistack-org/micro/v3 v3.2.14/go.mod h1:3j13mSd/rILNjyP0tEVtDxyDkJBtnHUXShNCuPHkC5A= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew= +golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/mdns.go b/mdns.go index cb6f48f..c23023c 100644 --- a/mdns.go +++ b/mdns.go @@ -15,9 +15,9 @@ import ( "time" "github.com/google/uuid" - "github.com/micro/go-micro/v3/logger" - "github.com/micro/go-micro/v3/registry" - "github.com/micro/go-micro/v3/util/mdns" + util "github.com/unistack-org/micro-register-mdns/v3/util" + "github.com/unistack-org/micro/v3/logger" + "github.com/unistack-org/micro/v3/register" ) const ( @@ -30,13 +30,13 @@ const ( type mdnsTxt struct { Service string Version string - Endpoints []*registry.Endpoint + Endpoints []*register.Endpoint Metadata map[string]string } type mdnsEntry struct { id string - node *mdns.Server + node *util.Server } // services are a key/value map, with the service name as a key and the value being a @@ -45,8 +45,8 @@ type mdnsEntry struct { type services map[string][]*mdnsEntry // mdsRegistry is a multicast dns registry -type mdnsRegistry struct { - opts registry.Options +type mdnsRegister struct { + opts register.Options // the top level domains, these can be overriden using options defaultDomain string @@ -61,18 +61,18 @@ type mdnsRegistry struct { watchers map[string]*mdnsWatcher // listener - listener chan *mdns.ServiceEntry + listener chan *util.ServiceEntry } type mdnsWatcher struct { id string - wo registry.WatchOptions - ch chan *mdns.ServiceEntry + wo register.WatchOptions + ch chan *util.ServiceEntry exit chan struct{} // the mdns domain domain string // the registry - registry *mdnsRegistry + registry *mdnsRegister } func encode(txt *mdnsTxt) ([]string, error) { @@ -87,8 +87,8 @@ func encode(txt *mdnsTxt) ([]string, error) { w := zlib.NewWriter(&buf) defer func() { if closeErr := w.Close(); closeErr != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Errorf("[mdns] registry close encoding writer err: %v", closeErr) + if logger.V(logger.ErrorLevel) { + logger.Errorf(context.TODO(), "[mdns] registry close encoding writer err: %v", closeErr) } } }() @@ -148,8 +148,8 @@ func decode(record []string) (*mdnsTxt, error) { return txt, nil } -func newRegistry(opts ...registry.Option) registry.Registry { - options := registry.Options{ +func newRegister(opts ...register.Option) register.Register { + options := register.Options{ Context: context.Background(), Timeout: time.Millisecond * 100, } @@ -159,12 +159,12 @@ func newRegistry(opts ...registry.Option) registry.Registry { } // set the domain - defaultDomain := registry.DefaultDomain + defaultDomain := register.DefaultDomain if d, ok := options.Context.Value("mdns.domain").(string); ok { defaultDomain = d } - return &mdnsRegistry{ + return &mdnsRegister{ defaultDomain: defaultDomain, globalDomain: globalDomain, opts: options, @@ -173,28 +173,38 @@ func newRegistry(opts ...registry.Option) registry.Registry { } } -func (m *mdnsRegistry) Init(opts ...registry.Option) error { +func (m *mdnsRegister) Init(opts ...register.Option) error { for _, o := range opts { o(&m.opts) } return nil } -func (m *mdnsRegistry) Options() registry.Options { +func (m *mdnsRegister) Options() register.Options { return m.opts } +func (m *mdnsRegister) Connect(ctx context.Context) error { + // TODO: real connect + return nil +} + +func (m *mdnsRegister) Disconnect(ctx context.Context) error { + // TODO: real disconnect + return nil +} + // createServiceMDNSEntry will create a new wildcard mdns entry for the service in the // given domain. This wildcard mdns entry is used when listing services. func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { ip := net.ParseIP("0.0.0.0") - s, err := mdns.NewMDNSService(name, "_services", domain+".", "", 9999, []net.IP{ip}, nil) + s, err := util.NewMDNSService(name, "_services", domain+".", "", 9999, []net.IP{ip}, nil) if err != nil { return nil, err } - srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}, LocalhostChecking: true}) + srv, err := util.NewServer(&util.Config{Zone: &util.DNSSDService{MDNSService: s} /*, LocalhostChecking: true*/}) if err != nil { return nil, err } @@ -202,7 +212,7 @@ func createServiceMDNSEntry(name, domain string) (*mdnsEntry, error) { return &mdnsEntry{id: "*", node: srv}, nil } -func (m *mdnsRegistry) createMDNSEntries(domain, serviceName string) ([]*mdnsEntry, error) { +func (m *mdnsRegister) createMDNSEntries(domain, serviceName string) ([]*mdnsEntry, error) { // if it already exists don't reegister it again entries, ok := m.domains[domain][serviceName] if ok { @@ -218,7 +228,7 @@ func (m *mdnsRegistry) createMDNSEntries(domain, serviceName string) ([]*mdnsEnt return []*mdnsEntry{entry}, nil } -func registerService(service *registry.Service, entries []*mdnsEntry, options registry.RegisterOptions) ([]*mdnsEntry, error) { +func registerService(service *register.Service, entries []*mdnsEntry, options register.RegisterOptions) ([]*mdnsEntry, error) { var lastError error for _, node := range service.Nodes { var seen bool @@ -254,11 +264,11 @@ func registerService(service *registry.Service, entries []*mdnsEntry, options re } port, _ := strconv.Atoi(pt) - if logger.V(logger.DebugLevel, logger.DefaultLogger) { - logger.Debugf("[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host) + if logger.V(logger.DebugLevel) { + logger.Debugf(context.TODO(), "[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host) } // we got here, new node - s, err := mdns.NewMDNSService( + s, err := util.NewMDNSService( node.Id, service.Name, options.Domain+".", @@ -272,7 +282,7 @@ func registerService(service *registry.Service, entries []*mdnsEntry, options re continue } - srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) + srv, err := util.NewServer(&util.Config{Zone: s /*, LocalhostChecking: true*/}) if err != nil { lastError = err continue @@ -284,7 +294,7 @@ func registerService(service *registry.Service, entries []*mdnsEntry, options re return entries, lastError } -func createGlobalDomainService(service *registry.Service, options registry.RegisterOptions) *registry.Service { +func createGlobalDomainService(service *register.Service, options register.RegisterOptions) *register.Service { srv := *service srv.Nodes = nil @@ -304,11 +314,11 @@ func createGlobalDomainService(service *registry.Service, options registry.Regis return &srv } -func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.RegisterOption) error { +func (m *mdnsRegister) Register(ctx context.Context, service *register.Service, opts ...register.RegisterOption) error { m.Lock() // parse the options - var options registry.RegisterOptions + var options register.RegisterOptions for _, o := range opts { o(&options) } @@ -336,7 +346,7 @@ func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.Regi // register in the global Domain so it can be queried as one if options.Domain != m.globalDomain { srv := createGlobalDomainService(service, options) - if err := m.Register(srv, append(opts, registry.RegisterDomain(m.globalDomain))...); err != nil { + if err := m.Register(ctx, srv, append(opts, register.RegisterDomain(m.globalDomain))...); err != nil { gerr = err } } @@ -344,9 +354,9 @@ func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.Regi return gerr } -func (m *mdnsRegistry) Deregister(service *registry.Service, opts ...registry.DeregisterOption) error { +func (m *mdnsRegister) Deregister(ctx context.Context, service *register.Service, opts ...register.DeregisterOption) error { // parse the options - var options registry.DeregisterOptions + var options register.DeregisterOptions for _, o := range opts { o(&options) } @@ -358,7 +368,7 @@ func (m *mdnsRegistry) Deregister(service *registry.Service, opts ...registry.De var err error if options.Domain != m.globalDomain { defer func() { - err = m.Deregister(service, append(opts, registry.DeregisterDomain(m.globalDomain))...) + err = m.Deregister(ctx, service, append(opts, register.DeregisterDomain(m.globalDomain))...) }() } @@ -420,28 +430,28 @@ func (m *mdnsRegistry) Deregister(service *registry.Service, opts ...registry.De return err } -func (m *mdnsRegistry) GetService(service string, opts ...registry.GetOption) ([]*registry.Service, error) { +func (m *mdnsRegister) LookupService(ctx context.Context, service string, opts ...register.LookupOption) ([]*register.Service, error) { // parse the options - var options registry.GetOptions + var options register.LookupOptions for _, o := range opts { o(&options) } if len(options.Domain) == 0 { options.Domain = m.defaultDomain } - if options.Domain == registry.WildcardDomain { + if options.Domain == register.WildcardDomain { options.Domain = m.globalDomain } - serviceMap := make(map[string]*registry.Service) - entries := make(chan *mdns.ServiceEntry, 10) + serviceMap := make(map[string]*register.Service) + entries := make(chan *util.ServiceEntry, 10) done := make(chan bool) - p := mdns.DefaultParams(service) + p := util.DefaultParams(service) // set context with timeout - var cancel context.CancelFunc - p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout) - defer cancel() + //var cancel context.CancelFunc + //p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout) + //defer cancel() // set entries channel p.Entries = entries // set the domain @@ -470,7 +480,7 @@ func (m *mdnsRegistry) GetService(service string, opts ...registry.GetOption) ([ s, ok := serviceMap[txt.Version] if !ok { - s = ®istry.Service{ + s = ®ister.Service{ Name: txt.Service, Version: txt.Version, Endpoints: txt.Endpoints, @@ -484,27 +494,27 @@ func (m *mdnsRegistry) GetService(service string, opts ...registry.GetOption) ([ } else if len(e.AddrV6) > 0 { addr = "[" + e.AddrV6.String() + "]" } else { - if logger.V(logger.InfoLevel, logger.DefaultLogger) { - logger.Infof("[mdns]: invalid endpoint received: %v", e) + if logger.V(logger.InfoLevel) { + logger.Infof(context.TODO(), "[mdns]: invalid endpoint received: %v", e) } continue } - s.Nodes = append(s.Nodes, ®istry.Node{ + s.Nodes = append(s.Nodes, ®ister.Node{ Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."), Address: fmt.Sprintf("%s:%d", addr, e.Port), Metadata: txt.Metadata, }) serviceMap[txt.Version] = s - case <-p.Context.Done(): - close(done) - return + //case <-p.Context.Done(): + // close(done) + // return } } }() // execute the query - if err := mdns.Query(p); err != nil { + if err := util.Query(ctx, p); err != nil { return nil, err } @@ -512,7 +522,7 @@ func (m *mdnsRegistry) GetService(service string, opts ...registry.GetOption) ([ <-done // create list and return - services := make([]*registry.Service, 0, len(serviceMap)) + services := make([]*register.Service, 0, len(serviceMap)) for _, service := range serviceMap { services = append(services, service) @@ -521,34 +531,34 @@ func (m *mdnsRegistry) GetService(service string, opts ...registry.GetOption) ([ return services, nil } -func (m *mdnsRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Service, error) { +func (m *mdnsRegister) ListServices(ctx context.Context, opts ...register.ListOption) ([]*register.Service, error) { // parse the options - var options registry.ListOptions + var options register.ListOptions for _, o := range opts { o(&options) } if len(options.Domain) == 0 { options.Domain = m.defaultDomain } - if options.Domain == registry.WildcardDomain { + if options.Domain == register.WildcardDomain { options.Domain = m.globalDomain } serviceMap := make(map[string]bool) - entries := make(chan *mdns.ServiceEntry, 10) + entries := make(chan *util.ServiceEntry, 10) done := make(chan bool) - p := mdns.DefaultParams("_services") + p := util.DefaultParams("_services") // set context with timeout - var cancel context.CancelFunc - p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout) - defer cancel() + //var cancel context.CancelFunc + //p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout) + //defer cancel() // set entries channel p.Entries = entries // set domain p.Domain = options.Domain - var services []*registry.Service + var services []*register.Service go func() { for { @@ -563,17 +573,17 @@ func (m *mdnsRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Se name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".") if !serviceMap[name] { serviceMap[name] = true - services = append(services, ®istry.Service{Name: name}) + services = append(services, ®ister.Service{Name: name}) } - case <-p.Context.Done(): - close(done) - return + //case <-p.Context.Done(): + // close(done) + //return } } }() // execute query - if err := mdns.Query(p); err != nil { + if err := util.Query(ctx, p); err != nil { return nil, err } @@ -583,22 +593,22 @@ func (m *mdnsRegistry) ListServices(opts ...registry.ListOption) ([]*registry.Se return services, nil } -func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) { - var wo registry.WatchOptions +func (m *mdnsRegister) Watch(ctx context.Context, opts ...register.WatchOption) (register.Watcher, error) { + var wo register.WatchOptions for _, o := range opts { o(&wo) } if len(wo.Domain) == 0 { wo.Domain = m.defaultDomain } - if wo.Domain == registry.WildcardDomain { + if wo.Domain == register.WildcardDomain { wo.Domain = m.globalDomain } md := &mdnsWatcher{ id: uuid.New().String(), wo: wo, - ch: make(chan *mdns.ServiceEntry, 32), + ch: make(chan *util.ServiceEntry, 32), exit: make(chan struct{}), domain: wo.Domain, registry: m, @@ -636,14 +646,14 @@ func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, er // reset the listener exit := make(chan struct{}) - ch := make(chan *mdns.ServiceEntry, 32) + ch := make(chan *util.ServiceEntry, 32) m.listener = ch m.mtx.Unlock() // send messages to the watchers go func() { - send := func(w *mdnsWatcher, e *mdns.ServiceEntry) { + send := func(w *mdnsWatcher, e *util.ServiceEntry) { select { case w.ch <- e: default: @@ -670,9 +680,9 @@ func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, er }() // start listening, blocking call - mdns.Listen(ch, exit) + util.Listen(ch, exit) - // mdns.Listen has unblocked + // util.Listen has unblocked // kill the saved listener m.mtx.Lock() m.listener = nil @@ -684,11 +694,11 @@ func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, er return md, nil } -func (m *mdnsRegistry) String() string { +func (m *mdnsRegister) String() string { return "mdns" } -func (m *mdnsWatcher) Next() (*registry.Result, error) { +func (m *mdnsWatcher) Next() (*register.Result, error) { for { select { case e := <-m.ch: @@ -713,7 +723,7 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) { action = "create" } - service := ®istry.Service{ + service := ®ister.Service{ Name: txt.Service, Version: txt.Version, Endpoints: txt.Endpoints, @@ -731,22 +741,22 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) { addr = e.AddrV4.String() } else if len(e.AddrV6) > 0 { addr = "[" + e.AddrV6.String() + "]" - } else { - addr = e.Addr.String() + // } else { + // addr = e.Addr.String() } - service.Nodes = append(service.Nodes, ®istry.Node{ + service.Nodes = append(service.Nodes, ®ister.Node{ Id: strings.TrimSuffix(e.Name, suffix), Address: fmt.Sprintf("%s:%d", addr, e.Port), Metadata: txt.Metadata, }) - return ®istry.Result{ + return ®ister.Result{ Action: action, Service: service, }, nil case <-m.exit: - return nil, registry.ErrWatcherStopped + return nil, register.ErrWatcherStopped } } } @@ -764,7 +774,11 @@ func (m *mdnsWatcher) Stop() { } } -// NewRegistry returns a new default registry which is mdns -func NewRegistry(opts ...registry.Option) registry.Registry { - return newRegistry(opts...) +func (m *mdnsRegister) Name() string { + return m.opts.Name +} + +// NewRegistry returns a new default registry which is mdns +func NewRegister(opts ...register.Option) register.Register { + return newRegister(opts...) } diff --git a/mdns_test.go b/mdns_test.go index 7e3d4c1..af8efce 100644 --- a/mdns_test.go +++ b/mdns_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/micro/go-micro/v3/registry" + "github.com/unistack-org/micro/v3/register" ) func TestMDNS(t *testing.T) { @@ -14,11 +14,11 @@ func TestMDNS(t *testing.T) { t.Skip() } - testData := []*registry.Service{ + testData := []*register.Service{ { Name: "test1", Version: "1.0.1", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test1-1", Address: "10.0.0.1:10001", @@ -31,7 +31,7 @@ func TestMDNS(t *testing.T) { { Name: "test2", Version: "1.0.2", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test2-1", Address: "10.0.0.2:10002", @@ -44,7 +44,7 @@ func TestMDNS(t *testing.T) { { Name: "test3", Version: "1.0.3", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test3-1", Address: "10.0.0.3:10003", @@ -58,14 +58,14 @@ func TestMDNS(t *testing.T) { travis := os.Getenv("TRAVIS") - var opts []registry.Option + var opts []register.Option if travis == "true" { - opts = append(opts, registry.Timeout(time.Millisecond*100)) + opts = append(opts, register.Timeout(time.Millisecond*100)) } // new registry - r := NewRegistry(opts...) + r := NewRegister(opts...) for _, service := range testData { // register service @@ -74,7 +74,7 @@ func TestMDNS(t *testing.T) { } // get registered service - s, err := r.GetService(service.Name) + s, err := r.LookupService(service.Name) if err != nil { t.Fatal(err) } @@ -146,14 +146,14 @@ func TestEncoding(t *testing.T) { Metadata: map[string]string{ "foo": "bar", }, - Endpoints: []*registry.Endpoint{ + Endpoints: []*register.Endpoint{ { Name: "endpoint1", - Request: ®istry.Value{ + Request: ®ister.Value{ Name: "request", Type: "request", }, - Response: ®istry.Value{ + Response: ®ister.Value{ Name: "response", Type: "response", }, @@ -204,11 +204,11 @@ func TestWatcher(t *testing.T) { t.Skip() } - testData := []*registry.Service{ + testData := []*register.Service{ { Name: "test1", Version: "1.0.1", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test1-1", Address: "10.0.0.1:10001", @@ -221,7 +221,7 @@ func TestWatcher(t *testing.T) { { Name: "test2", Version: "1.0.2", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test2-1", Address: "10.0.0.2:10002", @@ -234,7 +234,7 @@ func TestWatcher(t *testing.T) { { Name: "test3", Version: "1.0.3", - Nodes: []*registry.Node{ + Nodes: []*register.Node{ { Id: "test3-1", Address: "10.0.0.3:10003", @@ -246,7 +246,7 @@ func TestWatcher(t *testing.T) { }, } - testFn := func(service, s *registry.Service) { + testFn := func(service, s *register.Service) { if s == nil { t.Fatalf("Expected one result for %s got nil", service.Name) @@ -277,14 +277,14 @@ func TestWatcher(t *testing.T) { travis := os.Getenv("TRAVIS") - var opts []registry.Option + var opts []register.Option if travis == "true" { - opts = append(opts, registry.Timeout(time.Millisecond*100)) + opts = append(opts, register.Timeout(time.Millisecond*100)) } // new registry - r := NewRegistry(opts...) + r := NewRegister(opts...) w, err := r.Watch() if err != nil { diff --git a/options.go b/options.go index 0de2d3b..0923297 100644 --- a/options.go +++ b/options.go @@ -4,12 +4,12 @@ package mdns import ( "context" - "github.com/micro/go-micro/v3/registry" + "github.com/unistack-org/micro/v3/register" ) // Domain sets the mdnsDomain -func Domain(d string) registry.Option { - return func(o *registry.Options) { +func Domain(d string) register.Option { + return func(o *register.Options) { if o.Context == nil { o.Context = context.Background() } diff --git a/util/client.go b/util/client.go new file mode 100644 index 0000000..b948869 --- /dev/null +++ b/util/client.go @@ -0,0 +1,504 @@ +package mdns + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/unistack-org/micro/v3/logger" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// ServiceEntry is returned after we query for a service +type ServiceEntry struct { + Name string + Host string + AddrV4 net.IP + AddrV6 net.IP + Port int + Info string + InfoFields []string + TTL int + Type uint16 + + hasTXT bool + sent bool +} + +// complete is used to check if we have all the info we need +func (s *ServiceEntry) complete() bool { + return (s.AddrV4 != nil || s.AddrV6 != nil) && s.Port != 0 && s.hasTXT +} + +// QueryParam is used to customize how a Lookup is performed +type QueryParam struct { + Service string // Service to lookup + Domain string // Lookup domain, default "local" + Type dnsmessage.Type // Lookup type, defaults to dns.TypePTR + Interface *net.Interface // Multicast interface to use + Entries chan<- *ServiceEntry // Entries Channel + WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC +} + +// DefaultParams is used to return a default set of QueryParam's +func DefaultParams(service string) *QueryParam { + qp := &QueryParam{ + Service: service, + Domain: "local", + Entries: make(chan *ServiceEntry), + WantUnicastResponse: false, + } + return qp +} + +// Query looks up a given service, in a domain, waiting at most +// for a timeout before finishing the query. The results are streamed +// to a channel. Sends will not block, so clients should make sure to +// either read or buffer. +func Query(ctx context.Context, params *QueryParam) error { + // Create a new client + client, err := newClient() + if err != nil { + return err + } + defer client.Close() + + // Set the multicast interface + if params.Interface != nil { + if err := client.setInterface(params.Interface, false); err != nil { + return err + } + } + + // Ensure defaults are set + if params.Domain == "" { + params.Domain = "local" + } + + // Run the query + return client.query(ctx, params) +} + +// Listen listens indefinitely for multicast updates +func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { + // Create a new client + client, err := newClient() + if err != nil { + return err + } + defer client.Close() + + client.setInterface(nil, true) + + // Start listening for response packets + msgCh := make(chan []byte, 32) + + go client.recv(client.ipv4UnicastConn, msgCh) + go client.recv(client.ipv6UnicastConn, msgCh) + go client.recv(client.ipv4MulticastConn, msgCh) + go client.recv(client.ipv6MulticastConn, msgCh) + + sentry := make(map[string]*ServiceEntry) + + for { + select { + case <-exit: + return nil + case <-client.closedCh: + return nil + case msg := <-msgCh: + fmt.Printf("%#+v\n", msg) + entry := messageToEntry(msg, sentry) + if entry == nil { + continue + } + + // Check if this entry is complete + if entry.complete() { + if entry.sent { + continue + } + entry.sent = true + entries <- entry + sentry = make(map[string]*ServiceEntry) + } else { + // Fire off a node specific query + /* + h: + -&dnsmessage.Header{RecursionDesired: false} + m := dnsmessage.NewBuilder() + m.SetQuestion(e.Name, dns.TypePTR) + if err := client.sendQuery(m); err != nil { + logger.Errorf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) + } + */ + } + } + } + + return nil +} + +/* +// Lookup is the same as Query, however it uses all the default parameters +func Lookup(service string, entries chan<- *ServiceEntry) error { + params := DefaultParams(service) + params.Entries = entries + return Query(params) +} +*/ +// Client provides a query interface that can be used to +// search for service providers using mDNS +type client struct { + ipv4UnicastConn *net.UDPConn + ipv6UnicastConn *net.UDPConn + + ipv4MulticastConn *net.UDPConn + ipv6MulticastConn *net.UDPConn + + closed bool + closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. + closeLock sync.RWMutex +} + +// NewClient creates a new mdns Client that can be used to query +// for records +func newClient() (*client, error) { + // TODO(reddaly): At least attempt to bind to the port required in the spec. + // Create a IPv4 listener + uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err4 != nil && err6 != nil { + logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) + } + + if uconn4 == nil && uconn6 == nil { + return nil, fmt.Errorf("failed to bind to any unicast udp port") + } + + if uconn4 == nil { + uconn4 = &net.UDPConn{} + } + + if uconn6 == nil { + uconn6 = &net.UDPConn{} + } + + mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if err4 != nil && err6 != nil { + logger.Errorf(context.TODO(), "[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) + } + + if mconn4 == nil && mconn6 == nil { + return nil, fmt.Errorf("failed to bind to any multicast udp port") + } + + if mconn4 == nil { + mconn4 = &net.UDPConn{} + } + + if mconn6 == nil { + mconn6 = &net.UDPConn{} + } + + p1 := ipv4.NewPacketConn(mconn4) + p2 := ipv6.NewPacketConn(mconn6) + p1.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var errCount1, errCount2 int + + for _, iface := range ifaces { + if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } + } + + if len(ifaces) == errCount1 && len(ifaces) == errCount2 { + return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") + } + + c := &client{ + ipv4MulticastConn: mconn4, + ipv6MulticastConn: mconn6, + ipv4UnicastConn: uconn4, + ipv6UnicastConn: uconn6, + closedCh: make(chan struct{}), + } + return c, nil +} + +// Close is used to cleanup the client +func (c *client) Close() error { + c.closeLock.Lock() + defer c.closeLock.Unlock() + + if c.closed { + return nil + } + c.closed = true + + close(c.closedCh) + + if c.ipv4UnicastConn != nil { + c.ipv4UnicastConn.Close() + } + if c.ipv6UnicastConn != nil { + c.ipv6UnicastConn.Close() + } + if c.ipv4MulticastConn != nil { + c.ipv4MulticastConn.Close() + } + if c.ipv6MulticastConn != nil { + c.ipv6MulticastConn.Close() + } + + return nil +} + +// setInterface is used to set the query interface, uses sytem +// default if not provided +func (c *client) setInterface(iface *net.Interface, loopback bool) error { + p := ipv4.NewPacketConn(c.ipv4UnicastConn) + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return err + } + p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return err + } + p = ipv4.NewPacketConn(c.ipv4MulticastConn) + if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return err + } + p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return err + } + + if loopback { + p.SetMulticastLoopback(true) + p2.SetMulticastLoopback(true) + } + + return nil +} + +// query is used to perform a lookup and stream results +func (c *client) query(ctx context.Context, params *QueryParam) error { + // Create the service name + serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) + + // Start listening for response packets + msgCh := make(chan []byte, 32) + go c.recv(c.ipv4UnicastConn, msgCh) + go c.recv(c.ipv6UnicastConn, msgCh) + go c.recv(c.ipv4MulticastConn, msgCh) + go c.recv(c.ipv6MulticastConn, msgCh) + + // buf := make([]byte, 2, 514) + hdr := dnsmessage.Header{RecursionDesired: false} + b := dnsmessage.NewBuilder(nil, hdr) + //b.EnableCompression() + name, err := dnsmessage.NewName(serviceAddr) + if err != nil { + return err + } + q := dnsmessage.Question{Name: name, Class: dnsmessage.ClassINET} + if params.Type == 0 { + q.Type = dnsmessage.TypePTR + } else { + q.Type = params.Type + } + //q.Class |= 1 << 15 + if err = b.StartQuestions(); err != nil { + return err + } + if err = b.Question(q); err != nil { + return err + } + bbuf, err := b.Finish() + if err != nil { + return err + } + + // Send the query + // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question + // Section + // + // In the Question Section of a Multicast DNS query, the top bit of the qclass + // field is used to indicate that unicast responses are preferred for this + // particular question. (See Section 5.4.) + if err := c.sendQuery(bbuf); err != nil { + return err + } + + // Map the in-progress responses + inprogress := make(map[string]*ServiceEntry) + + for { + select { + case rsp := <-msgCh: + inp := messageToEntry(rsp, inprogress) + if inp == nil { + continue + } + + // Check if this entry is complete + if inp.complete() { + if inp.sent { + continue + } + inp.sent = true + select { + case params.Entries <- inp: + case <-ctx.Done(): + return nil + } + } else { + // Fire off a node specific query + // m := new(dns.Msg) + // m.SetQuestion(inp.Name, inp.Type) + // m.RecursionDesired = false + var buf []byte + if err := c.sendQuery(buf); err != nil { + logger.Errorf(context.TODO(), "[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) + } + } + case <-ctx.Done(): + return nil + } + } + +} + +// sendQuery is used to multicast a query out +func (c *client) sendQuery(buf []byte) error { + if c.ipv4UnicastConn != nil { + c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) + } + if c.ipv6UnicastConn != nil { + c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) + } + return nil +} + +// recv is used to receive until we get a shutdown +func (c *client) recv(l *net.UDPConn, msgCh chan []byte) { + if l == nil { + return + } + buf := make([]byte, 65536) + for { + select { + case <-c.closedCh: + return + default: + c.closeLock.Lock() + if c.closed { + c.closeLock.Unlock() + return + } + c.closeLock.Unlock() + n, err := l.Read(buf) + if err != nil { + if logger.V(logger.DebugLevel) { + logger.Debug(context.TODO(), err) + } + continue + } + msgCh <- buf[:n] + } + } +} + +/* +// ensureName is used to ensure the named node is in progress +func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { + if inp, ok := inprogress[name]; ok { + return inp + } + inp := &ServiceEntry{ + Name: name, + Type: typ, + } + inprogress[name] = inp + return inp +} + +// alias is used to setup an alias between two entries +func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { + srcEntry := ensureName(inprogress, src, typ) + inprogress[dst] = srcEntry +} +*/ +func messageToEntry(m []byte, inprogress map[string]*ServiceEntry) *ServiceEntry { + var inp *ServiceEntry + /* + for _, answer := range append(m.Answers, m.Additionals...) { + // TODO(reddaly): Check that response corresponds to serviceAddr? + switch answer.Header.Type { + case dnsmessage.TypePTR: + rr := answer.Body.(*dnsmessage.PTRResource) + // Create new entry for this + inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + case dnsmessage.TypeSRV: + // Check for a target mismatch + if rr.Target != rr.Hdr.Name { + alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) + } + + // Get the port + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Host = rr.Target + inp.Port = int(rr.Port) + case dnsmessage.TypeTXT: + // Pull out the txt + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.Info = strings.Join(rr.Txt, "|") + inp.InfoFields = rr.Txt + inp.hasTXT = true + case dnsmessage.TypeA: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.AddrV4 = rr.A + case dnsmessage.TypeAAAA: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) + if inp.complete() { + continue + } + inp.AddrV6 = rr.AAAA + } + + if inp != nil { + inp.TTL = int(answer.Header().Ttl) + } + } + */ + return inp +} diff --git a/util/dns_sd.go b/util/dns_sd.go new file mode 100644 index 0000000..18444c3 --- /dev/null +++ b/util/dns_sd.go @@ -0,0 +1,84 @@ +package mdns + +import "github.com/miekg/dns" + +// DNSSDService is a service that complies with the DNS-SD (RFC 6762) and MDNS +// (RFC 6762) specs for local, multicast-DNS-based discovery. +// +// DNSSDService implements the Zone interface and wraps an MDNSService instance. +// To deploy an mDNS service that is compliant with DNS-SD, it's recommended to +// register only the wrapped instance with the server. +// +// Example usage: +// service := &mdns.DNSSDService{ +// MDNSService: &mdns.MDNSService{ +// Instance: "My Foobar Service", +// Service: "_foobar._tcp", +// Port: 8000, +// } +// } +// server, err := mdns.NewServer(&mdns.Config{Zone: service}) +// if err != nil { +// log.Fatalf("Error creating server: %v", err) +// } +// defer server.Shutdown() +type DNSSDService struct { + MDNSService *MDNSService +} + +// Records returns DNS records in response to a DNS question. +// +// This function returns the DNS response of the underlying MDNSService +// instance. It also returns a PTR record for a request for " +// _services._dns-sd._udp.", as described in section 9 of RFC 6763 +// ("Service Type Enumeration"), to allow browsing of the underlying MDNSService +// instance. +func (s *DNSSDService) Records(q dns.Question) []dns.RR { + var recs []dns.RR + if q.Name == "_services._dns-sd._udp."+s.MDNSService.Domain+"." { + recs = s.dnssdMetaQueryRecords(q) + } + return append(recs, s.MDNSService.Records(q)...) +} + +// dnssdMetaQueryRecords returns the DNS records in response to a "meta-query" +// issued to browse for DNS-SD services, as per section 9. of RFC6763. +// +// A meta-query has a name of the form "_services._dns-sd._udp." where +// Domain is a fully-qualified domain, such as "local." +func (s *DNSSDService) dnssdMetaQueryRecords(q dns.Question) []dns.RR { + // Intended behavior, as described in the RFC: + // ...it may be useful for network administrators to find the list of + // advertised service types on the network, even if those Service Names + // are just opaque identifiers and not particularly informative in + // isolation. + // + // For this purpose, a special meta-query is defined. A DNS query for PTR + // records with the name "_services._dns-sd._udp." yields a set of + // PTR records, where the rdata of each PTR record is the two-abel + // name, plus the same domain, e.g., "_http._tcp.". + // Including the domain in the PTR rdata allows for slightly better name + // compression in Unicast DNS responses, but only the first two labels are + // relevant for the purposes of service type enumeration. These two-label + // service types can then be used to construct subsequent Service Instance + // Enumeration PTR queries, in this or others, to discover + // instances of that service type. + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Ptr: s.MDNSService.serviceAddr, + }, + } +} + +// Announcement returns DNS records that should be broadcast during the initial +// availability of the service, as described in section 8.3 of RFC 6762. +// TODO(reddaly): Add this when Announcement is added to the mdns.Zone interface. +//func (s *DNSSDService) Announcement() []dns.RR { +// return s.MDNSService.Announcement() +//} diff --git a/util/dns_sd_test.go b/util/dns_sd_test.go new file mode 100644 index 0000000..ab1522d --- /dev/null +++ b/util/dns_sd_test.go @@ -0,0 +1,71 @@ +// +build ignore + +package mdns + +import ( + "reflect" + "testing" + + "github.com/miekg/dns" +) + +type mockMDNSService struct{} + +func (s *mockMDNSService) Records(q dns.Question) []dns.RR { + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: "fakerecord", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 42, + }, + Ptr: "fake.local.", + }, + } +} + +func (s *mockMDNSService) Announcement() []dns.RR { + return []dns.RR{ + &dns.PTR{ + Hdr: dns.RR_Header{ + Name: "fakeannounce", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 42, + }, + Ptr: "fake.local.", + }, + } +} + +func TestDNSSDServiceRecords(t *testing.T) { + s := &DNSSDService{ + MDNSService: &MDNSService{ + serviceAddr: "_foobar._tcp.local.", + Domain: "local", + }, + } + q := dns.Question{ + Name: "_services._dns-sd._udp.local.", + Qtype: dns.TypePTR, + Qclass: dns.ClassINET, + } + recs := s.Records(q) + if got, want := len(recs), 1; got != want { + t.Fatalf("s.Records(%v) returned %v records, want %v", q, got, want) + } + + want := dns.RR(&dns.PTR{ + Hdr: dns.RR_Header{ + Name: "_services._dns-sd._udp.local.", + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Ptr: "_foobar._tcp.local.", + }) + if got := recs[0]; !reflect.DeepEqual(got, want) { + t.Errorf("s.Records()[0] = %v, want %v", got, want) + } +} diff --git a/util/server.go b/util/server.go new file mode 100644 index 0000000..d6cf172 --- /dev/null +++ b/util/server.go @@ -0,0 +1,546 @@ +package mdns + +import ( + "fmt" + "log" + "net" + "sync" + "sync/atomic" + + "github.com/miekg/dns" + registry "github.com/unistack-org/micro/v3/register" + regutil "github.com/unistack-org/micro/v3/util/register" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + mdnsGroupIPv4 = net.ParseIP("224.0.0.251") + mdnsGroupIPv6 = net.ParseIP("ff02::fb") + + // mDNS wildcard addresses + mdnsWildcardAddrIPv4 = &net.UDPAddr{ + IP: net.ParseIP("224.0.0.0"), + Port: 5353, + } + mdnsWildcardAddrIPv6 = &net.UDPAddr{ + IP: net.ParseIP("ff02::"), + Port: 5353, + } + + // mDNS endpoint addresses + ipv4Addr = &net.UDPAddr{ + IP: mdnsGroupIPv4, + Port: 5353, + } + ipv6Addr = &net.UDPAddr{ + IP: mdnsGroupIPv6, + Port: 5353, + } +) + +// Config is used to configure the mDNS server +type Config struct { + // Zone must be provided to support responding to queries + Zone Zone + + // Iface if provided binds the multicast listener to the given + // interface. If not provided, the system default multicase interface + // is used. + Iface *net.Interface + + // Port If it is not 0, replace the port 5353 with this port number. + Port int +} + +// mDNS server is used to listen for mDNS queries and respond if we +// have a matching local record +type Server struct { + config *Config + + ipv4conn *net.UDPConn + ipv6conn *net.UDPConn + + shutdown bool + shutdownCh chan struct{} + shutdownLock sync.Mutex + wg sync.WaitGroup + + updates chan *registry.Service + services map[string][]*registry.Service + records map[string][]dnsmessage.Resource + + sync.RWMutex +} + +// NewServer is used to create a new mDNS server from a config +func NewServer(config *Config) (*Server, error) { + setCustomPort(config.Port) + + // Create the listeners + // Create wildcard connections (because :5353 can be already taken by other apps) + ipv4conn, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) + ipv6conn, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) + if ipv4conn == nil && ipv6conn == nil { + return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!") + } + + if ipv4conn == nil { + ipv4conn = &net.UDPConn{} + } + if ipv6conn == nil { + ipv6conn = &net.UDPConn{} + } + + // Join multicast groups to receive announcements + p4 := ipv4.NewPacketConn(ipv4conn) + p6 := ipv6.NewPacketConn(ipv6conn) + p4.SetMulticastLoopback(true) + p6.SetMulticastLoopback(true) + + if config.Iface != nil { + if err := p4.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return nil, err + } + if err := p6.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return nil, err + } + } else { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + errCount1, errCount2 := 0, 0 + for _, iface := range ifaces { + if err := p4.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + if err := p6.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } + } + if len(ifaces) == errCount1 && len(ifaces) == errCount2 { + return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") + } + } + + s := &Server{ + config: config, + ipv4conn: ipv4conn, + ipv6conn: ipv6conn, + shutdownCh: make(chan struct{}), + records: make(map[string][]dnsmessage.Resource), + services: make(map[string][]*registry.Service), + updates: make(chan *registry.Service), + } + + go s.recv(s.ipv4conn) + go s.recv(s.ipv6conn) + + go s.update() + //s.wg.Add(1) + //go s.probe() + + return s, nil +} + +func (s *Server) update() { + var err error + var buf []byte + for svc := range s.updates { + fmt.Printf("update %#+v\n", svc) + if err = s.serviceToPacket(svc, buf); err != nil { + fmt.Printf("%v\n", err) + } else { + if s.sendResponse(buf, mdnsWildcardAddrIPv4); err != nil { + fmt.Printf("%v\n", err) + } + } + } +} + +// Shutdown is used to shutdown the listener +func (s *Server) Shutdown() error { + s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + + if s.shutdown { + return nil + } + + s.shutdown = true + close(s.shutdownCh) + s.unregister() + + if s.ipv4conn != nil { + s.ipv4conn.Close() + } + if s.ipv6conn != nil { + s.ipv6conn.Close() + } + + // s.wg.Wait() + return nil +} + +// recv is a long running routine to receive packets from an interface +func (s *Server) recv(c *net.UDPConn) { + if c == nil { + return + } + buf := make([]byte, 65536) + for { + s.shutdownLock.Lock() + if s.shutdown { + s.shutdownLock.Unlock() + return + } + s.shutdownLock.Unlock() + n, from, err := c.ReadFrom(buf) + if err != nil { + continue + } + if err := s.parsePacket(buf[:n], from); err != nil { + log.Printf("[ERR] mdns: Failed to handle query: %v", err) + } + } +} + +// parsePacket is used to parse an incoming packet +func (s *Server) parsePacket(buf []byte, from net.Addr) error { + var p dnsmessage.Parser + hdr, err := p.Start(buf) + if err != nil { + return err + } + return s.handleQuery(hdr, p, from) +} + +func (s *Server) LookupService(name string, opts ...registry.LookupOption) ([]*registry.Service, error) { + + return nil, nil +} + +func (s *Server) Register(service *registry.Service, opts ...registry.RegisterOption) error { + name := service.Name + ".micro." + s.Lock() + svcs, ok := s.services[name] + if !ok { + s.services[name] = []*registry.Service{service} + } else { + s.services[name] = regutil.Merge([]*registry.Service{service}, svcs) + } + s.Unlock() + s.updates <- service + return nil +} + +func (s *Server) serviceToPacket(svc *registry.Service, buf []byte) error { + var err error + var name dnsmessage.Name + + b := dnsmessage.NewBuilder(buf, dnsmessage.Header{}) + b.EnableCompression() + if err = b.StartAnswers(); err != nil { + return err + } + + if name, err = dnsmessage.NewName(svc.Name + ".micro."); err != nil { + return err + } + + if err = b.AResource(dnsmessage.ResourceHeader{Name: name, Class: dnsmessage.ClassINET, TTL: 60}, + dnsmessage.AResource{}); err != nil { + return err + } + + buf, err = b.Finish() + + return err +} + +// handleQuery is used to handle an incoming query +func (s *Server) handleQuery(hdr dnsmessage.Header, p dnsmessage.Parser, from net.Addr) error { + if hdr.OpCode != 0 { + // "In both multicast query and multicast response messages, the OPCODE MUST + // be zero on transmission (only standard queries are currently supported + // over multicast). Multicast DNS messages received with an OPCODE other + // than zero MUST be silently ignored." Note: OpcodeQuery == 0 + return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", hdr.OpCode, hdr) + } + if hdr.RCode != dnsmessage.RCodeSuccess { + // "In both multicast query and multicast response messages, the Response + // Code MUST be zero on transmission. Multicast DNS messages received with + // non-zero Response Codes MUST be silently ignored." + return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", hdr.RCode, hdr) + } + + // TODO(reddaly): Handle "TC (Truncated) Bit": + // In query messages, if the TC bit is set, it means that additional + // Known-Answer records may be following shortly. A responder SHOULD + // record this fact, and wait for those additional Known-Answer records, + // before deciding whether to respond. If the TC bit is clear, it means + // that the querying host has no additional Known Answers. + if hdr.Truncated { + return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", hdr) + } + + questions, err := p.AllQuestions() + if err != nil { + return err + } + + var unicastAnswer, multicastAnswer []dnsmessage.Resource + + for _, question := range questions { + mrecs, urecs := s.handleQuestion(question) + fmt.Printf("%#+v %#+v\n", mrecs, urecs) + multicastAnswer = append(multicastAnswer, mrecs...) + unicastAnswer = append(unicastAnswer, urecs...) + } + + rsp := func(unistact bool, buf []byte) error { + // See section 18 of RFC 6762 for rules about DNS headers. + // 18.1: ID (Query Identifier) + // 0 for multicast response, query.Id for unicast response + id := uint16(0) + if true /*unicast*/ { + id = hdr.ID + } + + hdrnew := dnsmessage.Header{ + ID: id, + // 18.2: QR (Query/Response) Bit - must be set to 1 in response. + Response: true, + // 18.3: OPCODE - must be zero in response (OpcodeQuery == 0) + OpCode: 0, + // 18.4: AA (Authoritative Answer) Bit - must be set to 1 + Authoritative: true, + // The following fields must all be set to 0: + // 18.5: TC (TRUNCATED) Bit + // 18.6: RD (Recursion Desired) Bit + // 18.7: RA (Recursion Available) Bit + // 18.8: Z (Zero) Bit + // 18.9: AD (Authentic Data) Bit + // 18.10: CD (Checking Disabled) Bit + // 18.11: RCODE (Response Code) + } + + b := dnsmessage.NewBuilder(buf, hdrnew) + b.EnableCompression() + + buf, err = b.Finish() + + return err + } + + var buf1 []byte + if err = rsp(false, buf1); err != nil { + return err + } + if len(buf1) > 0 { + if err := s.sendResponse(buf1, from); err != nil { + return fmt.Errorf("mdns: error sending multicast response: %v", err) + } + } + + var buf2 []byte + if err = rsp(true, buf2); err != nil { + return err + } + if len(buf2) > 0 { + if err := s.sendResponse(buf2, from); err != nil { + return fmt.Errorf("mdns: error sending unicast response: %v", err) + } + } + + return nil +} + +// handleQuestion is used to handle an incoming question +// +// The response to a question may be transmitted over multicast, unicast, or +// both. The return values are DNS records for each transmission type. +func (s *Server) handleQuestion(q dnsmessage.Question) (multicastRecs, unicastRecs []dnsmessage.Resource) { + records, ok := s.records[q.Name.String()] + + // Handle unicast and multicast responses. + // TODO(reddaly): The decision about sending over unicast vs. multicast is not + // yet fully compliant with RFC 6762. For example, the unicast bit should be + // ignored if the records in question are close to TTL expiration. For now, + // we just use the unicast bit to make the decision, as per the spec: + // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question + // Section + // + // In the Question Section of a Multicast DNS query, the top bit of the + // qclass field is used to indicate that unicast responses are preferred + // for this particular question. (See Section 5.4.) + if ok { + if q.Class&(1<<15) != 0 { + return nil, records + } + + return records, nil + } + + services, ok := s.services[q.Name.String()] + if !ok { + return nil, nil + } + + fmt.Printf("%s\n", q.Name.String()) + fmt.Printf("%#+v\n", services) + return nil, nil +} + +/* +func (s *Server) probe() { + defer s.wg.Done() + + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return + } + + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypePTR) + q.RecursionDesired = false + + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Priority: 0, + Weight: 0, + Port: uint16(sd.Port), + Target: sd.HostName, + } + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: defaultTTL, + }, + Txt: sd.TXT, + } + q.Ns = []dns.RR{srv, txt} + + randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) + + for i := 0; i < 3; i++ { + if err := s.SendMulticast(q); err != nil { + log.Println("[ERR] mdns: failed to send probe:", err.Error()) + } + time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) + } + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + + // set for query + q.SetQuestion(name, dns.TypeANY) + + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + // reset + q.SetQuestion(name, dns.TypePTR) + + // From RFC6762 + // The Multicast DNS responder MUST send at least two unsolicited + // responses, one second apart. To provide increased robustness against + // packet loss, a responder MAY send up to eight unsolicited responses, + // provided that the interval between unsolicited responses increases by + // at least a factor of two with every response sent. + timeout := 1 * time.Second + timer := time.NewTimer(timeout) + for i := 0; i < 3; i++ { + if err := s.SendMulticast(resp); err != nil { + log.Println("[ERR] mdns: failed to send announcement:", err.Error()) + } + select { + case <-timer.C: + timeout *= 2 + timer.Reset(timeout) + case <-s.shutdownCh: + timer.Stop() + return + } + } +} +*/ + +// multicastResponse us used to send a multicast response packet +func (s *Server) SendMulticast(msg *dns.Msg) error { + buf, err := msg.Pack() + if err != nil { + return err + } + if s.ipv4conn != nil { + s.ipv4conn.WriteToUDP(buf, ipv4Addr) + } + if s.ipv6conn != nil { + s.ipv6conn.WriteToUDP(buf, ipv6Addr) + } + return nil +} + +// sendResponse is used to send a response packet +func (s *Server) sendResponse(buf []byte, from net.Addr) error { + var err error + + // TODO(reddaly): Respect the unicast argument, and allow sending responses + // over multicast. + + // Determine the socket to send from + addr := from.(*net.UDPAddr) + if addr.IP.To4() != nil { + _, err = s.ipv4conn.WriteToUDP(buf, addr) + } else { + _, err = s.ipv6conn.WriteToUDP(buf, addr) + } + + return err +} + +func (s *Server) unregister() error { + sd, ok := s.config.Zone.(*MDNSService) + if !ok { + return nil + } + + atomic.StoreUint32(&sd.TTL, 0) + name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) + + q := new(dns.Msg) + q.SetQuestion(name, dns.TypeANY) + + resp := new(dns.Msg) + resp.MsgHdr.Response = true + resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) + + return s.SendMulticast(resp) +} + +func setCustomPort(port int) { + if port != 0 { + if mdnsWildcardAddrIPv4.Port != port { + mdnsWildcardAddrIPv4.Port = port + } + if mdnsWildcardAddrIPv6.Port != port { + mdnsWildcardAddrIPv6.Port = port + } + if ipv4Addr.Port != port { + ipv4Addr.Port = port + } + if ipv6Addr.Port != port { + ipv6Addr.Port = port + } + } +} diff --git a/util/server_test.go b/util/server_test.go new file mode 100644 index 0000000..45dfe99 --- /dev/null +++ b/util/server_test.go @@ -0,0 +1,65 @@ +// +build ignore + +package mdns + +import ( + "testing" + + registry "github.com/unistack-org/micro/v3/register" +) + +var ( + svc1 = ®istry.Service{ + Name: "foo", + Version: "latest", + Nodes: []*registry.Node{ + ®istry.Node{ + Id: "1", + Address: "127.0.0.1", + }, + }, + } +) + +func TestServer_StartStop(t *testing.T) { + //s := makeService(t) + srv, err := NewServer(&Config{}) + if err != nil { + t.Fatalf("err: %v", err) + } + if err = srv.Shutdown(); err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestServer_Lookup(t *testing.T) { + srv1, err := NewServer(&Config{}) + if err != nil { + t.Fatalf("err: %v", err) + } + defer srv1.Shutdown() + /* + srv2, err := NewServer(&Config{}) + if err != nil { + t.Fatalf("err: %v", err) + } + defer srv2.Shutdown() + */ + if err = srv1.Register(svc1); err != nil { + t.Fatalf("err: %v", err) + } + /* + select {} + services, err := srv2.GetService("foo") + if err != nil { + t.Fatalf("err: %v", err) + } else if len(services) == 0 { + t.Fatalf("empty service") + } + + for _, svc := range services { + fmt.Printf("%#+v\n", svc) + } + */ + select {} +} diff --git a/util/zone.go b/util/zone.go new file mode 100644 index 0000000..abbab4b --- /dev/null +++ b/util/zone.go @@ -0,0 +1,309 @@ +package mdns + +import ( + "fmt" + "net" + "os" + "strings" + "sync/atomic" + + "github.com/miekg/dns" +) + +const ( + // defaultTTL is the default TTL value in returned DNS records in seconds. + defaultTTL = 120 +) + +// Zone is the interface used to integrate with the server and +// to serve records dynamically +type Zone interface { + // Records returns DNS records in response to a DNS question. + Records(q dns.Question) []dns.RR +} + +// MDNSService is used to export a named service by implementing a Zone +type MDNSService struct { + Instance string // Instance name (e.g. "hostService name") + Service string // Service name (e.g. "_http._tcp.") + Domain string // If blank, assumes "local" + HostName string // Host machine DNS name (e.g. "mymachine.net.") + Port int // Service Port + IPs []net.IP // IP addresses for the service's host + TXT []string // Service TXT records + TTL uint32 + serviceAddr string // Fully qualified service address + instanceAddr string // Fully qualified instance address + enumAddr string // _services._dns-sd._udp. +} + +// validateFQDN returns an error if the passed string is not a fully qualified +// hdomain name (more specifically, a hostname). +func validateFQDN(s string) error { + if len(s) == 0 { + return fmt.Errorf("FQDN must not be blank") + } + if s[len(s)-1] != '.' { + return fmt.Errorf("FQDN must end in period: %s", s) + } + // TODO(reddaly): Perform full validation. + + return nil +} + +// NewMDNSService returns a new instance of MDNSService. +// +// If domain, hostName, or ips is set to the zero value, then a default value +// will be inferred from the operating system. +// +// TODO(reddaly): This interface may need to change to account for "unique +// record" conflict rules of the mDNS protocol. Upon startup, the server should +// check to ensure that the instance name does not conflict with other instance +// names, and, if required, select a new name. There may also be conflicting +// hostName A/AAAA records. +func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) { + // Sanity check inputs + if instance == "" { + return nil, fmt.Errorf("missing service instance name") + } + if service == "" { + return nil, fmt.Errorf("missing service name") + } + if port == 0 { + return nil, fmt.Errorf("missing service port") + } + + // Set default domain + if domain == "" { + domain = "local." + } + if err := validateFQDN(domain); err != nil { + return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err) + } + + // Get host information if no host is specified. + if hostName == "" { + var err error + hostName, err = os.Hostname() + if err != nil { + return nil, fmt.Errorf("could not determine host: %v", err) + } + hostName = fmt.Sprintf("%s.", hostName) + } + if err := validateFQDN(hostName); err != nil { + return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err) + } + + if len(ips) == 0 { + var err error + ips, err = net.LookupIP(trimDot(hostName)) + if err != nil { + // Try appending the host domain suffix and lookup again + // (required for Linux-based hosts) + tmpHostName := fmt.Sprintf("%s%s", hostName, domain) + + ips, err = net.LookupIP(trimDot(tmpHostName)) + + if err != nil { + return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName) + } + } + } + for _, ip := range ips { + if ip.To4() == nil && ip.To16() == nil { + return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip) + } + } + + return &MDNSService{ + Instance: instance, + Service: service, + Domain: domain, + HostName: hostName, + Port: port, + IPs: ips, + TXT: txt, + TTL: defaultTTL, + serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)), + instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)), + enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)), + }, nil +} + +// trimDot is used to trim the dots from the start or end of a string +func trimDot(s string) string { + return strings.Trim(s, ".") +} + +// Records returns DNS records in response to a DNS question. +func (m *MDNSService) Records(q dns.Question) []dns.RR { + switch q.Name { + case m.enumAddr: + return m.serviceEnum(q) + case m.serviceAddr: + return m.serviceRecords(q) + case m.instanceAddr: + return m.instanceRecords(q) + case m.HostName: + if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA { + return m.instanceRecords(q) + } + fallthrough + default: + return nil + } +} + +func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Ptr: m.serviceAddr, + } + return []dns.RR{rr} + default: + return nil + } +} + +// serviceRecords is called when the query matches the service name +func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + fallthrough + case dns.TypePTR: + // Build a PTR response for the service + rr := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Ptr: m.instanceAddr, + } + servRec := []dns.RR{rr} + + // Get the instance records + instRecs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeANY, + }) + + // Return the service record with the instance records + return append(servRec, instRecs...) + default: + return nil + } +} + +// serviceRecords is called when the query matches the instance name +func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR { + switch q.Qtype { + case dns.TypeANY: + // Get the SRV, which includes A and AAAA + recs := m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeSRV, + }) + + // Add the TXT record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeTXT, + })...) + return recs + + case dns.TypeA: + var rr []dns.RR + for _, ip := range m.IPs { + if ip4 := ip.To4(); ip4 != nil { + rr = append(rr, &dns.A{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + A: ip4, + }) + } + } + return rr + + case dns.TypeAAAA: + var rr []dns.RR + for _, ip := range m.IPs { + if ip.To4() != nil { + // TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and + // putinto AAAA records, but the current logic puts ipv4-encodable + // addresses into the A records exclusively. Perhaps this should be + // configurable? + continue + } + + if ip16 := ip.To16(); ip16 != nil { + rr = append(rr, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: m.HostName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + AAAA: ip16, + }) + } + } + return rr + + case dns.TypeSRV: + // Create the SRV Record + srv := &dns.SRV{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Priority: 10, + Weight: 1, + Port: uint16(m.Port), + Target: m.HostName, + } + recs := []dns.RR{srv} + + // Add the A record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeA, + })...) + + // Add the AAAA record + recs = append(recs, m.instanceRecords(dns.Question{ + Name: m.instanceAddr, + Qtype: dns.TypeAAAA, + })...) + return recs + + case dns.TypeTXT: + txt := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: atomic.LoadUint32(&m.TTL), + }, + Txt: m.TXT, + } + return []dns.RR{txt} + } + return nil +} diff --git a/util/zone_test.go b/util/zone_test.go new file mode 100644 index 0000000..7c8476b --- /dev/null +++ b/util/zone_test.go @@ -0,0 +1,277 @@ +// +build ignore + +package mdns + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/miekg/dns" +) + +func makeService(t *testing.T) *MDNSService { + return makeServiceWithServiceName(t, "_http._tcp") +} + +func makeServiceWithServiceName(t *testing.T, service string) *MDNSService { + m, err := NewMDNSService( + "hostname", + service, + "local.", + "testhost.", + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")}, + []string{"Local web server"}) // TXT + + if err != nil { + t.Fatalf("err: %v", err) + } + + return m +} + +func TestNewMDNSService_BadParams(t *testing.T) { + for _, test := range []struct { + testName string + hostName string + domain string + }{ + { + "NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name", + "hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc. + "local.", // legal + }, + { + "NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name", + "hostname.", // legal + "local", // should be "local." + }, + } { + _, err := NewMDNSService( + "instance name", + "_http._tcp", + test.domain, + test.hostName, + 80, // port + []net.IP{net.IP([]byte{192, 168, 0, 42})}, + []string{"Local web server"}) // TXT + if err == nil { + t.Fatalf("%s: error expected, but got none", test.testName) + } + } +} + +func TestMDNSService_BadAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "random", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 0 { + t.Fatalf("bad: %v", recs) + } +} + +func TestMDNSService_ServiceAddr(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if got, want := len(recs), 5; got != want { + t.Fatalf("got %d records, want %d: %v", got, want, recs) + } + + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) + } + + if _, ok := recs[1].(*dns.SRV); !ok { + t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs) + } + if _, ok := recs[2].(*dns.A); !ok { + t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs) + } + if _, ok := recs[3].(*dns.AAAA); !ok { + t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs) + } + if _, ok := recs[4].(*dns.TXT); !ok { + t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs) + } + + q.Qtype = dns.TypePTR + if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) { + t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2) + } +} + +func TestMDNSService_InstanceAddr_ANY(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeANY, + } + recs := s.Records(q) + if len(recs) != 4 { + t.Fatalf("bad: %v", recs) + } + if _, ok := recs[0].(*dns.SRV); !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.AAAA); !ok { + t.Fatalf("bad: %v", recs[2]) + } + if _, ok := recs[3].(*dns.TXT); !ok { + t.Fatalf("bad: %v", recs[3]) + } +} + +func TestMDNSService_InstanceAddr_SRV(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeSRV, + } + recs := s.Records(q) + if len(recs) != 3 { + t.Fatalf("bad: %v", recs) + } + srv, ok := recs[0].(*dns.SRV) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if _, ok := recs[1].(*dns.A); !ok { + t.Fatalf("bad: %v", recs[1]) + } + if _, ok := recs[2].(*dns.AAAA); !ok { + t.Fatalf("bad: %v", recs[2]) + } + + if srv.Port != uint16(s.Port) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_A(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeA, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a, ok := recs[0].(*dns.A) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_AAAA(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeAAAA, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + a4, ok := recs[0].(*dns.AAAA) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc") + if got := len(ip6); got != net.IPv6len { + t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len) + } + if !bytes.Equal(a4.AAAA, ip6) { + t.Fatalf("bad: %v", recs[0]) + } +} + +func TestMDNSService_InstanceAddr_TXT(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "hostname._http._tcp.local.", + Qtype: dns.TypeTXT, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + txt, ok := recs[0].(*dns.TXT) + if !ok { + t.Fatalf("bad: %v", recs[0]) + } + if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) { + t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want) + } +} + +func TestMDNSService_HostNameQuery(t *testing.T) { + s := makeService(t) + for _, test := range []struct { + q dns.Question + want []dns.RR + }{ + { + dns.Question{Name: "testhost.", Qtype: dns.TypeA}, + []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "testhost.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 120, + }, + A: net.IP([]byte{192, 168, 0, 42}), + }}, + }, + { + dns.Question{Name: "testhost.", Qtype: dns.TypeAAAA}, + []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "testhost.", + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 120, + }, + AAAA: net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc"), + }}, + }, + } { + if got := s.Records(test.q); !reflect.DeepEqual(got, test.want) { + t.Errorf("hostname query failed: s.Records(%v) = %v, want %v", test.q, got, test.want) + } + } +} + +func TestMDNSService_serviceEnum_PTR(t *testing.T) { + s := makeService(t) + q := dns.Question{ + Name: "_services._dns-sd._udp.local.", + Qtype: dns.TypePTR, + } + recs := s.Records(q) + if len(recs) != 1 { + t.Fatalf("bad: %v", recs) + } + if ptr, ok := recs[0].(*dns.PTR); !ok { + t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs) + } else if got, want := ptr.Ptr, "_http._tcp.local."; got != want { + t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want) + } +}