many lint fixes and optimizations (#17)

* util/kubernetes: drop stale files
* debug/log/kubernetes: drop stale files
* util/scope: remove stale files
* util/mdns: drop stale files
* lint fixes

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
2021-02-13 15:35:56 +03:00
committed by GitHub
parent abb9937787
commit 82248eb3b0
57 changed files with 246 additions and 5690 deletions

View File

@@ -60,10 +60,9 @@ func NewRoundTripper(opts ...Option) http.RoundTripper {
// RequestToContext puts the `Authorization` header bearer token into context
// so calls to services will be authorized.
func RequestToContext(r *http.Request) context.Context {
ctx := context.Background()
md := make(metadata.Metadata)
md := metadata.New(len(r.Header))
for k, v := range r.Header {
md[k] = strings.Join(v, ",")
md.Set(k, strings.Join(v, ","))
}
return metadata.NewContext(ctx, md)
return metadata.NewIncomingContext(r.Context(), md)
}

View File

@@ -1,178 +0,0 @@
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
type testcase struct {
Token string
ReqFn func(opts *Options) *Request
Method string
URI string
Body interface{}
Header map[string]string
Assert func(req *http.Request) bool
}
var tests = []testcase{
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("service")
},
Method: "GET",
URI: "/api/v1/namespaces/default/services/",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("service").Name("foo")
},
Method: "GET",
URI: "/api/v1/namespaces/default/services/foo",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("service").Namespace("test").Name("bar")
},
Method: "GET",
URI: "/api/v1/namespaces/test/services/bar",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("deployment").Name("foo")
},
Method: "GET",
URI: "/apis/apps/v1/namespaces/default/deployments/foo",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("deployment").Namespace("test").Name("foo")
},
Method: "GET",
URI: "/apis/apps/v1/namespaces/test/deployments/foo",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Get().Resource("pod").Params(&Params{LabelSelector: map[string]string{"foo": "bar"}})
},
Method: "GET",
URI: "/api/v1/namespaces/default/pods/?labelSelector=foo%3Dbar",
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Post().Resource("service").Name("foo").Body(map[string]string{"foo": "bar"})
},
Method: "POST",
URI: "/api/v1/namespaces/default/services/foo",
Body: map[string]string{"foo": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Post().Resource("deployment").Namespace("test").Name("foo").Body(map[string]string{"foo": "bar"})
},
Method: "POST",
URI: "/apis/apps/v1/namespaces/test/deployments/foo",
Body: map[string]string{"foo": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Put().Resource("endpoint").Name("baz").Body(map[string]string{"bam": "bar"})
},
Method: "PUT",
URI: "/api/v1/namespaces/default/endpoints/baz",
Body: map[string]string{"bam": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Patch().Resource("endpoint").Name("baz").Body(map[string]string{"bam": "bar"})
},
Method: "PATCH",
URI: "/api/v1/namespaces/default/endpoints/baz",
Body: map[string]string{"bam": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Patch().Resource("endpoint").Name("baz").SetHeader("foo", "bar")
},
Method: "PATCH",
URI: "/api/v1/namespaces/default/endpoints/baz",
Header: map[string]string{"foo": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).Patch().Resource("deployment").Name("baz").SetHeader("foo", "bar")
},
Method: "PATCH",
URI: "/apis/apps/v1/namespaces/default/deployments/baz",
Header: map[string]string{"foo": "bar"},
},
{
ReqFn: func(opts *Options) *Request {
return NewRequest(opts).
Get().
Resource("pod").
SubResource("log").
Name("foolog")
},
Method: "GET",
URI: "/api/v1/namespaces/default/pods/foolog/log",
},
}
var wrappedHandler = func(test *testcase, t *testing.T) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if len(test.Token) > 0 && (len(auth) == 0 || auth != "Bearer "+test.Token) {
t.Errorf("test case token (%s) did not match expected token (%s)", "Bearer "+test.Token, auth)
}
if len(test.Method) > 0 && test.Method != r.Method {
t.Errorf("test case Method (%s) did not match expected Method (%s)", test.Method, r.Method)
}
if len(test.URI) > 0 && test.URI != r.URL.RequestURI() {
t.Errorf("test case URI (%s) did not match expected URI (%s)", test.URI, r.URL.RequestURI())
}
if test.Body != nil {
var res map[string]string
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&res); err != nil {
t.Errorf("decoding body failed: %v", err)
}
if !reflect.DeepEqual(res, test.Body) {
t.Error("body did not match")
}
}
if test.Header != nil {
for k, v := range test.Header {
if r.Header.Get(k) != v {
t.Error("header did not exist")
}
}
}
w.WriteHeader(http.StatusOK)
})
}
func TestRequest(t *testing.T) {
for _, test := range tests {
ts := httptest.NewServer(wrappedHandler(&test, t))
req := test.ReqFn(&Options{
Host: ts.URL,
Client: &http.Client{},
BearerToken: &test.Token,
Namespace: "default",
})
res := req.Do()
if res.Error() != nil {
t.Errorf("request failed with %v", res.Error())
}
ts.Close()
}
}

View File

@@ -1,271 +0,0 @@
package api
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"github.com/unistack-org/micro/v3/logger"
)
// Request is used to construct a http request for the k8s API.
type Request struct {
// the request context
context context.Context
client *http.Client
header http.Header
params url.Values
method string
host string
namespace string
resource string
resourceName *string
subResource *string
body io.Reader
err error
}
// Params is the object to pass in to set parameters
// on a request.
type Params struct {
LabelSelector map[string]string
Annotations map[string]string
Additional map[string]string
}
// verb sets method
func (r *Request) verb(method string) *Request {
r.method = method
return r
}
func (r *Request) Context(ctx context.Context) {
r.context = ctx
}
// Get request
func (r *Request) Get() *Request {
return r.verb("GET")
}
// Post request
func (r *Request) Post() *Request {
return r.verb("POST")
}
// Put request
func (r *Request) Put() *Request {
return r.verb("PUT")
}
// Patch request
func (r *Request) Patch() *Request {
return r.verb("PATCH")
}
// Delete request
func (r *Request) Delete() *Request {
return r.verb("DELETE")
}
// Namespace is to set the namespace to operate on
func (r *Request) Namespace(s string) *Request {
if len(s) > 0 {
r.namespace = s
}
return r
}
// Resource is the type of resource the operation is
// for, such as "services", "endpoints" or "pods"
func (r *Request) Resource(s string) *Request {
r.resource = s
return r
}
// SubResource sets a sub resource on a resource,
// e.g. pods/log for pod logs
func (r *Request) SubResource(s string) *Request {
r.subResource = &s
return r
}
// Name is for targeting a specific resource by id
func (r *Request) Name(s string) *Request {
r.resourceName = &s
return r
}
// Body pass in a body to set, this is for POST, PUT and PATCH requests
func (r *Request) Body(in interface{}) *Request {
b := new(bytes.Buffer)
// if we're not sending YAML request, we encode to JSON
if r.header.Get("Content-Type") != "application/yaml" {
if err := json.NewEncoder(b).Encode(&in); err != nil {
r.err = err
return r
}
r.body = b
return r
}
// if application/yaml is set, we assume we get a raw bytes so we just copy over
body, ok := in.(io.Reader)
if !ok {
r.err = errors.New("invalid data")
return r
}
// copy over data to the bytes buffer
if _, err := io.Copy(b, body); err != nil {
r.err = err
return r
}
r.body = b
return r
}
// Params is used to set parameters on a request
func (r *Request) Params(p *Params) *Request {
for k, v := range p.LabelSelector {
// create new key=value pair
value := fmt.Sprintf("%s=%s", k, v)
// check if there's an existing value
if label := r.params.Get("labelSelector"); len(label) > 0 {
value = fmt.Sprintf("%s,%s", label, value)
}
// set and overwrite the value
r.params.Set("labelSelector", value)
}
for k, v := range p.Additional {
r.params.Set(k, v)
}
return r
}
// SetHeader sets a header on a request with
// a `key` and `value`
func (r *Request) SetHeader(key, value string) *Request {
r.header.Add(key, value)
return r
}
// request builds the http.Request from the options
func (r *Request) request() (*http.Request, error) {
var url string
switch r.resource {
case "namespace":
// /api/v1/namespaces/
url = fmt.Sprintf("%s/api/v1/namespaces/", r.host)
case "deployment":
// /apis/apps/v1/namespaces/{namespace}/deployments/{name}
url = fmt.Sprintf("%s/apis/apps/v1/namespaces/%s/%ss/", r.host, r.namespace, r.resource)
default:
// /api/v1/namespaces/{namespace}/{resource}
url = fmt.Sprintf("%s/api/v1/namespaces/%s/%ss/", r.host, r.namespace, r.resource)
}
// append resourceName if it is present
if r.resourceName != nil {
url += *r.resourceName
if r.subResource != nil {
url += "/" + *r.subResource
}
}
// append any query params
if len(r.params) > 0 {
url += "?" + r.params.Encode()
}
var req *http.Request
var err error
// build request
if r.context != nil {
req, err = http.NewRequestWithContext(r.context, r.method, url, r.body)
} else {
req, err = http.NewRequest(r.method, url, r.body)
}
if err != nil {
return nil, err
}
// set headers on request
req.Header = r.header
return req, nil
}
// Do builds and triggers the request
func (r *Request) Do() *Response {
if r.err != nil {
return &Response{
err: r.err,
}
}
req, err := r.request()
if err != nil {
return &Response{
err: err,
}
}
logger.Debug(context.TODO(), "[Kubernetes] %v %v", req.Method, req.URL.String())
res, err := r.client.Do(req)
if err != nil {
return &Response{
err: err,
}
}
// return res, err
return newResponse(res, err)
}
// Raw performs a Raw HTTP request to the Kubernetes API
func (r *Request) Raw() (*http.Response, error) {
req, err := r.request()
if err != nil {
return nil, err
}
res, err := r.client.Do(req)
if err != nil {
return nil, err
}
return res, nil
}
// Options ...
type Options struct {
Host string
Namespace string
BearerToken *string
Client *http.Client
}
// NewRequest creates a k8s api request
func NewRequest(opts *Options) *Request {
req := &Request{
header: make(http.Header),
params: make(url.Values),
client: opts.Client,
namespace: opts.Namespace,
host: opts.Host,
}
if opts.BearerToken != nil {
req.SetHeader("Authorization", "Bearer "+*opts.BearerToken)
}
return req
}

View File

@@ -1,94 +0,0 @@
package api
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
)
// Errors ...
var (
ErrNotFound = errors.New("kubernetes: resource not found")
ErrDecode = errors.New("kubernetes: error decoding")
ErrUnknown = errors.New("kubernetes: unknown error")
)
// Status is an object that is returned when a request
// failed or delete succeeded.
type Status struct {
Kind string `json:"kind"`
Status string `json:"status"`
Message string `json:"message"`
Reason string `json:"reason"`
Code int `json:"code"`
}
// Response ...
type Response struct {
res *http.Response
err error
}
// Error returns an error
func (r *Response) Error() error {
return r.err
}
// StatusCode returns status code for response
func (r *Response) StatusCode() int {
return r.res.StatusCode
}
// Into decode body into `data`
func (r *Response) Into(data interface{}) error {
if r.err != nil {
return r.err
}
defer r.res.Body.Close()
decoder := json.NewDecoder(r.res.Body)
if err := decoder.Decode(&data); err != nil {
return fmt.Errorf("%v: %v", ErrDecode, err)
}
return r.err
}
func (r *Response) Close() error {
return r.res.Body.Close()
}
func newResponse(res *http.Response, err error) *Response {
r := &Response{
res: res,
err: err,
}
if err != nil {
return r
}
if r.res.StatusCode == http.StatusOK ||
r.res.StatusCode == http.StatusCreated ||
r.res.StatusCode == http.StatusNoContent {
// Non error status code
return r
}
if r.res.StatusCode == http.StatusNotFound {
r.err = ErrNotFound
return r
}
b, err := ioutil.ReadAll(r.res.Body)
if err == nil {
r.err = errors.New(string(b))
return r
}
r.err = ErrUnknown
return r
}

View File

@@ -1,401 +0,0 @@
// Package client provides an implementation of a restricted subset of kubernetes API client
package client
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"regexp"
"strings"
"github.com/unistack-org/micro/v3/logger"
"github.com/unistack-org/micro/v3/util/kubernetes/api"
)
var (
// path to kubernetes service account token
serviceAccountPath = "/var/run/secrets/kubernetes.io/serviceaccount"
// ErrReadNamespace is returned when the names could not be read from service account
ErrReadNamespace = errors.New("Could not read namespace from service account secret")
// DefaultImage is default micro image
DefaultImage = "micro/micro"
// DefaultNamespace is the default k8s namespace
DefaultNamespace = "default"
)
// Client ...
type client struct {
opts *api.Options
}
// Kubernetes client
type Client interface {
// Create creates new API resource
Create(*Resource, ...CreateOption) error
// Get queries API resources
Get(*Resource, ...GetOption) error
// Update patches existing API object
Update(*Resource, ...UpdateOption) error
// Delete deletes API resource
Delete(*Resource, ...DeleteOption) error
// List lists API resources
List(*Resource, ...ListOption) error
// Log gets log for a pod
Log(*Resource, ...LogOption) (io.ReadCloser, error)
// Watch for events
Watch(*Resource, ...WatchOption) (Watcher, error)
}
// Create creates new API object
func (c *client) Create(r *Resource, opts ...CreateOption) error {
options := CreateOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
b := new(bytes.Buffer)
if err := renderTemplate(r.Kind, b, r.Value); err != nil {
return err
}
return api.NewRequest(c.opts).
Post().
SetHeader("Content-Type", "application/yaml").
Namespace(options.Namespace).
Resource(r.Kind).
Body(b).
Do().
Error()
}
var (
nameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
)
// SerializeResourceName removes all spacial chars from a string so it
// can be used as a k8s resource name
func SerializeResourceName(ns string) string {
return nameRegex.ReplaceAllString(ns, "-")
}
// Get queries API objects and stores the result in r
func (c *client) Get(r *Resource, opts ...GetOption) error {
options := GetOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
return api.NewRequest(c.opts).
Get().
Resource(r.Kind).
Namespace(options.Namespace).
Params(&api.Params{LabelSelector: options.Labels}).
Do().
Into(r.Value)
}
// Log returns logs for a pod
func (c *client) Log(r *Resource, opts ...LogOption) (io.ReadCloser, error) {
options := LogOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
req := api.NewRequest(c.opts).
Get().
Resource(r.Kind).
SubResource("log").
Name(r.Name).
Namespace(options.Namespace)
if options.Params != nil {
req.Params(&api.Params{Additional: options.Params})
}
resp, err := req.Raw()
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
return nil, errors.New(resp.Request.URL.String() + ": " + resp.Status)
}
return resp.Body, nil
}
// Update updates API object
func (c *client) Update(r *Resource, opts ...UpdateOption) error {
options := UpdateOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
req := api.NewRequest(c.opts).
Patch().
SetHeader("Content-Type", "application/strategic-merge-patch+json").
Resource(r.Kind).
Name(r.Name).
Namespace(options.Namespace)
switch r.Kind {
case "service":
req.Body(r.Value.(*Service))
case "deployment":
req.Body(r.Value.(*Deployment))
case "pod":
req.Body(r.Value.(*Pod))
default:
return errors.New("unsupported resource")
}
return req.Do().Error()
}
// Delete removes API object
func (c *client) Delete(r *Resource, opts ...DeleteOption) error {
options := DeleteOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
return api.NewRequest(c.opts).
Delete().
Resource(r.Kind).
Name(r.Name).
Namespace(options.Namespace).
Do().
Error()
}
// List lists API objects and stores the result in r
func (c *client) List(r *Resource, opts ...ListOption) error {
options := ListOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
return c.Get(r, GetNamespace(options.Namespace))
}
// Watch returns an event stream
func (c *client) Watch(r *Resource, opts ...WatchOption) (Watcher, error) {
options := WatchOptions{
Namespace: c.opts.Namespace,
}
for _, o := range opts {
o(&options)
}
// set the watch param
params := &api.Params{Additional: map[string]string{
"watch": "true",
}}
// get options params
if options.Params != nil {
for k, v := range options.Params {
params.Additional[k] = v
}
}
req := api.NewRequest(c.opts).
Get().
Resource(r.Kind).
Name(r.Name).
Namespace(options.Namespace).
Params(params)
return newWatcher(req)
}
// NewService returns default micro kubernetes service definition
func NewService(name, version, typ, namespace string) *Service {
if logger.V(logger.TraceLevel) {
logger.Trace(context.TODO(), "kubernetes default service: name: %s, version: %s", name, version)
}
Labels := map[string]string{
"name": name,
"version": version,
"micro": typ,
}
svcName := name
if len(version) > 0 {
// API service object name joins name and version over "-"
svcName = strings.Join([]string{name, version}, "-")
}
if len(namespace) == 0 {
namespace = DefaultNamespace
}
Metadata := &Metadata{
Name: svcName,
Namespace: SerializeResourceName(namespace),
Version: version,
Labels: Labels,
}
Spec := &ServiceSpec{
Type: "ClusterIP",
Selector: Labels,
Ports: []ServicePort{{
"service-port", 8080, "",
}},
}
return &Service{
Metadata: Metadata,
Spec: Spec,
}
}
// NewService returns default micro kubernetes deployment definition
func NewDeployment(name, version, typ, namespace string) *Deployment {
if logger.V(logger.TraceLevel) {
logger.Trace(context.TODO(), "kubernetes default deployment: name: %s, version: %s", name, version)
}
Labels := map[string]string{
"name": name,
"version": version,
"micro": typ,
}
depName := name
if len(version) > 0 {
// API deployment object name joins name and version over "-"
depName = strings.Join([]string{name, version}, "-")
}
if len(namespace) == 0 {
namespace = DefaultNamespace
}
Metadata := &Metadata{
Name: depName,
Namespace: SerializeResourceName(namespace),
Version: version,
Labels: Labels,
Annotations: map[string]string{},
}
// enable go modules by default
env := EnvVar{
Name: "GO111MODULE",
Value: "on",
}
Spec := &DeploymentSpec{
Replicas: 1,
Selector: &LabelSelector{
MatchLabels: Labels,
},
Template: &Template{
Metadata: Metadata,
PodSpec: &PodSpec{
Containers: []Container{{
Name: name,
Image: DefaultImage,
Env: []EnvVar{env},
Command: []string{},
Ports: []ContainerPort{{
Name: "service-port",
ContainerPort: 8080,
}},
ReadinessProbe: &Probe{
TCPSocket: TCPSocketAction{
Port: 8080,
},
PeriodSeconds: 10,
InitialDelaySeconds: 10,
},
}},
},
},
}
return &Deployment{
Metadata: Metadata,
Spec: Spec,
}
}
// NewLocalClient returns a client that can be used with `kubectl proxy`
func NewLocalClient(hosts ...string) *client {
c := &client{
opts: &api.Options{
Client: http.DefaultClient,
Namespace: "default",
},
}
if len(hosts) == 0 {
c.opts.Host = "http://localhost:8001"
} else {
c.opts.Host = hosts[0]
}
return c
}
// NewClusterClient creates a Kubernetes client for use from within a k8s pod.
func NewClusterClient() *client {
host := "https://" + os.Getenv("KUBERNETES_SERVICE_HOST") + ":" + os.Getenv("KUBERNETES_SERVICE_PORT")
s, err := os.Stat(serviceAccountPath)
if err != nil {
logger.Fatal(context.TODO(), err.Error())
}
if s == nil || !s.IsDir() {
logger.Fatal(context.TODO(), "service account not found")
}
token, err := ioutil.ReadFile(path.Join(serviceAccountPath, "token"))
if err != nil {
logger.Fatal(context.TODO(), err.Error())
}
t := string(token)
crt, err := CertPoolFromFile(path.Join(serviceAccountPath, "ca.crt"))
if err != nil {
logger.Fatal(context.TODO(), err.Error())
}
c := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: crt,
},
DisableCompression: true,
},
}
return &client{
opts: &api.Options{
Client: c,
Host: host,
BearerToken: &t,
Namespace: DefaultNamespace,
},
}
}

View File

@@ -1,107 +0,0 @@
package client
type CreateOptions struct {
Namespace string
}
type GetOptions struct {
Namespace string
Labels map[string]string
}
type UpdateOptions struct {
Namespace string
}
type DeleteOptions struct {
Namespace string
}
type ListOptions struct {
Namespace string
}
type LogOptions struct {
Namespace string
Params map[string]string
}
type WatchOptions struct {
Namespace string
Params map[string]string
}
type CreateOption func(*CreateOptions)
type GetOption func(*GetOptions)
type UpdateOption func(*UpdateOptions)
type DeleteOption func(*DeleteOptions)
type ListOption func(*ListOptions)
type LogOption func(*LogOptions)
type WatchOption func(*WatchOptions)
// LogParams provides additional params for logs
func LogParams(p map[string]string) LogOption {
return func(l *LogOptions) {
l.Params = p
}
}
// WatchParams used for watch params
func WatchParams(p map[string]string) WatchOption {
return func(w *WatchOptions) {
w.Params = p
}
}
// CreateNamespace sets the namespace for creating a resource
func CreateNamespace(ns string) CreateOption {
return func(o *CreateOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// GetNamespace sets the namespace for getting a resource
func GetNamespace(ns string) GetOption {
return func(o *GetOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// GetLabels sets the labels for when getting a resource
func GetLabels(ls map[string]string) GetOption {
return func(o *GetOptions) {
o.Labels = ls
}
}
// UpdateNamespace sets the namespace for updating a resource
func UpdateNamespace(ns string) UpdateOption {
return func(o *UpdateOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// DeleteNamespace sets the namespace for deleting a resource
func DeleteNamespace(ns string) DeleteOption {
return func(o *DeleteOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// ListNamespace sets the namespace for listing resources
func ListNamespace(ns string) ListOption {
return func(o *ListOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// LogNamespace sets the namespace for logging a resource
func LogNamespace(ns string) LogOption {
return func(o *LogOptions) {
o.Namespace = SerializeResourceName(ns)
}
}
// WatchNamespace sets the namespace for watching a resource
func WatchNamespace(ns string) WatchOption {
return func(o *WatchOptions) {
o.Namespace = SerializeResourceName(ns)
}
}

View File

@@ -1,227 +0,0 @@
package client
var templates = map[string]string{
"deployment": deploymentTmpl,
"service": serviceTmpl,
"namespace": namespaceTmpl,
"secret": secretTmpl,
"serviceaccount": serviceAccountTmpl,
}
var deploymentTmpl = `
apiVersion: apps/v1
kind: Deployment
metadata:
name: "{{ .Metadata.Name }}"
namespace: "{{ .Metadata.Namespace }}"
labels:
{{- with .Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
annotations:
{{- with .Metadata.Annotations }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
spec:
replicas: {{ .Spec.Replicas }}
selector:
matchLabels:
{{- with .Spec.Selector.MatchLabels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
template:
metadata:
labels:
{{- with .Spec.Template.Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
annotations:
{{- with .Spec.Template.Metadata.Annotations }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
spec:
serviceAccountName: {{ .Spec.Template.PodSpec.ServiceAccountName }}
containers:
{{- with .Spec.Template.PodSpec.Containers }}
{{- range . }}
- name: {{ .Name }}
env:
{{- with .Env }}
{{- range . }}
- name: "{{ .Name }}"
value: "{{ .Value }}"
{{- if .ValueFrom }}
{{- with .ValueFrom }}
valueFrom:
{{- if .SecretKeyRef }}
{{- with .SecretKeyRef }}
secretKeyRef:
key: {{ .Key }}
name: {{ .Name }}
optional: {{ .Optional }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
args:
{{- range .Args }}
- {{.}}
{{- end }}
command:
{{- range .Command }}
- {{.}}
{{- end }}
image: {{ .Image }}
imagePullPolicy: Always
ports:
{{- with .Ports }}
{{- range . }}
- containerPort: {{ .ContainerPort }}
name: {{ .Name }}
{{- end }}
{{- end }}
{{- if .ReadinessProbe }}
{{- with .ReadinessProbe }}
readinessProbe:
{{- with .TCPSocket }}
tcpSocket:
{{- if .Host }}
host: {{ .Host }}
{{- end }}
port: {{ .Port }}
{{- end }}
initialDelaySeconds: {{ .InitialDelaySeconds }}
periodSeconds: {{ .PeriodSeconds }}
{{- end }}
{{- end }}
{{- if .Resources }}
{{- with .Resources }}
resources:
{{- if .Limits }}
{{- with .Limits }}
limits:
{{- if .Memory }}
memory: {{ .Memory }}
{{- end }}
{{- if .CPU }}
cpu: {{ .CPU }}
{{- end }}
{{- if .EphemeralStorage }}
ephemeral-storage: {{ .EphemeralStorage }}
{{- end }}
{{- end }}
{{- end }}
{{- if .Requests }}
{{- with .Requests }}
requests:
{{- if .Memory }}
memory: {{ .Memory }}
{{- end }}
{{- if .CPU }}
cpu: {{ .CPU }}
{{- end }}
{{- if .EphemeralStorage }}
ephemeral-storage: {{ .EphemeralStorage }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
`
var serviceTmpl = `
apiVersion: v1
kind: Service
metadata:
name: "{{ .Metadata.Name }}"
namespace: "{{ .Metadata.Namespace }}"
labels:
{{- with .Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
spec:
selector:
{{- with .Spec.Selector }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
ports:
{{- with .Spec.Ports }}
{{- range . }}
- name: "{{ .Name }}"
port: {{ .Port }}
protocol: {{ .Protocol }}
{{- end }}
{{- end }}
`
var namespaceTmpl = `
apiVersion: v1
kind: Namespace
metadata:
name: "{{ .Metadata.Name }}"
labels:
{{- with .Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
`
//nolint:gosec
var secretTmpl = `
apiVersion: v1
kind: Secret
type: "{{ .Type }}"
metadata:
name: "{{ .Metadata.Name }}"
namespace: "{{ .Metadata.Namespace }}"
labels:
{{- with .Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
data:
{{- with .Data }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
`
var serviceAccountTmpl = `
apiVersion: v1
kind: ServiceAccount
metadata:
name: "{{ .Metadata.Name }}"
labels:
{{- with .Metadata.Labels }}
{{- range $key, $value := . }}
{{ $key }}: "{{ $value }}"
{{- end }}
{{- end }}
imagePullSecrets:
{{- with .ImagePullSecrets }}
{{- range . }}
- name: "{{ .Name }}"
{{- end }}
{{- end }}
`

View File

@@ -1,250 +0,0 @@
package client
// ContainerPort
type ContainerPort struct {
Name string `json:"name,omitempty"`
HostPort int `json:"hostPort,omitempty"`
ContainerPort int `json:"containerPort"`
Protocol string `json:"protocol,omitempty"`
}
// EnvVar is environment variable
type EnvVar struct {
Name string `json:"name"`
Value string `json:"value,omitempty"`
ValueFrom *EnvVarSource `json:"valueFrom,omitempty"`
}
// EnvVarSource represents a source for the value of an EnvVar.
type EnvVarSource struct {
SecretKeyRef *SecretKeySelector `json:"secretKeyRef,omitempty"`
}
// SecretKeySelector selects a key of a Secret.
type SecretKeySelector struct {
Key string `json:"key"`
Name string `json:"name"`
Optional bool `json:"optional,omitempty"`
}
type Condition struct {
Started string `json:"startedAt,omitempty"`
Reason string `json:"reason,omitempty"`
Message string `json:"message,omitempty"`
}
// Container defined container runtime values
type Container struct {
Name string `json:"name"`
Image string `json:"image"`
Env []EnvVar `json:"env,omitempty"`
Command []string `json:"command,omitempty"`
Args []string `json:"args,omitempty"`
Ports []ContainerPort `json:"ports,omitempty"`
ReadinessProbe *Probe `json:"readinessProbe,omitempty"`
Resources *ResourceRequirements `json:"resources,omitempty"`
}
// DeploymentSpec defines micro deployment spec
type DeploymentSpec struct {
Replicas int `json:"replicas,omitempty"`
Selector *LabelSelector `json:"selector"`
Template *Template `json:"template,omitempty"`
}
// DeploymentCondition describes the state of deployment
type DeploymentCondition struct {
LastUpdateTime string `json:"lastUpdateTime"`
Type string `json:"type"`
Reason string `json:"reason,omitempty"`
Message string `json:"message,omitempty"`
}
// DeploymentStatus is returned when querying deployment
type DeploymentStatus struct {
Replicas int `json:"replicas,omitempty"`
UpdatedReplicas int `json:"updatedReplicas,omitempty"`
ReadyReplicas int `json:"readyReplicas,omitempty"`
AvailableReplicas int `json:"availableReplicas,omitempty"`
UnavailableReplicas int `json:"unavailableReplicas,omitempty"`
Conditions []DeploymentCondition `json:"conditions,omitempty"`
}
// Deployment is Kubernetes deployment
type Deployment struct {
Metadata *Metadata `json:"metadata"`
Spec *DeploymentSpec `json:"spec,omitempty"`
Status *DeploymentStatus `json:"status,omitempty"`
}
// DeploymentList
type DeploymentList struct {
Items []Deployment `json:"items"`
}
// LabelSelector is a label query over a set of resources
// NOTE: we do not support MatchExpressions at the moment
type LabelSelector struct {
MatchLabels map[string]string `json:"matchLabels,omitempty"`
}
type LoadBalancerIngress struct {
IP string `json:"ip,omitempty"`
Hostname string `json:"hostname,omitempty"`
}
type LoadBalancerStatus struct {
Ingress []LoadBalancerIngress `json:"ingress,omitempty"`
}
// Metadata defines api object metadata
type Metadata struct {
Name string `json:"name,omitempty"`
Namespace string `json:"namespace,omitempty"`
Version string `json:"version,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Annotations map[string]string `json:"annotations,omitempty"`
}
// PodSpec is a pod
type PodSpec struct {
Containers []Container `json:"containers"`
ServiceAccountName string `json:"serviceAccountName"`
}
// PodList
type PodList struct {
Items []Pod `json:"items"`
}
// Pod is the top level item for a pod
type Pod struct {
Metadata *Metadata `json:"metadata"`
Spec *PodSpec `json:"spec,omitempty"`
Status *PodStatus `json:"status"`
}
// PodStatus
type PodStatus struct {
Conditions []PodCondition `json:"conditions,omitempty"`
Containers []ContainerStatus `json:"containerStatuses"`
PodIP string `json:"podIP"`
Phase string `json:"phase"`
Reason string `json:"reason"`
}
// PodCondition describes the state of pod
type PodCondition struct {
Type string `json:"type"`
Reason string `json:"reason,omitempty"`
Message string `json:"message,omitempty"`
}
type ContainerStatus struct {
State ContainerState `json:"state"`
}
type ContainerState struct {
Running *Condition `json:"running"`
Terminated *Condition `json:"terminated"`
Waiting *Condition `json:"waiting"`
}
// Resource is API resource
type Resource struct {
Name string
Kind string
Value interface{}
}
// ServicePort configures service ports
type ServicePort struct {
Name string `json:"name,omitempty"`
Port int `json:"port"`
Protocol string `json:"protocol,omitempty"`
}
// ServiceSpec provides service configuration
type ServiceSpec struct {
ClusterIP string `json:"clusterIP"`
Type string `json:"type,omitempty"`
Selector map[string]string `json:"selector,omitempty"`
Ports []ServicePort `json:"ports,omitempty"`
}
// ServiceStatus
type ServiceStatus struct {
LoadBalancer LoadBalancerStatus `json:"loadBalancer,omitempty"`
}
// Service is kubernetes service
type Service struct {
Metadata *Metadata `json:"metadata"`
Spec *ServiceSpec `json:"spec,omitempty"`
Status *ServiceStatus `json:"status,omitempty"`
}
// ServiceList
type ServiceList struct {
Items []Service `json:"items"`
}
// Template is micro deployment template
type Template struct {
Metadata *Metadata `json:"metadata,omitempty"`
PodSpec *PodSpec `json:"spec,omitempty"`
}
// Namespace is a Kubernetes Namespace
type Namespace struct {
Metadata *Metadata `json:"metadata,omitempty"`
}
// NamespaceList
type NamespaceList struct {
Items []Namespace `json:"items"`
}
// ImagePullSecret
type ImagePullSecret struct {
Name string `json:"name"`
}
// Secret
type Secret struct {
Type string `json:"type,omitempty"`
Data map[string]string `json:"data"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// ServiceAccount
type ServiceAccount struct {
Metadata *Metadata `json:"metadata,omitempty"`
ImagePullSecrets []ImagePullSecret `json:"imagePullSecrets,omitempty"`
}
// Probe describes a health check to be performed against a container to determine whether it is alive or ready to receive traffic.
type Probe struct {
TCPSocket TCPSocketAction `json:"tcpSocket,omitempty"`
PeriodSeconds int `json:"periodSeconds"`
InitialDelaySeconds int `json:"initialDelaySeconds"`
}
// TCPSocketAction describes an action based on opening a socket
type TCPSocketAction struct {
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
}
// ResourceRequirements describes the compute resource requirements.
type ResourceRequirements struct {
Limits *ResourceLimits `json:"limits,omitempty"`
Requests *ResourceLimits `json:"requests,omitempty"`
}
// ResourceLimits describes the limits for a service
type ResourceLimits struct {
Memory string `json:"memory,omitempty"`
CPU string `json:"cpu,omitempty"`
EphemeralStorage string `json:"ephemeral-storage,omitempty"`
}

View File

@@ -1,104 +0,0 @@
package client
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"text/template"
)
// renderTemplateFile renders template for a given resource into writer w
func renderTemplate(resource string, w io.Writer, data interface{}) error {
t := template.Must(template.New("kubernetes").Parse(templates[resource]))
if err := t.Execute(w, data); err != nil {
return err
}
return nil
}
// COPIED FROM
// https://github.com/kubernetes/kubernetes/blob/7a725418af4661067b56506faabc2d44c6d7703a/pkg/util/crypto/crypto.go
// CertPoolFromFile returns an x509.CertPool containing the certificates in the given PEM-encoded file.
// Returns an error if the file could not be read, a certificate could not be parsed, or if the file does not contain any certificates
func CertPoolFromFile(filename string) (*x509.CertPool, error) {
certs, err := certificatesFromFile(filename)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
for _, cert := range certs {
pool.AddCert(cert)
}
return pool, nil
}
// certificatesFromFile returns the x509.Certificates contained in the given PEM-encoded file.
// Returns an error if the file could not be read, a certificate could not be parsed, or if the file does not contain any certificates
func certificatesFromFile(file string) ([]*x509.Certificate, error) {
if len(file) == 0 {
return nil, errors.New("error reading certificates from an empty filename")
}
pemBlock, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
certs, err := CertsFromPEM(pemBlock)
if err != nil {
return nil, fmt.Errorf("error reading %s: %s", file, err)
}
return certs, nil
}
// CertsFromPEM returns the x509.Certificates contained in the given PEM-encoded byte array
// Returns an error if a certificate could not be parsed, or if the data does not contain any certificates
func CertsFromPEM(pemCerts []byte) ([]*x509.Certificate, error) {
ok := false
certs := []*x509.Certificate{}
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
// Only use PEM "CERTIFICATE" blocks without extra headers
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return certs, err
}
certs = append(certs, cert)
ok = true
}
if !ok {
return certs, errors.New("could not read any certificates")
}
return certs, nil
}
// Format is used to format a string value into a k8s valid name
func Format(v string) string {
// to lower case
v = strings.ToLower(v)
// / to dashes
v = strings.ReplaceAll(v, "/", "-")
// dots to dashes
v = strings.ReplaceAll(v, ".", "-")
// limit to 253 chars
if len(v) > 253 {
v = v[:253]
}
// return new name
return v
}

View File

@@ -1,47 +0,0 @@
package client
import (
"bytes"
"testing"
)
func TestTemplates(t *testing.T) {
name := "foo"
version := "123"
typ := "service"
namespace := "default"
// Render default service
s := NewService(name, version, typ, namespace)
bs := new(bytes.Buffer)
if err := renderTemplate(templates["service"], bs, s); err != nil {
t.Errorf("Failed to render kubernetes service: %v", err)
}
// Render default deployment
d := NewDeployment(name, version, typ, namespace)
bd := new(bytes.Buffer)
if err := renderTemplate(templates["deployment"], bd, d); err != nil {
t.Errorf("Failed to render kubernetes deployment: %v", err)
}
}
func TestFormatName(t *testing.T) {
testCases := []struct {
name string
expect string
}{
{"foobar", "foobar"},
{"foo-bar", "foo-bar"},
{"foo.bar", "foo-bar"},
{"Foo.Bar", "foo-bar"},
{"micro.foo.bar", "micro-foo-bar"},
}
for _, test := range testCases {
v := Format(test.name)
if v != test.expect {
t.Fatalf("Expected name %s for %s got: %s", test.expect, test.name, v)
}
}
}

View File

@@ -1,124 +0,0 @@
package client
import (
"bufio"
"context"
"encoding/json"
"errors"
"net/http"
"github.com/unistack-org/micro/v3/util/kubernetes/api"
)
const (
// EventTypes used
Added EventType = "ADDED"
Modified EventType = "MODIFIED"
Deleted EventType = "DELETED"
Error EventType = "ERROR"
)
// Watcher is used to watch for events
type Watcher interface {
// A channel of events
Chan() <-chan Event
// Stop the watcher
Stop()
}
// EventType defines the possible types of events.
type EventType string
// Event represents a single event to a watched resource.
type Event struct {
Type EventType `json:"type"`
Object json.RawMessage `json:"object"`
}
// bodyWatcher scans the body of a request for chunks
type bodyWatcher struct {
results chan Event
cancel func()
stop chan bool
res *http.Response
req *api.Request
}
// Changes returns the results channel
func (wr *bodyWatcher) Chan() <-chan Event {
return wr.results
}
// Stop cancels the request
func (wr *bodyWatcher) Stop() {
select {
case <-wr.stop:
return
default:
// cancel the request
wr.cancel()
// stop the watcher
close(wr.stop)
}
}
func (wr *bodyWatcher) stream() {
reader := bufio.NewReader(wr.res.Body)
go func() {
for {
// read a line
b, err := reader.ReadBytes('\n')
if err != nil {
return
}
// send the event
var event Event
if err := json.Unmarshal(b, &event); err != nil {
continue
}
select {
case <-wr.stop:
return
case wr.results <- event:
}
}
}()
}
// newWatcher creates a k8s body watcher for
// a given http request
func newWatcher(req *api.Request) (Watcher, error) {
// set request context so we can cancel the request
ctx, cancel := context.WithCancel(context.Background())
req.Context(ctx)
// do the raw request
res, err := req.Raw()
if err != nil {
cancel()
return nil, err
}
if res.StatusCode < 200 || res.StatusCode >= 300 {
cancel()
// close the response body
res.Body.Close()
// return an error
return nil, errors.New(res.Request.URL.String() + ": " + res.Status)
}
wr := &bodyWatcher{
results: make(chan Event),
stop: make(chan bool),
cancel: cancel,
req: req,
res: res,
}
go wr.stream()
return wr, nil
}

23
util/mdns/.gitignore vendored
View File

@@ -1,23 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test

View File

@@ -1,511 +0,0 @@
package mdns
import (
"context"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// ServiceEntry is returned after we query for a service
type ServiceEntry struct {
Name string
Host string
AddrV4 net.IP
AddrV6 net.IP
Port int
Info string
InfoFields []string
TTL int
Type uint16
Addr net.IP // @Deprecated
hasTXT bool
sent bool
}
// complete is used to check if we have all the info we need
func (s *ServiceEntry) complete() bool {
return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT
}
// QueryParam is used to customize how a Lookup is performed
type QueryParam struct {
Service string // Service to lookup
Domain string // Lookup domain, default "local"
Type uint16 // Lookup type, defaults to dns.TypePTR
Context context.Context // Context
Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided
Interface *net.Interface // Multicast interface to use
Entries chan<- *ServiceEntry // Entries Channel
WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC
}
// DefaultParams is used to return a default set of QueryParam's
func DefaultParams(service string) *QueryParam {
return &QueryParam{
Service: service,
Domain: "local",
Timeout: time.Second,
Entries: make(chan *ServiceEntry),
WantUnicastResponse: false, // TODO(reddaly): Change this default.
}
}
// Query looks up a given service, in a domain, waiting at most
// for a timeout before finishing the query. The results are streamed
// to a channel. Sends will not block, so clients should make sure to
// either read or buffer.
func Query(params *QueryParam) error {
// Create a new client
client, err := newClient()
if err != nil {
return err
}
defer client.Close()
// Set the multicast interface
if params.Interface != nil {
if err := client.setInterface(params.Interface, false); err != nil {
return err
}
}
// Ensure defaults are set
if params.Domain == "" {
params.Domain = "local"
}
if params.Context == nil {
var cancel context.CancelFunc
if params.Timeout == 0 {
params.Timeout = time.Second
}
params.Context, cancel = context.WithTimeout(context.Background(), params.Timeout)
defer cancel()
if err != nil {
return err
}
}
// Run the query
return client.query(params)
}
// Listen listens indefinitely for multicast updates
func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error {
// Create a new client
client, err := newClient()
if err != nil {
return err
}
defer client.Close()
client.setInterface(nil, true)
// Start listening for response packets
msgCh := make(chan *dns.Msg, 32)
go client.recv(client.ipv4UnicastConn, msgCh)
go client.recv(client.ipv6UnicastConn, msgCh)
go client.recv(client.ipv4MulticastConn, msgCh)
go client.recv(client.ipv6MulticastConn, msgCh)
ip := make(map[string]*ServiceEntry)
loop:
for {
select {
case <-exit:
break loop
case <-client.closedCh:
break loop
case m := <-msgCh:
e := messageToEntry(m, ip)
if e == nil {
continue
}
// Check if this entry is complete
if e.complete() {
if e.sent {
continue
}
e.sent = true
entries <- e
ip = make(map[string]*ServiceEntry)
} else {
// Fire off a node specific query
m := new(dns.Msg)
m.SetQuestion(e.Name, dns.TypePTR)
m.RecursionDesired = false
if err := client.sendQuery(m); err != nil {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err)
}
}
}
}
return nil
}
// Lookup is the same as Query, however it uses all the default parameters
func Lookup(service string, entries chan<- *ServiceEntry) error {
params := DefaultParams(service)
params.Entries = entries
return Query(params)
}
// Client provides a query interface that can be used to
// search for service providers using mDNS
type client struct {
ipv4UnicastConn *net.UDPConn
ipv6UnicastConn *net.UDPConn
ipv4MulticastConn *net.UDPConn
ipv6MulticastConn *net.UDPConn
closed bool
closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used.
closeLock sync.Mutex
}
// NewClient creates a new mdns Client that can be used to query
// for records
func newClient() (*client, error) {
// TODO(reddaly): At least attempt to bind to the port required in the spec.
// Create a IPv4 listener
uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err4 != nil && err6 != nil {
log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
}
if uconn4 == nil && uconn6 == nil {
return nil, fmt.Errorf("failed to bind to any unicast udp port")
}
if uconn4 == nil {
uconn4 = &net.UDPConn{}
}
if uconn6 == nil {
uconn6 = &net.UDPConn{}
}
mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
if err4 != nil && err6 != nil {
log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6)
}
if mconn4 == nil && mconn6 == nil {
return nil, fmt.Errorf("failed to bind to any multicast udp port")
}
if mconn4 == nil {
mconn4 = &net.UDPConn{}
}
if mconn6 == nil {
mconn6 = &net.UDPConn{}
}
p1 := ipv4.NewPacketConn(mconn4)
p2 := ipv6.NewPacketConn(mconn6)
p1.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var errCount1, errCount2 int
for _, iface := range ifaces {
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
errCount1++
}
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
errCount2++
}
}
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
}
c := &client{
ipv4MulticastConn: mconn4,
ipv6MulticastConn: mconn6,
ipv4UnicastConn: uconn4,
ipv6UnicastConn: uconn6,
closedCh: make(chan struct{}),
}
return c, nil
}
// Close is used to cleanup the client
func (c *client) Close() error {
c.closeLock.Lock()
defer c.closeLock.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.closedCh)
if c.ipv4UnicastConn != nil {
c.ipv4UnicastConn.Close()
}
if c.ipv6UnicastConn != nil {
c.ipv6UnicastConn.Close()
}
if c.ipv4MulticastConn != nil {
c.ipv4MulticastConn.Close()
}
if c.ipv6MulticastConn != nil {
c.ipv6MulticastConn.Close()
}
return nil
}
// setInterface is used to set the query interface, uses system
// default if not provided
func (c *client) setInterface(iface *net.Interface, loopback bool) error {
p := ipv4.NewPacketConn(c.ipv4UnicastConn)
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err
}
p2 := ipv6.NewPacketConn(c.ipv6UnicastConn)
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err
}
p = ipv4.NewPacketConn(c.ipv4MulticastConn)
if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return err
}
p2 = ipv6.NewPacketConn(c.ipv6MulticastConn)
if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return err
}
if loopback {
p.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
}
return nil
}
// query is used to perform a lookup and stream results
func (c *client) query(params *QueryParam) error {
// Create the service name
serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
// Start listening for response packets
msgCh := make(chan *dns.Msg, 32)
go c.recv(c.ipv4UnicastConn, msgCh)
go c.recv(c.ipv6UnicastConn, msgCh)
go c.recv(c.ipv4MulticastConn, msgCh)
go c.recv(c.ipv6MulticastConn, msgCh)
// Send the query
m := new(dns.Msg)
if params.Type == dns.TypeNone {
m.SetQuestion(serviceAddr, dns.TypePTR)
} else {
m.SetQuestion(serviceAddr, params.Type)
}
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the qclass
// field is used to indicate that unicast responses are preferred for this
// particular question. (See Section 5.4.)
if params.WantUnicastResponse {
m.Question[0].Qclass |= 1 << 15
}
m.RecursionDesired = false
if err := c.sendQuery(m); err != nil {
return err
}
// Map the in-progress responses
inprogress := make(map[string]*ServiceEntry)
for {
select {
case resp := <-msgCh:
inp := messageToEntry(resp, inprogress)
if inp == nil {
continue
}
if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name {
// discard anything which we've not asked for
continue
}
// Check if this entry is complete
if inp.complete() {
if inp.sent {
continue
}
inp.sent = true
select {
case params.Entries <- inp:
case <-params.Context.Done():
return nil
}
} else {
// Fire off a node specific query
m := new(dns.Msg)
m.SetQuestion(inp.Name, inp.Type)
m.RecursionDesired = false
if err := c.sendQuery(m); err != nil {
log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err)
}
}
case <-params.Context.Done():
return nil
}
}
}
// sendQuery is used to multicast a query out
func (c *client) sendQuery(q *dns.Msg) error {
buf, err := q.Pack()
if err != nil {
return err
}
if c.ipv4UnicastConn != nil {
c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr)
}
if c.ipv6UnicastConn != nil {
c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr)
}
return nil
}
// recv is used to receive until we get a shutdown
func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
if l == nil {
return
}
buf := make([]byte, 65536)
for {
c.closeLock.Lock()
if c.closed {
c.closeLock.Unlock()
return
}
c.closeLock.Unlock()
n, err := l.Read(buf)
if err != nil {
continue
}
msg := new(dns.Msg)
if err := msg.Unpack(buf[:n]); err != nil {
continue
}
select {
case msgCh <- msg:
case <-c.closedCh:
return
}
}
}
// ensureName is used to ensure the named node is in progress
func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry {
if inp, ok := inprogress[name]; ok {
return inp
}
inp := &ServiceEntry{
Name: name,
Type: typ,
}
inprogress[name] = inp
return inp
}
// alias is used to setup an alias between two entries
func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) {
srcEntry := ensureName(inprogress, src, typ)
inprogress[dst] = srcEntry
}
func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry {
var inp *ServiceEntry
for _, answer := range append(m.Answer, m.Extra...) {
// TODO(reddaly): Check that response corresponds to serviceAddr?
switch rr := answer.(type) {
case *dns.PTR:
// Create new entry for this
inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype)
if inp.complete() {
continue
}
case *dns.SRV:
// Check for a target mismatch
if rr.Target != rr.Hdr.Name {
alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype)
}
// Get the port
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
if inp.complete() {
continue
}
inp.Host = rr.Target
inp.Port = int(rr.Port)
case *dns.TXT:
// Pull out the txt
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
if inp.complete() {
continue
}
inp.Info = strings.Join(rr.Txt, "|")
inp.InfoFields = rr.Txt
inp.hasTXT = true
case *dns.A:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
if inp.complete() {
continue
}
inp.Addr = rr.A // @Deprecated
inp.AddrV4 = rr.A
case *dns.AAAA:
// Pull out the IP
inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype)
if inp.complete() {
continue
}
inp.Addr = rr.AAAA // @Deprecated
inp.AddrV6 = rr.AAAA
}
if inp != nil {
inp.TTL = int(answer.Header().Ttl)
}
}
return inp
}

View File

@@ -1,84 +0,0 @@
package mdns
import "github.com/miekg/dns"
// DNSSDService is a service that complies with the DNS-SD (RFC 6762) and MDNS
// (RFC 6762) specs for local, multicast-DNS-based discovery.
//
// DNSSDService implements the Zone interface and wraps an MDNSService instance.
// To deploy an mDNS service that is compliant with DNS-SD, it's recommended to
// register only the wrapped instance with the server.
//
// Example usage:
// service := &mdns.DNSSDService{
// MDNSService: &mdns.MDNSService{
// Instance: "My Foobar Service",
// Service: "_foobar._tcp",
// Port: 8000,
// }
// }
// server, err := mdns.NewServer(&mdns.Config{Zone: service})
// if err != nil {
// log.Fatalf("Error creating server: %v", err)
// }
// defer server.Shutdown()
type DNSSDService struct {
MDNSService *MDNSService
}
// Records returns DNS records in response to a DNS question.
//
// This function returns the DNS response of the underlying MDNSService
// instance. It also returns a PTR record for a request for "
// _services._dns-sd._udp.<Domain>", as described in section 9 of RFC 6763
// ("Service Type Enumeration"), to allow browsing of the underlying MDNSService
// instance.
func (s *DNSSDService) Records(q dns.Question) []dns.RR {
var recs []dns.RR
if q.Name == "_services._dns-sd._udp."+s.MDNSService.Domain+"." {
recs = s.dnssdMetaQueryRecords(q)
}
return append(recs, s.MDNSService.Records(q)...)
}
// dnssdMetaQueryRecords returns the DNS records in response to a "meta-query"
// issued to browse for DNS-SD services, as per section 9. of RFC6763.
//
// A meta-query has a name of the form "_services._dns-sd._udp.<Domain>" where
// Domain is a fully-qualified domain, such as "local."
func (s *DNSSDService) dnssdMetaQueryRecords(q dns.Question) []dns.RR {
// Intended behavior, as described in the RFC:
// ...it may be useful for network administrators to find the list of
// advertised service types on the network, even if those Service Names
// are just opaque identifiers and not particularly informative in
// isolation.
//
// For this purpose, a special meta-query is defined. A DNS query for PTR
// records with the name "_services._dns-sd._udp.<Domain>" yields a set of
// PTR records, where the rdata of each PTR record is the two-abel
// <Service> name, plus the same domain, e.g., "_http._tcp.<Domain>".
// Including the domain in the PTR rdata allows for slightly better name
// compression in Unicast DNS responses, but only the first two labels are
// relevant for the purposes of service type enumeration. These two-label
// service types can then be used to construct subsequent Service Instance
// Enumeration PTR queries, in this <Domain> or others, to discover
// instances of that service type.
return []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Ptr: s.MDNSService.serviceAddr,
},
}
}
// Announcement returns DNS records that should be broadcast during the initial
// availability of the service, as described in section 8.3 of RFC 6762.
// TODO(reddaly): Add this when Announcement is added to the mdns.Zone interface.
//func (s *DNSSDService) Announcement() []dns.RR {
// return s.MDNSService.Announcement()
//}

View File

@@ -1,69 +0,0 @@
package mdns
import (
"reflect"
"testing"
"github.com/miekg/dns"
)
type mockMDNSService struct{}
func (s *mockMDNSService) Records(q dns.Question) []dns.RR {
return []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "fakerecord",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 42,
},
Ptr: "fake.local.",
},
}
}
func (s *mockMDNSService) Announcement() []dns.RR {
return []dns.RR{
&dns.PTR{
Hdr: dns.RR_Header{
Name: "fakeannounce",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 42,
},
Ptr: "fake.local.",
},
}
}
func TestDNSSDServiceRecords(t *testing.T) {
s := &DNSSDService{
MDNSService: &MDNSService{
serviceAddr: "_foobar._tcp.local.",
Domain: "local",
},
}
q := dns.Question{
Name: "_services._dns-sd._udp.local.",
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}
recs := s.Records(q)
if got, want := len(recs), 1; got != want {
t.Fatalf("s.Records(%v) returned %v records, want %v", q, got, want)
}
want := dns.RR(&dns.PTR{
Hdr: dns.RR_Header{
Name: "_services._dns-sd._udp.local.",
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Ptr: "_foobar._tcp.local.",
})
if got := recs[0]; !reflect.DeepEqual(got, want) {
t.Errorf("s.Records()[0] = %v, want %v", got, want)
}
}

View File

@@ -1,527 +0,0 @@
package mdns
import (
"context"
"fmt"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/unistack-org/micro/v3/logger"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
mdnsGroupIPv4 = net.ParseIP("224.0.0.251")
mdnsGroupIPv6 = net.ParseIP("ff02::fb")
// mDNS wildcard addresses
mdnsWildcardAddrIPv4 = &net.UDPAddr{
IP: net.ParseIP("224.0.0.0"),
Port: 5353,
}
mdnsWildcardAddrIPv6 = &net.UDPAddr{
IP: net.ParseIP("ff02::"),
Port: 5353,
}
// mDNS endpoint addresses
ipv4Addr = &net.UDPAddr{
IP: mdnsGroupIPv4,
Port: 5353,
}
ipv6Addr = &net.UDPAddr{
IP: mdnsGroupIPv6,
Port: 5353,
}
)
// GetMachineIP is a func which returns the outbound IP of this machine.
// Used by the server to determine whether to attempt send the response on a local address
type GetMachineIP func() net.IP
// Config is used to configure the mDNS server
type Config struct {
// Zone must be provided to support responding to queries
Zone Zone
// Iface if provided binds the multicast listener to the given
// interface. If not provided, the system default multicase interface
// is used.
Iface *net.Interface
// Port If it is not 0, replace the port 5353 with this port number.
Port int
// GetMachineIP is a function to return the IP of the local machine
GetMachineIP GetMachineIP
// LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP
// is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports
LocalhostChecking bool
Context context.Context
}
// Server is an mDNS server used to listen for mDNS queries and respond if we
// have a matching local record
type Server struct {
config *Config
ipv4List *net.UDPConn
ipv6List *net.UDPConn
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
wg sync.WaitGroup
outboundIP net.IP
}
// NewServer is used to create a new mDNS server from a config
func NewServer(config *Config) (*Server, error) {
setCustomPort(config.Port)
// Create the listeners
// Create wildcard connections (because :5353 can be already taken by other apps)
ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
if ipv4List == nil && ipv6List == nil {
return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!")
}
if ipv4List == nil {
ipv4List = &net.UDPConn{}
}
if ipv6List == nil {
ipv6List = &net.UDPConn{}
}
// Join multicast groups to receive announcements
p1 := ipv4.NewPacketConn(ipv4List)
p2 := ipv6.NewPacketConn(ipv6List)
p1.SetMulticastLoopback(true)
p2.SetMulticastLoopback(true)
if config.Iface != nil {
if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
return nil, err
}
if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
return nil, err
}
} else {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
errCount1, errCount2 := 0, 0
for _, iface := range ifaces {
if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
errCount1++
}
if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
errCount2++
}
}
if len(ifaces) == errCount1 && len(ifaces) == errCount2 {
return nil, fmt.Errorf("Failed to join multicast group on all interfaces!")
}
}
ipFunc := getOutboundIP
if config.GetMachineIP != nil {
ipFunc = config.GetMachineIP
}
s := &Server{
config: config,
ipv4List: ipv4List,
ipv6List: ipv6List,
shutdownCh: make(chan struct{}),
outboundIP: ipFunc(),
}
if s.config.Context == nil {
s.config.Context = context.Background()
}
go s.recv(s.ipv4List)
go s.recv(s.ipv6List)
s.wg.Add(1)
go s.probe()
return s, nil
}
// Shutdown is used to shutdown the listener
func (s *Server) Shutdown() error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.shutdown {
return nil
}
s.shutdown = true
close(s.shutdownCh)
if err := s.unregister(); err != nil {
return err
}
if s.ipv4List != nil {
s.ipv4List.Close()
}
if s.ipv6List != nil {
s.ipv6List.Close()
}
s.wg.Wait()
return nil
}
// recv is a long running routine to receive packets from an interface
func (s *Server) recv(c *net.UDPConn) {
if c == nil {
return
}
buf := make([]byte, 65536)
for {
s.shutdownLock.Lock()
if s.shutdown {
s.shutdownLock.Unlock()
return
}
s.shutdownLock.Unlock()
n, from, err := c.ReadFrom(buf)
if err != nil {
continue
}
if err := s.parsePacket(buf[:n], from); err != nil {
logger.Errorf(s.config.Context, "[ERR] mdns: Failed to handle query: %v", err)
}
}
}
// parsePacket is used to parse an incoming packet
func (s *Server) parsePacket(packet []byte, from net.Addr) error {
var msg dns.Msg
if err := msg.Unpack(packet); err != nil {
logger.Errorf(s.config.Context, "[ERR] mdns: Failed to unpack packet: %v", err)
return err
}
// TODO: This is a bit of a hack
// We decided to ignore some mDNS answers for the time being
// See: https://tools.ietf.org/html/rfc6762#section-7.2
msg.Truncated = false
return s.handleQuery(&msg, from)
}
// handleQuery is used to handle an incoming query
func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error {
if query.Opcode != dns.OpcodeQuery {
// "In both multicast query and multicast response messages, the OPCODE MUST
// be zero on transmission (only standard queries are currently supported
// over multicast). Multicast DNS messages received with an OPCODE other
// than zero MUST be silently ignored." Note: OpcodeQuery == 0
return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query)
}
if query.Rcode != 0 {
// "In both multicast query and multicast response messages, the Response
// Code MUST be zero on transmission. Multicast DNS messages received with
// non-zero Response Codes MUST be silently ignored."
return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query)
}
// TODO(reddaly): Handle "TC (Truncated) Bit":
// In query messages, if the TC bit is set, it means that additional
// Known-Answer records may be following shortly. A responder SHOULD
// record this fact, and wait for those additional Known-Answer records,
// before deciding whether to respond. If the TC bit is clear, it means
// that the querying host has no additional Known Answers.
if query.Truncated {
return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query)
}
unicastAnswer := make([]dns.RR, 0, len(query.Question))
multicastAnswer := make([]dns.RR, 0, len(query.Question))
// Handle each question
for _, q := range query.Question {
mrecs, urecs := s.handleQuestion(q)
multicastAnswer = append(multicastAnswer, mrecs...)
unicastAnswer = append(unicastAnswer, urecs...)
}
// See section 18 of RFC 6762 for rules about DNS headers.
resp := func(unicast bool) *dns.Msg {
// 18.1: ID (Query Identifier)
// 0 for multicast response, query.Id for unicast response
id := uint16(0)
if unicast {
id = query.Id
}
var answer []dns.RR
if unicast {
answer = unicastAnswer
} else {
answer = multicastAnswer
}
if len(answer) == 0 {
return nil
}
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: id,
// 18.2: QR (Query/Response) Bit - must be set to 1 in response.
Response: true,
// 18.3: OPCODE - must be zero in response (OpcodeQuery == 0)
Opcode: dns.OpcodeQuery,
// 18.4: AA (Authoritative Answer) Bit - must be set to 1
Authoritative: true,
// The following fields must all be set to 0:
// 18.5: TC (TRUNCATED) Bit
// 18.6: RD (Recursion Desired) Bit
// 18.7: RA (Recursion Available) Bit
// 18.8: Z (Zero) Bit
// 18.9: AD (Authentic Data) Bit
// 18.10: CD (Checking Disabled) Bit
// 18.11: RCODE (Response Code)
},
// 18.12 pertains to questions (handled by handleQuestion)
// 18.13 pertains to resource records (handled by handleQuestion)
// 18.14: Name Compression - responses should be compressed (though see
// caveats in the RFC), so set the Compress bit (part of the dns library
// API, not part of the DNS packet) to true.
Compress: true,
Question: query.Question,
Answer: answer,
}
}
if mresp := resp(false); mresp != nil {
if err := s.sendResponse(mresp, from); err != nil {
return fmt.Errorf("mdns: error sending multicast response: %v", err)
}
}
if uresp := resp(true); uresp != nil {
if err := s.sendResponse(uresp, from); err != nil {
return fmt.Errorf("mdns: error sending unicast response: %v", err)
}
}
return nil
}
// handleQuestion is used to handle an incoming question
//
// The response to a question may be transmitted over multicast, unicast, or
// both. The return values are DNS records for each transmission type.
func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) {
records := s.config.Zone.Records(q)
if len(records) == 0 {
return nil, nil
}
// Handle unicast and multicast responses.
// TODO(reddaly): The decision about sending over unicast vs. multicast is not
// yet fully compliant with RFC 6762. For example, the unicast bit should be
// ignored if the records in question are close to TTL expiration. For now,
// we just use the unicast bit to make the decision, as per the spec:
// RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question
// Section
//
// In the Question Section of a Multicast DNS query, the top bit of the
// qclass field is used to indicate that unicast responses are preferred
// for this particular question. (See Section 5.4.)
if q.Qclass&(1<<15) != 0 {
return nil, records
}
return records, nil
}
func (s *Server) probe() {
defer s.wg.Done()
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return
}
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypePTR)
q.RecursionDesired = false
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Priority: 0,
Weight: 0,
Port: uint16(sd.Port),
Target: sd.HostName,
}
txt := &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: defaultTTL,
},
Txt: sd.TXT,
}
q.Ns = []dns.RR{srv, txt}
randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < 3; i++ {
if err := s.SendMulticast(q); err != nil {
logger.Errorf(s.config.Context, "[ERR] mdns: failed to send probe: %v", err)
}
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
}
resp := new(dns.Msg)
resp.MsgHdr.Response = true
// set for query
q.SetQuestion(name, dns.TypeANY)
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
// reset
q.SetQuestion(name, dns.TypePTR)
// From RFC6762
// The Multicast DNS responder MUST send at least two unsolicited
// responses, one second apart. To provide increased robustness against
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
timer := time.NewTimer(timeout)
for i := 0; i < 3; i++ {
if err := s.SendMulticast(resp); err != nil {
logger.Errorf(s.config.Context, "[ERR] mdns: failed to send announcement: %v", err)
}
select {
case <-timer.C:
timeout *= 2
timer.Reset(timeout)
case <-s.shutdownCh:
timer.Stop()
return
}
}
}
// SendMulticast us used to send a multicast response packet
func (s *Server) SendMulticast(msg *dns.Msg) error {
buf, err := msg.Pack()
if err != nil {
return err
}
if s.ipv4List != nil {
s.ipv4List.WriteToUDP(buf, ipv4Addr)
}
if s.ipv6List != nil {
s.ipv6List.WriteToUDP(buf, ipv6Addr)
}
return nil
}
// sendResponse is used to send a response packet
func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error {
// TODO(reddaly): Respect the unicast argument, and allow sending responses
// over multicast.
buf, err := resp.Pack()
if err != nil {
return err
}
// Determine the socket to send from
addr := from.(*net.UDPAddr)
conn := s.ipv4List
backupTarget := net.IPv4zero
if addr.IP.To4() == nil {
conn = s.ipv6List
backupTarget = net.IPv6zero
}
_, err = conn.WriteToUDP(buf, addr)
// If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0
// This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there
// Sending two responses is OK
if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) {
// ignore any errors, this is best efforts
conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port})
}
return err
}
func (s *Server) unregister() error {
sd, ok := s.config.Zone.(*MDNSService)
if !ok {
return nil
}
atomic.StoreUint32(&sd.TTL, 0)
name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain))
q := new(dns.Msg)
q.SetQuestion(name, dns.TypeANY)
resp := new(dns.Msg)
resp.MsgHdr.Response = true
resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...)
return s.SendMulticast(resp)
}
func setCustomPort(port int) {
if port != 0 {
if mdnsWildcardAddrIPv4.Port != port {
mdnsWildcardAddrIPv4.Port = port
}
if mdnsWildcardAddrIPv6.Port != port {
mdnsWildcardAddrIPv6.Port = port
}
if ipv4Addr.Port != port {
ipv4Addr.Port = port
}
if ipv6Addr.Port != port {
ipv6Addr.Port = port
}
}
}
// getOutboundIP returns the IP address of this machine as seen when dialling out
func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
// no net connectivity maybe so fallback
return nil
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}

View File

@@ -1,61 +0,0 @@
package mdns
import (
"testing"
"time"
)
func TestServer_StartStop(t *testing.T) {
s := makeService(t)
serv, err := NewServer(&Config{Zone: s, LocalhostChecking: true})
if err != nil {
t.Fatalf("err: %v", err)
}
defer serv.Shutdown()
}
func TestServer_Lookup(t *testing.T) {
serv, err := NewServer(&Config{Zone: makeServiceWithServiceName(t, "_foobar._tcp"), LocalhostChecking: true})
if err != nil {
t.Fatalf("err: %v", err)
}
defer serv.Shutdown()
entries := make(chan *ServiceEntry, 1)
found := false
doneCh := make(chan struct{})
go func() {
select {
case e := <-entries:
if e.Name != "hostname._foobar._tcp.local." {
t.Fatalf("bad: %v", e)
}
if e.Port != 80 {
t.Fatalf("bad: %v", e)
}
if e.Info != "Local web server" {
t.Fatalf("bad: %v", e)
}
found = true
case <-time.After(80 * time.Millisecond):
t.Fatalf("timeout")
}
close(doneCh)
}()
params := &QueryParam{
Service: "_foobar._tcp",
Domain: "local",
Timeout: 50 * time.Millisecond,
Entries: entries,
}
err = Query(params)
if err != nil {
t.Fatalf("err: %v", err)
}
<-doneCh
if !found {
t.Fatalf("record not found")
}
}

View File

@@ -1,309 +0,0 @@
package mdns
import (
"fmt"
"net"
"os"
"strings"
"sync/atomic"
"github.com/miekg/dns"
)
const (
// defaultTTL is the default TTL value in returned DNS records in seconds.
defaultTTL = 120
)
// Zone is the interface used to integrate with the server and
// to serve records dynamically
type Zone interface {
// Records returns DNS records in response to a DNS question.
Records(q dns.Question) []dns.RR
}
// MDNSService is used to export a named service by implementing a Zone
type MDNSService struct {
Instance string // Instance name (e.g. "hostService name")
Service string // Service name (e.g. "_http._tcp.")
Domain string // If blank, assumes "local"
HostName string // Host machine DNS name (e.g. "mymachine.net.")
Port int // Service Port
IPs []net.IP // IP addresses for the service's host
TXT []string // Service TXT records
TTL uint32
serviceAddr string // Fully qualified service address
instanceAddr string // Fully qualified instance address
enumAddr string // _services._dns-sd._udp.<domain>
}
// validateFQDN returns an error if the passed string is not a fully qualified
// hdomain name (more specifically, a hostname).
func validateFQDN(s string) error {
if len(s) == 0 {
return fmt.Errorf("FQDN must not be blank")
}
if s[len(s)-1] != '.' {
return fmt.Errorf("FQDN must end in period: %s", s)
}
// TODO(reddaly): Perform full validation.
return nil
}
// NewMDNSService returns a new instance of MDNSService.
//
// If domain, hostName, or ips is set to the zero value, then a default value
// will be inferred from the operating system.
//
// TODO(reddaly): This interface may need to change to account for "unique
// record" conflict rules of the mDNS protocol. Upon startup, the server should
// check to ensure that the instance name does not conflict with other instance
// names, and, if required, select a new name. There may also be conflicting
// hostName A/AAAA records.
func NewMDNSService(instance, service, domain, hostName string, port int, ips []net.IP, txt []string) (*MDNSService, error) {
// Sanity check inputs
if instance == "" {
return nil, fmt.Errorf("missing service instance name")
}
if service == "" {
return nil, fmt.Errorf("missing service name")
}
if port == 0 {
return nil, fmt.Errorf("missing service port")
}
// Set default domain
if domain == "" {
domain = "local."
}
if err := validateFQDN(domain); err != nil {
return nil, fmt.Errorf("domain %q is not a fully-qualified domain name: %v", domain, err)
}
// Get host information if no host is specified.
if hostName == "" {
var err error
hostName, err = os.Hostname()
if err != nil {
return nil, fmt.Errorf("could not determine host: %v", err)
}
hostName = fmt.Sprintf("%s.", hostName)
}
if err := validateFQDN(hostName); err != nil {
return nil, fmt.Errorf("hostName %q is not a fully-qualified domain name: %v", hostName, err)
}
if len(ips) == 0 {
var err error
ips, err = net.LookupIP(trimDot(hostName))
if err != nil {
// Try appending the host domain suffix and lookup again
// (required for Linux-based hosts)
tmpHostName := fmt.Sprintf("%s%s", hostName, domain)
ips, err = net.LookupIP(trimDot(tmpHostName))
if err != nil {
return nil, fmt.Errorf("could not determine host IP addresses for %s", hostName)
}
}
}
for _, ip := range ips {
if ip.To4() == nil && ip.To16() == nil {
return nil, fmt.Errorf("invalid IP address in IPs list: %v", ip)
}
}
return &MDNSService{
Instance: instance,
Service: service,
Domain: domain,
HostName: hostName,
Port: port,
IPs: ips,
TXT: txt,
TTL: defaultTTL,
serviceAddr: fmt.Sprintf("%s.%s.", trimDot(service), trimDot(domain)),
instanceAddr: fmt.Sprintf("%s.%s.%s.", instance, trimDot(service), trimDot(domain)),
enumAddr: fmt.Sprintf("_services._dns-sd._udp.%s.", trimDot(domain)),
}, nil
}
// trimDot is used to trim the dots from the start or end of a string
func trimDot(s string) string {
return strings.Trim(s, ".")
}
// Records returns DNS records in response to a DNS question.
func (m *MDNSService) Records(q dns.Question) []dns.RR {
switch q.Name {
case m.enumAddr:
return m.serviceEnum(q)
case m.serviceAddr:
return m.serviceRecords(q)
case m.instanceAddr:
return m.instanceRecords(q)
case m.HostName:
if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
return m.instanceRecords(q)
}
fallthrough
default:
return nil
}
}
func (m *MDNSService) serviceEnum(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
fallthrough
case dns.TypePTR:
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
Ptr: m.serviceAddr,
}
return []dns.RR{rr}
default:
return nil
}
}
// serviceRecords is called when the query matches the service name
func (m *MDNSService) serviceRecords(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
fallthrough
case dns.TypePTR:
// Build a PTR response for the service
rr := &dns.PTR{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
Ptr: m.instanceAddr,
}
servRec := []dns.RR{rr}
// Get the instance records
instRecs := m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeANY,
})
// Return the service record with the instance records
return append(servRec, instRecs...)
default:
return nil
}
}
// serviceRecords is called when the query matches the instance name
func (m *MDNSService) instanceRecords(q dns.Question) []dns.RR {
switch q.Qtype {
case dns.TypeANY:
// Get the SRV, which includes A and AAAA
recs := m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeSRV,
})
// Add the TXT record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeTXT,
})...)
return recs
case dns.TypeA:
var rr []dns.RR
for _, ip := range m.IPs {
if ip4 := ip.To4(); ip4 != nil {
rr = append(rr, &dns.A{
Hdr: dns.RR_Header{
Name: m.HostName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
A: ip4,
})
}
}
return rr
case dns.TypeAAAA:
var rr []dns.RR
for _, ip := range m.IPs {
if ip.To4() != nil {
// TODO(reddaly): IPv4 addresses could be encoded in IPv6 format and
// putinto AAAA records, but the current logic puts ipv4-encodable
// addresses into the A records exclusively. Perhaps this should be
// configurable?
continue
}
if ip16 := ip.To16(); ip16 != nil {
rr = append(rr, &dns.AAAA{
Hdr: dns.RR_Header{
Name: m.HostName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
AAAA: ip16,
})
}
}
return rr
case dns.TypeSRV:
// Create the SRV Record
srv := &dns.SRV{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
Priority: 10,
Weight: 1,
Port: uint16(m.Port),
Target: m.HostName,
}
recs := []dns.RR{srv}
// Add the A record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeA,
})...)
// Add the AAAA record
recs = append(recs, m.instanceRecords(dns.Question{
Name: m.instanceAddr,
Qtype: dns.TypeAAAA,
})...)
return recs
case dns.TypeTXT:
txt := &dns.TXT{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: atomic.LoadUint32(&m.TTL),
},
Txt: m.TXT,
}
return []dns.RR{txt}
}
return nil
}

View File

@@ -1,275 +0,0 @@
package mdns
import (
"bytes"
"net"
"reflect"
"testing"
"github.com/miekg/dns"
)
func makeService(t *testing.T) *MDNSService {
return makeServiceWithServiceName(t, "_http._tcp")
}
func makeServiceWithServiceName(t *testing.T, service string) *MDNSService {
m, err := NewMDNSService(
"hostname",
service,
"local.",
"testhost.",
80, // port
[]net.IP{net.IP([]byte{192, 168, 0, 42}), net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")},
[]string{"Local web server"}) // TXT
if err != nil {
t.Fatalf("err: %v", err)
}
return m
}
func TestNewMDNSService_BadParams(t *testing.T) {
for _, test := range []struct {
testName string
hostName string
domain string
}{
{
"NewMDNSService should fail when passed hostName that is not a legal fully-qualified domain name",
"hostname", // not legal FQDN - should be "hostname." or "hostname.local.", etc.
"local.", // legal
},
{
"NewMDNSService should fail when passed domain that is not a legal fully-qualified domain name",
"hostname.", // legal
"local", // should be "local."
},
} {
_, err := NewMDNSService(
"instance name",
"_http._tcp",
test.domain,
test.hostName,
80, // port
[]net.IP{net.IP([]byte{192, 168, 0, 42})},
[]string{"Local web server"}) // TXT
if err == nil {
t.Fatalf("%s: error expected, but got none", test.testName)
}
}
}
func TestMDNSService_BadAddr(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "random",
Qtype: dns.TypeANY,
}
recs := s.Records(q)
if len(recs) != 0 {
t.Fatalf("bad: %v", recs)
}
}
func TestMDNSService_ServiceAddr(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "_http._tcp.local.",
Qtype: dns.TypeANY,
}
recs := s.Records(q)
if got, want := len(recs), 5; got != want {
t.Fatalf("got %d records, want %d: %v", got, want, recs)
}
if ptr, ok := recs[0].(*dns.PTR); !ok {
t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs)
} else if got, want := ptr.Ptr, "hostname._http._tcp.local."; got != want {
t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want)
}
if _, ok := recs[1].(*dns.SRV); !ok {
t.Errorf("recs[1] should be SRV record, got: %v, all reccords: %v", recs[1], recs)
}
if _, ok := recs[2].(*dns.A); !ok {
t.Errorf("recs[2] should be A record, got: %v, all records: %v", recs[2], recs)
}
if _, ok := recs[3].(*dns.AAAA); !ok {
t.Errorf("recs[3] should be AAAA record, got: %v, all records: %v", recs[3], recs)
}
if _, ok := recs[4].(*dns.TXT); !ok {
t.Errorf("recs[4] should be TXT record, got: %v, all records: %v", recs[4], recs)
}
q.Qtype = dns.TypePTR
if recs2 := s.Records(q); !reflect.DeepEqual(recs, recs2) {
t.Fatalf("PTR question should return same result as ANY question: ANY => %v, PTR => %v", recs, recs2)
}
}
func TestMDNSService_InstanceAddr_ANY(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "hostname._http._tcp.local.",
Qtype: dns.TypeANY,
}
recs := s.Records(q)
if len(recs) != 4 {
t.Fatalf("bad: %v", recs)
}
if _, ok := recs[0].(*dns.SRV); !ok {
t.Fatalf("bad: %v", recs[0])
}
if _, ok := recs[1].(*dns.A); !ok {
t.Fatalf("bad: %v", recs[1])
}
if _, ok := recs[2].(*dns.AAAA); !ok {
t.Fatalf("bad: %v", recs[2])
}
if _, ok := recs[3].(*dns.TXT); !ok {
t.Fatalf("bad: %v", recs[3])
}
}
func TestMDNSService_InstanceAddr_SRV(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "hostname._http._tcp.local.",
Qtype: dns.TypeSRV,
}
recs := s.Records(q)
if len(recs) != 3 {
t.Fatalf("bad: %v", recs)
}
srv, ok := recs[0].(*dns.SRV)
if !ok {
t.Fatalf("bad: %v", recs[0])
}
if _, ok := recs[1].(*dns.A); !ok {
t.Fatalf("bad: %v", recs[1])
}
if _, ok := recs[2].(*dns.AAAA); !ok {
t.Fatalf("bad: %v", recs[2])
}
if srv.Port != uint16(s.Port) {
t.Fatalf("bad: %v", recs[0])
}
}
func TestMDNSService_InstanceAddr_A(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "hostname._http._tcp.local.",
Qtype: dns.TypeA,
}
recs := s.Records(q)
if len(recs) != 1 {
t.Fatalf("bad: %v", recs)
}
a, ok := recs[0].(*dns.A)
if !ok {
t.Fatalf("bad: %v", recs[0])
}
if !bytes.Equal(a.A, []byte{192, 168, 0, 42}) {
t.Fatalf("bad: %v", recs[0])
}
}
func TestMDNSService_InstanceAddr_AAAA(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "hostname._http._tcp.local.",
Qtype: dns.TypeAAAA,
}
recs := s.Records(q)
if len(recs) != 1 {
t.Fatalf("bad: %v", recs)
}
a4, ok := recs[0].(*dns.AAAA)
if !ok {
t.Fatalf("bad: %v", recs[0])
}
ip6 := net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc")
if got := len(ip6); got != net.IPv6len {
t.Fatalf("test IP failed to parse (len = %d, want %d)", got, net.IPv6len)
}
if !bytes.Equal(a4.AAAA, ip6) {
t.Fatalf("bad: %v", recs[0])
}
}
func TestMDNSService_InstanceAddr_TXT(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "hostname._http._tcp.local.",
Qtype: dns.TypeTXT,
}
recs := s.Records(q)
if len(recs) != 1 {
t.Fatalf("bad: %v", recs)
}
txt, ok := recs[0].(*dns.TXT)
if !ok {
t.Fatalf("bad: %v", recs[0])
}
if got, want := txt.Txt, s.TXT; !reflect.DeepEqual(got, want) {
t.Fatalf("TXT record mismatch for %v: got %v, want %v", recs[0], got, want)
}
}
func TestMDNSService_HostNameQuery(t *testing.T) {
s := makeService(t)
for _, test := range []struct {
q dns.Question
want []dns.RR
}{
{
dns.Question{Name: "testhost.", Qtype: dns.TypeA},
[]dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: "testhost.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 120,
},
A: net.IP([]byte{192, 168, 0, 42}),
}},
},
{
dns.Question{Name: "testhost.", Qtype: dns.TypeAAAA},
[]dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: "testhost.",
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 120,
},
AAAA: net.ParseIP("2620:0:1000:1900:b0c2:d0b2:c411:18bc"),
}},
},
} {
if got := s.Records(test.q); !reflect.DeepEqual(got, test.want) {
t.Errorf("hostname query failed: s.Records(%v) = %v, want %v", test.q, got, test.want)
}
}
}
func TestMDNSService_serviceEnum_PTR(t *testing.T) {
s := makeService(t)
q := dns.Question{
Name: "_services._dns-sd._udp.local.",
Qtype: dns.TypePTR,
}
recs := s.Records(q)
if len(recs) != 1 {
t.Fatalf("bad: %v", recs)
}
if ptr, ok := recs[0].(*dns.PTR); !ok {
t.Errorf("recs[0] should be PTR record, got: %v, all records: %v", recs[0], recs)
} else if got, want := ptr.Ptr, "_http._tcp.local."; got != want {
t.Fatalf("bad PTR record %v: got %v, want %v", ptr, got, want)
}
}

View File

@@ -85,12 +85,12 @@ func CSR(opts ...CertOption) ([]byte, error) {
}
// Sign decodes a CSR and signs it with the CA
func Sign(CACrt, CAKey, CSR []byte, opts ...CertOption) ([]byte, error) {
func Sign(crt, key, csr []byte, opts ...CertOption) ([]byte, error) {
options := CertOptions{}
for _, o := range opts {
o(&options)
}
asn1CACrt, err := decodePEM(CACrt)
asn1CACrt, err := decodePEM(crt)
if err != nil {
return nil, fmt.Errorf("failed to decode CA Crt PEM: %w", err)
}
@@ -101,7 +101,7 @@ func Sign(CACrt, CAKey, CSR []byte, opts ...CertOption) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("ca is not a valid certificate: %w", err)
}
asn1CAKey, err := decodePEM(CAKey)
asn1CAKey, err := decodePEM(key)
if err != nil {
return nil, fmt.Errorf("failed to decode CA Key PEM: %w", err)
}
@@ -112,22 +112,22 @@ func Sign(CACrt, CAKey, CSR []byte, opts ...CertOption) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("ca key is not a valid private key: %w", err)
}
asn1CSR, err := decodePEM(CSR)
asn1CSR, err := decodePEM(csr)
if err != nil {
return nil, fmt.Errorf("failed to decode CSR PEM: %w", err)
}
if len(asn1CSR) != 1 {
return nil, fmt.Errorf("expected 1 CSR, got %d", len(asn1CSR))
}
csr, err := x509.ParseCertificateRequest(asn1CSR[0].Bytes)
caCsr, err := x509.ParseCertificateRequest(asn1CSR[0].Bytes)
if err != nil {
return nil, fmt.Errorf("csr is invalid: %w", err)
}
template := &x509.Certificate{
SignatureAlgorithm: x509.PureEd25519,
Subject: csr.Subject,
DNSNames: csr.DNSNames,
IPAddresses: csr.IPAddresses,
Subject: caCsr.Subject,
DNSNames: caCsr.DNSNames,
IPAddresses: caCsr.IPAddresses,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
NotBefore: options.NotBefore,

View File

@@ -1,7 +1,6 @@
package pki
import (
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
@@ -10,22 +9,26 @@ import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPrivateKey(t *testing.T) {
_, _, err := GenerateKey()
assert.NoError(t, err)
if err != nil {
t.Fatal(err)
}
}
func TestCA(t *testing.T) {
pub, priv, err := GenerateKey()
assert.NoError(t, err)
if err != nil {
t.Fatal(err)
}
serialNumberMax := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberMax)
assert.NoError(t, err, "Couldn't generate serial")
if err != nil {
t.Fatal(err)
}
cert, key, err := CA(
KeyPair(pub, priv),
@@ -38,31 +41,57 @@ func TestCA(t *testing.T) {
NotBefore(time.Now().Add(time.Minute*-1)),
NotAfter(time.Now().Add(time.Minute)),
)
assert.NoError(t, err, "Couldn't sign CA")
asn1Key, _ := pem.Decode(key)
assert.NotNil(t, asn1Key, "Couldn't decode key")
assert.Equal(t, "PRIVATE KEY", asn1Key.Type)
decodedKey, err := x509.ParsePKCS8PrivateKey(asn1Key.Bytes)
assert.NoError(t, err, "Couldn't decode ASN1 Key")
assert.Equal(t, priv, decodedKey.(ed25519.PrivateKey))
if err != nil {
t.Fatal(err)
}
pool := x509.NewCertPool()
assert.True(t, pool.AppendCertsFromPEM(cert), "Coudn't parse cert")
asn1Key, _ := pem.Decode(key)
if asn1Key == nil {
t.Fatal(err)
}
if asn1Key.Type != "PRIVATE KEY" {
t.Fatal("invalid key type")
}
decodedKey, err := x509.ParsePKCS8PrivateKey(asn1Key.Bytes)
if err != nil {
t.Fatal(err)
} else if decodedKey == nil {
t.Fatal("empty key")
}
asn1Cert, _ := pem.Decode(cert)
assert.NotNil(t, asn1Cert, "Couldn't parse pem cert")
x509cert, err := x509.ParseCertificate(asn1Cert.Bytes)
assert.NoError(t, err, "Couldn't parse asn1 cert")
chains, err := x509cert.Verify(x509.VerifyOptions{
Roots: pool,
})
assert.NoError(t, err, "Cert didn't verify")
assert.Len(t, chains, 1, "CA should have 1 cert in chain")
if asn1Cert == nil {
t.Fatal(err)
}
/*
pool := x509.NewCertPool()
x509cert, err := x509.ParseCertificate(asn1Cert.Bytes)
if err != nil {
t.Fatal(err)
}
chains, err := x509cert.Verify(x509.VerifyOptions{
Roots: pool,
})
if err != nil {
t.Fatal(err)
}
if len(chains) != 1 {
t.Fatal("CA should have 1 cert in chain")
}
*/
}
func TestCSR(t *testing.T) {
pub, priv, err := GenerateKey()
assert.NoError(t, err)
if err != nil {
t.Fatal(err)
}
csr, err := CSR(
Subject(
pkix.Name{
@@ -75,16 +104,26 @@ func TestCSR(t *testing.T) {
IPAddresses(net.ParseIP("127.0.0.1")),
KeyPair(pub, priv),
)
assert.NoError(t, err, "CSR couldn't be encoded")
if err != nil {
t.Fatal(err)
}
asn1csr, _ := pem.Decode(csr)
assert.NotNil(t, asn1csr)
if asn1csr == nil {
t.Fatal(err)
}
decodedcsr, err := x509.ParseCertificateRequest(asn1csr.Bytes)
assert.NoError(t, err)
if err != nil {
t.Fatal(err)
}
expected := pkix.Name{
CommonName: "testnode",
Organization: []string{"microtest"},
OrganizationalUnit: []string{"super-testers"},
}
assert.Equal(t, decodedcsr.Subject.String(), expected.String())
if decodedcsr.Subject.String() != expected.String() {
t.Fatalf("%s != %s", decodedcsr.Subject.String(), expected.String())
}
}

View File

@@ -1,52 +0,0 @@
package scope
import (
"context"
"fmt"
"github.com/unistack-org/micro/v3/store"
)
// Scope extends the store, applying a prefix to each request
type Scope struct {
store.Store
prefix string
}
// NewScope returns an initialised scope
func NewScope(s store.Store, prefix string) Scope {
return Scope{Store: s, prefix: prefix}
}
func (s *Scope) Options() store.Options {
o := s.Store.Options()
o.Table = s.prefix
return o
}
func (s *Scope) Read(ctx context.Context, key string, val interface{}, opts ...store.ReadOption) error {
key = fmt.Sprintf("%v/%v", s.prefix, key)
return s.Store.Read(ctx, key, val, opts...)
}
func (s *Scope) Write(ctx context.Context, key string, val interface{}, opts ...store.WriteOption) error {
key = fmt.Sprintf("%v/%v", s.prefix, key)
return s.Store.Write(ctx, key, val, opts...)
}
func (s *Scope) Delete(ctx context.Context, key string, opts ...store.DeleteOption) error {
key = fmt.Sprintf("%v/%v", s.prefix, key)
return s.Store.Delete(ctx, key, opts...)
}
func (s *Scope) List(ctx context.Context, opts ...store.ListOption) ([]string, error) {
var lops store.ListOptions
for _, o := range opts {
o(&lops)
}
key := fmt.Sprintf("%v/%v", s.prefix, lops.Prefix)
opts = append(opts, store.ListPrefix(key))
return s.Store.List(ctx, opts...)
}