diff --git a/broker/http_broker.go b/broker/http_broker.go index fefeed4d..a7f6243d 100644 --- a/broker/http_broker.go +++ b/broker/http_broker.go @@ -22,10 +22,10 @@ import ( "github.com/micro/go-micro/codec/json" merr "github.com/micro/go-micro/errors" "github.com/micro/go-micro/registry" + maddr "github.com/micro/go-micro/util/addr" + mnet "github.com/micro/go-micro/util/net" + mls "github.com/micro/go-micro/util/tls" "github.com/micro/go-rcache" - maddr "github.com/micro/util/go/lib/addr" - mnet "github.com/micro/util/go/lib/net" - mls "github.com/micro/util/go/lib/tls" "golang.org/x/net/http2" ) diff --git a/server/rpc_server.go b/server/rpc_server.go index 69f880b3..6838d65e 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -16,8 +16,7 @@ import ( "github.com/micro/go-micro/metadata" "github.com/micro/go-micro/registry" "github.com/micro/go-micro/transport" - - "github.com/micro/util/go/lib/addr" + "github.com/micro/go-micro/util/addr" ) type rpcServer struct { diff --git a/transport/http_transport.go b/transport/http_transport.go index f72ed8cd..ec098ca2 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -13,9 +13,9 @@ import ( "sync" "time" - maddr "github.com/micro/util/go/lib/addr" - mnet "github.com/micro/util/go/lib/net" - mls "github.com/micro/util/go/lib/tls" + maddr "github.com/micro/go-micro/util/addr" + mnet "github.com/micro/go-micro/util/net" + mls "github.com/micro/go-micro/util/tls" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) diff --git a/util/addr/addr.go b/util/addr/addr.go new file mode 100644 index 00000000..ab3acca1 --- /dev/null +++ b/util/addr/addr.go @@ -0,0 +1,126 @@ +package addr + +import ( + "fmt" + "net" +) + +var ( + privateBlocks []*net.IPNet +) + +func init() { + for _, b := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "fd00::/8"} { + if _, block, err := net.ParseCIDR(b); err == nil { + privateBlocks = append(privateBlocks, block) + } + } +} + +func isPrivateIP(ipAddr string) bool { + ip := net.ParseIP(ipAddr) + for _, priv := range privateBlocks { + if priv.Contains(ip) { + return true + } + } + return false +} + +// Extract returns a real ip +func Extract(addr string) (string, error) { + // if addr specified then its returned + if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]") { + return addr, nil + } + + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("Failed to get interfaces! Err: %v", err) + } + + var addrs []net.Addr + for _, iface := range ifaces { + ifaceAddrs, err := iface.Addrs() + if err != nil { + // ignore error, interface can dissapear from system + continue + } + addrs = append(addrs, ifaceAddrs...) + } + + var ipAddr []byte + var publicIP []byte + + for _, rawAddr := range addrs { + var ip net.IP + switch addr := rawAddr.(type) { + case *net.IPAddr: + ip = addr.IP + case *net.IPNet: + ip = addr.IP + default: + continue + } + + if !isPrivateIP(ip.String()) { + publicIP = ip + continue + } + + ipAddr = ip + break + } + + // return private ip + if ipAddr != nil { + return net.IP(ipAddr).String(), nil + } + + // return public or virtual ip + if publicIP != nil { + return net.IP(publicIP).String(), nil + } + + return "", fmt.Errorf("No IP address found, and explicit IP not provided") +} + +// IPs returns all known ips +func IPs() []string { + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + + var ipAddrs []string + + for _, i := range ifaces { + addrs, err := i.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + if ip == nil { + continue + } + + ip = ip.To4() + if ip == nil { + continue + } + + ipAddrs = append(ipAddrs, ip.String()) + } + } + + return ipAddrs +} diff --git a/util/addr/addr_test.go b/util/addr/addr_test.go new file mode 100644 index 00000000..ad23b128 --- /dev/null +++ b/util/addr/addr_test.go @@ -0,0 +1,38 @@ +package addr + +import ( + "net" + "testing" +) + +func TestExtractor(t *testing.T) { + testData := []struct { + addr string + expect string + parse bool + }{ + {"127.0.0.1", "127.0.0.1", false}, + {"10.0.0.1", "10.0.0.1", false}, + {"", "", true}, + {"0.0.0.0", "", true}, + {"[::]", "", true}, + } + + for _, d := range testData { + addr, err := Extract(d.addr) + if err != nil { + t.Errorf("Unexpected error %v", err) + } + + if d.parse { + ip := net.ParseIP(addr) + if ip == nil { + t.Error("Unexpected nil IP") + } + + } else if addr != d.expect { + t.Errorf("Expected %s got %s", d.expect, addr) + } + } + +} diff --git a/util/backoff/backoff.go b/util/backoff/backoff.go new file mode 100644 index 00000000..013d5291 --- /dev/null +++ b/util/backoff/backoff.go @@ -0,0 +1,14 @@ +// Package backoff provides backoff functionality +package backoff + +import ( + "math" + "time" +) + +func Do(attempts int) time.Duration { + if attempts == 0 { + return time.Duration(0) + } + return time.Duration(math.Pow(10, float64(attempts))) * time.Millisecond +} diff --git a/util/ctx/ctx.go b/util/ctx/ctx.go new file mode 100644 index 00000000..2fb69a43 --- /dev/null +++ b/util/ctx/ctx.go @@ -0,0 +1,18 @@ +package ctx + +import ( + "context" + "net/http" + "strings" + + "github.com/micro/go-micro/metadata" +) + +func FromRequest(r *http.Request) context.Context { + ctx := context.Background() + md := make(metadata.Metadata) + for k, v := range r.Header { + md[k] = strings.Join(v, ",") + } + return metadata.NewContext(ctx, md) +} diff --git a/util/ctx/ctx_test.go b/util/ctx/ctx_test.go new file mode 100644 index 00000000..440baea1 --- /dev/null +++ b/util/ctx/ctx_test.go @@ -0,0 +1,41 @@ +package ctx + +import ( + "net/http" + "testing" + + "github.com/micro/go-micro/metadata" +) + +func TestRequestToContext(t *testing.T) { + testData := []struct { + request *http.Request + expect metadata.Metadata + }{ + { + &http.Request{ + Header: http.Header{ + "foo1": []string{"bar"}, + "foo2": []string{"bar", "baz"}, + }, + }, + metadata.Metadata{ + "foo1": "bar", + "foo2": "bar,baz", + }, + }, + } + + for _, d := range testData { + ctx := FromRequest(d.request) + md, ok := metadata.FromContext(ctx) + if !ok { + t.Fatalf("Expected metadata for request %+v", d.request) + } + for k, v := range d.expect { + if val := md[k]; val != v { + t.Fatalf("Expected %s for key %s for expected md %+v, got md %+v", v, k, d.expect, md) + } + } + } +} diff --git a/util/file/file.go b/util/file/file.go new file mode 100644 index 00000000..3e61d3f3 --- /dev/null +++ b/util/file/file.go @@ -0,0 +1,15 @@ +package file + +import "os" + +// Exists returns true if the path is existing +func Exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return true, err +} diff --git a/util/file/file_test.go b/util/file/file_test.go new file mode 100644 index 00000000..4ee3fabc --- /dev/null +++ b/util/file/file_test.go @@ -0,0 +1,17 @@ +package file + +import ( + "testing" +) + +func TestExists(t *testing.T) { + ok, err := Exists("/") + + if ok { + return + } + + if !ok || err != nil { + t.Fatalf("Test Exists fail, %s", err) + } +} diff --git a/util/grpc/grpc.go b/util/grpc/grpc.go new file mode 100644 index 00000000..b06a0673 --- /dev/null +++ b/util/grpc/grpc.go @@ -0,0 +1,40 @@ +package grpc + +import ( + "fmt" + "strings" +) + +// ServiceMethod converts a gRPC method to a Go method +// Input: +// Foo.Bar, /Foo/Bar, /package.Foo/Bar, /a.package.Foo/Bar +// Output: +// [Foo, Bar] +func ServiceMethod(m string) (string, string, error) { + if len(m) == 0 { + return "", "", fmt.Errorf("malformed method name: %q", m) + } + + // grpc method + if m[0] == '/' { + // [ , Foo, Bar] + // [ , package.Foo, Bar] + // [ , a.package.Foo, Bar] + parts := strings.Split(m, "/") + if len(parts) != 3 || len(parts[1]) == 0 || len(parts[2]) == 0 { + return "", "", fmt.Errorf("malformed method name: %q", m) + } + service := strings.Split(parts[1], ".") + return service[len(service)-1], parts[2], nil + } + + // non grpc method + parts := strings.Split(m, ".") + + // expect [Foo, Bar] + if len(parts) != 2 { + return "", "", fmt.Errorf("malformed method name: %q", m) + } + + return parts[0], parts[1], nil +} diff --git a/util/grpc/grpc_test.go b/util/grpc/grpc_test.go new file mode 100644 index 00000000..c7f30cc1 --- /dev/null +++ b/util/grpc/grpc_test.go @@ -0,0 +1,46 @@ +package grpc + +import ( + "testing" +) + +func TestServiceMethod(t *testing.T) { + type testCase struct { + input string + service string + method string + err bool + } + + methods := []testCase{ + {"Foo.Bar", "Foo", "Bar", false}, + {"/Foo/Bar", "Foo", "Bar", false}, + {"/package.Foo/Bar", "Foo", "Bar", false}, + {"/a.package.Foo/Bar", "Foo", "Bar", false}, + {"a.package.Foo/Bar", "", "", true}, + {"/Foo/Bar/Baz", "", "", true}, + {"Foo.Bar.Baz", "", "", true}, + } + for _, test := range methods { + service, method, err := ServiceMethod(test.input) + if err != nil && test.err == true { + continue + } + // unexpected error + if err != nil && test.err == false { + t.Fatalf("unexpected err %v for %+v", err, test) + } + // expecter error + if test.err == true && err == nil { + t.Fatalf("expected error for %+v: got service: %s method: %s", test, service, method) + } + + if service != test.service { + t.Fatalf("wrong service for %+v: got service: %s method: %s", test, service, method) + } + + if method != test.method { + t.Fatalf("wrong method for %+v: got service: %s method: %s", test, service, method) + } + } +} diff --git a/util/http/http.go b/util/http/http.go new file mode 100644 index 00000000..06bff69f --- /dev/null +++ b/util/http/http.go @@ -0,0 +1,23 @@ +package http + +import ( + "net/http" + + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/selector" +) + +func NewRoundTripper(opts ...Option) http.RoundTripper { + options := Options{ + Registry: registry.DefaultRegistry, + } + for _, o := range opts { + o(&options) + } + + return &roundTripper{ + rt: http.DefaultTransport, + st: selector.Random, + opts: options, + } +} diff --git a/util/http/http_test.go b/util/http/http_test.go new file mode 100644 index 00000000..b7bfe370 --- /dev/null +++ b/util/http/http_test.go @@ -0,0 +1,87 @@ +package http + +import ( + "io/ioutil" + "net" + "net/http" + "strconv" + "testing" + + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/registry/memory" +) + +func TestRoundTripper(t *testing.T) { + m := memory.NewRegistry() + + rt := NewRoundTripper( + WithRegistry(m), + ) + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`hello world`)) + }) + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + go http.Serve(l, nil) + + host, p, _ := net.SplitHostPort(l.Addr().String()) + port, _ := strconv.Atoi(p) + + m.Register(®istry.Service{ + Name: "example.com", + Nodes: []*registry.Node{ + { + Id: "1", + Address: host, + Port: port, + }, + }, + }) + + req, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatal(err) + } + + w, err := rt.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + + b, err := ioutil.ReadAll(w.Body) + if err != nil { + t.Fatal(err) + } + w.Body.Close() + + if string(b) != "hello world" { + t.Fatal("response is", string(b)) + } + + // test http request + c := &http.Client{ + Transport: rt, + } + + rsp, err := c.Get("http://example.com") + if err != nil { + t.Fatal(err) + } + + b, err = ioutil.ReadAll(rsp.Body) + if err != nil { + t.Fatal(err) + } + rsp.Body.Close() + + if string(b) != "hello world" { + t.Fatal("response is", string(b)) + } + +} diff --git a/util/http/options.go b/util/http/options.go new file mode 100644 index 00000000..a248e422 --- /dev/null +++ b/util/http/options.go @@ -0,0 +1,17 @@ +package http + +import ( + "github.com/micro/go-micro/registry" +) + +type Options struct { + Registry registry.Registry +} + +type Option func(*Options) + +func WithRegistry(r registry.Registry) Option { + return func(o *Options) { + o.Registry = r + } +} diff --git a/util/http/roundtripper.go b/util/http/roundtripper.go new file mode 100644 index 00000000..f90ef34f --- /dev/null +++ b/util/http/roundtripper.go @@ -0,0 +1,40 @@ +package http + +import ( + "errors" + "fmt" + "net/http" + + "github.com/micro/go-micro/selector" +) + +type roundTripper struct { + rt http.RoundTripper + st selector.Strategy + opts Options +} + +func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + s, err := r.opts.Registry.GetService(req.URL.Host) + if err != nil { + return nil, err + } + + next := r.st(s) + + // rudimentary retry 3 times + for i := 0; i < 3; i++ { + n, err := next() + if err != nil { + continue + } + req.URL.Host = fmt.Sprintf("%s:%d", n.Address, n.Port) + w, err := r.rt.RoundTrip(req) + if err != nil { + continue + } + return w, nil + } + + return nil, errors.New("failed request") +} diff --git a/util/net/net.go b/util/net/net.go new file mode 100644 index 00000000..b092068f --- /dev/null +++ b/util/net/net.go @@ -0,0 +1,63 @@ +package net + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" +) + +// Listen takes addr:portmin-portmax and binds to the first available port +// Example: Listen("localhost:5000-6000", fn) +func Listen(addr string, fn func(string) (net.Listener, error)) (net.Listener, error) { + // host:port || host:min-max + parts := strings.Split(addr, ":") + + // + if len(parts) < 2 { + return fn(addr) + } + + // try to extract port range + ports := strings.Split(parts[len(parts)-1], "-") + + // single port + if len(ports) < 2 { + return fn(addr) + } + + // we have a port range + + // extract min port + min, err := strconv.Atoi(ports[0]) + if err != nil { + return nil, errors.New("unable to extract port range") + } + + // extract max port + max, err := strconv.Atoi(ports[1]) + if err != nil { + return nil, errors.New("unable to extract port range") + } + + // set host + host := parts[:len(parts)-1] + + // range the ports + for port := min; port <= max; port++ { + // try bind to host:port + ln, err := fn(fmt.Sprintf("%s:%d", host, port)) + if err == nil { + return ln, nil + } + + // hit max port + if port == max { + return nil, err + } + } + + // why are we here? + return nil, fmt.Errorf("unable to bind to %s", addr) +} diff --git a/util/net/net_test.go b/util/net/net_test.go new file mode 100644 index 00000000..a9fca743 --- /dev/null +++ b/util/net/net_test.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" + "testing" +) + +func TestListen(t *testing.T) { + fn := func(addr string) (net.Listener, error) { + return net.Listen("tcp", addr) + } + + // try to create a number of listeners + for i := 0; i < 10; i++ { + l, err := Listen("localhost:10000-11000", fn) + if err != nil { + t.Fatal(err) + } + defer l.Close() + } +} diff --git a/util/tls/tls.go b/util/tls/tls.go new file mode 100644 index 00000000..6df2c6e4 --- /dev/null +++ b/util/tls/tls.go @@ -0,0 +1,74 @@ +package tls + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "time" +) + +func Certificate(host ...string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(time.Hour * 24 * 365) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + for _, h := range host { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + // create public key + certOut := bytes.NewBuffer(nil) + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + // create private key + keyOut := bytes.NewBuffer(nil) + b, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + + return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) +}