diff --git a/client/client.go b/client/client.go index 542b8e8e..f6720cdf 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "github.com/myodc/go-micro/registry" "github.com/myodc/go-micro/transport" "golang.org/x/net/context" ) @@ -11,9 +12,18 @@ type Client interface { NewJsonRequest(string, string, interface{}) Request Call(context.Context, Request, interface{}) error CallRemote(context.Context, string, Request, interface{}) error + Stream(context.Context, Request, interface{}) (Streamer, error) + StreamRemote(context.Context, string, Request, interface{}) (Streamer, error) +} + +type Streamer interface { + Request() Request + Error() error + Close() error } type options struct { + registry registry.Registry transport transport.Transport } @@ -23,6 +33,12 @@ var ( DefaultClient Client = newRpcClient() ) +func Registry(r registry.Registry) Option { + return func(o *options) { + o.registry = r + } +} + func Transport(t transport.Transport) Option { return func(o *options) { o.transport = t @@ -37,6 +53,14 @@ func CallRemote(ctx context.Context, address string, request Request, response i return DefaultClient.CallRemote(ctx, address, request, response) } +func Stream(ctx context.Context, request Request, responseChan interface{}) (Streamer, error) { + return DefaultClient.Stream(ctx, request, responseChan) +} + +func StreamRemote(ctx context.Context, address string, request Request, responseChan interface{}) (Streamer, error) { + return DefaultClient.StreamRemote(ctx, address, request, responseChan) +} + func NewClient(opt ...Option) Client { return newRpcClient(opt...) } diff --git a/client/rpc_client.go b/client/rpc_client.go index 57a90662..6c185456 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -1,7 +1,6 @@ package client import ( - "bytes" "fmt" "math/rand" "net/http" @@ -13,11 +12,8 @@ import ( "github.com/myodc/go-micro/transport" rpc "github.com/youtube/vitess/go/rpcplus" - js "github.com/youtube/vitess/go/rpcplus/jsonrpc" - pb "github.com/youtube/vitess/go/rpcplus/pbrpc" "golang.org/x/net/context" - "google.golang.org/grpc" ) type headerRoundTripper struct { @@ -54,46 +50,8 @@ func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) } func (r *rpcClient) call(ctx context.Context, address string, request Request, response interface{}) error { - switch request.ContentType() { - case "application/grpc": - cc, err := grpc.Dial(address) - if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error connecting to server: %v", err)) - } - if err := grpc.Invoke(ctx, request.Method(), request.Request(), response, cc); err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) - } - return nil - } - - pReq := &rpc.Request{ - ServiceMethod: request.Method(), - } - - reqB := bytes.NewBuffer(nil) - defer reqB.Reset() - buf := &buffer{ - reqB, - } - - var cc rpc.ClientCodec - switch request.ContentType() { - case "application/octet-stream": - cc = pb.NewClientCodec(buf) - case "application/json": - cc = js.NewClientCodec(buf) - default: - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Unsupported request type: %s", request.ContentType())) - } - - err := cc.WriteRequest(pReq, request.Request()) - if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error writing request: %v", err)) - } - msg := &transport.Message{ Header: make(map[string]string), - Body: reqB.Bytes(), } md, ok := c.GetMetadata(ctx) @@ -110,42 +68,37 @@ func (r *rpcClient) call(ctx context.Context, address string, request Request, r return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } - rsp, err := c.Send(msg) + client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c)) + return client.Call(ctx, request.Method(), request.Request(), response) +} + +func (r *rpcClient) stream(ctx context.Context, address string, request Request, responseChan interface{}) (Streamer, error) { + msg := &transport.Message{ + Header: make(map[string]string), + } + + md, ok := c.GetMetadata(ctx) + if ok { + for k, v := range md { + msg.Header[k] = v + } + } + + msg.Header["Content-Type"] = request.ContentType() + + c, err := r.opts.transport.Dial(address, transport.WithStream()) if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) + return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } - rspB := bytes.NewBuffer(rsp.Body) - defer rspB.Reset() - rBuf := &buffer{ - rspB, - } + client := rpc.NewClientWithCodec(newRpcPlusCodec(msg, c)) + call := client.StreamGo(request.Method(), request.Request(), responseChan) - switch rsp.Header["Content-Type"] { - case "application/octet-stream": - cc = pb.NewClientCodec(rBuf) - case "application/json": - cc = js.NewClientCodec(rBuf) - default: - return errors.InternalServerError("go.micro.client", string(rsp.Body)) - } - - pRsp := &rpc.Response{} - err = cc.ReadResponseHeader(pRsp) - if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error reading response headers: %v", err)) - } - - if len(pRsp.Error) > 0 { - return errors.Parse(pRsp.Error) - } - - err = cc.ReadResponseBody(response) - if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error reading response body: %v", err)) - } - - return nil + return &rpcStream{ + request: request, + call: call, + client: client, + }, nil } func (r *rpcClient) CallRemote(ctx context.Context, address string, request Request, response interface{}) error { @@ -174,6 +127,31 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return r.call(ctx, address, request, response) } +func (r *rpcClient) StreamRemote(ctx context.Context, address string, request Request, responseChan interface{}) (Streamer, error) { + return r.stream(ctx, address, request, responseChan) +} + +func (r *rpcClient) Stream(ctx context.Context, request Request, responseChan interface{}) (Streamer, error) { + service, err := registry.GetService(request.Service()) + if err != nil { + return nil, errors.InternalServerError("go.micro.client", err.Error()) + } + + if len(service.Nodes) == 0 { + return nil, errors.NotFound("go.micro.client", "Service not found") + } + + n := rand.Int() % len(service.Nodes) + node := service.Nodes[n] + + address := node.Address + if node.Port > 0 { + address = fmt.Sprintf("%s:%d", address, node.Port) + } + + return r.stream(ctx, address, request, responseChan) +} + func (r *rpcClient) NewRequest(service, method string, request interface{}) Request { return r.NewProtoRequest(service, method, request) } diff --git a/client/rpc_codec.go b/client/rpc_codec.go new file mode 100644 index 00000000..223d52b8 --- /dev/null +++ b/client/rpc_codec.go @@ -0,0 +1,85 @@ +package client + +import ( + "bytes" + "fmt" + + "github.com/myodc/go-micro/transport" + rpc "github.com/youtube/vitess/go/rpcplus" + js "github.com/youtube/vitess/go/rpcplus/jsonrpc" + pb "github.com/youtube/vitess/go/rpcplus/pbrpc" +) + +type rpcPlusCodec struct { + client transport.Client + codec rpc.ClientCodec + + req *transport.Message + + wbuf *bytes.Buffer + rbuf *bytes.Buffer +} + +func newRpcPlusCodec(req *transport.Message, client transport.Client) *rpcPlusCodec { + return &rpcPlusCodec{ + req: req, + client: client, + wbuf: bytes.NewBuffer(nil), + rbuf: bytes.NewBuffer(nil), + } +} + +func (c *rpcPlusCodec) WriteRequest(req *rpc.Request, body interface{}) error { + c.wbuf.Reset() + buf := &buffer{c.wbuf} + + var cc rpc.ClientCodec + switch c.req.Header["Content-Type"] { + case "application/octet-stream": + cc = pb.NewClientCodec(buf) + case "application/json": + cc = js.NewClientCodec(buf) + default: + return fmt.Errorf("unsupported request type: %s", c.req.Header["Content-Type"]) + } + + if err := cc.WriteRequest(req, body); err != nil { + return err + } + + c.req.Body = c.wbuf.Bytes() + return c.client.Send(c.req) +} + +func (c *rpcPlusCodec) ReadResponseHeader(r *rpc.Response) error { + var m transport.Message + + if err := c.client.Recv(&m); err != nil { + return err + } + + c.rbuf.Reset() + c.rbuf.Write(m.Body) + buf := &buffer{c.rbuf} + + switch m.Header["Content-Type"] { + case "application/octet-stream": + c.codec = pb.NewClientCodec(buf) + case "application/json": + c.codec = js.NewClientCodec(buf) + default: + return fmt.Errorf("%s", string(m.Body)) + } + + return c.codec.ReadResponseHeader(r) +} + +func (c *rpcPlusCodec) ReadResponseBody(r interface{}) error { + return c.codec.ReadResponseBody(r) +} + +func (c *rpcPlusCodec) Close() error { + c.rbuf.Reset() + c.wbuf.Reset() + return c.client.Close() +} diff --git a/client/rpc_stream.go b/client/rpc_stream.go new file mode 100644 index 00000000..7816af1a --- /dev/null +++ b/client/rpc_stream.go @@ -0,0 +1,23 @@ +package client + +import ( + rpc "github.com/youtube/vitess/go/rpcplus" +) + +type rpcStream struct { + request Request + call *rpc.Call + client *rpc.Client +} + +func (r *rpcStream) Request() Request { + return r.request +} + +func (r *rpcStream) Error() error { + return r.call.Error +} + +func (r *rpcStream) Close() error { + return r.client.Close() +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 1f67cf2e..a18d9e88 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -88,40 +88,37 @@ var ( Usage: "Comma-separated list of transport addresses", }, } + + Brokers = map[string]func([]string, ...broker.Option) broker.Broker{ + "http": http.NewBroker, + "nats": nats.NewBroker, + "rabbitmq": rabbitmq.NewBroker, + } + + Registries = map[string]func([]string, ...registry.Option) registry.Registry{ + "kubernetes": kubernetes.NewRegistry, + "consul": consul.NewRegistry, + "etcd": etcd.NewRegistry, + } + + Transports = map[string]func([]string, ...transport.Option) transport.Transport{ + "http": thttp.NewTransport, + "rabbitmq": trmq.NewTransport, + "nats": tnats.NewTransport, + } ) func Setup(c *cli.Context) error { - bAddrs := strings.Split(c.String("broker_address"), ",") - - switch c.String("broker") { - case "http": - broker.DefaultBroker = http.NewBroker(bAddrs) - case "nats": - broker.DefaultBroker = nats.NewBroker(bAddrs) - case "rabbitmq": - broker.DefaultBroker = rabbitmq.NewBroker(bAddrs) + if b, ok := Brokers[c.String("broker")]; ok { + broker.DefaultBroker = b(strings.Split(c.String("broker_address"), ",")) } - rAddrs := strings.Split(c.String("registry_address"), ",") - - switch c.String("registry") { - case "kubernetes": - registry.DefaultRegistry = kubernetes.NewRegistry(rAddrs) - case "consul": - registry.DefaultRegistry = consul.NewRegistry(rAddrs) - case "etcd": - registry.DefaultRegistry = etcd.NewRegistry(rAddrs) + if r, ok := Registries[c.String("registry")]; ok { + registry.DefaultRegistry = r(strings.Split(c.String("registry_address"), ",")) } - tAddrs := strings.Split(c.String("transport_address"), ",") - - switch c.String("transport") { - case "http": - transport.DefaultTransport = thttp.NewTransport(tAddrs) - case "rabbitmq": - transport.DefaultTransport = trmq.NewTransport(tAddrs) - case "nats": - transport.DefaultTransport = tnats.NewTransport(tAddrs) + if t, ok := Transports[c.String("transport")]; ok { + transport.DefaultTransport = t(strings.Split(c.String("transport_address"), ",")) } metadata := make(map[string]string) diff --git a/examples/client/main.go b/examples/client/main.go index 0aef3b8e..5cbbe331 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -10,9 +10,7 @@ import ( "golang.org/x/net/context" ) -func main() { - cmd.Init() - +func call(i int) { // Create new request to service go.micro.srv.example, method Example.Call req := client.NewRequest("go.micro.srv.example", "Example.Call", &example.Request{ Name: "John", @@ -28,9 +26,45 @@ func main() { // Call service if err := client.Call(ctx, req, rsp); err != nil { - fmt.Println(err) + fmt.Println("err: ", err, rsp) return } - fmt.Println(rsp.Msg) + fmt.Println("Call:", i, "rsp:", rsp.Msg) +} + +func stream() { + // Create new request to service go.micro.srv.example, method Example.Call + req := client.NewRequest("go.micro.srv.example", "Example.Stream", &example.StreamingRequest{ + Count: int64(10), + }) + + rspChan := make(chan *example.StreamingResponse, 10) + + stream, err := client.Stream(context.Background(), req, rspChan) + if err != nil { + fmt.Println("err:", err) + return + } + + for rsp := range rspChan { + fmt.Println("Stream: rsp:", rsp.Count) + } + + if stream.Error() != nil { + fmt.Println("err:", err) + return + } + + stream.Close() +} + +func main() { + cmd.Init() + + for i := 0; i < 10; i++ { + call(i) + } + + stream() } diff --git a/examples/server/handler/example.go b/examples/server/handler/example.go index b0bcc225..3911409b 100644 --- a/examples/server/handler/example.go +++ b/examples/server/handler/example.go @@ -13,7 +13,24 @@ type Example struct{} func (e *Example) Call(ctx context.Context, req *example.Request, rsp *example.Response) error { md, _ := c.GetMetadata(ctx) - log.Info("Received Example.Call request with metadata: %v", md) + log.Infof("Received Example.Call request with metadata: %v", md) rsp.Msg = server.Config().Id() + ": Hello " + req.Name return nil } + +func (e *Example) Stream(ctx context.Context, req *example.StreamingRequest, response func(interface{}) error) error { + log.Infof("Received Example.Stream request with count: %d", req.Count) + for i := 0; i < int(req.Count); i++ { + log.Infof("Responding: %d", i) + + r := &example.StreamingResponse{ + Count: int64(i), + } + + if err := response(r); err != nil { + return err + } + } + + return nil +} diff --git a/examples/server/proto/example/example.pb.go b/examples/server/proto/example/example.pb.go index 251d6517..33ebb3bf 100644 --- a/examples/server/proto/example/example.pb.go +++ b/examples/server/proto/example/example.pb.go @@ -11,6 +11,8 @@ It is generated from these files: It has these top-level messages: Request Response + StreamingRequest + StreamingResponse */ package example @@ -35,5 +37,21 @@ func (m *Response) Reset() { *m = Response{} } func (m *Response) String() string { return proto.CompactTextString(m) } func (*Response) ProtoMessage() {} +type StreamingRequest struct { + Count int64 `protobuf:"varint,1,opt,name=count" json:"count,omitempty"` +} + +func (m *StreamingRequest) Reset() { *m = StreamingRequest{} } +func (m *StreamingRequest) String() string { return proto.CompactTextString(m) } +func (*StreamingRequest) ProtoMessage() {} + +type StreamingResponse struct { + Count int64 `protobuf:"varint,1,opt,name=count" json:"count,omitempty"` +} + +func (m *StreamingResponse) Reset() { *m = StreamingResponse{} } +func (m *StreamingResponse) String() string { return proto.CompactTextString(m) } +func (*StreamingResponse) ProtoMessage() {} + func init() { } diff --git a/examples/server/proto/example/example.proto b/examples/server/proto/example/example.proto index 095efa85..676613fd 100644 --- a/examples/server/proto/example/example.proto +++ b/examples/server/proto/example/example.proto @@ -7,3 +7,11 @@ message Request { message Response { string msg = 1; } + +message StreamingRequest { + int64 count = 1; +} + +message StreamingResponse { + int64 count = 1; +} diff --git a/registry/consul_registry.go b/registry/consul_registry.go index 1b863bbe..0fa3dbbf 100644 --- a/registry/consul_registry.go +++ b/registry/consul_registry.go @@ -54,7 +54,6 @@ func newConsulRegistry(addrs []string, opts ...Option) Registry { services: make(map[string]*Service), } - cr.Watch() return cr } @@ -156,6 +155,6 @@ func (c *consulRegistry) ListServices() ([]*Service, error) { return services, nil } -func (c *consulRegistry) Watch() { - newConsulWatcher(c) +func (c *consulRegistry) Watch() (Watcher, error) { + return newConsulWatcher(c) } diff --git a/registry/consul_watcher.go b/registry/consul_watcher.go index 8a3ff0e9..da5932f4 100644 --- a/registry/consul_watcher.go +++ b/registry/consul_watcher.go @@ -15,20 +15,22 @@ type serviceWatcher struct { name string } -func newConsulWatcher(cr *consulRegistry) *consulWatcher { +func newConsulWatcher(cr *consulRegistry) (Watcher, error) { cw := &consulWatcher{ Registry: cr, watchers: make(map[string]*watch.WatchPlan), } wp, err := watch.Parse(map[string]interface{}{"type": "services"}) - if err == nil { - wp.Handler = cw.Handle - go wp.Run(cr.Address) - cw.wp = wp + if err != nil { + return nil, err } - return cw + wp.Handler = cw.Handle + go wp.Run(cr.Address) + cw.wp = wp + + return cw, nil } func (cw *consulWatcher) serviceHandler(idx uint64, data interface{}) { diff --git a/registry/etcd/etcd.go b/registry/etcd/etcd.go index 9f8f7fbb..cf38eb44 100644 --- a/registry/etcd/etcd.go +++ b/registry/etcd/etcd.go @@ -147,8 +147,9 @@ func (e *etcdRegistry) ListServices() ([]*registry.Service, error) { return services, nil } -func (e *etcdRegistry) Watch() { - newEtcdWatcher(e) +func (e *etcdRegistry) Watch() (registry.Watcher, error) { + // todo: fix watcher + return newEtcdWatcher(e) } func NewRegistry(addrs []string, opt ...registry.Option) registry.Registry { @@ -170,8 +171,5 @@ func NewRegistry(addrs []string, opt ...registry.Option) registry.Registry { services: make(map[string]*registry.Service), } - // Need to fix watcher - // e.Watch() - return e } diff --git a/registry/etcd/watcher.go b/registry/etcd/watcher.go index 1c60832f..4e8fa55b 100644 --- a/registry/etcd/watcher.go +++ b/registry/etcd/watcher.go @@ -10,7 +10,7 @@ type etcdWatcher struct { stop chan bool } -func newEtcdWatcher(r *etcdRegistry) *etcdWatcher { +func newEtcdWatcher(r *etcdRegistry) (registry.Watcher, error) { ew := &etcdWatcher{ registry: r, stop: make(chan bool), @@ -19,53 +19,55 @@ func newEtcdWatcher(r *etcdRegistry) *etcdWatcher { ch := make(chan *etcd.Response) go r.client.Watch(prefix, 0, true, ch, ew.stop) + go ew.watch(ch) - go func() { - for rsp := range ch { - if rsp.Node.Dir { - continue - } + return ew, nil +} - s := decode(rsp.Node.Value) - if s == nil { - continue - } - - r.Lock() - - service, ok := r.services[s.Name] - if !ok { - if rsp.Action == "create" { - r.services[s.Name] = s - } - r.Unlock() - continue - } - - switch rsp.Action { - case "delete": - var nodes []*registry.Node - for _, node := range service.Nodes { - var seen bool - for _, n := range s.Nodes { - if node.Id == n.Id { - seen = true - break - } - } - if !seen { - nodes = append(nodes, node) - } - } - service.Nodes = nodes - case "create": - service.Nodes = append(service.Nodes, s.Nodes...) - } - r.Unlock() +func (e *etcdWatcher) watch(ch chan *etcd.Response) { + for rsp := range ch { + if rsp.Node.Dir { + continue } - }() - return ew + s := decode(rsp.Node.Value) + if s == nil { + continue + } + + e.registry.Lock() + + service, ok := e.registry.services[s.Name] + if !ok { + if rsp.Action == "create" { + e.registry.services[s.Name] = s + } + e.registry.Unlock() + continue + } + + switch rsp.Action { + case "delete": + var nodes []*registry.Node + for _, node := range service.Nodes { + var seen bool + for _, n := range s.Nodes { + if node.Id == n.Id { + seen = true + break + } + } + if !seen { + nodes = append(nodes, node) + } + } + service.Nodes = nodes + case "create": + service.Nodes = append(service.Nodes, s.Nodes...) + } + + e.registry.Unlock() + } } func (ew *etcdWatcher) Stop() { diff --git a/registry/kubernetes/kubernetes.go b/registry/kubernetes/kubernetes.go index 88a76e52..2627bdf0 100644 --- a/registry/kubernetes/kubernetes.go +++ b/registry/kubernetes/kubernetes.go @@ -19,10 +19,6 @@ type kregistry struct { services map[string]*registry.Service } -func (c *kregistry) Watch() { - newWatcher(c) -} - func (c *kregistry) Deregister(s *registry.Service) error { return nil } @@ -97,6 +93,10 @@ func (c *kregistry) ListServices() ([]*registry.Service, error) { return services, nil } +func (c *kregistry) Watch() (registry.Watcher, error) { + return newWatcher(c) +} + func NewRegistry(addrs []string, opts ...registry.Option) registry.Registry { host := "http://" + os.Getenv("KUBERNETES_RO_SERVICE_HOST") + ":" + os.Getenv("KUBERNETES_RO_SERVICE_PORT") if len(addrs) > 0 { @@ -113,7 +113,5 @@ func NewRegistry(addrs []string, opts ...registry.Option) registry.Registry { services: make(map[string]*registry.Service), } - kr.Watch() - return kr } diff --git a/registry/kubernetes/watcher.go b/registry/kubernetes/watcher.go index 4059d8e9..666ef81a 100644 --- a/registry/kubernetes/watcher.go +++ b/registry/kubernetes/watcher.go @@ -1,73 +1,91 @@ package kubernetes import ( - "fmt" "net" - "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/api" - "github.com/GoogleCloudPlatform/kubernetes/pkg/proxy/config" - "github.com/GoogleCloudPlatform/kubernetes/pkg/util" + "github.com/GoogleCloudPlatform/kubernetes/pkg/fields" + "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" + "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" "github.com/myodc/go-micro/registry" ) type watcher struct { registry *kregistry + watcher watch.Interface } -func (k *watcher) OnUpdate(services []api.Service) { - fmt.Println("got update") - activeServices := util.StringSet{} - for _, svc := range services { - fmt.Printf("%#v\n", svc.ObjectMeta) - name, exists := svc.ObjectMeta.Labels["name"] - if !exists { - continue - } - - activeServices.Insert(name) - serviceIP := net.ParseIP(svc.Spec.ClusterIP) - - ks := ®istry.Service{ - Name: name, - Nodes: []*registry.Node{ - ®istry.Node{ - Address: serviceIP.String(), - Port: svc.Spec.Ports[0].Port, - }, - }, - } - - k.registry.mtx.Lock() - k.registry.services[name] = ks - k.registry.mtx.Unlock() +func (k *watcher) update(event watch.Event) { + if event.Object == nil { + return } + var service *api.Service + switch obj := event.Object.(type) { + case *api.Service: + service = obj + default: + return + } + + name, exists := service.ObjectMeta.Labels["name"] + if !exists { + return + } + + switch event.Type { + case watch.Added, watch.Modified: + case watch.Deleted: + k.registry.mtx.Lock() + delete(k.registry.services, name) + k.registry.mtx.Unlock() + return + default: + return + } + + serviceIP := net.ParseIP(service.Spec.ClusterIP) + k.registry.mtx.Lock() - defer k.registry.mtx.Unlock() - for name, _ := range k.registry.services { - if !activeServices.Has(name) { - delete(k.registry.services, name) - } + k.registry.services[name] = ®istry.Service{ + Name: name, + Nodes: []*registry.Node{ + ®istry.Node{ + Address: serviceIP.String(), + Port: service.Spec.Ports[0].Port, + }, + }, } + k.registry.mtx.Unlock() } -func newWatcher(kr *kregistry) *watcher { - serviceConfig := config.NewServiceConfig() - endpointsConfig := config.NewEndpointsConfig() +func (k *watcher) Stop() { + k.watcher.Stop() +} - config.NewSourceAPI( - kr.client.Services(api.NamespaceAll), - kr.client.Endpoints(api.NamespaceAll), - time.Second*10, - serviceConfig.Channel("api"), - endpointsConfig.Channel("api"), - ) +func newWatcher(kr *kregistry) (registry.Watcher, error) { + svi := kr.client.Services(api.NamespaceAll) - ks := &watcher{ + services, err := svi.List(labels.Everything()) + if err != nil { + return nil, err + } + + watch, err := svi.Watch(labels.Everything(), fields.Everything(), services.ResourceVersion) + if err != nil { + return nil, err + } + + w := &watcher{ registry: kr, + watcher: watch, } - serviceConfig.RegisterHandler(ks) - return ks + go func() { + for event := range watch.ResultChan() { + w.update(event) + } + }() + + return w, nil } diff --git a/registry/registry.go b/registry/registry.go index 018fd791..530ef2ca 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -5,6 +5,11 @@ type Registry interface { Deregister(*Service) error GetService(string) (*Service, error) ListServices() ([]*Service, error) + Watch() (Watcher, error) +} + +type Watcher interface { + Stop() } type Service struct { diff --git a/server/buffer.go b/server/buffer.go index e3f6ebec..e833f1a6 100644 --- a/server/buffer.go +++ b/server/buffer.go @@ -5,8 +5,7 @@ import ( ) type buffer struct { - io.Reader - io.Writer + io.ReadWriter } func (b *buffer) Close() error { diff --git a/server/options.go b/server/options.go index 798d575f..08117b9c 100644 --- a/server/options.go +++ b/server/options.go @@ -1,10 +1,12 @@ package server import ( + "github.com/myodc/go-micro/registry" "github.com/myodc/go-micro/transport" ) type options struct { + registry registry.Registry transport transport.Transport metadata map[string]string name string @@ -19,6 +21,10 @@ func newOptions(opt ...Option) options { o(&opts) } + if opts.registry == nil { + opts.registry = registry.DefaultRegistry + } + if opts.transport == nil { opts.transport = transport.DefaultTransport } diff --git a/server/rpc_codec.go b/server/rpc_codec.go new file mode 100644 index 00000000..03651bad --- /dev/null +++ b/server/rpc_codec.go @@ -0,0 +1,82 @@ +package server + +import ( + "bytes" + "fmt" + + "github.com/myodc/go-micro/transport" + rpc "github.com/youtube/vitess/go/rpcplus" + js "github.com/youtube/vitess/go/rpcplus/jsonrpc" + pb "github.com/youtube/vitess/go/rpcplus/pbrpc" +) + +type rpcPlusCodec struct { + socket transport.Socket + codec rpc.ServerCodec + + req *transport.Message + + wbuf *bytes.Buffer + rbuf *bytes.Buffer +} + +func newRpcPlusCodec(req *transport.Message, socket transport.Socket) *rpcPlusCodec { + return &rpcPlusCodec{ + socket: socket, + req: req, + wbuf: bytes.NewBuffer(nil), + rbuf: bytes.NewBuffer(nil), + } +} + +func (c *rpcPlusCodec) ReadRequestHeader(r *rpc.Request) error { + c.rbuf.Reset() + c.rbuf.Write(c.req.Body) + buf := &buffer{c.rbuf} + + switch c.req.Header["Content-Type"] { + case "application/octet-stream": + c.codec = pb.NewServerCodec(buf) + case "application/json": + c.codec = js.NewServerCodec(buf) + default: + return fmt.Errorf("unsupported content type %s", c.req.Header["Content-Type"]) + } + + return c.codec.ReadRequestHeader(r) +} + +func (c *rpcPlusCodec) ReadRequestBody(r interface{}) error { + return c.codec.ReadRequestBody(r) +} + +func (c *rpcPlusCodec) WriteResponse(r *rpc.Response, body interface{}, last bool) error { + c.wbuf.Reset() + buf := &buffer{c.wbuf} + + var cc rpc.ServerCodec + switch c.req.Header["Content-Type"] { + case "application/octet-stream": + cc = pb.NewServerCodec(buf) + case "application/json": + cc = js.NewServerCodec(buf) + default: + return fmt.Errorf("unsupported request type: %s", c.req.Header["Content-Type"]) + } + + if err := cc.WriteResponse(r, body, last); err != nil { + return err + } + + return c.socket.Send(&transport.Message{ + Header: map[string]string{"Content-Type": c.req.Header["Content-Type"]}, + Body: c.wbuf.Bytes(), + }) + +} + +func (c *rpcPlusCodec) Close() error { + c.wbuf.Reset() + c.rbuf.Reset() + return c.socket.Close() +} diff --git a/server/rpc_server.go b/server/rpc_server.go index 276160f8..db372d02 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -1,15 +1,11 @@ package server import ( - "bytes" - c "github.com/myodc/go-micro/context" "github.com/myodc/go-micro/transport" log "github.com/golang/glog" rpc "github.com/youtube/vitess/go/rpcplus" - js "github.com/youtube/vitess/go/rpcplus/jsonrpc" - pb "github.com/youtube/vitess/go/rpcplus/pbrpc" "golang.org/x/net/context" ) @@ -34,42 +30,17 @@ func (s *rpcServer) accept(sock transport.Socket) { return } - rbq := bytes.NewBuffer(msg.Body) - rsp := bytes.NewBuffer(nil) - defer rsp.Reset() - defer rbq.Reset() - - buf := &buffer{ - rbq, - rsp, - } - - var cc rpc.ServerCodec - switch msg.Header["Content-Type"] { - case "application/octet-stream": - cc = pb.NewServerCodec(buf) - case "application/json": - cc = js.NewServerCodec(buf) - default: - return - } + codec := newRpcPlusCodec(&msg, sock) // strip our headers - ct := msg.Header["Content-Type"] - delete(msg.Header, "Content-Type") - - ctx := c.WithMetadata(context.Background(), msg.Header) - - if err := s.rpc.ServeRequestWithContext(ctx, cc); err != nil { - return + hdr := make(map[string]string) + for k, v := range msg.Header { + hdr[k] = v } + delete(hdr, "Content-Type") - sock.Send(&transport.Message{ - Header: map[string]string{ - "Content-Type": ct, - }, - Body: rsp.Bytes(), - }) + ctx := c.WithMetadata(context.Background(), hdr) + s.rpc.ServeRequestWithContext(ctx, codec) } func (s *rpcServer) Config() options { diff --git a/server/server.go b/server/server.go index 16a1c2a1..707bd534 100644 --- a/server/server.go +++ b/server/server.go @@ -50,6 +50,12 @@ func Address(a string) Option { } } +func Registry(r registry.Registry) Option { + return func(o *options) { + o.registry = r + } +} + func Transport(t transport.Transport) Option { return func(o *options) { o.transport = t @@ -121,9 +127,9 @@ func Run() error { log.Infof("Registering node: %s", node.Id) - err := registry.Register(service) + err := config.registry.Register(service) if err != nil { - log.Fatal("Failed to register: %v", err) + log.Fatalf("Failed to register: %v", err) } ch := make(chan os.Signal, 1) @@ -131,7 +137,7 @@ func Run() error { log.Infof("Received signal %s", <-ch) log.Infof("Deregistering %s", node.Id) - registry.Deregister(service) + config.registry.Deregister(service) return Stop() } diff --git a/transport/http_transport.go b/transport/http_transport.go index f17f3d2c..079f12ec 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -1,6 +1,7 @@ package transport import ( + "bufio" "bytes" "errors" "io/ioutil" @@ -9,34 +10,27 @@ import ( "net/url" ) -type headerRoundTripper struct { - r http.RoundTripper -} - -type httpTransport struct { - client *http.Client -} +type httpTransport struct{} type httpTransportClient struct { - ht *httpTransport - addr string + ht *httpTransport + addr string + conn net.Conn + buff *bufio.Reader + dialOpts dialOptions + r chan *http.Request } type httpTransportSocket struct { - r *http.Request - w http.ResponseWriter + r *http.Request + conn net.Conn } type httpTransportListener struct { listener net.Listener } -func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - r.Header.Set("X-Client-Version", "1.0") - return t.r.RoundTrip(r) -} - -func (h *httpTransportClient) Send(m *Message) (*Message, error) { +func (h *httpTransportClient) Send(m *Message) error { header := make(http.Header) for k, v := range m.Header { @@ -49,7 +43,7 @@ func (h *httpTransportClient) Send(m *Message) (*Message, error) { reqB, } - hreq := &http.Request{ + req := &http.Request{ Method: "POST", URL: &url.URL{ Scheme: "http", @@ -61,15 +55,26 @@ func (h *httpTransportClient) Send(m *Message) (*Message, error) { Host: h.addr, } - rsp, err := h.ht.client.Do(hreq) + h.r <- req + + return req.Write(h.conn) +} + +func (h *httpTransportClient) Recv(m *Message) error { + var r *http.Request + if !h.dialOpts.stream { + r = <-h.r + } + + rsp, err := http.ReadResponse(h.buff, r) if err != nil { - return nil, err + return err } defer rsp.Body.Close() b, err := ioutil.ReadAll(rsp.Body) if err != nil { - return nil, err + return err } mr := &Message{ @@ -85,11 +90,12 @@ func (h *httpTransportClient) Send(m *Message) (*Message, error) { } } - return mr, nil + *m = *mr + return nil } func (h *httpTransportClient) Close() error { - return nil + return h.conn.Close() } func (h *httpTransportSocket) Recv(m *Message) error { @@ -101,7 +107,7 @@ func (h *httpTransportSocket) Recv(m *Message) error { if err != nil { return err } - + h.r.Body.Close() mr := &Message{ Header: make(map[string]string), Body: b, @@ -120,16 +126,30 @@ func (h *httpTransportSocket) Recv(m *Message) error { } func (h *httpTransportSocket) Send(m *Message) error { - for k, v := range m.Header { - h.w.Header().Set(k, v) + b := bytes.NewBuffer(m.Body) + defer b.Reset() + rsp := &http.Response{ + Header: h.r.Header, + Body: &buffer{b}, + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: int64(len(m.Body)), + // Request: h.r, } - _, err := h.w.Write(m.Body) - return err + for k, v := range m.Header { + rsp.Header.Set(k, v) + } + + return rsp.Write(h.conn) } func (h *httpTransportSocket) Close() error { - return nil + // TODO: fix this + return h.conn.Close() } func (h *httpTransportListener) Addr() string { @@ -143,9 +163,14 @@ func (h *httpTransportListener) Close() error { func (h *httpTransportListener) Accept(fn func(Socket)) error { srv := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + return + } + fn(&httpTransportSocket{ - r: r, - w: w, + conn: conn, + r: r, }) }), } @@ -153,10 +178,25 @@ func (h *httpTransportListener) Accept(fn func(Socket)) error { return srv.Serve(h.listener) } -func (h *httpTransport) Dial(addr string) (Client, error) { +func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + var dopts dialOptions + + for _, opt := range opts { + opt(&dopts) + } + return &httpTransportClient{ - ht: h, - addr: addr, + ht: h, + addr: addr, + conn: conn, + buff: bufio.NewReader(conn), + dialOpts: dopts, + r: make(chan *http.Request, 1), }, nil } @@ -172,8 +212,5 @@ func (h *httpTransport) Listen(addr string) (Listener, error) { } func newHttpTransport(addrs []string, opt ...Option) *httpTransport { - client := &http.Client{} - client.Transport = &headerRoundTripper{http.DefaultTransport} - - return &httpTransport{client: client} + return &httpTransport{} } diff --git a/transport/nats/nats.go b/transport/nats/nats.go index ce53b311..15ccaa92 100644 --- a/transport/nats/nats.go +++ b/transport/nats/nats.go @@ -17,6 +17,8 @@ type ntport struct { type ntportClient struct { conn *nats.Conn addr string + id string + sub *nats.Subscription } type ntportSocket struct { @@ -30,26 +32,32 @@ type ntportListener struct { exit chan bool } -func (n *ntportClient) Send(m *transport.Message) (*transport.Message, error) { +func (n *ntportClient) Send(m *transport.Message) error { b, err := json.Marshal(m) if err != nil { - return nil, err + return err } - rsp, err := n.conn.Request(n.addr, b, time.Second*10) + return n.conn.PublishRequest(n.addr, n.id, b) +} + +func (n *ntportClient) Recv(m *transport.Message) error { + rsp, err := n.sub.NextMsg(time.Second * 10) if err != nil { - return nil, err + return err } var mr *transport.Message if err := json.Unmarshal(rsp.Data, &mr); err != nil { - return nil, err + return err } - return mr, nil + *m = *mr + return nil } func (n *ntportClient) Close() error { + n.sub.Unsubscribe() n.conn.Close() return nil } @@ -102,7 +110,7 @@ func (n *ntportListener) Accept(fn func(transport.Socket)) error { return s.Unsubscribe() } -func (n *ntport) Dial(addr string) (transport.Client, error) { +func (n *ntport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { cAddr := nats.DefaultURL if len(n.addrs) > 0 && strings.HasPrefix(n.addrs[0], "nats://") { @@ -114,9 +122,17 @@ func (n *ntport) Dial(addr string) (transport.Client, error) { return nil, err } + id := nats.NewInbox() + sub, err := c.SubscribeSync(id) + if err != nil { + return nil, err + } + return &ntportClient{ conn: c, addr: addr, + id: id, + sub: sub, }, nil } diff --git a/transport/rabbitmq/rabbitmq.go b/transport/rabbitmq/rabbitmq.go index cc94962e..6b8d9e52 100644 --- a/transport/rabbitmq/rabbitmq.go +++ b/transport/rabbitmq/rabbitmq.go @@ -15,18 +15,21 @@ import ( type rmqtport struct { conn *rabbitMQConn addrs []string -} -type rmqtportClient struct { once sync.Once - rt *rmqtport - addr string replyTo string sync.Mutex inflight map[string]chan amqp.Delivery } +type rmqtportClient struct { + rt *rmqtport + addr string + corId string + reply chan amqp.Delivery +} + type rmqtportSocket struct { conn *rabbitMQConn d *amqp.Delivery @@ -37,86 +40,34 @@ type rmqtportListener struct { addr string } -func (r *rmqtportClient) init() { - <-r.rt.conn.Init() - if err := r.rt.conn.Channel.DeclareReplyQueue(r.replyTo); err != nil { - return - } - deliveries, err := r.rt.conn.Channel.ConsumeQueue(r.replyTo) - if err != nil { - return - } - go func() { - for delivery := range deliveries { - go r.handle(delivery) - } - }() -} - -func (r *rmqtportClient) handle(delivery amqp.Delivery) { - ch := r.getReq(delivery.CorrelationId) - if ch == nil { - return - } - select { - case ch <- delivery: - default: - } -} - -func (r *rmqtportClient) putReq(id string) chan amqp.Delivery { - r.Lock() - ch := make(chan amqp.Delivery, 1) - r.inflight[id] = ch - r.Unlock() - return ch -} - -func (r *rmqtportClient) getReq(id string) chan amqp.Delivery { - r.Lock() - defer r.Unlock() - if ch, ok := r.inflight[id]; ok { - delete(r.inflight, id) - return ch - } - return nil -} - -func (r *rmqtportClient) Send(m *transport.Message) (*transport.Message, error) { - r.once.Do(r.init) - +func (r *rmqtportClient) Send(m *transport.Message) error { if !r.rt.conn.IsConnected() { - return nil, errors.New("Not connected to AMQP") + return errors.New("Not connected to AMQP") } - id, err := uuid.NewV4() - if err != nil { - return nil, err - } - - replyChan := r.putReq(id.String()) - headers := amqp.Table{} - for k, v := range m.Header { headers[k] = v } message := amqp.Publishing{ - CorrelationId: id.String(), + CorrelationId: r.corId, Timestamp: time.Now().UTC(), Body: m.Body, - ReplyTo: r.replyTo, + ReplyTo: r.rt.replyTo, Headers: headers, } if err := r.rt.conn.Publish("micro", r.addr, message); err != nil { - r.getReq(id.String()) - return nil, err + return err } + return nil +} + +func (r *rmqtportClient) Recv(m *transport.Message) error { select { - case d := <-replyChan: + case d := <-r.reply: mr := &transport.Message{ Header: make(map[string]string), Body: d.Body, @@ -126,13 +77,15 @@ func (r *rmqtportClient) Send(m *transport.Message) (*transport.Message, error) mr.Header[k] = fmt.Sprintf("%v", v) } - return mr, nil + *m = *mr + return nil case <-time.After(time.Second * 10): - return nil, errors.New("timed out") + return errors.New("timed out") } } func (r *rmqtportClient) Close() error { + r.rt.popReq(r.corId) return nil } @@ -202,17 +155,68 @@ func (r *rmqtportListener) Accept(fn func(transport.Socket)) error { return nil } -func (r *rmqtport) Dial(addr string) (transport.Client, error) { +func (r *rmqtport) putReq(id string) chan amqp.Delivery { + r.Lock() + ch := make(chan amqp.Delivery, 1) + r.inflight[id] = ch + r.Unlock() + return ch +} + +func (r *rmqtport) getReq(id string) chan amqp.Delivery { + r.Lock() + defer r.Unlock() + if ch, ok := r.inflight[id]; ok { + return ch + } + return nil +} + +func (r *rmqtport) popReq(id string) { + r.Lock() + defer r.Unlock() + if _, ok := r.inflight[id]; ok { + delete(r.inflight, id) + } +} + +func (r *rmqtport) init() { + <-r.conn.Init() + if err := r.conn.Channel.DeclareReplyQueue(r.replyTo); err != nil { + return + } + deliveries, err := r.conn.Channel.ConsumeQueue(r.replyTo) + if err != nil { + return + } + go func() { + for delivery := range deliveries { + go r.handle(delivery) + } + }() +} + +func (r *rmqtport) handle(delivery amqp.Delivery) { + ch := r.getReq(delivery.CorrelationId) + if ch == nil { + return + } + ch <- delivery +} + +func (r *rmqtport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { id, err := uuid.NewV4() if err != nil { return nil, err } + r.once.Do(r.init) + return &rmqtportClient{ - rt: r, - addr: addr, - inflight: make(map[string]chan amqp.Delivery), - replyTo: fmt.Sprintf("replyTo-%s", id.String()), + rt: r, + addr: addr, + corId: id.String(), + reply: r.putReq(id.String()), }, nil } @@ -232,8 +236,12 @@ func (r *rmqtport) Listen(addr string) (transport.Listener, error) { } func NewTransport(addrs []string, opt ...transport.Option) transport.Transport { + id, _ := uuid.NewV4() + return &rmqtport{ - conn: newRabbitMQConn("", addrs), - addrs: addrs, + conn: newRabbitMQConn("", addrs), + addrs: addrs, + replyTo: id.String(), + inflight: make(map[string]chan amqp.Delivery), } } diff --git a/transport/transport.go b/transport/transport.go index 8211e50d..1c84b889 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -1,6 +1,7 @@ package transport type Message struct { + Id string Header map[string]string Body []byte } @@ -12,7 +13,8 @@ type Socket interface { } type Client interface { - Send(*Message) (*Message, error) + Recv(*Message) error + Send(*Message) error Close() error } @@ -23,24 +25,36 @@ type Listener interface { } type Transport interface { - Dial(addr string) (Client, error) + Dial(addr string, opts ...DialOption) (Client, error) Listen(addr string) (Listener, error) } type options struct{} +type dialOptions struct { + stream bool +} + type Option func(*options) +type DialOption func(*dialOptions) + var ( DefaultTransport Transport = newHttpTransport([]string{}) ) +func WithStream() DialOption { + return func(o *dialOptions) { + o.stream = true + } +} + func NewTransport(addrs []string, opt ...Option) Transport { return newHttpTransport(addrs, opt...) } -func Dial(addr string) (Client, error) { - return DefaultTransport.Dial(addr) +func Dial(addr string, opts ...DialOption) (Client, error) { + return DefaultTransport.Dial(addr, opts...) } func Listen(addr string) (Listener, error) {