Compare commits

...

27 Commits

Author SHA1 Message Date
Asim Aslam
d090a97a3d Merge pull request #396 from micro/error
Fix #394 invalid error handling in rpc_router ServeRequest
2019-01-22 14:28:41 +00:00
Asim Aslam
8a0d5f0489 log if we can't even respond 2019-01-22 13:55:04 +00:00
Asim Aslam
2ed676acf4 handle errors differently 2019-01-22 13:52:18 +00:00
Asim Aslam
d8ba18deff change logging 2019-01-22 12:18:33 +00:00
Asim Aslam
1321782785 in case of reload return nil 2019-01-19 10:20:16 +00:00
Asim Aslam
48b80dd051 replace memory registry 2019-01-18 17:29:17 +00:00
Asim Aslam
943219f203 Merge pull request #393 from micro/legacy
Evolution of Codecs and Methods
2019-01-18 12:40:05 +00:00
Asim Aslam
6468733d98 Use protocol from node metadata 2019-01-18 12:30:39 +00:00
Asim Aslam
9bd32645be Account for old target 2019-01-18 10:43:41 +00:00
Asim Aslam
f41be53ff8 Add ability to process legacy requests 2019-01-18 10:23:36 +00:00
Asim Aslam
2cd2258731 For the legacy 2019-01-18 10:12:57 +00:00
Asim Aslam
9ce9977d21 Don't read unless we have b 2019-01-17 12:09:04 +00:00
Asim Aslam
617db003d4 Copy metadata 2019-01-17 09:40:49 +00:00
Asim Aslam
7b89b36e37 add benchmarks 2019-01-16 18:54:43 +00:00
Asim Aslam
e2e426b90c Increase default pool size 2019-01-16 18:54:32 +00:00
Asim Aslam
5b95ce7f26 Silence broker during tests 2019-01-16 18:54:04 +00:00
Asim Aslam
082f57fcad We can just check nil vals 2019-01-16 15:42:42 +00:00
Asim Aslam
cc5629fb6b Don't return zero length services 2019-01-16 15:41:37 +00:00
Asim Aslam
784a89b488 Allow bytes.Frame to be set to sent just bytes 2019-01-16 15:27:57 +00:00
Asim Aslam
a9c0b95603 update readme 2019-01-16 13:12:21 +00:00
Asim Aslam
7bd0bd14c8 Merge pull request #386 from micro/mdns
Set MDNS as default registry
2019-01-15 16:59:07 +00:00
Asim Aslam
7314af347b Set MDNS as default registry 2019-01-15 16:50:37 +00:00
Asim Aslam
00661f8a99 Clarify log message 2019-01-15 15:17:30 +00:00
Asim Aslam
e362466e8a use default router 2019-01-14 21:45:43 +00:00
Asim Aslam
c1d0237370 Add client response 2019-01-14 21:30:43 +00:00
Asim Aslam
f2ac73eae5 only log error if its plus 3 2019-01-14 16:09:51 +00:00
Asim Aslam
39c24baca9 rename mock things to memory 2019-01-14 15:27:25 +00:00
64 changed files with 2214 additions and 1619 deletions

View File

@@ -19,8 +19,9 @@ Follow us on [Twitter](https://twitter.com/microhq) or join the [Slack](http://s
Go Micro abstracts away the details of distributed systems. Here are the main features.
- **Service Discovery** - Automatic service registration and name resolution. Service discovery is at the core of micro service
development. When service A needs to speak to service B it needs the location of that service. Consul is the default discovery
system with multicast DNS (mdns) as a local option or the SWIM protocol (gossip) for zero dependency p2p networks.
development. When service A needs to speak to service B it needs the location of that service. The default discovery mechanism is
multicast DNS (mdns), a zeroconf system. You can optionally set gossip using the SWIM protocol for p2p networks or consul for a
resilient cloud-native setup.
- **Load Balancing** - Client side load balancing built on service discovery. Once we have the addresses of any number of instances
of a service we now need a way to decide which node to route to. We use random hashed load balancing to provide even distribution

View File

@@ -19,7 +19,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/micro/go-log"
"github.com/micro/go-micro/codec/json"
merr "github.com/micro/go-micro/errors"
"github.com/micro/go-micro/registry"
@@ -397,7 +396,6 @@ func (h *httpBroker) Connect() error {
return err
}
log.Logf("Broker Listening on %s", l.Addr().String())
addr := h.address
h.address = l.Addr().String()

View File

@@ -5,13 +5,26 @@ import (
"testing"
"time"
glog "github.com/go-log/log"
"github.com/google/uuid"
"github.com/micro/go-micro/registry/mock"
"github.com/micro/go-log"
"github.com/micro/go-micro/registry/memory"
)
func newTestRegistry() *memory.Registry {
r := memory.NewRegistry()
m := r.(*memory.Registry)
m.Setup()
return m
}
func sub(be *testing.B, c int) {
// set no op logger
log.SetLogger(glog.DefaultLogger)
be.StopTimer()
m := mock.NewRegistry()
m := newTestRegistry()
b := NewBroker(Registry(m))
topic := uuid.New().String()
@@ -69,8 +82,11 @@ func sub(be *testing.B, c int) {
}
func pub(be *testing.B, c int) {
// set no op logger
log.SetLogger(glog.DefaultLogger)
be.StopTimer()
m := mock.NewRegistry()
m := newTestRegistry()
b := NewBroker(Registry(m))
topic := uuid.New().String()
@@ -139,7 +155,7 @@ func pub(be *testing.B, c int) {
}
func TestBroker(t *testing.T) {
m := mock.NewRegistry()
m := newTestRegistry()
b := NewBroker(Registry(m))
if err := b.Init(); err != nil {
@@ -186,7 +202,7 @@ func TestBroker(t *testing.T) {
}
func TestConcurrentSubBroker(t *testing.T) {
m := mock.NewRegistry()
m := newTestRegistry()
b := NewBroker(Registry(m))
if err := b.Init(); err != nil {
@@ -243,7 +259,7 @@ func TestConcurrentSubBroker(t *testing.T) {
}
func TestConcurrentPubBroker(t *testing.T) {
m := mock.NewRegistry()
m := newTestRegistry()
b := NewBroker(Registry(m))
if err := b.Init(); err != nil {

View File

@@ -1,5 +1,5 @@
// Package mock provides a mock broker for testing
package mock
// Package memory provides a memory broker
package memory
import (
"errors"
@@ -9,20 +9,20 @@ import (
"github.com/micro/go-micro/broker"
)
type mockBroker struct {
type memoryBroker struct {
opts broker.Options
sync.RWMutex
connected bool
Subscribers map[string][]*mockSubscriber
Subscribers map[string][]*memorySubscriber
}
type mockPublication struct {
type memoryPublication struct {
topic string
message *broker.Message
}
type mockSubscriber struct {
type memorySubscriber struct {
id string
topic string
exit chan bool
@@ -30,15 +30,15 @@ type mockSubscriber struct {
opts broker.SubscribeOptions
}
func (m *mockBroker) Options() broker.Options {
func (m *memoryBroker) Options() broker.Options {
return m.opts
}
func (m *mockBroker) Address() string {
func (m *memoryBroker) Address() string {
return ""
}
func (m *mockBroker) Connect() error {
func (m *memoryBroker) Connect() error {
m.Lock()
defer m.Unlock()
@@ -51,7 +51,7 @@ func (m *mockBroker) Connect() error {
return nil
}
func (m *mockBroker) Disconnect() error {
func (m *memoryBroker) Disconnect() error {
m.Lock()
defer m.Unlock()
@@ -64,14 +64,14 @@ func (m *mockBroker) Disconnect() error {
return nil
}
func (m *mockBroker) Init(opts ...broker.Option) error {
func (m *memoryBroker) Init(opts ...broker.Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mockBroker) Publish(topic string, message *broker.Message, opts ...broker.PublishOption) error {
func (m *memoryBroker) Publish(topic string, message *broker.Message, opts ...broker.PublishOption) error {
m.Lock()
defer m.Unlock()
@@ -84,7 +84,7 @@ func (m *mockBroker) Publish(topic string, message *broker.Message, opts ...brok
return nil
}
p := &mockPublication{
p := &memoryPublication{
topic: topic,
message: message,
}
@@ -98,7 +98,7 @@ func (m *mockBroker) Publish(topic string, message *broker.Message, opts ...brok
return nil
}
func (m *mockBroker) Subscribe(topic string, handler broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) {
func (m *memoryBroker) Subscribe(topic string, handler broker.Handler, opts ...broker.SubscribeOption) (broker.Subscriber, error) {
m.Lock()
defer m.Unlock()
@@ -111,7 +111,7 @@ func (m *mockBroker) Subscribe(topic string, handler broker.Handler, opts ...bro
o(&options)
}
sub := &mockSubscriber{
sub := &memorySubscriber{
exit: make(chan bool, 1),
id: uuid.New().String(),
topic: topic,
@@ -124,7 +124,7 @@ func (m *mockBroker) Subscribe(topic string, handler broker.Handler, opts ...bro
go func() {
<-sub.exit
m.Lock()
var newSubscribers []*mockSubscriber
var newSubscribers []*memorySubscriber
for _, sb := range m.Subscribers[topic] {
if sb.id == sub.id {
continue
@@ -138,31 +138,31 @@ func (m *mockBroker) Subscribe(topic string, handler broker.Handler, opts ...bro
return sub, nil
}
func (m *mockBroker) String() string {
return "mock"
func (m *memoryBroker) String() string {
return "memory"
}
func (m *mockPublication) Topic() string {
func (m *memoryPublication) Topic() string {
return m.topic
}
func (m *mockPublication) Message() *broker.Message {
func (m *memoryPublication) Message() *broker.Message {
return m.message
}
func (m *mockPublication) Ack() error {
func (m *memoryPublication) Ack() error {
return nil
}
func (m *mockSubscriber) Options() broker.SubscribeOptions {
func (m *memorySubscriber) Options() broker.SubscribeOptions {
return m.opts
}
func (m *mockSubscriber) Topic() string {
func (m *memorySubscriber) Topic() string {
return m.topic
}
func (m *mockSubscriber) Unsubscribe() error {
func (m *memorySubscriber) Unsubscribe() error {
m.exit <- true
return nil
}
@@ -173,8 +173,8 @@ func NewBroker(opts ...broker.Option) broker.Broker {
o(&options)
}
return &mockBroker{
return &memoryBroker{
opts: options,
Subscribers: make(map[string][]*mockSubscriber),
Subscribers: make(map[string][]*memorySubscriber),
}
}

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"fmt"
@@ -7,7 +7,7 @@ import (
"github.com/micro/go-micro/broker"
)
func TestBroker(t *testing.T) {
func TestMemoryBroker(t *testing.T) {
b := NewBroker()
if err := b.Connect(); err != nil {

View File

@@ -38,7 +38,9 @@ type Message interface {
type Request interface {
// The service to call
Service() string
// The endpoint to call
// The action to take
Method() string
// The endpoint to invoke
Endpoint() string
// The content type
ContentType() string
@@ -62,11 +64,19 @@ type Response interface {
// Stream is the inteface for a bidirectional synchronous stream
type Stream interface {
// Context for the stream
Context() context.Context
// The request made
Request() Request
// The response read
Response() Response
// Send will encode and send a request
Send(interface{}) error
// Recv will decode and read a response
Recv(interface{}) error
// Error returns the stream error
Error() error
// Close closes the stream
Close() error
}
@@ -97,7 +107,7 @@ var (
// DefaultRequestTimeout is the default request timeout
DefaultRequestTimeout = time.Second * 5
// DefaultPoolSize sets the connection pool size
DefaultPoolSize = 1
DefaultPoolSize = 100
// DefaultPoolTTL sets the connection pool ttl
DefaultPoolTTL = time.Minute
)

View File

@@ -4,10 +4,11 @@ import (
"bytes"
"context"
"fmt"
"net"
"strconv"
"sync"
"time"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/micro/go-micro/broker"
@@ -56,7 +57,12 @@ func (r *rpcClient) newCodec(contentType string) (codec.NewCodec, error) {
return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType)
}
func (r *rpcClient) call(ctx context.Context, address string, req Request, resp interface{}, opts CallOptions) error {
func (r *rpcClient) call(ctx context.Context, node *registry.Node, req Request, resp interface{}, opts CallOptions) error {
address := node.Address
if node.Port > 0 {
address = fmt.Sprintf("%s:%d", address, node.Port)
}
msg := &transport.Message{
Header: make(map[string]string),
}
@@ -75,9 +81,16 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp
// set the accept header
msg.Header["Accept"] = req.ContentType()
cf, err := r.newCodec(req.ContentType())
if err != nil {
return errors.InternalServerError("go.micro.client", err.Error())
// setup old protocol
cf := setupProtocol(msg, node)
// no codec specified
if cf == nil {
var err error
cf, err = r.newCodec(req.ContentType())
if err != nil {
return errors.InternalServerError("go.micro.client", err.Error())
}
}
var grr error
@@ -92,13 +105,20 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp
seq := atomic.LoadUint64(&r.seq)
atomic.AddUint64(&r.seq, 1)
codec := newRpcCodec(msg, c, cf)
rsp := &rpcResponse{
socket: c,
codec: codec,
}
stream := &rpcStream{
context: ctx,
request: req,
closed: make(chan bool),
codec: newRpcCodec(msg, c, cf),
id: fmt.Sprintf("%v", seq),
context: ctx,
request: req,
response: rsp,
codec: codec,
closed: make(chan bool),
id: fmt.Sprintf("%v", seq),
}
defer stream.Close()
@@ -137,7 +157,12 @@ func (r *rpcClient) call(ctx context.Context, address string, req Request, resp
}
}
func (r *rpcClient) stream(ctx context.Context, address string, req Request, opts CallOptions) (Stream, error) {
func (r *rpcClient) stream(ctx context.Context, node *registry.Node, req Request, opts CallOptions) (Stream, error) {
address := node.Address
if node.Port > 0 {
address = fmt.Sprintf("%s:%d", address, node.Port)
}
msg := &transport.Message{
Header: make(map[string]string),
}
@@ -156,9 +181,16 @@ func (r *rpcClient) stream(ctx context.Context, address string, req Request, opt
// set the accept header
msg.Header["Accept"] = req.ContentType()
cf, err := r.newCodec(req.ContentType())
if err != nil {
return nil, errors.InternalServerError("go.micro.client", err.Error())
// set old codecs
cf := setupProtocol(msg, node)
// no codec specified
if cf == nil {
var err error
cf, err = r.newCodec(req.ContentType())
if err != nil {
return nil, errors.InternalServerError("go.micro.client", err.Error())
}
}
dOpts := []transport.DialOption{
@@ -174,11 +206,19 @@ func (r *rpcClient) stream(ctx context.Context, address string, req Request, opt
return nil, errors.InternalServerError("go.micro.client", "connection error: %v", err)
}
codec := newRpcCodec(msg, c, cf)
rsp := &rpcResponse{
socket: c,
codec: codec,
}
stream := &rpcStream{
context: ctx,
request: req,
closed: make(chan bool),
codec: newRpcCodec(msg, c, cf),
context: ctx,
request: req,
response: rsp,
closed: make(chan bool),
codec: newRpcCodec(msg, c, cf),
}
ch := make(chan error, 1)
@@ -230,9 +270,19 @@ func (r *rpcClient) Options() Options {
func (r *rpcClient) next(request Request, opts CallOptions) (selector.Next, error) {
// return remote address
if len(opts.Address) > 0 {
address := opts.Address
port := 0
host, sport, err := net.SplitHostPort(opts.Address)
if err == nil {
address = host
port, _ = strconv.Atoi(sport)
}
return func() (*registry.Node, error) {
return &registry.Node{
Address: opts.Address,
Address: address,
Port: port,
}, nil
}, nil
}
@@ -308,14 +358,8 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac
return errors.InternalServerError("go.micro.client", "error getting next %s node: %v", request.Service(), err.Error())
}
// set the address
address := node.Address
if node.Port > 0 {
address = fmt.Sprintf("%s:%d", address, node.Port)
}
// make the call
err = rcall(ctx, address, request, response, callOpts)
err = rcall(ctx, node, request, response, callOpts)
r.opts.Selector.Mark(request.Service(), node, err)
return err
}
@@ -391,12 +435,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOpt
return nil, errors.InternalServerError("go.micro.client", "error getting next %s node: %v", request.Service(), err.Error())
}
address := node.Address
if node.Port > 0 {
address = fmt.Sprintf("%s:%d", address, node.Port)
}
stream, err := r.stream(ctx, address, request, callOpts)
stream, err := r.stream(ctx, node, request, callOpts)
r.opts.Selector.Mark(request.Service(), node, err)
return stream, err
}

View File

@@ -7,18 +7,25 @@ import (
"github.com/micro/go-micro/errors"
"github.com/micro/go-micro/registry"
"github.com/micro/go-micro/registry/mock"
"github.com/micro/go-micro/registry/memory"
"github.com/micro/go-micro/selector"
)
func newTestRegistry() registry.Registry {
r := memory.NewRegistry()
r.(*memory.Registry).Setup()
return r
}
func TestCallAddress(t *testing.T) {
var called bool
service := "test.service"
endpoint := "Test.Endpoint"
address := "10.1.10.1:8080"
address := "10.1.10.1"
port := 8080
wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, addr string, req Request, rsp interface{}, opts CallOptions) error {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error {
called = true
if req.Service() != service {
@@ -29,8 +36,12 @@ func TestCallAddress(t *testing.T) {
return fmt.Errorf("expected service: %s got %s", endpoint, req.Endpoint())
}
if addr != address {
return fmt.Errorf("expected address: %s got %s", address, addr)
if node.Address != address {
return fmt.Errorf("expected address: %s got %s", address, node.Address)
}
if node.Port != port {
return fmt.Errorf("expected address: %d got %d", port, node.Port)
}
// don't do the call
@@ -38,7 +49,7 @@ func TestCallAddress(t *testing.T) {
}
}
r := mock.NewRegistry()
r := newTestRegistry()
c := NewClient(
Registry(r),
WrapCall(wrap),
@@ -48,7 +59,7 @@ func TestCallAddress(t *testing.T) {
req := c.NewRequest(service, endpoint, nil)
// test calling remote address
if err := c.Call(context.Background(), req, nil, WithAddress(address)); err != nil {
if err := c.Call(context.Background(), req, nil, WithAddress(fmt.Sprintf("%s:%d", address, port))); err != nil {
t.Fatal("call with address error", err)
}
@@ -61,12 +72,12 @@ func TestCallAddress(t *testing.T) {
func TestCallRetry(t *testing.T) {
service := "test.service"
endpoint := "Test.Endpoint"
address := "10.1.10.1:8080"
address := "10.1.10.1"
var called int
wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, addr string, req Request, rsp interface{}, opts CallOptions) error {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error {
called++
if called == 1 {
return errors.InternalServerError("test.error", "retry request")
@@ -77,7 +88,7 @@ func TestCallRetry(t *testing.T) {
}
}
r := mock.NewRegistry()
r := newTestRegistry()
c := NewClient(
Registry(r),
WrapCall(wrap),
@@ -102,12 +113,11 @@ func TestCallWrapper(t *testing.T) {
id := "test.1"
service := "test.service"
endpoint := "Test.Endpoint"
host := "10.1.10.1"
address := "10.1.10.1"
port := 8080
address := "10.1.10.1:8080"
wrap := func(cf CallFunc) CallFunc {
return func(ctx context.Context, addr string, req Request, rsp interface{}, opts CallOptions) error {
return func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error {
called = true
if req.Service() != service {
@@ -118,8 +128,8 @@ func TestCallWrapper(t *testing.T) {
return fmt.Errorf("expected service: %s got %s", endpoint, req.Endpoint())
}
if addr != address {
return fmt.Errorf("expected address: %s got %s", address, addr)
if node.Address != address {
return fmt.Errorf("expected address: %s got %s", address, node.Address)
}
// don't do the call
@@ -127,7 +137,7 @@ func TestCallWrapper(t *testing.T) {
}
}
r := mock.NewRegistry()
r := newTestRegistry()
c := NewClient(
Registry(r),
WrapCall(wrap),
@@ -140,7 +150,7 @@ func TestCallWrapper(t *testing.T) {
Nodes: []*registry.Node{
&registry.Node{
Id: id,
Address: host,
Address: address,
Port: port,
},
},

View File

@@ -12,6 +12,7 @@ import (
"github.com/micro/go-micro/codec/proto"
"github.com/micro/go-micro/codec/protorpc"
"github.com/micro/go-micro/errors"
"github.com/micro/go-micro/registry"
"github.com/micro/go-micro/transport"
)
@@ -58,6 +59,15 @@ var (
"application/proto-rpc": protorpc.NewCodec,
"application/octet-stream": raw.NewCodec,
}
// TODO: remove legacy codec list
defaultCodecs = map[string]codec.NewCodec{
"application/json": jsonrpc.NewCodec,
"application/json-rpc": jsonrpc.NewCodec,
"application/protobuf": protorpc.NewCodec,
"application/proto-rpc": protorpc.NewCodec,
"application/octet-stream": protorpc.NewCodec,
}
)
func (rwc *readWriteCloser) Read(p []byte) (n int, err error) {
@@ -74,6 +84,27 @@ func (rwc *readWriteCloser) Close() error {
return nil
}
// setupProtocol sets up the old protocol
func setupProtocol(msg *transport.Message, node *registry.Node) codec.NewCodec {
protocol := node.Metadata["protocol"]
// got protocol
if len(protocol) > 0 {
return nil
}
// no protocol use old codecs
switch msg.Header["Content-Type"] {
case "application/json":
msg.Header["Content-Type"] = "application/json-rpc"
case "application/protobuf":
msg.Header["Content-Type"] = "application/proto-rpc"
}
// now return codec
return defaultCodecs[msg.Header["Content-Type"]]
}
func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCodec) codec.Codec {
rwc := &readWriteCloser{
wbuf: bytes.NewBuffer(nil),
@@ -88,41 +119,55 @@ func newRpcCodec(req *transport.Message, client transport.Client, c codec.NewCod
return r
}
func (c *rpcCodec) Write(wm *codec.Message, body interface{}) error {
func (c *rpcCodec) Write(m *codec.Message, body interface{}) error {
c.buf.wbuf.Reset()
m := &codec.Message{
Id: wm.Id,
Target: wm.Target,
Endpoint: wm.Endpoint,
Type: codec.Request,
Header: map[string]string{
"X-Micro-Id": wm.Id,
"X-Micro-Service": wm.Target,
"X-Micro-Endpoint": wm.Endpoint,
},
// create header
if m.Header == nil {
m.Header = map[string]string{}
}
if err := c.codec.Write(m, body); err != nil {
return errors.InternalServerError("go.micro.client.codec", err.Error())
// copy original header
for k, v := range c.req.Header {
m.Header[k] = v
}
// set body
if len(wm.Body) > 0 {
c.req.Body = wm.Body
} else {
c.req.Body = c.buf.wbuf.Bytes()
// set the mucp headers
m.Header["X-Micro-Id"] = m.Id
m.Header["X-Micro-Service"] = m.Target
m.Header["X-Micro-Method"] = m.Method
m.Header["X-Micro-Endpoint"] = m.Endpoint
// if body is bytes Frame don't encode
if body != nil {
b, ok := body.(*raw.Frame)
if ok {
// set body
m.Body = b.Data
body = nil
}
}
// set header
for k, v := range m.Header {
c.req.Header[k] = v
if len(m.Body) == 0 {
// write to codec
if err := c.codec.Write(m, body); err != nil {
return errors.InternalServerError("go.micro.client.codec", err.Error())
}
// set body
m.Body = c.buf.wbuf.Bytes()
}
// create new transport message
msg := transport.Message{
Header: m.Header,
Body: m.Body,
}
// send the request
if err := c.client.Send(c.req); err != nil {
if err := c.client.Send(&msg); err != nil {
return errors.InternalServerError("go.micro.client.transport", err.Error())
}
return nil
}
@@ -141,6 +186,7 @@ func (c *rpcCodec) ReadHeader(wm *codec.Message, r codec.MessageType) error {
// read header
err := c.codec.ReadHeader(&me, r)
wm.Endpoint = me.Endpoint
wm.Method = me.Method
wm.Id = me.Id
wm.Error = me.Error
@@ -149,11 +195,16 @@ func (c *rpcCodec) ReadHeader(wm *codec.Message, r codec.MessageType) error {
wm.Error = me.Header["X-Micro-Error"]
}
// check method in header
// check endpoint in header
if len(me.Endpoint) == 0 {
wm.Endpoint = me.Header["X-Micro-Endpoint"]
}
// check method in header
if len(me.Method) == 0 {
wm.Method = me.Header["X-Micro-Method"]
}
if len(me.Id) == 0 {
wm.Id = me.Header["X-Micro-Id"]
}

View File

@@ -5,7 +5,7 @@ import (
"time"
"github.com/micro/go-micro/transport"
"github.com/micro/go-micro/transport/mock"
"github.com/micro/go-micro/transport/memory"
)
func testPool(t *testing.T, size int, ttl time.Duration) {
@@ -13,7 +13,7 @@ func testPool(t *testing.T, size int, ttl time.Duration) {
p := newPool(size, ttl)
// mock transport
tr := mock.NewTransport()
tr := memory.NewTransport()
// listen
l, err := tr.Listen(":0")

View File

@@ -6,6 +6,7 @@ import (
type rpcRequest struct {
service string
method string
endpoint string
contentType string
codec codec.Codec
@@ -27,6 +28,7 @@ func newRequest(service, endpoint string, request interface{}, contentType strin
return &rpcRequest{
service: service,
method: endpoint,
endpoint: endpoint,
body: request,
contentType: contentType,
@@ -42,6 +44,10 @@ func (r *rpcRequest) Service() string {
return r.service
}
func (r *rpcRequest) Method() string {
return r.method
}
func (r *rpcRequest) Endpoint() string {
return r.endpoint
}

View File

@@ -12,7 +12,7 @@ type rpcResponse struct {
codec codec.Codec
}
func (r *rpcResponse) Codec() codec.Writer {
func (r *rpcResponse) Codec() codec.Reader {
return r.codec
}

View File

@@ -11,12 +11,13 @@ import (
// Implements the streamer interface
type rpcStream struct {
sync.RWMutex
id string
closed chan bool
err error
request Request
codec codec.Codec
context context.Context
id string
closed chan bool
err error
request Request
response Response
codec codec.Codec
context context.Context
}
func (r *rpcStream) isClosed() bool {
@@ -36,6 +37,10 @@ func (r *rpcStream) Request() Request {
return r.request
}
func (r *rpcStream) Response() Response {
return r.response
}
func (r *rpcStream) Send(msg interface{}) error {
r.Lock()
defer r.Unlock()
@@ -48,6 +53,7 @@ func (r *rpcStream) Send(msg interface{}) error {
req := codec.Message{
Id: r.id,
Target: r.request.Service(),
Method: r.request.Method(),
Endpoint: r.request.Endpoint(),
Type: codec.Request,
}

View File

@@ -2,10 +2,12 @@ package client
import (
"context"
"github.com/micro/go-micro/registry"
)
// CallFunc represents the individual call func
type CallFunc func(ctx context.Context, address string, req Request, rsp interface{}, opts CallOptions) error
type CallFunc func(ctx context.Context, node *registry.Node, req Request, rsp interface{}, opts CallOptions) error
// CallWrapper is a low level wrapper for the CallFunc
type CallWrapper func(CallFunc) CallFunc

View File

@@ -17,12 +17,14 @@ import (
// brokers
"github.com/micro/go-micro/broker"
"github.com/micro/go-micro/broker/http"
"github.com/micro/go-micro/broker/memory"
// registries
"github.com/micro/go-micro/registry"
"github.com/micro/go-micro/registry/consul"
"github.com/micro/go-micro/registry/gossip"
"github.com/micro/go-micro/registry/mdns"
rmem "github.com/micro/go-micro/registry/memory"
// selectors
"github.com/micro/go-micro/selector"
@@ -32,6 +34,7 @@ import (
// transports
"github.com/micro/go-micro/transport"
thttp "github.com/micro/go-micro/transport/http"
tmem "github.com/micro/go-micro/transport/memory"
)
type Cmd interface {
@@ -164,7 +167,8 @@ var (
}
DefaultBrokers = map[string]func(...broker.Option) broker.Broker{
"http": http.NewBroker,
"http": http.NewBroker,
"memory": memory.NewBroker,
}
DefaultClients = map[string]func(...client.Option) client.Client{
@@ -175,6 +179,7 @@ var (
"consul": consul.NewRegistry,
"gossip": gossip.NewRegistry,
"mdns": mdns.NewRegistry,
"memory": rmem.NewRegistry,
}
DefaultSelectors = map[string]func(...selector.Option) selector.Selector{
@@ -189,15 +194,16 @@ var (
}
DefaultTransports = map[string]func(...transport.Option) transport.Transport{
"http": thttp.NewTransport,
"memory": tmem.NewTransport,
"http": thttp.NewTransport,
}
// used for default selection as the fall back
defaultClient = "rpc"
defaultServer = "rpc"
defaultBroker = "http"
defaultRegistry = "consul"
defaultSelector = "cache"
defaultRegistry = "mdns"
defaultSelector = "registry"
defaultTransport = "http"
)

View File

@@ -13,33 +13,50 @@ type Codec struct {
Conn io.ReadWriteCloser
}
// Frame gives us the ability to define raw data to send over the pipes
type Frame struct {
Data []byte
}
func (c *Codec) ReadHeader(m *codec.Message, t codec.MessageType) error {
return nil
}
func (c *Codec) ReadBody(b interface{}) error {
v, ok := b.(*[]byte)
if !ok {
return fmt.Errorf("failed to read body: %v is not type of *[]byte", b)
}
// read bytes
buf, err := ioutil.ReadAll(c.Conn)
if err != nil {
return err
}
// set bytes
*v = buf
switch b.(type) {
case *[]byte:
v := b.(*[]byte)
*v = buf
case *Frame:
v := b.(*Frame)
v.Data = buf
default:
return fmt.Errorf("failed to read body: %v is not type of *[]byte", b)
}
return nil
}
func (c *Codec) Write(m *codec.Message, b interface{}) error {
v, ok := b.(*[]byte)
if !ok {
return fmt.Errorf("failed to write: %v is not type of *[]byte", b)
var v []byte
switch b.(type) {
case *Frame:
v = b.(*Frame).Data
case *[]byte:
ve := b.(*[]byte)
v = *ve
case []byte:
v = b.([]byte)
default:
return fmt.Errorf("failed to write: %v is not type of *[]byte or []byte", b)
}
_, err := c.Conn.Write(*v)
_, err := c.Conn.Write(v)
return err
}

View File

@@ -13,6 +13,9 @@ type Message struct {
func (n Marshaler) Marshal(v interface{}) ([]byte, error) {
switch v.(type) {
case *[]byte:
ve := v.(*[]byte)
return *ve, nil
case []byte:
return v.([]byte), nil
case *Message:

View File

@@ -53,6 +53,7 @@ type Message struct {
Id string
Type MessageType
Target string
Method string
Endpoint string
Error string

View File

@@ -45,9 +45,9 @@ func newClientCodec(conn io.ReadWriteCloser) *clientCodec {
func (c *clientCodec) Write(m *codec.Message, b interface{}) error {
c.Lock()
c.pending[m.Id] = m.Endpoint
c.pending[m.Id] = m.Method
c.Unlock()
c.req.Method = m.Endpoint
c.req.Method = m.Method
c.req.Params[0] = b
c.req.ID = m.Id
return c.enc.Encode(&c.req)
@@ -66,7 +66,7 @@ func (c *clientCodec) ReadHeader(m *codec.Message) error {
}
c.Lock()
m.Endpoint = c.pending[c.resp.ID]
m.Method = c.pending[c.resp.ID]
delete(c.pending, c.resp.ID)
c.Unlock()

View File

@@ -53,7 +53,7 @@ func (c *serverCodec) ReadHeader(m *codec.Message) error {
if err := c.dec.Decode(&c.req); err != nil {
return err
}
m.Endpoint = c.req.Method
m.Method = c.req.Method
m.Id = fmt.Sprintf("%v", c.req.ID)
c.req.ID = nil
return nil

View File

@@ -18,13 +18,13 @@ func (c *Codec) ReadHeader(m *codec.Message, t codec.MessageType) error {
}
func (c *Codec) ReadBody(b interface{}) error {
if b == nil {
return nil
}
buf, err := ioutil.ReadAll(c.Conn)
if err != nil {
return err
}
if b == nil {
return nil
}
return proto.Unmarshal(buf, b.(proto.Message))
}

View File

@@ -47,7 +47,7 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
c.Lock()
defer c.Unlock()
// This is protobuf, of course we copy it.
pbr := &Request{ServiceMethod: &m.Endpoint, Seq: id(m.Id)}
pbr := &Request{ServiceMethod: &m.Method, Seq: id(m.Id)}
data, err := proto.Marshal(pbr)
if err != nil {
return err
@@ -73,7 +73,7 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
case codec.Response, codec.Error:
c.Lock()
defer c.Unlock()
rtmp := &Response{ServiceMethod: &m.Endpoint, Seq: id(m.Id), Error: &m.Error}
rtmp := &Response{ServiceMethod: &m.Method, Seq: id(m.Id), Error: &m.Error}
data, err := proto.Marshal(rtmp)
if err != nil {
return err
@@ -126,7 +126,7 @@ func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
if err != nil {
return err
}
m.Endpoint = rtmp.GetServiceMethod()
m.Method = rtmp.GetServiceMethod()
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
case codec.Response:
data, err := ReadNetString(c.rwc)
@@ -138,7 +138,7 @@ func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
if err != nil {
return err
}
m.Endpoint = rtmp.GetServiceMethod()
m.Method = rtmp.GetServiceMethod()
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
m.Error = rtmp.GetError()
case codec.Publication:

View File

@@ -5,7 +5,7 @@ import (
"sync"
"testing"
"github.com/micro/go-micro/registry/mock"
"github.com/micro/go-micro/registry/memory"
proto "github.com/micro/go-micro/server/debug/proto"
)
@@ -13,9 +13,12 @@ func TestFunction(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
r := memory.NewRegistry()
r.(*memory.Registry).Setup()
// create service
fn := NewFunction(
Registry(mock.NewRegistry()),
Registry(r),
Name("test.function"),
AfterStart(func() error {
wg.Done()

View File

@@ -12,6 +12,14 @@ type metaKey struct{}
// from Transport headers.
type Metadata map[string]string
func Copy(md Metadata) Metadata {
cmd := make(Metadata)
for k, v := range md {
cmd[k] = v
}
return cmd
}
func FromContext(ctx context.Context) (Metadata, bool) {
md, ok := ctx.Value(metaKey{}).(Metadata)
return md, ok

View File

@@ -5,6 +5,21 @@ import (
"testing"
)
func TestMetadataCopy(t *testing.T) {
md := Metadata{
"foo": "bar",
"bar": "baz",
}
cp := Copy(md)
for k, v := range md {
if cv := cp[k]; cv != v {
t.Fatalf("Got %s:%s for %s:%s", k, cv, k, v)
}
}
}
func TestMetadataContext(t *testing.T) {
md := Metadata{
"foo": "bar",

View File

@@ -1,11 +1,387 @@
// Package consul provides a consul based registry and is the default discovery system
package consul
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"sync"
"time"
consul "github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry"
hash "github.com/mitchellh/hashstructure"
)
// NewRegistry returns a new consul registry
func NewRegistry(opts ...registry.Option) registry.Registry {
return registry.NewRegistry(opts...)
type consulRegistry struct {
Address string
Client *consul.Client
opts registry.Options
// connect enabled
connect bool
queryOptions *consul.QueryOptions
sync.Mutex
register map[string]uint64
// lastChecked tracks when a node was last checked as existing in Consul
lastChecked map[string]time.Time
}
func getDeregisterTTL(t time.Duration) time.Duration {
// splay slightly for the watcher?
splay := time.Second * 5
deregTTL := t + splay
// consul has a minimum timeout on deregistration of 1 minute.
if t < time.Minute {
deregTTL = time.Minute + splay
}
return deregTTL
}
func newTransport(config *tls.Config) *http.Transport {
if config == nil {
config = &tls.Config{
InsecureSkipVerify: true,
}
}
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: config,
}
runtime.SetFinalizer(&t, func(tr **http.Transport) {
(*tr).CloseIdleConnections()
})
return t
}
func configure(c *consulRegistry, opts ...registry.Option) {
// set opts
for _, o := range opts {
o(&c.opts)
}
// use default config
config := consul.DefaultConfig()
if c.opts.Context != nil {
// Use the consul config passed in the options, if available
if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok {
config = co
}
if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok {
c.connect = cn
}
// Use the consul query options passed in the options, if available
if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil {
c.queryOptions = qo
}
if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok {
c.queryOptions.AllowStale = as
}
}
// check if there are any addrs
if len(c.opts.Addrs) > 0 {
addr, port, err := net.SplitHostPort(c.opts.Addrs[0])
if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
port = "8500"
addr = c.opts.Addrs[0]
config.Address = fmt.Sprintf("%s:%s", addr, port)
} else if err == nil {
config.Address = fmt.Sprintf("%s:%s", addr, port)
}
}
// requires secure connection?
if c.opts.Secure || c.opts.TLSConfig != nil {
if config.HttpClient == nil {
config.HttpClient = new(http.Client)
}
config.Scheme = "https"
// We're going to support InsecureSkipVerify
config.HttpClient.Transport = newTransport(c.opts.TLSConfig)
}
// set timeout
if c.opts.Timeout > 0 {
config.HttpClient.Timeout = c.opts.Timeout
}
// create the client
client, _ := consul.NewClient(config)
// set address/client
c.Address = config.Address
c.Client = client
}
func (c *consulRegistry) Init(opts ...registry.Option) error {
configure(c, opts...)
return nil
}
func (c *consulRegistry) Deregister(s *registry.Service) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
// delete our hash and time check of the service
c.Lock()
delete(c.register, s.Name)
delete(c.lastChecked, s.Name)
c.Unlock()
node := s.Nodes[0]
return c.Client.Agent().ServiceDeregister(node.Id)
}
func (c *consulRegistry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
var regTCPCheck bool
var regInterval time.Duration
var options registry.RegisterOptions
for _, o := range opts {
o(&options)
}
if c.opts.Context != nil {
if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok {
regTCPCheck = true
regInterval = tcpCheckInterval
}
}
// create hash of service; uint64
h, err := hash.Hash(s, nil)
if err != nil {
return err
}
// use first node
node := s.Nodes[0]
// get existing hash and last checked time
c.Lock()
v, ok := c.register[s.Name]
lastChecked := c.lastChecked[s.Name]
c.Unlock()
// if it's already registered and matches then just pass the check
if ok && v == h {
if options.TTL == time.Duration(0) {
// ensure that our service hasn't been deregistered by Consul
if time.Since(lastChecked) <= getDeregisterTTL(regInterval) {
return nil
}
services, _, err := c.Client.Health().Checks(s.Name, c.queryOptions)
if err == nil {
for _, v := range services {
if v.ServiceID == node.Id {
return nil
}
}
}
} else {
// if the err is nil we're all good, bail out
// if not, we don't know what the state is, so full re-register
if err := c.Client.Agent().PassTTL("service:"+node.Id, ""); err == nil {
return nil
}
}
}
// encode the tags
tags := encodeMetadata(node.Metadata)
tags = append(tags, encodeEndpoints(s.Endpoints)...)
tags = append(tags, encodeVersion(s.Version)...)
var check *consul.AgentServiceCheck
if regTCPCheck {
deregTTL := getDeregisterTTL(regInterval)
check = &consul.AgentServiceCheck{
TCP: fmt.Sprintf("%s:%d", node.Address, node.Port),
Interval: fmt.Sprintf("%v", regInterval),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
// if the TTL is greater than 0 create an associated check
} else if options.TTL > time.Duration(0) {
deregTTL := getDeregisterTTL(options.TTL)
check = &consul.AgentServiceCheck{
TTL: fmt.Sprintf("%v", options.TTL),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
}
// register the service
asr := &consul.AgentServiceRegistration{
ID: node.Id,
Name: s.Name,
Tags: tags,
Port: node.Port,
Address: node.Address,
Check: check,
}
// Specify consul connect
if c.connect {
asr.Connect = &consul.AgentServiceConnect{
Native: true,
}
}
if err := c.Client.Agent().ServiceRegister(asr); err != nil {
return err
}
// save our hash and time check of the service
c.Lock()
c.register[s.Name] = h
c.lastChecked[s.Name] = time.Now()
c.Unlock()
// if the TTL is 0 we don't mess with the checks
if options.TTL == time.Duration(0) {
return nil
}
// pass the healthcheck
return c.Client.Agent().PassTTL("service:"+node.Id, "")
}
func (c *consulRegistry) GetService(name string) ([]*registry.Service, error) {
var rsp []*consul.ServiceEntry
var err error
// if we're connect enabled only get connect services
if c.connect {
rsp, _, err = c.Client.Health().Connect(name, "", false, c.queryOptions)
} else {
rsp, _, err = c.Client.Health().Service(name, "", false, c.queryOptions)
}
if err != nil {
return nil, err
}
serviceMap := map[string]*registry.Service{}
for _, s := range rsp {
if s.Service.Service != name {
continue
}
// version is now a tag
version, _ := decodeVersion(s.Service.Tags)
// service ID is now the node id
id := s.Service.ID
// key is always the version
key := version
// address is service address
address := s.Service.Address
// use node address
if len(address) == 0 {
address = s.Node.Address
}
svc, ok := serviceMap[key]
if !ok {
svc = &registry.Service{
Endpoints: decodeEndpoints(s.Service.Tags),
Name: s.Service.Service,
Version: version,
}
serviceMap[key] = svc
}
var del bool
for _, check := range s.Checks {
// delete the node if the status is critical
if check.Status == "critical" {
del = true
break
}
}
// if delete then skip the node
if del {
continue
}
svc.Nodes = append(svc.Nodes, &registry.Node{
Id: id,
Address: address,
Port: s.Service.Port,
Metadata: decodeMetadata(s.Service.Tags),
})
}
var services []*registry.Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (c *consulRegistry) ListServices() ([]*registry.Service, error) {
rsp, _, err := c.Client.Catalog().Services(c.queryOptions)
if err != nil {
return nil, err
}
var services []*registry.Service
for service := range rsp {
services = append(services, &registry.Service{Name: service})
}
return services, nil
}
func (c *consulRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
return newConsulWatcher(c, opts...)
}
func (c *consulRegistry) String() string {
return "consul"
}
func (c *consulRegistry) Options() registry.Options {
return c.opts
}
func NewRegistry(opts ...registry.Option) registry.Registry {
cr := &consulRegistry{
opts: registry.Options{},
register: make(map[string]uint64),
lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{
AllowStale: true,
},
}
configure(cr, opts...)
return cr
}

170
registry/consul/encoding.go Normal file
View File

@@ -0,0 +1,170 @@
package consul
import (
"bytes"
"compress/zlib"
"encoding/hex"
"encoding/json"
"io/ioutil"
"github.com/micro/go-micro/registry"
)
func encode(buf []byte) string {
var b bytes.Buffer
defer b.Reset()
w := zlib.NewWriter(&b)
if _, err := w.Write(buf); err != nil {
return ""
}
w.Close()
return hex.EncodeToString(b.Bytes())
}
func decode(d string) []byte {
hr, err := hex.DecodeString(d)
if err != nil {
return nil
}
br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br)
if err != nil {
return nil
}
rbuf, err := ioutil.ReadAll(zr)
if err != nil {
return nil
}
return rbuf
}
func encodeEndpoints(en []*registry.Endpoint) []string {
var tags []string
for _, e := range en {
if b, err := json.Marshal(e); err == nil {
tags = append(tags, "e-"+encode(b))
}
}
return tags
}
func decodeEndpoints(tags []string) []*registry.Endpoint {
var en []*registry.Endpoint
// use the first format you find
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 'e' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var e *registry.Endpoint
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
if err := json.Unmarshal(buf, &e); err == nil {
en = append(en, e)
}
// set version
ver = tag[1]
}
return en
}
func encodeMetadata(md map[string]string) []string {
var tags []string
for k, v := range md {
if b, err := json.Marshal(map[string]string{
k: v,
}); err == nil {
// new encoding
tags = append(tags, "t-"+encode(b))
}
}
return tags
}
func decodeMetadata(tags []string) map[string]string {
md := make(map[string]string)
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 't' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var kv map[string]string
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
// Now unmarshal
if err := json.Unmarshal(buf, &kv); err == nil {
for k, v := range kv {
md[k] = v
}
}
// set version
ver = tag[1]
}
return md
}
func encodeVersion(v string) []string {
return []string{"v-" + encode([]byte(v))}
}
func decodeVersion(tags []string) (string, bool) {
for _, tag := range tags {
if len(tag) < 2 || tag[0] != 'v' {
continue
}
// Old encoding was plain
if tag[1] == '=' {
return tag[2:], true
}
// New encoding is hex
if tag[1] == '-' {
return string(decode(tag[2:])), true
}
}
return "", false
}

View File

@@ -0,0 +1,147 @@
package consul
import (
"encoding/json"
"testing"
"github.com/micro/go-micro/registry"
)
func TestEncodingEndpoints(t *testing.T) {
eps := []*registry.Endpoint{
&registry.Endpoint{
Name: "endpoint1",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo1": "bar1",
},
},
&registry.Endpoint{
Name: "endpoint2",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo2": "bar2",
},
},
&registry.Endpoint{
Name: "endpoint3",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo3": "bar3",
},
},
}
testEp := func(ep *registry.Endpoint, enc string) {
// encode endpoint
e := encodeEndpoints([]*registry.Endpoint{ep})
// check there are two tags; old and new
if len(e) != 1 {
t.Fatalf("Expected 1 encoded tags, got %v", e)
}
// check old encoding
var seen bool
for _, en := range e {
if en == enc {
seen = true
break
}
}
if !seen {
t.Fatalf("Expected %s but not found", enc)
}
// decode
d := decodeEndpoints([]string{enc})
if len(d) == 0 {
t.Fatalf("Expected %v got %v", ep, d)
}
// check name
if d[0].Name != ep.Name {
t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name)
}
// check all the metadata exists
for k, v := range ep.Metadata {
if gv := d[0].Metadata[k]; gv != v {
t.Fatalf("Expected key %s val %s got val %s", k, v, gv)
}
}
}
for _, ep := range eps {
// JSON encoded
jencoded, err := json.Marshal(ep)
if err != nil {
t.Fatal(err)
}
// HEX encoded
hencoded := encode(jencoded)
// endpoint tag
hepTag := "e-" + hencoded
testEp(ep, hepTag)
}
}
func TestEncodingVersion(t *testing.T) {
testData := []struct {
decoded string
encoded string
}{
{"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"},
{"latest", "v-789cca492c492d2e01040000ffff08cc028e"},
}
for _, data := range testData {
e := encodeVersion(data.decoded)
if e[0] != data.encoded {
t.Fatalf("Expected %s got %s", data.encoded, e)
}
d, ok := decodeVersion(e)
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
d, ok = decodeVersion([]string{data.encoded})
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
}
}

View File

@@ -1,4 +1,4 @@
package registry
package consul
import (
"bytes"
@@ -10,6 +10,7 @@ import (
"time"
consul "github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry"
)
type mockRegistry struct {
@@ -56,7 +57,7 @@ func newConsulTestRegistry(r *mockRegistry) (*consulRegistry, func()) {
return &consulRegistry{
Address: cfg.Address,
Client: cl,
opts: Options{},
opts: registry.Options{},
register: make(map[string]uint64),
lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{

View File

@@ -1,4 +1,4 @@
package registry
package consul
import (
"errors"
@@ -6,23 +6,24 @@ import (
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/watch"
"github.com/micro/go-micro/registry"
)
type consulWatcher struct {
r *consulRegistry
wo WatchOptions
wo registry.WatchOptions
wp *watch.Plan
watchers map[string]*watch.Plan
next chan *Result
next chan *registry.Result
exit chan bool
sync.RWMutex
services map[string][]*Service
services map[string][]*registry.Service
}
func newConsulWatcher(cr *consulRegistry, opts ...WatchOption) (Watcher, error) {
var wo WatchOptions
func newConsulWatcher(cr *consulRegistry, opts ...registry.WatchOption) (registry.Watcher, error) {
var wo registry.WatchOptions
for _, o := range opts {
o(&wo)
}
@@ -31,9 +32,9 @@ func newConsulWatcher(cr *consulRegistry, opts ...WatchOption) (Watcher, error)
r: cr,
wo: wo,
exit: make(chan bool),
next: make(chan *Result, 10),
next: make(chan *registry.Result, 10),
watchers: make(map[string]*watch.Plan),
services: make(map[string][]*Service),
services: make(map[string][]*registry.Service),
}
wp, err := watch.Parse(map[string]interface{}{"type": "services"})
@@ -54,7 +55,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
return
}
serviceMap := map[string]*Service{}
serviceMap := map[string]*registry.Service{}
serviceName := ""
for _, e := range entries {
@@ -75,7 +76,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
svc, ok := serviceMap[key]
if !ok {
svc = &Service{
svc = &registry.Service{
Endpoints: decodeEndpoints(e.Service.Tags),
Name: e.Service.Service,
Version: version,
@@ -98,7 +99,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
continue
}
svc.Nodes = append(svc.Nodes, &Node{
svc.Nodes = append(svc.Nodes, &registry.Node{
Id: id,
Address: address,
Port: e.Service.Port,
@@ -108,13 +109,13 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
cw.RLock()
// make a copy
rservices := make(map[string][]*Service)
rservices := make(map[string][]*registry.Service)
for k, v := range cw.services {
rservices[k] = v
}
cw.RUnlock()
var newServices []*Service
var newServices []*registry.Service
// serviceMap is the new set of services keyed by name+version
for _, newService := range serviceMap {
@@ -125,7 +126,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
oldServices, ok := rservices[serviceName]
if !ok {
// does not exist? then we're creating brand new entries
cw.next <- &Result{Action: "create", Service: newService}
cw.next <- &registry.Result{Action: "create", Service: newService}
continue
}
@@ -142,7 +143,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
// yes? then it's an update
action = "update"
var nodes []*Node
var nodes []*registry.Node
// check the old nodes to see if they've been deleted
for _, oldNode := range oldService.Nodes {
var seen bool
@@ -163,11 +164,11 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
if len(nodes) > 0 {
delService := oldService
delService.Nodes = nodes
cw.next <- &Result{Action: "delete", Service: delService}
cw.next <- &registry.Result{Action: "delete", Service: delService}
}
}
cw.next <- &Result{Action: action, Service: newService}
cw.next <- &registry.Result{Action: action, Service: newService}
}
// Now check old versions that may not be in new services map
@@ -175,7 +176,7 @@ func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) {
// old version does not exist in new version map
// kill it with fire!
if _, ok := serviceMap[old.Version]; !ok {
cw.next <- &Result{Action: "delete", Service: old}
cw.next <- &registry.Result{Action: "delete", Service: old}
}
}
@@ -209,13 +210,13 @@ func (cw *consulWatcher) handle(idx uint64, data interface{}) {
wp.Handler = cw.serviceHandler
go wp.Run(cw.r.Address)
cw.watchers[service] = wp
cw.next <- &Result{Action: "create", Service: &Service{Name: service}}
cw.next <- &registry.Result{Action: "create", Service: &registry.Service{Name: service}}
}
}
cw.RLock()
// make a copy
rservices := make(map[string][]*Service)
rservices := make(map[string][]*registry.Service)
for k, v := range cw.services {
rservices[k] = v
}
@@ -235,12 +236,12 @@ func (cw *consulWatcher) handle(idx uint64, data interface{}) {
if _, ok := services[service]; !ok {
w.Stop()
delete(cw.watchers, service)
cw.next <- &Result{Action: "delete", Service: &Service{Name: service}}
cw.next <- &registry.Result{Action: "delete", Service: &registry.Service{Name: service}}
}
}
}
func (cw *consulWatcher) Next() (*Result, error) {
func (cw *consulWatcher) Next() (*registry.Result, error) {
select {
case <-cw.exit:
return nil, errors.New("result chan closed")

View File

@@ -1,9 +1,10 @@
package registry
package consul
import (
"testing"
"github.com/hashicorp/consul/api"
"github.com/micro/go-micro/registry"
)
func TestHealthyServiceHandler(t *testing.T) {
@@ -58,8 +59,8 @@ func TestUnhealthyNodeServiceHandler(t *testing.T) {
func newWatcher() *consulWatcher {
return &consulWatcher{
exit: make(chan bool),
next: make(chan *Result, 10),
services: make(map[string][]*Service),
next: make(chan *registry.Result, 10),
services: make(map[string][]*registry.Service),
}
}

View File

@@ -1,386 +0,0 @@
package registry
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"sync"
"time"
consul "github.com/hashicorp/consul/api"
hash "github.com/mitchellh/hashstructure"
)
type consulRegistry struct {
Address string
Client *consul.Client
opts Options
// connect enabled
connect bool
queryOptions *consul.QueryOptions
sync.Mutex
register map[string]uint64
// lastChecked tracks when a node was last checked as existing in Consul
lastChecked map[string]time.Time
}
func getDeregisterTTL(t time.Duration) time.Duration {
// splay slightly for the watcher?
splay := time.Second * 5
deregTTL := t + splay
// consul has a minimum timeout on deregistration of 1 minute.
if t < time.Minute {
deregTTL = time.Minute + splay
}
return deregTTL
}
func newTransport(config *tls.Config) *http.Transport {
if config == nil {
config = &tls.Config{
InsecureSkipVerify: true,
}
}
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: config,
}
runtime.SetFinalizer(&t, func(tr **http.Transport) {
(*tr).CloseIdleConnections()
})
return t
}
func configure(c *consulRegistry, opts ...Option) {
// set opts
for _, o := range opts {
o(&c.opts)
}
// use default config
config := consul.DefaultConfig()
if c.opts.Context != nil {
// Use the consul config passed in the options, if available
if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok {
config = co
}
if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok {
c.connect = cn
}
// Use the consul query options passed in the options, if available
if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil {
c.queryOptions = qo
}
if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok {
c.queryOptions.AllowStale = as
}
}
// check if there are any addrs
if len(c.opts.Addrs) > 0 {
addr, port, err := net.SplitHostPort(c.opts.Addrs[0])
if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
port = "8500"
addr = c.opts.Addrs[0]
config.Address = fmt.Sprintf("%s:%s", addr, port)
} else if err == nil {
config.Address = fmt.Sprintf("%s:%s", addr, port)
}
}
// requires secure connection?
if c.opts.Secure || c.opts.TLSConfig != nil {
if config.HttpClient == nil {
config.HttpClient = new(http.Client)
}
config.Scheme = "https"
// We're going to support InsecureSkipVerify
config.HttpClient.Transport = newTransport(c.opts.TLSConfig)
}
// set timeout
if c.opts.Timeout > 0 {
config.HttpClient.Timeout = c.opts.Timeout
}
// create the client
client, _ := consul.NewClient(config)
// set address/client
c.Address = config.Address
c.Client = client
}
func newConsulRegistry(opts ...Option) Registry {
cr := &consulRegistry{
opts: Options{},
register: make(map[string]uint64),
lastChecked: make(map[string]time.Time),
queryOptions: &consul.QueryOptions{
AllowStale: true,
},
}
configure(cr, opts...)
return cr
}
func (c *consulRegistry) Init(opts ...Option) error {
configure(c, opts...)
return nil
}
func (c *consulRegistry) Deregister(s *Service) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
// delete our hash and time check of the service
c.Lock()
delete(c.register, s.Name)
delete(c.lastChecked, s.Name)
c.Unlock()
node := s.Nodes[0]
return c.Client.Agent().ServiceDeregister(node.Id)
}
func (c *consulRegistry) Register(s *Service, opts ...RegisterOption) error {
if len(s.Nodes) == 0 {
return errors.New("Require at least one node")
}
var regTCPCheck bool
var regInterval time.Duration
var options RegisterOptions
for _, o := range opts {
o(&options)
}
if c.opts.Context != nil {
if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok {
regTCPCheck = true
regInterval = tcpCheckInterval
}
}
// create hash of service; uint64
h, err := hash.Hash(s, nil)
if err != nil {
return err
}
// use first node
node := s.Nodes[0]
// get existing hash and last checked time
c.Lock()
v, ok := c.register[s.Name]
lastChecked := c.lastChecked[s.Name]
c.Unlock()
// if it's already registered and matches then just pass the check
if ok && v == h {
if options.TTL == time.Duration(0) {
// ensure that our service hasn't been deregistered by Consul
if time.Since(lastChecked) <= getDeregisterTTL(regInterval) {
return nil
}
services, _, err := c.Client.Health().Checks(s.Name, c.queryOptions)
if err == nil {
for _, v := range services {
if v.ServiceID == node.Id {
return nil
}
}
}
} else {
// if the err is nil we're all good, bail out
// if not, we don't know what the state is, so full re-register
if err := c.Client.Agent().PassTTL("service:"+node.Id, ""); err == nil {
return nil
}
}
}
// encode the tags
tags := encodeMetadata(node.Metadata)
tags = append(tags, encodeEndpoints(s.Endpoints)...)
tags = append(tags, encodeVersion(s.Version)...)
var check *consul.AgentServiceCheck
if regTCPCheck {
deregTTL := getDeregisterTTL(regInterval)
check = &consul.AgentServiceCheck{
TCP: fmt.Sprintf("%s:%d", node.Address, node.Port),
Interval: fmt.Sprintf("%v", regInterval),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
// if the TTL is greater than 0 create an associated check
} else if options.TTL > time.Duration(0) {
deregTTL := getDeregisterTTL(options.TTL)
check = &consul.AgentServiceCheck{
TTL: fmt.Sprintf("%v", options.TTL),
DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
}
}
// register the service
asr := &consul.AgentServiceRegistration{
ID: node.Id,
Name: s.Name,
Tags: tags,
Port: node.Port,
Address: node.Address,
Check: check,
}
// Specify consul connect
if c.connect {
asr.Connect = &consul.AgentServiceConnect{
Native: true,
}
}
if err := c.Client.Agent().ServiceRegister(asr); err != nil {
return err
}
// save our hash and time check of the service
c.Lock()
c.register[s.Name] = h
c.lastChecked[s.Name] = time.Now()
c.Unlock()
// if the TTL is 0 we don't mess with the checks
if options.TTL == time.Duration(0) {
return nil
}
// pass the healthcheck
return c.Client.Agent().PassTTL("service:"+node.Id, "")
}
func (c *consulRegistry) GetService(name string) ([]*Service, error) {
var rsp []*consul.ServiceEntry
var err error
// if we're connect enabled only get connect services
if c.connect {
rsp, _, err = c.Client.Health().Connect(name, "", false, c.queryOptions)
} else {
rsp, _, err = c.Client.Health().Service(name, "", false, c.queryOptions)
}
if err != nil {
return nil, err
}
serviceMap := map[string]*Service{}
for _, s := range rsp {
if s.Service.Service != name {
continue
}
// version is now a tag
version, _ := decodeVersion(s.Service.Tags)
// service ID is now the node id
id := s.Service.ID
// key is always the version
key := version
// address is service address
address := s.Service.Address
// use node address
if len(address) == 0 {
address = s.Node.Address
}
svc, ok := serviceMap[key]
if !ok {
svc = &Service{
Endpoints: decodeEndpoints(s.Service.Tags),
Name: s.Service.Service,
Version: version,
}
serviceMap[key] = svc
}
var del bool
for _, check := range s.Checks {
// delete the node if the status is critical
if check.Status == "critical" {
del = true
break
}
}
// if delete then skip the node
if del {
continue
}
svc.Nodes = append(svc.Nodes, &Node{
Id: id,
Address: address,
Port: s.Service.Port,
Metadata: decodeMetadata(s.Service.Tags),
})
}
var services []*Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (c *consulRegistry) ListServices() ([]*Service, error) {
rsp, _, err := c.Client.Catalog().Services(c.queryOptions)
if err != nil {
return nil, err
}
var services []*Service
for service := range rsp {
services = append(services, &Service{Name: service})
}
return services, nil
}
func (c *consulRegistry) Watch(opts ...WatchOption) (Watcher, error) {
return newConsulWatcher(c, opts...)
}
func (c *consulRegistry) String() string {
return "consul"
}
func (c *consulRegistry) Options() Options {
return c.opts
}

View File

@@ -6,163 +6,68 @@ import (
"encoding/hex"
"encoding/json"
"io/ioutil"
"strings"
)
func encode(buf []byte) string {
var b bytes.Buffer
defer b.Reset()
func encode(txt *mdnsTxt) ([]string, error) {
b, err := json.Marshal(txt)
if err != nil {
return nil, err
}
w := zlib.NewWriter(&b)
if _, err := w.Write(buf); err != nil {
return ""
var buf bytes.Buffer
defer buf.Reset()
w := zlib.NewWriter(&buf)
if _, err := w.Write(b); err != nil {
return nil, err
}
w.Close()
return hex.EncodeToString(b.Bytes())
encoded := hex.EncodeToString(buf.Bytes())
// individual txt limit
if len(encoded) <= 255 {
return []string{encoded}, nil
}
// split encoded string
var record []string
for len(encoded) > 255 {
record = append(record, encoded[:255])
encoded = encoded[255:]
}
record = append(record, encoded)
return record, nil
}
func decode(d string) []byte {
hr, err := hex.DecodeString(d)
func decode(record []string) (*mdnsTxt, error) {
encoded := strings.Join(record, "")
hr, err := hex.DecodeString(encoded)
if err != nil {
return nil
return nil, err
}
br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br)
if err != nil {
return nil
return nil, err
}
rbuf, err := ioutil.ReadAll(zr)
if err != nil {
return nil
return nil, err
}
return rbuf
}
var txt *mdnsTxt
func encodeEndpoints(en []*Endpoint) []string {
var tags []string
for _, e := range en {
if b, err := json.Marshal(e); err == nil {
tags = append(tags, "e-"+encode(b))
}
if err := json.Unmarshal(rbuf, &txt); err != nil {
return nil, err
}
return tags
}
func decodeEndpoints(tags []string) []*Endpoint {
var en []*Endpoint
// use the first format you find
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 'e' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var e *Endpoint
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
if err := json.Unmarshal(buf, &e); err == nil {
en = append(en, e)
}
// set version
ver = tag[1]
}
return en
}
func encodeMetadata(md map[string]string) []string {
var tags []string
for k, v := range md {
if b, err := json.Marshal(map[string]string{
k: v,
}); err == nil {
// new encoding
tags = append(tags, "t-"+encode(b))
}
}
return tags
}
func decodeMetadata(tags []string) map[string]string {
md := make(map[string]string)
var ver byte
for _, tag := range tags {
if len(tag) == 0 || tag[0] != 't' {
continue
}
// check version
if ver > 0 && tag[1] != ver {
continue
}
var kv map[string]string
var buf []byte
// Old encoding was plain
if tag[1] == '=' {
buf = []byte(tag[2:])
}
// New encoding is hex
if tag[1] == '-' {
buf = decode(tag[2:])
}
// Now unmarshal
if err := json.Unmarshal(buf, &kv); err == nil {
for k, v := range kv {
md[k] = v
}
}
// set version
ver = tag[1]
}
return md
}
func encodeVersion(v string) []string {
return []string{"v-" + encode([]byte(v))}
}
func decodeVersion(tags []string) (string, bool) {
for _, tag := range tags {
if len(tag) < 2 || tag[0] != 'v' {
continue
}
// Old encoding was plain
if tag[1] == '=' {
return tag[2:], true
}
// New encoding is hex
if tag[1] == '-' {
return string(decode(tag[2:])), true
}
}
return "", false
return txt, nil
}

View File

@@ -1,146 +1,65 @@
package registry
import (
"encoding/json"
"testing"
)
func TestEncodingEndpoints(t *testing.T) {
eps := []*Endpoint{
&Endpoint{
Name: "endpoint1",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
func TestEncoding(t *testing.T) {
testData := []*mdnsTxt{
&mdnsTxt{
Version: "1.0.0",
Metadata: map[string]string{
"foo1": "bar1",
"foo": "bar",
},
},
&Endpoint{
Name: "endpoint2",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo2": "bar2",
},
},
&Endpoint{
Name: "endpoint3",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo3": "bar3",
Endpoints: []*Endpoint{
&Endpoint{
Name: "endpoint1",
Request: &Value{
Name: "request",
Type: "request",
},
Response: &Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo1": "bar1",
},
},
},
},
}
testEp := func(ep *Endpoint, enc string) {
// encode endpoint
e := encodeEndpoints([]*Endpoint{ep})
// check there are two tags; old and new
if len(e) != 1 {
t.Fatalf("Expected 1 encoded tags, got %v", e)
}
// check old encoding
var seen bool
for _, en := range e {
if en == enc {
seen = true
break
}
}
if !seen {
t.Fatalf("Expected %s but not found", enc)
}
// decode
d := decodeEndpoints([]string{enc})
if len(d) == 0 {
t.Fatalf("Expected %v got %v", ep, d)
}
// check name
if d[0].Name != ep.Name {
t.Fatalf("Expected ep %s got %s", ep.Name, d[0].Name)
}
// check all the metadata exists
for k, v := range ep.Metadata {
if gv := d[0].Metadata[k]; gv != v {
t.Fatalf("Expected key %s val %s got val %s", k, v, gv)
}
}
}
for _, ep := range eps {
// JSON encoded
jencoded, err := json.Marshal(ep)
for _, d := range testData {
encoded, err := encode(d)
if err != nil {
t.Fatal(err)
}
// HEX encoded
hencoded := encode(jencoded)
// endpoint tag
hepTag := "e-" + hencoded
testEp(ep, hepTag)
}
}
func TestEncodingVersion(t *testing.T) {
testData := []struct {
decoded string
encoded string
}{
{"1.0.0", "v-789c32d433d03300040000ffff02ce00ee"},
{"latest", "v-789cca492c492d2e01040000ffff08cc028e"},
}
for _, data := range testData {
e := encodeVersion(data.decoded)
if e[0] != data.encoded {
t.Fatalf("Expected %s got %s", data.encoded, e)
}
d, ok := decodeVersion(e)
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
d, ok = decodeVersion([]string{data.encoded})
if !ok {
t.Fatalf("Unexpected %t for %s", ok, data.encoded)
}
if d != data.decoded {
t.Fatalf("Expected %s got %s", data.decoded, d)
}
for _, txt := range encoded {
if len(txt) > 255 {
t.Fatalf("One of parts for txt is %d characters", len(txt))
}
}
decoded, err := decode(encoded)
if err != nil {
t.Fatal(err)
}
if decoded.Version != d.Version {
t.Fatalf("Expected version %s got %s", d.Version, decoded.Version)
}
if len(decoded.Endpoints) != len(d.Endpoints) {
t.Fatalf("Expected %d endpoints, got %d", len(d.Endpoints), len(decoded.Endpoints))
}
for k, v := range d.Metadata {
if val := decoded.Metadata[k]; val != v {
t.Fatalf("Expected %s=%s got %s=%s", k, v, k, val)
}
}
}
}

View File

@@ -1,73 +0,0 @@
package mdns
import (
"bytes"
"compress/zlib"
"encoding/hex"
"encoding/json"
"io/ioutil"
"strings"
)
func encode(txt *mdnsTxt) ([]string, error) {
b, err := json.Marshal(txt)
if err != nil {
return nil, err
}
var buf bytes.Buffer
defer buf.Reset()
w := zlib.NewWriter(&buf)
if _, err := w.Write(b); err != nil {
return nil, err
}
w.Close()
encoded := hex.EncodeToString(buf.Bytes())
// individual txt limit
if len(encoded) <= 255 {
return []string{encoded}, nil
}
// split encoded string
var record []string
for len(encoded) > 255 {
record = append(record, encoded[:255])
encoded = encoded[255:]
}
record = append(record, encoded)
return record, nil
}
func decode(record []string) (*mdnsTxt, error) {
encoded := strings.Join(record, "")
hr, err := hex.DecodeString(encoded)
if err != nil {
return nil, err
}
br := bytes.NewReader(hr)
zr, err := zlib.NewReader(br)
if err != nil {
return nil, err
}
rbuf, err := ioutil.ReadAll(zr)
if err != nil {
return nil, err
}
var txt *mdnsTxt
if err := json.Unmarshal(rbuf, &txt); err != nil {
return nil, err
}
return txt, nil
}

View File

@@ -1,67 +0,0 @@
package mdns
import (
"testing"
"github.com/micro/go-micro/registry"
)
func TestEncoding(t *testing.T) {
testData := []*mdnsTxt{
&mdnsTxt{
Version: "1.0.0",
Metadata: map[string]string{
"foo": "bar",
},
Endpoints: []*registry.Endpoint{
&registry.Endpoint{
Name: "endpoint1",
Request: &registry.Value{
Name: "request",
Type: "request",
},
Response: &registry.Value{
Name: "response",
Type: "response",
},
Metadata: map[string]string{
"foo1": "bar1",
},
},
},
},
}
for _, d := range testData {
encoded, err := encode(d)
if err != nil {
t.Fatal(err)
}
for _, txt := range encoded {
if len(txt) > 255 {
t.Fatalf("One of parts for txt is %d characters", len(txt))
}
}
decoded, err := decode(encoded)
if err != nil {
t.Fatal(err)
}
if decoded.Version != d.Version {
t.Fatalf("Expected version %s got %s", d.Version, decoded.Version)
}
if len(decoded.Endpoints) != len(d.Endpoints) {
t.Fatalf("Expected %d endpoints, got %d", len(d.Endpoints), len(decoded.Endpoints))
}
for k, v := range d.Metadata {
if val := decoded.Metadata[k]; val != v {
t.Fatalf("Expected %s=%s got %s=%s", k, v, k, val)
}
}
}
}

View File

@@ -1,339 +1,11 @@
// Package mdns is a multicast dns registry
// Package mdns provides a multicast dns registry
package mdns
/*
MDNS is a multicast dns registry for service discovery
This creates a zero dependency system which is great
where multicast dns is available. This usually depends
on the ability to leverage udp and multicast/broadcast.
*/
import (
"net"
"strings"
"sync"
"time"
"github.com/micro/go-micro/registry"
"github.com/micro/mdns"
hash "github.com/mitchellh/hashstructure"
)
type mdnsTxt struct {
Service string
Version string
Endpoints []*registry.Endpoint
Metadata map[string]string
}
type mdnsEntry struct {
hash uint64
id string
node *mdns.Server
}
type mdnsRegistry struct {
opts registry.Options
sync.Mutex
services map[string][]*mdnsEntry
}
func newRegistry(opts ...registry.Option) registry.Registry {
options := registry.Options{
Timeout: time.Millisecond * 100,
}
return &mdnsRegistry{
opts: options,
services: make(map[string][]*mdnsEntry),
}
}
func (m *mdnsRegistry) Init(opts ...registry.Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mdnsRegistry) Options() registry.Options {
return m.opts
}
func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.RegisterOption) error {
m.Lock()
defer m.Unlock()
entries, ok := m.services[service.Name]
// first entry, create wildcard used for list queries
if !ok {
s, err := mdns.NewMDNSService(
service.Name,
"_services",
"",
"",
9999,
[]net.IP{net.ParseIP("0.0.0.0")},
nil,
)
if err != nil {
return err
}
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{s}})
if err != nil {
return err
}
// append the wildcard entry
entries = append(entries, &mdnsEntry{id: "*", node: srv})
}
var gerr error
for _, node := range service.Nodes {
// create hash of service; uint64
h, err := hash.Hash(node, nil)
if err != nil {
gerr = err
continue
}
var seen bool
var e *mdnsEntry
for _, entry := range entries {
if node.Id == entry.id {
seen = true
e = entry
break
}
}
// already registered, continue
if seen && e.hash == h {
continue
// hash doesn't match, shutdown
} else if seen {
e.node.Shutdown()
// doesn't exist
} else {
e = &mdnsEntry{hash: h}
}
txt, err := encode(&mdnsTxt{
Service: service.Name,
Version: service.Version,
Endpoints: service.Endpoints,
Metadata: node.Metadata,
})
if err != nil {
gerr = err
continue
}
// we got here, new node
s, err := mdns.NewMDNSService(
node.Id,
service.Name,
"",
"",
node.Port,
[]net.IP{net.ParseIP(node.Address)},
txt,
)
if err != nil {
gerr = err
continue
}
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
if err != nil {
gerr = err
continue
}
e.id = node.Id
e.node = srv
entries = append(entries, e)
}
// save
m.services[service.Name] = entries
return gerr
}
func (m *mdnsRegistry) Deregister(service *registry.Service) error {
m.Lock()
defer m.Unlock()
var newEntries []*mdnsEntry
// loop existing entries, check if any match, shutdown those that do
for _, entry := range m.services[service.Name] {
var remove bool
for _, node := range service.Nodes {
if node.Id == entry.id {
entry.node.Shutdown()
remove = true
break
}
}
// keep it?
if !remove {
newEntries = append(newEntries, entry)
}
}
// last entry is the wildcard for list queries. Remove it.
if len(newEntries) == 1 && newEntries[0].id == "*" {
newEntries[0].node.Shutdown()
delete(m.services, service.Name)
} else {
m.services[service.Name] = newEntries
}
return nil
}
func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*registry.Service)
go func() {
for {
select {
case e := <-entryCh:
// list record so skip
if p.Service == "_services" {
continue
}
if e.TTL == 0 {
continue
}
txt, err := decode(e.InfoFields)
if err != nil {
continue
}
if txt.Service != service {
continue
}
s, ok := serviceMap[txt.Version]
if !ok {
s = &registry.Service{
Name: txt.Service,
Version: txt.Version,
Endpoints: txt.Endpoints,
}
}
s.Nodes = append(s.Nodes, &registry.Node{
Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
Address: e.AddrV4.String(),
Port: e.Port,
Metadata: txt.Metadata,
})
serviceMap[txt.Version] = s
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
// create list and return
var services []*registry.Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (m *mdnsRegistry) ListServices() ([]*registry.Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool)
var services []*registry.Service
go func() {
for {
select {
case e := <-entryCh:
if e.TTL == 0 {
continue
}
name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
if !serviceMap[name] {
serviceMap[name] = true
services = append(services, &registry.Service{Name: name})
}
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
return services, nil
}
func (m *mdnsRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
var wo registry.WatchOptions
for _, o := range opts {
o(&wo)
}
md := &mdnsWatcher{
wo: wo,
ch: make(chan *mdns.ServiceEntry, 32),
exit: make(chan struct{}),
}
go func() {
if err := mdns.Listen(md.ch, md.exit); err != nil {
md.Stop()
}
}()
return md, nil
}
func (m *mdnsRegistry) String() string {
return "mdns"
}
// NewRegistry returns a new mdns registry
func NewRegistry(opts ...registry.Option) registry.Registry {
return newRegistry(opts...)
return registry.NewRegistry(opts...)
}

332
registry/mdns_registry.go Normal file
View File

@@ -0,0 +1,332 @@
// Package mdns is a multicast dns registry
package registry
import (
"net"
"strings"
"sync"
"time"
"github.com/micro/mdns"
hash "github.com/mitchellh/hashstructure"
)
type mdnsTxt struct {
Service string
Version string
Endpoints []*Endpoint
Metadata map[string]string
}
type mdnsEntry struct {
hash uint64
id string
node *mdns.Server
}
type mdnsRegistry struct {
opts Options
sync.Mutex
services map[string][]*mdnsEntry
}
func newRegistry(opts ...Option) Registry {
options := Options{
Timeout: time.Millisecond * 100,
}
return &mdnsRegistry{
opts: options,
services: make(map[string][]*mdnsEntry),
}
}
func (m *mdnsRegistry) Init(opts ...Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mdnsRegistry) Options() Options {
return m.opts
}
func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
m.Lock()
defer m.Unlock()
entries, ok := m.services[service.Name]
// first entry, create wildcard used for list queries
if !ok {
s, err := mdns.NewMDNSService(
service.Name,
"_services",
"",
"",
9999,
[]net.IP{net.ParseIP("0.0.0.0")},
nil,
)
if err != nil {
return err
}
srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{s}})
if err != nil {
return err
}
// append the wildcard entry
entries = append(entries, &mdnsEntry{id: "*", node: srv})
}
var gerr error
for _, node := range service.Nodes {
// create hash of service; uint64
h, err := hash.Hash(node, nil)
if err != nil {
gerr = err
continue
}
var seen bool
var e *mdnsEntry
for _, entry := range entries {
if node.Id == entry.id {
seen = true
e = entry
break
}
}
// already registered, continue
if seen && e.hash == h {
continue
// hash doesn't match, shutdown
} else if seen {
e.node.Shutdown()
// doesn't exist
} else {
e = &mdnsEntry{hash: h}
}
txt, err := encode(&mdnsTxt{
Service: service.Name,
Version: service.Version,
Endpoints: service.Endpoints,
Metadata: node.Metadata,
})
if err != nil {
gerr = err
continue
}
// we got here, new node
s, err := mdns.NewMDNSService(
node.Id,
service.Name,
"",
"",
node.Port,
[]net.IP{net.ParseIP(node.Address)},
txt,
)
if err != nil {
gerr = err
continue
}
srv, err := mdns.NewServer(&mdns.Config{Zone: s})
if err != nil {
gerr = err
continue
}
e.id = node.Id
e.node = srv
entries = append(entries, e)
}
// save
m.services[service.Name] = entries
return gerr
}
func (m *mdnsRegistry) Deregister(service *Service) error {
m.Lock()
defer m.Unlock()
var newEntries []*mdnsEntry
// loop existing entries, check if any match, shutdown those that do
for _, entry := range m.services[service.Name] {
var remove bool
for _, node := range service.Nodes {
if node.Id == entry.id {
entry.node.Shutdown()
remove = true
break
}
}
// keep it?
if !remove {
newEntries = append(newEntries, entry)
}
}
// last entry is the wildcard for list queries. Remove it.
if len(newEntries) == 1 && newEntries[0].id == "*" {
newEntries[0].node.Shutdown()
delete(m.services, service.Name)
} else {
m.services[service.Name] = newEntries
}
return nil
}
func (m *mdnsRegistry) GetService(service string) ([]*Service, error) {
p := mdns.DefaultParams(service)
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]*Service)
go func() {
for {
select {
case e := <-entryCh:
// list record so skip
if p.Service == "_services" {
continue
}
if e.TTL == 0 {
continue
}
txt, err := decode(e.InfoFields)
if err != nil {
continue
}
if txt.Service != service {
continue
}
s, ok := serviceMap[txt.Version]
if !ok {
s = &Service{
Name: txt.Service,
Version: txt.Version,
Endpoints: txt.Endpoints,
}
}
s.Nodes = append(s.Nodes, &Node{
Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
Address: e.AddrV4.String(),
Port: e.Port,
Metadata: txt.Metadata,
})
serviceMap[txt.Version] = s
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
// create list and return
var services []*Service
for _, service := range serviceMap {
services = append(services, service)
}
return services, nil
}
func (m *mdnsRegistry) ListServices() ([]*Service, error) {
p := mdns.DefaultParams("_services")
p.Timeout = m.opts.Timeout
entryCh := make(chan *mdns.ServiceEntry, 10)
p.Entries = entryCh
exit := make(chan bool)
defer close(exit)
serviceMap := make(map[string]bool)
var services []*Service
go func() {
for {
select {
case e := <-entryCh:
if e.TTL == 0 {
continue
}
name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
if !serviceMap[name] {
serviceMap[name] = true
services = append(services, &Service{Name: name})
}
case <-exit:
return
}
}
}()
if err := mdns.Query(p); err != nil {
return nil, err
}
return services, nil
}
func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) {
var wo WatchOptions
for _, o := range opts {
o(&wo)
}
md := &mdnsWatcher{
wo: wo,
ch: make(chan *mdns.ServiceEntry, 32),
exit: make(chan struct{}),
}
go func() {
if err := mdns.Listen(md.ch, md.exit); err != nil {
md.Stop()
}
}()
return md, nil
}
func (m *mdnsRegistry) String() string {
return "mdns"
}
// NewRegistry returns a new default registry which is mdns
func NewRegistry(opts ...Option) Registry {
return newRegistry(opts...)
}

View File

@@ -1,19 +1,17 @@
package mdns
package registry
import (
"testing"
"time"
"github.com/micro/go-micro/registry"
)
func TestMDNS(t *testing.T) {
testData := []*registry.Service{
&registry.Service{
testData := []*Service{
&Service{
Name: "test1",
Version: "1.0.1",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test1-1",
Address: "10.0.0.1",
Port: 10001,
@@ -23,11 +21,11 @@ func TestMDNS(t *testing.T) {
},
},
},
&registry.Service{
&Service{
Name: "test2",
Version: "1.0.2",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test2-1",
Address: "10.0.0.2",
Port: 10002,
@@ -37,11 +35,11 @@ func TestMDNS(t *testing.T) {
},
},
},
&registry.Service{
&Service{
Name: "test3",
Version: "1.0.3",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test3-1",
Address: "10.0.0.3",
Port: 10003,

View File

@@ -1,20 +1,19 @@
package mdns
package registry
import (
"errors"
"strings"
"github.com/micro/go-micro/registry"
"github.com/micro/mdns"
)
type mdnsWatcher struct {
wo registry.WatchOptions
wo WatchOptions
ch chan *mdns.ServiceEntry
exit chan struct{}
}
func (m *mdnsWatcher) Next() (*registry.Result, error) {
func (m *mdnsWatcher) Next() (*Result, error) {
for {
select {
case e := <-m.ch:
@@ -41,7 +40,7 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) {
action = "create"
}
service := &registry.Service{
service := &Service{
Name: txt.Service,
Version: txt.Version,
Endpoints: txt.Endpoints,
@@ -52,14 +51,14 @@ func (m *mdnsWatcher) Next() (*registry.Result, error) {
continue
}
service.Nodes = append(service.Nodes, &registry.Node{
service.Nodes = append(service.Nodes, &Node{
Id: strings.TrimSuffix(e.Name, "."+service.Name+".local."),
Address: e.AddrV4.String(),
Port: e.Port,
Metadata: txt.Metadata,
})
return &registry.Result{
return &Result{
Action: action,
Service: service,
}, nil

51
registry/memory/data.go Normal file
View File

@@ -0,0 +1,51 @@
package memory
import (
"github.com/micro/go-micro/registry"
)
var (
// mock data
Data = map[string][]*registry.Service{
"foo": []*registry.Service{
{
Name: "foo",
Version: "1.0.0",
Nodes: []*registry.Node{
{
Id: "foo-1.0.0-123",
Address: "localhost",
Port: 9999,
},
{
Id: "foo-1.0.0-321",
Address: "localhost",
Port: 9999,
},
},
},
{
Name: "foo",
Version: "1.0.1",
Nodes: []*registry.Node{
{
Id: "foo-1.0.1-321",
Address: "localhost",
Port: 6666,
},
},
},
{
Name: "foo",
Version: "1.0.3",
Nodes: []*registry.Node{
{
Id: "foo-1.0.3-345",
Address: "localhost",
Port: 8888,
},
},
},
},
}
)

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"github.com/micro/go-micro/registry"

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"testing"

160
registry/memory/memory.go Normal file
View File

@@ -0,0 +1,160 @@
// Package memory provides an in-memory registry
package memory
import (
"context"
"sync"
"time"
"github.com/google/uuid"
"github.com/micro/go-micro/registry"
)
type Registry struct {
options registry.Options
sync.RWMutex
Services map[string][]*registry.Service
Watchers map[string]*Watcher
}
var (
timeout = time.Millisecond * 10
)
// Setup sets mock data
func (m *Registry) Setup() {
m.Lock()
defer m.Unlock()
// add some memory data
m.Services = Data
}
func (m *Registry) watch(r *registry.Result) {
var watchers []*Watcher
m.RLock()
for _, w := range m.Watchers {
watchers = append(watchers, w)
}
m.RUnlock()
for _, w := range watchers {
select {
case <-w.exit:
m.Lock()
delete(m.Watchers, w.id)
m.Unlock()
default:
select {
case w.res <- r:
case <-time.After(timeout):
}
}
}
}
func (m *Registry) Init(opts ...registry.Option) error {
for _, o := range opts {
o(&m.options)
}
// add services
m.Lock()
for k, v := range getServices(m.options.Context) {
s := m.Services[k]
m.Services[k] = addServices(s, v)
}
m.Unlock()
return nil
}
func (m *Registry) Options() registry.Options {
return m.options
}
func (m *Registry) GetService(service string) ([]*registry.Service, error) {
m.RLock()
s, ok := m.Services[service]
if !ok || len(s) == 0 {
m.RUnlock()
return nil, registry.ErrNotFound
}
m.RUnlock()
return s, nil
}
func (m *Registry) ListServices() ([]*registry.Service, error) {
m.RLock()
var services []*registry.Service
for _, service := range m.Services {
services = append(services, service...)
}
m.RUnlock()
return services, nil
}
func (m *Registry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
go m.watch(&registry.Result{Action: "update", Service: s})
m.Lock()
services := addServices(m.Services[s.Name], []*registry.Service{s})
m.Services[s.Name] = services
m.Unlock()
return nil
}
func (m *Registry) Deregister(s *registry.Service) error {
go m.watch(&registry.Result{Action: "delete", Service: s})
m.Lock()
services := delServices(m.Services[s.Name], []*registry.Service{s})
m.Services[s.Name] = services
m.Unlock()
return nil
}
func (m *Registry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
var wo registry.WatchOptions
for _, o := range opts {
o(&wo)
}
w := &Watcher{
exit: make(chan bool),
res: make(chan *registry.Result),
id: uuid.New().String(),
wo: wo,
}
m.Lock()
m.Watchers[w.id] = w
m.Unlock()
return w, nil
}
func (m *Registry) String() string {
return "memory"
}
func NewRegistry(opts ...registry.Option) registry.Registry {
options := registry.Options{
Context: context.Background(),
}
for _, o := range opts {
o(&options)
}
services := getServices(options.Context)
if services == nil {
services = make(map[string][]*registry.Service)
}
return &Registry{
options: options,
Services: services,
Watchers: make(map[string]*Watcher),
}
}

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"testing"
@@ -80,7 +80,7 @@ var (
}
)
func TestMockRegistry(t *testing.T) {
func TestMemoryRegistry(t *testing.T) {
m := NewRegistry()
fn := func(k string, v []*registry.Service) {
@@ -107,11 +107,6 @@ func TestMockRegistry(t *testing.T) {
}
}
// test existing mock data
for k, v := range mockData {
fn(k, v)
}
// register data
for _, v := range testData {
for _, service := range v {
@@ -123,7 +118,6 @@ func TestMockRegistry(t *testing.T) {
// using test data
for k, v := range testData {
fn(k, v)
}

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"errors"
@@ -6,12 +6,12 @@ import (
"github.com/micro/go-micro/registry"
)
type mockWatcher struct {
type memoryWatcher struct {
exit chan bool
opts registry.WatchOptions
}
func (m *mockWatcher) Next() (*registry.Result, error) {
func (m *memoryWatcher) Next() (*registry.Result, error) {
// not implement so we just block until exit
select {
case <-m.exit:
@@ -19,7 +19,7 @@ func (m *mockWatcher) Next() (*registry.Result, error) {
}
}
func (m *mockWatcher) Stop() {
func (m *memoryWatcher) Stop() {
select {
case <-m.exit:
return

View File

@@ -0,0 +1,27 @@
package memory
import (
"context"
"github.com/micro/go-micro/registry"
)
type servicesKey struct{}
func getServices(ctx context.Context) map[string][]*registry.Service {
s, ok := ctx.Value(servicesKey{}).(map[string][]*registry.Service)
if !ok {
return nil
}
return s
}
// Services is an option that preloads service data
func Services(s map[string][]*registry.Service) registry.Option {
return func(o *registry.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, servicesKey{}, s)
}
}

View File

@@ -0,0 +1,37 @@
package memory
import (
"errors"
"github.com/micro/go-micro/registry"
)
type Watcher struct {
id string
wo registry.WatchOptions
res chan *registry.Result
exit chan bool
}
func (m *Watcher) Next() (*registry.Result, error) {
for {
select {
case r := <-m.res:
if len(m.wo.Service) > 0 && m.wo.Service != r.Service.Name {
continue
}
return r, nil
case <-m.exit:
return nil, errors.New("watcher stopped")
}
}
}
func (m *Watcher) Stop() {
select {
case <-m.exit:
return
default:
close(m.exit)
}
}

View File

@@ -0,0 +1,30 @@
package memory
import (
"testing"
"github.com/micro/go-micro/registry"
)
func TestWatcher(t *testing.T) {
w := &Watcher{
id: "test",
res: make(chan *registry.Result),
exit: make(chan bool),
}
go func() {
w.res <- &registry.Result{}
}()
_, err := w.Next()
if err != nil {
t.Fatal("unexpected err", err)
}
w.Stop()
if _, err := w.Next(); err == nil {
t.Fatal("expected error on Next()")
}
}

View File

@@ -1,133 +0,0 @@
// Package mock provides a mock registry for testing
package mock
import (
"sync"
"github.com/micro/go-micro/registry"
)
type mockRegistry struct {
sync.RWMutex
Services map[string][]*registry.Service
}
var (
mockData = map[string][]*registry.Service{
"foo": []*registry.Service{
{
Name: "foo",
Version: "1.0.0",
Nodes: []*registry.Node{
{
Id: "foo-1.0.0-123",
Address: "localhost",
Port: 9999,
},
{
Id: "foo-1.0.0-321",
Address: "localhost",
Port: 9999,
},
},
},
{
Name: "foo",
Version: "1.0.1",
Nodes: []*registry.Node{
{
Id: "foo-1.0.1-321",
Address: "localhost",
Port: 6666,
},
},
},
{
Name: "foo",
Version: "1.0.3",
Nodes: []*registry.Node{
{
Id: "foo-1.0.3-345",
Address: "localhost",
Port: 8888,
},
},
},
},
}
)
func (m *mockRegistry) init() {
m.Lock()
defer m.Unlock()
// add some mock data
m.Services = mockData
}
func (m *mockRegistry) GetService(service string) ([]*registry.Service, error) {
m.Lock()
defer m.Unlock()
s, ok := m.Services[service]
if !ok || len(s) == 0 {
return nil, registry.ErrNotFound
}
return s, nil
}
func (m *mockRegistry) ListServices() ([]*registry.Service, error) {
m.Lock()
defer m.Unlock()
var services []*registry.Service
for _, service := range m.Services {
services = append(services, service...)
}
return services, nil
}
func (m *mockRegistry) Register(s *registry.Service, opts ...registry.RegisterOption) error {
m.Lock()
defer m.Unlock()
services := addServices(m.Services[s.Name], []*registry.Service{s})
m.Services[s.Name] = services
return nil
}
func (m *mockRegistry) Deregister(s *registry.Service) error {
m.Lock()
defer m.Unlock()
services := delServices(m.Services[s.Name], []*registry.Service{s})
m.Services[s.Name] = services
return nil
}
func (m *mockRegistry) Watch(opts ...registry.WatchOption) (registry.Watcher, error) {
var wopts registry.WatchOptions
for _, o := range opts {
o(&wopts)
}
return &mockWatcher{exit: make(chan bool), opts: wopts}, nil
}
func (m *mockRegistry) String() string {
return "mock"
}
func (m *mockRegistry) Init(opts ...registry.Option) error {
return nil
}
func (m *mockRegistry) Options() registry.Options {
return registry.Options{}
}
func NewRegistry(opts ...registry.Options) registry.Registry {
m := &mockRegistry{Services: make(map[string][]*registry.Service)}
m.init()
return m
}

View File

@@ -26,7 +26,7 @@ type RegisterOption func(*RegisterOptions)
type WatchOption func(*WatchOptions)
var (
DefaultRegistry = newConsulRegistry()
DefaultRegistry = NewRegistry()
// Not found error when GetService is called
ErrNotFound = errors.New("not found")
@@ -34,10 +34,6 @@ var (
ErrWatcherStopped = errors.New("watcher stopped")
)
func NewRegistry(opts ...Option) Registry {
return newConsulRegistry(opts...)
}
// Register a service node. Additionally supply options such as TTL.
func Register(s *Service, opts ...RegisterOption) error {
return DefaultRegistry.Register(s, opts...)

View File

@@ -1,18 +1,16 @@
package mdns
package registry
import (
"testing"
"github.com/micro/go-micro/registry"
)
func TestWatcher(t *testing.T) {
testData := []*registry.Service{
&registry.Service{
testData := []*Service{
&Service{
Name: "test1",
Version: "1.0.1",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test1-1",
Address: "10.0.0.1",
Port: 10001,
@@ -22,11 +20,11 @@ func TestWatcher(t *testing.T) {
},
},
},
&registry.Service{
&Service{
Name: "test2",
Version: "1.0.2",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test2-1",
Address: "10.0.0.2",
Port: 10002,
@@ -36,11 +34,11 @@ func TestWatcher(t *testing.T) {
},
},
},
&registry.Service{
&Service{
Name: "test3",
Version: "1.0.3",
Nodes: []*registry.Node{
&registry.Node{
Nodes: []*Node{
&Node{
Id: "test3-1",
Address: "10.0.0.3",
Port: 10003,
@@ -52,7 +50,7 @@ func TestWatcher(t *testing.T) {
},
}
testFn := func(service, s *registry.Service) {
testFn := func(service, s *Service) {
if s == nil {
t.Fatalf("Expected one result for %s got nil", service.Name)

View File

@@ -28,6 +28,27 @@ var (
DefaultTTL = time.Minute
)
// isValid checks if the service is valid
func (c *registrySelector) isValid(services []*registry.Service, ttl time.Time) bool {
// no services exist
if len(services) == 0 {
return false
}
// ttl is invalid
if ttl.IsZero() {
return false
}
// time since ttl is longer than timeout
if time.Since(ttl) > c.ttl {
return false
}
// ok
return true
}
func (c *registrySelector) quit() bool {
select {
case <-c.exit:
@@ -83,12 +104,12 @@ func (c *registrySelector) get(service string) ([]*registry.Service, error) {
c.RLock()
// check the cache first
services, ok := c.cache[service]
services := c.cache[service]
// get cache ttl
ttl, kk := c.ttls[service]
ttl := c.ttls[service]
// got services && within ttl so return cache
if ok && kk && time.Since(ttl) < c.ttl {
if c.isValid(services, ttl) {
// make a copy
cp := c.cp(services)
// unlock the read
@@ -255,6 +276,9 @@ func (c *registrySelector) run(name string) {
c.Unlock()
}()
// error counter
var cerr int
for {
// exit early if already dead
if c.quit() {
@@ -265,11 +289,16 @@ func (c *registrySelector) run(name string) {
w, err := c.so.Registry.Watch(
registry.WatchService(name),
)
if err != nil {
if c.quit() {
return
}
log.Log(err)
cerr++
if cerr > 3 {
log.Log(err)
cerr = 0
}
time.Sleep(time.Second)
continue
}
@@ -279,9 +308,16 @@ func (c *registrySelector) run(name string) {
if c.quit() {
return
}
log.Log(err)
cerr++
if cerr > 3 {
cerr = 0
log.Log(err)
}
continue
}
// reset err counter
cerr = 0
}
}
@@ -290,12 +326,16 @@ func (c *registrySelector) run(name string) {
func (c *registrySelector) watch(w registry.Watcher) error {
defer w.Stop()
// reload chan
reload := make(chan bool, 1)
// manage this loop
go func() {
// wait for exit or reload signal
select {
case <-c.exit:
case <-c.reload:
reload <- true
}
// stop the watcher
@@ -305,7 +345,12 @@ func (c *registrySelector) watch(w registry.Watcher) error {
for {
res, err := w.Next()
if err != nil {
return err
select {
case <-reload:
return nil
default:
return err
}
}
c.update(res)
}

View File

@@ -3,13 +3,15 @@ package selector
import (
"testing"
"github.com/micro/go-micro/registry/mock"
"github.com/micro/go-micro/registry/memory"
)
func TestRegistrySelector(t *testing.T) {
counts := map[string]int{}
cache := NewSelector(Registry(mock.NewRegistry()))
r := memory.NewRegistry()
r.(*memory.Registry).Setup()
cache := NewSelector(Registry(r))
next, err := cache.Select("foo")
if err != nil {

View File

@@ -41,6 +41,15 @@ var (
"application/proto-rpc": protorpc.NewCodec,
"application/octet-stream": raw.NewCodec,
}
// TODO: remove legacy codec list
defaultCodecs = map[string]codec.NewCodec{
"application/json": jsonrpc.NewCodec,
"application/json-rpc": jsonrpc.NewCodec,
"application/protobuf": protorpc.NewCodec,
"application/proto-rpc": protorpc.NewCodec,
"application/octet-stream": protorpc.NewCodec,
}
)
func (rwc *readWriteCloser) Read(p []byte) (n int, err error) {
@@ -57,6 +66,42 @@ func (rwc *readWriteCloser) Close() error {
return nil
}
// setupProtocol sets up the old protocol
func setupProtocol(msg *transport.Message) codec.NewCodec {
service := msg.Header["X-Micro-Service"]
method := msg.Header["X-Micro-Method"]
endpoint := msg.Header["X-Micro-Endpoint"]
protocol := msg.Header["X-Micro-Protocol"]
target := msg.Header["X-Micro-Target"]
// if the protocol exists (mucp) do nothing
if len(protocol) > 0 {
return nil
}
// if no service/method/endpoint then it's the old protocol
if len(service) == 0 && len(method) == 0 && len(endpoint) == 0 {
return defaultCodecs[msg.Header["Content-Type"]]
}
// old target method specified
if len(target) > 0 {
return defaultCodecs[msg.Header["Content-Type"]]
}
// no method then set to endpoint
if len(method) == 0 {
msg.Header["X-Micro-Method"] = method
}
// no endpoint then set to method
if len(endpoint) == 0 {
msg.Header["X-Micro-Endpoint"] = method
}
return nil
}
func newRpcCodec(req *transport.Message, socket transport.Socket, c codec.NewCodec) codec.Codec {
rwc := &readWriteCloser{
rbuf: bytes.NewBuffer(req.Body),
@@ -99,6 +144,9 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
m.Header = tm.Header
// set the message body
m.Body = tm.Body
// set req
c.req = &tm
}
// no longer first read
@@ -106,6 +154,7 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
// set some internal things
m.Target = m.Header["X-Micro-Service"]
m.Method = m.Header["X-Micro-Method"]
m.Endpoint = m.Header["X-Micro-Endpoint"]
m.Id = m.Header["X-Micro-Id"]
@@ -113,13 +162,29 @@ func (c *rpcCodec) ReadHeader(r *codec.Message, t codec.MessageType) error {
err := c.codec.ReadHeader(&m, codec.Request)
// set the method/id
r.Method = m.Method
r.Endpoint = m.Endpoint
r.Id = m.Id
// TODO: remove the old legacy cruft
if len(r.Endpoint) == 0 {
r.Endpoint = r.Method
}
return err
}
func (c *rpcCodec) ReadBody(b interface{}) error {
// don't read empty body
if len(c.req.Body) == 0 {
return nil
}
// read raw data
if v, ok := b.(*raw.Frame); ok {
v.Data = c.req.Body
return nil
}
// decode the usual way
return c.codec.ReadBody(b)
}
@@ -128,11 +193,17 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
// create a new message
m := &codec.Message{
Target: r.Target,
Method: r.Method,
Endpoint: r.Endpoint,
Id: r.Id,
Error: r.Error,
Type: r.Type,
Header: map[string]string{},
Header: r.Header,
}
if m.Header == nil {
m.Header = map[string]string{}
}
// set request id
@@ -145,6 +216,11 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
m.Header["X-Micro-Service"] = r.Target
}
// set request method
if len(r.Method) > 0 {
m.Header["X-Micro-Method"] = r.Method
}
// set request endpoint
if len(r.Endpoint) > 0 {
m.Header["X-Micro-Endpoint"] = r.Endpoint
@@ -157,10 +233,13 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
// the body being sent
var body []byte
// if we have encoded data just send it
if len(r.Body) > 0 {
// is it a raw frame?
if v, ok := b.(*raw.Frame); ok {
body = v.Data
// if we have encoded data just send it
} else if len(r.Body) > 0 {
body = r.Body
// write to the body
// write the body to codec
} else if err := c.codec.Write(m, b); err != nil {
c.buf.wbuf.Reset()
@@ -171,7 +250,6 @@ func (c *rpcCodec) Write(r *codec.Message, b interface{}) error {
if err := c.codec.Write(m, nil); err != nil {
return err
}
// write the body
} else {
// set the body
body = c.buf.wbuf.Bytes()

View File

@@ -7,6 +7,7 @@ import (
type rpcRequest struct {
service string
method string
endpoint string
contentType string
socket transport.Socket
@@ -34,6 +35,10 @@ func (r *rpcRequest) Service() string {
return r.service
}
func (r *rpcRequest) Method() string {
return r.method
}
func (r *rpcRequest) Endpoint() string {
return r.endpoint
}

View File

@@ -162,33 +162,30 @@ func prepareMethod(method reflect.Method) *methodType {
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
}
func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, errmsg string, last bool) (err error) {
func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, last bool) error {
msg := new(codec.Message)
msg.Type = codec.Response
resp := router.getResponse()
resp.msg = msg
// Encode the response header
resp.msg.Endpoint = req.msg.Endpoint
if errmsg != "" {
resp.msg.Error = errmsg
reply = invalidRequest
}
resp.msg.Id = req.msg.Id
sending.Lock()
err = cc.Write(resp.msg, reply)
err := cc.Write(resp.msg, reply)
sending.Unlock()
router.freeResponse(resp)
return err
}
func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) {
func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) error {
defer router.freeRequest(req)
function := mtype.method.Func
var returnValues []reflect.Value
r := &rpcRequest{
service: req.msg.Target,
contentType: req.msg.Header["Content-Type"],
method: req.msg.Method,
endpoint: req.msg.Endpoint,
body: req.msg.Body,
}
@@ -205,18 +202,13 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex,
return nil
}
errmsg := ""
err := fn(ctx, r, replyv.Interface())
if err != nil {
errmsg = err.Error()
// execute handler
if err := fn(ctx, r, replyv.Interface()); err != nil {
return err
}
err = router.sendResponse(sending, req, replyv.Interface(), cc, errmsg, true)
if err != nil {
log.Log("rpc call: unable to send response: ", err)
}
router.freeRequest(req)
return
// send response
return router.sendResponse(sending, req, replyv.Interface(), cc, true)
}
// declare a local error to see if we errored out already
@@ -249,16 +241,15 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex,
// client.Stream request
r.stream = true
errmsg := ""
// execute handler
if err := fn(ctx, r, stream); err != nil {
errmsg = err.Error()
return err
}
// this is the last packet, we don't do anything with
// the error here (well sendStreamResponse will log it
// already)
router.sendResponse(sending, req, nil, cc, errmsg, true)
router.freeRequest(req)
return router.sendResponse(sending, req, nil, cc, true)
}
func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
@@ -447,11 +438,9 @@ func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response)
}
// send a response if we actually managed to read a header.
if req != nil {
router.sendResponse(sending, req, invalidRequest, rsp.Codec(), err.Error(), true)
router.freeRequest(req)
}
return err
}
service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
return nil
return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
}

View File

@@ -38,7 +38,7 @@ func newRpcServer(opts ...Option) Server {
options := newOptions(opts...)
return &rpcServer{
opts: options,
router: newRpcRouter(),
router: DefaultRouter,
handlers: make(map[string]Handler),
subscribers: make(map[*subscriber][]broker.Subscriber),
exit: make(chan chan error),
@@ -97,17 +97,23 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
ct = DefaultContentType
}
// TODO: needs better error handling
cf, err := s.newCodec(ct)
if err != nil {
sock.Send(&transport.Message{
Header: map[string]string{
"Content-Type": "text/plain",
},
Body: []byte(err.Error()),
})
s.wg.Done()
return
// setup old protocol
cf := setupProtocol(&msg)
// no old codec
if cf == nil {
// TODO: needs better error handling
var err error
if cf, err = s.newCodec(ct); err != nil {
sock.Send(&transport.Message{
Header: map[string]string{
"Content-Type": "text/plain",
},
Body: []byte(err.Error()),
})
s.wg.Done()
return
}
}
rcodec := newRpcCodec(&msg, sock, cf)
@@ -115,6 +121,7 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
// internal request
request := &rpcRequest{
service: msg.Header["X-Micro-Service"],
method: msg.Header["X-Micro-Method"],
endpoint: msg.Header["X-Micro-Endpoint"],
contentType: ct,
codec: rcodec,
@@ -151,12 +158,15 @@ func (s *rpcServer) ServeConn(sock transport.Socket) {
// TODO: handle error better
if err := handler(ctx, request, response); err != nil {
// write an error response
rcodec.Write(&codec.Message{
err = rcodec.Write(&codec.Message{
Header: msg.Header,
Error: err.Error(),
Type: codec.Error,
}, nil)
// could not write the error response
if err != nil {
log.Logf("rpc: unable to write error response: %v", err)
}
s.wg.Done()
return
}
@@ -188,12 +198,6 @@ func (s *rpcServer) Init(opts ...Option) error {
for _, opt := range opts {
opt(&s.opts)
}
// update router
r := newRpcRouter()
r.serviceMap = s.router.serviceMap
s.router = r
s.Unlock()
return nil
}
@@ -282,6 +286,7 @@ func (s *rpcServer) Register() error {
node.Metadata["broker"] = config.Broker.String()
node.Metadata["server"] = s.String()
node.Metadata["registry"] = config.Registry.String()
node.Metadata["protocol"] = "mucp"
s.RLock()
// Maps are ordered randomly, sort the keys for consistency
@@ -436,7 +441,8 @@ func (s *rpcServer) Start() error {
return err
}
log.Logf("Listening on %s", ts.Addr())
log.Logf("Transport [%s] Listening on %s", config.Transport.String(), ts.Addr())
s.Lock()
// swap address
addr := s.opts.Address
@@ -491,7 +497,12 @@ func (s *rpcServer) Start() error {
}()
// TODO: subscribe to cruft
return config.Broker.Connect()
if err := config.Broker.Connect(); err != nil {
return err
}
log.Logf("Broker [%s] Listening on %s", config.Broker.String(), config.Broker.Address())
return nil
}
func (s *rpcServer) Stop() error {

View File

@@ -31,6 +31,8 @@ func (r *rpcStream) Send(msg interface{}) error {
defer r.Unlock()
resp := codec.Message{
Target: r.request.Service(),
Method: r.request.Method(),
Endpoint: r.request.Endpoint(),
Id: r.id,
Type: codec.Response,

View File

@@ -45,6 +45,8 @@ type Message interface {
type Request interface {
// Service name requested
Service() string
// The action requested
Method() string
// Endpoint name requested
Endpoint() string
// Content type provided

View File

@@ -2,62 +2,179 @@ package micro
import (
"context"
"errors"
"sync"
"testing"
"github.com/micro/go-micro/registry/mock"
glog "github.com/go-log/log"
"github.com/micro/go-log"
"github.com/micro/go-micro/client"
"github.com/micro/go-micro/registry/memory"
proto "github.com/micro/go-micro/server/debug/proto"
)
func TestService(t *testing.T) {
var wg sync.WaitGroup
func testShutdown(wg *sync.WaitGroup, cancel func()) {
// add 1
wg.Add(1)
// shutdown the service
cancel()
// wait for stop
wg.Wait()
}
func testService(ctx context.Context, wg *sync.WaitGroup, name string) Service {
// set no op logger
log.SetLogger(glog.DefaultLogger)
// add self
wg.Add(1)
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
r := memory.NewRegistry()
r.(*memory.Registry).Setup()
// create service
service := NewService(
Name("test.service"),
return NewService(
Name(name),
Context(ctx),
Registry(mock.NewRegistry()),
Registry(r),
AfterStart(func() error {
wg.Done()
return nil
}),
AfterStop(func() error {
wg.Done()
return nil
}),
)
}
func testRequest(ctx context.Context, c client.Client, name string) error {
// test call debug
req := c.NewRequest(
name,
"Debug.Health",
new(proto.HealthRequest),
)
// we can't test service.Init as it parses the command line
// service.Init()
rsp := new(proto.HealthResponse)
err := c.Call(context.TODO(), req, rsp)
if err != nil {
return err
}
if rsp.Status != "ok" {
return errors.New("service response: " + rsp.Status)
}
return nil
}
// TestService tests running and calling a service
func TestService(t *testing.T) {
// waitgroup for server start
var wg sync.WaitGroup
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
// start test server
service := testService(ctx, &wg, "test.service")
// run service
go func() {
// wait for start
// wait for service to start
wg.Wait()
// test call debug
req := service.Client().NewRequest(
"test.service",
"Debug.Health",
new(proto.HealthRequest),
)
rsp := new(proto.HealthResponse)
err := service.Client().Call(context.TODO(), req, rsp)
if err != nil {
// make a test call
if err := testRequest(ctx, service.Client(), "test.service"); err != nil {
t.Fatal(err)
}
if rsp.Status != "ok" {
t.Fatalf("service response: %s", rsp.Status)
}
// shutdown the service
cancel()
testShutdown(&wg, cancel)
}()
// start service
if err := service.Run(); err != nil {
t.Fatal(err)
}
}
func benchmarkService(b *testing.B, n int, name string) {
// stop the timer
b.StopTimer()
// waitgroup for server start
var wg sync.WaitGroup
// cancellation context
ctx, cancel := context.WithCancel(context.Background())
// create test server
service := testService(ctx, &wg, name)
// start the server
go func() {
if err := service.Run(); err != nil {
b.Fatal(err)
}
}()
// wait for service to start
wg.Wait()
// make a test call to warm the cache
for i := 0; i < 10; i++ {
if err := testRequest(ctx, service.Client(), name); err != nil {
b.Fatal(err)
}
}
// start the timer
b.StartTimer()
// number of iterations
for i := 0; i < b.N; i++ {
// for concurrency
for j := 0; j < n; j++ {
wg.Add(1)
go func() {
err := testRequest(ctx, service.Client(), name)
wg.Done()
if err != nil {
b.Fatal(err)
}
}()
}
// wait for test completion
wg.Wait()
}
// stop the timer
b.StopTimer()
// shutdown service
testShutdown(&wg, cancel)
}
func BenchmarkService1(b *testing.B) {
benchmarkService(b, 1, "test.service.1")
}
func BenchmarkService8(b *testing.B) {
benchmarkService(b, 8, "test.service.8")
}
func BenchmarkService16(b *testing.B) {
benchmarkService(b, 16, "test.service.16")
}
func BenchmarkService32(b *testing.B) {
benchmarkService(b, 32, "test.service.32")
}
func BenchmarkService64(b *testing.B) {
benchmarkService(b, 64, "test.service.64")
}

View File

@@ -1,4 +1,5 @@
package mock
// Package memory is an in-memory transport
package memory
import (
"errors"
@@ -11,7 +12,7 @@ import (
"github.com/micro/go-micro/transport"
)
type mockSocket struct {
type memorySocket struct {
recv chan *transport.Message
send chan *transport.Message
// sock exit
@@ -23,26 +24,26 @@ type mockSocket struct {
remote string
}
type mockClient struct {
*mockSocket
type memoryClient struct {
*memorySocket
opts transport.DialOptions
}
type mockListener struct {
type memoryListener struct {
addr string
exit chan bool
conn chan *mockSocket
conn chan *memorySocket
opts transport.ListenOptions
}
type mockTransport struct {
type memoryTransport struct {
opts transport.Options
sync.Mutex
listeners map[string]*mockListener
listeners map[string]*memoryListener
}
func (ms *mockSocket) Recv(m *transport.Message) error {
func (ms *memorySocket) Recv(m *transport.Message) error {
select {
case <-ms.exit:
return errors.New("connection closed")
@@ -54,15 +55,15 @@ func (ms *mockSocket) Recv(m *transport.Message) error {
return nil
}
func (ms *mockSocket) Local() string {
func (ms *memorySocket) Local() string {
return ms.local
}
func (ms *mockSocket) Remote() string {
func (ms *memorySocket) Remote() string {
return ms.remote
}
func (ms *mockSocket) Send(m *transport.Message) error {
func (ms *memorySocket) Send(m *transport.Message) error {
select {
case <-ms.exit:
return errors.New("connection closed")
@@ -73,7 +74,7 @@ func (ms *mockSocket) Send(m *transport.Message) error {
return nil
}
func (ms *mockSocket) Close() error {
func (ms *memorySocket) Close() error {
select {
case <-ms.exit:
return nil
@@ -83,11 +84,11 @@ func (ms *mockSocket) Close() error {
return nil
}
func (m *mockListener) Addr() string {
func (m *memoryListener) Addr() string {
return m.addr
}
func (m *mockListener) Close() error {
func (m *memoryListener) Close() error {
select {
case <-m.exit:
return nil
@@ -97,13 +98,13 @@ func (m *mockListener) Close() error {
return nil
}
func (m *mockListener) Accept(fn func(transport.Socket)) error {
func (m *memoryListener) Accept(fn func(transport.Socket)) error {
for {
select {
case <-m.exit:
return nil
case c := <-m.conn:
go fn(&mockSocket{
go fn(&memorySocket{
lexit: c.lexit,
exit: c.exit,
send: c.recv,
@@ -115,7 +116,7 @@ func (m *mockListener) Accept(fn func(transport.Socket)) error {
}
}
func (m *mockTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) {
func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) {
m.Lock()
defer m.Unlock()
@@ -129,8 +130,8 @@ func (m *mockTransport) Dial(addr string, opts ...transport.DialOption) (transpo
o(&options)
}
client := &mockClient{
&mockSocket{
client := &memoryClient{
&memorySocket{
send: make(chan *transport.Message),
recv: make(chan *transport.Message),
exit: make(chan bool),
@@ -145,13 +146,13 @@ func (m *mockTransport) Dial(addr string, opts ...transport.DialOption) (transpo
select {
case <-listener.exit:
return nil, errors.New("connection error")
case listener.conn <- client.mockSocket:
case listener.conn <- client.memorySocket:
}
return client, nil
}
func (m *mockTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) {
func (m *memoryTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) {
m.Lock()
defer m.Unlock()
@@ -174,10 +175,10 @@ func (m *mockTransport) Listen(addr string, opts ...transport.ListenOption) (tra
return nil, errors.New("already listening on " + addr)
}
listener := &mockListener{
listener := &memoryListener{
opts: options,
addr: addr,
conn: make(chan *mockSocket),
conn: make(chan *memorySocket),
exit: make(chan bool),
}
@@ -186,19 +187,19 @@ func (m *mockTransport) Listen(addr string, opts ...transport.ListenOption) (tra
return listener, nil
}
func (m *mockTransport) Init(opts ...transport.Option) error {
func (m *memoryTransport) Init(opts ...transport.Option) error {
for _, o := range opts {
o(&m.opts)
}
return nil
}
func (m *mockTransport) Options() transport.Options {
func (m *memoryTransport) Options() transport.Options {
return m.opts
}
func (m *mockTransport) String() string {
return "mock"
func (m *memoryTransport) String() string {
return "memory"
}
func NewTransport(opts ...transport.Option) transport.Transport {
@@ -207,8 +208,8 @@ func NewTransport(opts ...transport.Option) transport.Transport {
o(&options)
}
return &mockTransport{
return &memoryTransport{
opts: options,
listeners: make(map[string]*mockListener),
listeners: make(map[string]*memoryListener),
}
}

View File

@@ -1,4 +1,4 @@
package mock
package memory
import (
"testing"
@@ -6,7 +6,7 @@ import (
"github.com/micro/go-micro/transport"
)
func TestTransport(t *testing.T) {
func TestMemoryTransport(t *testing.T) {
tr := NewTransport()
// bind / listen

View File

@@ -13,15 +13,11 @@ type clientWrapper struct {
}
func (c *clientWrapper) setHeaders(ctx context.Context) context.Context {
md := make(metadata.Metadata)
if mda, ok := metadata.FromContext(ctx); ok {
// make copy of metadata
for k, v := range mda {
md[k] = v
}
}
// copy metadata
mda, _ := metadata.FromContext(ctx)
md := metadata.Copy(mda)
// set headers
for k, v := range c.headers {
if _, ok := md[k]; !ok {
md[k] = v