diff --git a/client/rpc_client.go b/client/rpc_client.go index b917ae2a..524b5049 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -3,14 +3,13 @@ package client import ( "bytes" "fmt" - "io/ioutil" "math/rand" "net/http" - "net/url" "time" "github.com/myodc/go-micro/errors" "github.com/myodc/go-micro/registry" + "github.com/myodc/go-micro/transport" rpc "github.com/youtube/vitess/go/rpcplus" js "github.com/youtube/vitess/go/rpcplus/jsonrpc" pb "github.com/youtube/vitess/go/rpcplus/pbrpc" @@ -22,7 +21,9 @@ type headerRoundTripper struct { r http.RoundTripper } -type RpcClient struct{} +type RpcClient struct { + transport transport.Transport +} func init() { rand.Seed(time.Now().UnixNano()) @@ -71,48 +72,43 @@ func (r *RpcClient) call(address, path string, request Request, response interfa return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error writing request: %v", err)) } - client := &http.Client{} - client.Transport = &headerRoundTripper{http.DefaultTransport} - - request.Headers().Set("Content-Type", request.ContentType()) - - hreq := &http.Request{ - Method: "POST", - URL: &url.URL{ - Scheme: "http", - Host: address, - Path: path, - }, - Header: request.Headers().(http.Header), - Body: buf, - ContentLength: int64(reqB.Len()), - Host: address, + msg := &transport.Message{ + Header: make(map[string]string), + Body: reqB.Bytes(), } - rsp, err := client.Do(hreq) + h, _ := request.Headers().(http.Header) + for k, v := range h { + if len(v) > 0 { + msg.Header[k] = v[0] + } + } + + msg.Header["Content-Type"] = request.ContentType() + + c, err := r.transport.NewClient(request.Service(), address) if err != nil { return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } - defer rsp.Body.Close() - b, err := ioutil.ReadAll(rsp.Body) + rsp, err := c.Send(msg) if err != nil { - return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error reading response: %v", err)) + return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } - rspB := bytes.NewBuffer(b) + rspB := bytes.NewBuffer(rsp.Body) defer rspB.Reset() rBuf := &buffer{ rspB, } - switch rsp.Header.Get("Content-Type") { + switch rsp.Header["Content-Type"] { case "application/octet-stream": cc = pb.NewClientCodec(rBuf) case "application/json": cc = js.NewClientCodec(rBuf) default: - return errors.InternalServerError("go.micro.client", string(b)) + return errors.InternalServerError("go.micro.client", string(rsp.Body)) } pRsp := &rpc.Response{} @@ -167,5 +163,7 @@ func (r *RpcClient) NewJsonRequest(service, method string, request interface{}) } func NewRpcClient() *RpcClient { - return &RpcClient{} + return &RpcClient{ + transport: transport.DefaultTransport, + } } diff --git a/cmd/cmd.go b/cmd/cmd.go index ff3c12d6..fdc069f2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,6 +11,7 @@ import ( "github.com/myodc/go-micro/registry" "github.com/myodc/go-micro/server" "github.com/myodc/go-micro/store" + "github.com/myodc/go-micro/transport" ) var ( @@ -30,7 +31,6 @@ var ( cli.StringFlag{ Name: "broker_address", EnvVar: "MICRO_BROKER_ADDRESS", - Value: ":0", Usage: "Comma-separated list of broker addresses", }, cli.StringFlag{ @@ -42,7 +42,6 @@ var ( cli.StringFlag{ Name: "registry_address", EnvVar: "MICRO_REGISTRY_ADDRESS", - Value: "127.0.0.1:8500", Usage: "Comma-separated list of registry addresses", }, cli.StringFlag{ @@ -54,42 +53,63 @@ var ( cli.StringFlag{ Name: "store_address", EnvVar: "MICRO_STORE_ADDRESS", - Value: "127.0.0.1:8500", Usage: "Comma-separated list of store addresses", }, + cli.StringFlag{ + Name: "transport", + EnvVar: "MICRO_TRANSPORT", + Value: "http", + Usage: "Transport mechanism used; http, rabbitmq, etc", + }, + cli.StringFlag{ + Name: "transport_address", + EnvVar: "MICRO_TRANSPORT_ADDRESS", + Usage: "Comma-separated list of transport addresses", + }, } ) func Setup(c *cli.Context) error { server.Address = c.String("server_address") - broker_addrs := strings.Split(c.String("broker_address"), ",") + bAddrs := strings.Split(c.String("broker_address"), ",") switch c.String("broker") { case "http": - broker.DefaultBroker = broker.NewHttpBroker(broker_addrs) + broker.DefaultBroker = broker.NewHttpBroker(bAddrs) case "nats": - broker.DefaultBroker = broker.NewNatsBroker(broker_addrs) + broker.DefaultBroker = broker.NewNatsBroker(bAddrs) } - registry_addrs := strings.Split(c.String("registry_address"), ",") + rAddrs := strings.Split(c.String("registry_address"), ",") switch c.String("registry") { case "kubernetes": - registry.DefaultRegistry = registry.NewKubernetesRegistry(registry_addrs) + registry.DefaultRegistry = registry.NewKubernetesRegistry(rAddrs) case "consul": - registry.DefaultRegistry = registry.NewConsulRegistry(registry_addrs) + registry.DefaultRegistry = registry.NewConsulRegistry(rAddrs) } - store_addrs := strings.Split(c.String("store_address"), ",") + sAddrs := strings.Split(c.String("store_address"), ",") switch c.String("store") { case "memcached": - store.DefaultStore = store.NewMemcacheStore(store_addrs) + store.DefaultStore = store.NewMemcacheStore(sAddrs) case "memory": - store.DefaultStore = store.NewMemoryStore(store_addrs) + store.DefaultStore = store.NewMemoryStore(sAddrs) case "etcd": - store.DefaultStore = store.NewEtcdStore(store_addrs) + store.DefaultStore = store.NewEtcdStore(sAddrs) + } + + tAddrs := strings.Split(c.String("transport_address"), ",") + + switch c.String("transport") { + case "http": + transport.DefaultTransport = transport.NewHttpTransport(tAddrs) + case "rabbitmq": + transport.DefaultTransport = transport.NewRabbitMQTransport(tAddrs) + case "nats": + transport.DefaultTransport = transport.NewNatsTransport(tAddrs) } return nil diff --git a/server/rpc_server.go b/server/rpc_server.go index be7d9712..9e65b9c2 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -2,27 +2,23 @@ package server import ( "bytes" - "fmt" - "io/ioutil" - "net" "net/http" - "runtime/debug" - "strconv" "sync" - "github.com/bradfitz/http2" log "github.com/golang/glog" - "github.com/myodc/go-micro/errors" + "github.com/myodc/go-micro/transport" rpc "github.com/youtube/vitess/go/rpcplus" js "github.com/youtube/vitess/go/rpcplus/jsonrpc" pb "github.com/youtube/vitess/go/rpcplus/pbrpc" + "golang.org/x/net/context" ) type RpcServer struct { - mtx sync.RWMutex - rpc *rpc.Server - address string - exit chan chan error + mtx sync.RWMutex + address string + transport transport.Transport + rpc *rpc.Server + exit chan chan error } var ( @@ -30,92 +26,14 @@ var ( RpcPath = "/_rpc" ) -func executeRequestSafely(c *serverContext, r *http.Request) { - defer func() { - if x := recover(); x != nil { - log.Warningf("Panicked on request: %v", r) - log.Warningf("%v: %v", x, string(debug.Stack())) - err := errors.InternalServerError("go.micro.server", "Unexpected error") - c.WriteHeader(500) - c.Write([]byte(err.Error())) - } - }() - - http.DefaultServeMux.ServeHTTP(c, r) -} - -func (s *RpcServer) handler(w http.ResponseWriter, r *http.Request) { - c := &serverContext{ - req: &serverRequest{r}, - outHeader: w.Header(), - } - - ctxs.Lock() - ctxs.m[r] = c - ctxs.Unlock() - defer func() { - ctxs.Lock() - delete(ctxs.m, r) - ctxs.Unlock() - }() - - // Patch up RemoteAddr so it looks reasonable. - if addr := r.Header.Get("X-Forwarded-For"); len(addr) > 0 { - r.RemoteAddr = addr - } else { - // Should not normally reach here, but pick a sensible default anyway. - r.RemoteAddr = "127.0.0.1" - } - // The address in the headers will most likely be of these forms: - // 123.123.123.123 - // 2001:db8::1 - // net/http.Request.RemoteAddr is specified to be in "IP:port" form. - if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { - // Assume the remote address is only a host; add a default port. - r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") - } - - executeRequestSafely(c, r) - c.outHeader = nil // make sure header changes aren't respected any more - - // Avoid nil Write call if c.Write is never called. - if c.outCode != 0 { - w.WriteHeader(c.outCode) - } - if c.outBody != nil { - w.Write(c.outBody) - } -} - -func (s *RpcServer) Address() string { - s.mtx.RLock() - defer s.mtx.RUnlock() - return s.address -} - -func (s *RpcServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - serveCtx := getServerContext(req) - - // TODO: get user scope from context - // check access - - if req.Method != "POST" { - err := errors.BadRequest("go.micro.server", "Method not allowed") - http.Error(w, err.Error(), http.StatusMethodNotAllowed) - return - } - defer req.Body.Close() - - b, err := ioutil.ReadAll(req.Body) +func (s *RpcServer) serve(sock transport.Socket) { + // serveCtx := getServerContext(req) + msg, err := sock.Recv() if err != nil { - errr := errors.InternalServerError("go.micro.server", fmt.Sprintf("Error reading request body: %v", err)) - w.WriteHeader(500) - w.Write([]byte(errr.Error())) - log.Errorf("Erroring reading request body: %v", err) return } - rbq := bytes.NewBuffer(b) + rbq := bytes.NewBuffer(msg.Body) rsp := bytes.NewBuffer(nil) defer rsp.Reset() defer rbq.Reset() @@ -126,36 +44,34 @@ func (s *RpcServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } var cc rpc.ServerCodec - switch req.Header.Get("Content-Type") { + switch msg.Header["Content-Type"] { case "application/octet-stream": cc = pb.NewServerCodec(buf) case "application/json": cc = js.NewServerCodec(buf) default: - err = errors.InternalServerError("go.micro.server", fmt.Sprintf("Unsupported content-type: %v", req.Header.Get("Content-Type"))) - w.WriteHeader(500) - w.Write([]byte(err.Error())) return + // return nil, errors.InternalServerError("go.micro.server", fmt.Sprintf("Unsupported content-type: %v", req.Header.Get("Content-Type"))) } - ctx := newContext(&ctx{}, serveCtx) - err = s.rpc.ServeRequestWithContext(ctx, cc) + //ctx := newContext(&ctx{}, serveCtx) + err = s.rpc.ServeRequestWithContext(context.Background(), cc) if err != nil { - // This should not be possible. - w.WriteHeader(500) - w.Write([]byte(err.Error())) - log.Errorf("Erroring serving request: %v", err) return } - w.Header().Set("Content-Type", req.Header.Get("Content-Type")) - w.Header().Set("Content-Length", strconv.Itoa(rsp.Len())) - w.Write(rsp.Bytes()) + sock.WriteHeader("Content-Type", msg.Header["Content-Type"]) + sock.Write(rsp.Bytes()) +} + +func (s *RpcServer) Address() string { + s.mtx.RLock() + address := s.address + s.mtx.RUnlock() + return address } func (s *RpcServer) Init() error { - log.Infof("Rpc handler %s", RpcPath) - http.Handle(RpcPath, s) return nil } @@ -180,28 +96,22 @@ func (s *RpcServer) Register(r Receiver) error { func (s *RpcServer) Start() error { registerHealthChecker(http.DefaultServeMux) - l, err := net.Listen("tcp", s.address) + ts, err := s.transport.NewServer(Name, s.address) if err != nil { return err } - log.Infof("Listening on %s", l.Addr().String()) + log.Infof("Listening on %s", ts.Addr()) - s.mtx.Lock() - s.address = l.Addr().String() - s.mtx.Unlock() + s.mtx.RLock() + s.address = ts.Addr() + s.mtx.RUnlock() - srv := &http.Server{ - Handler: http.HandlerFunc(s.handler), - } - - http2.ConfigureServer(srv, nil) - - go srv.Serve(l) + go ts.Serve(s.serve) go func() { ch := <-s.exit - ch <- l.Close() + ch <- ts.Close() }() return nil @@ -215,8 +125,9 @@ func (s *RpcServer) Stop() error { func NewRpcServer(address string) *RpcServer { return &RpcServer{ - rpc: rpc.NewServer(), - address: address, - exit: make(chan chan error), + address: address, + transport: transport.DefaultTransport, + rpc: rpc.NewServer(), + exit: make(chan chan error), } } diff --git a/transport/buffer.go b/transport/buffer.go new file mode 100644 index 00000000..b3fac087 --- /dev/null +++ b/transport/buffer.go @@ -0,0 +1,13 @@ +package transport + +import ( + "io" +) + +type buffer struct { + io.ReadWriter +} + +func (b *buffer) Close() error { + return nil +} diff --git a/transport/http_transport.go b/transport/http_transport.go new file mode 100644 index 00000000..abd0a82b --- /dev/null +++ b/transport/http_transport.go @@ -0,0 +1,170 @@ +package transport + +import ( + "bytes" + "io/ioutil" + "net" + "net/http" + "net/url" +) + +type headerRoundTripper struct { + r http.RoundTripper +} + +type HttpTransport struct { + client *http.Client +} + +type HttpTransportClient struct { + client *http.Client + target string +} + +type HttpTransportSocket struct { + r *http.Request + w http.ResponseWriter +} + +type HttpTransportServer struct { + listener net.Listener +} + +func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + r.Header.Set("X-Client-Version", "1.0") + return t.r.RoundTrip(r) +} + +func (h *HttpTransportClient) Send(m *Message) (*Message, error) { + header := make(http.Header) + + for k, v := range m.Header { + header.Set(k, v) + } + + reqB := bytes.NewBuffer(m.Body) + defer reqB.Reset() + buf := &buffer{ + reqB, + } + + hreq := &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: h.target, + // Path: path, + }, + Header: header, + Body: buf, + ContentLength: int64(reqB.Len()), + Host: h.target, + } + + rsp, err := h.client.Do(hreq) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + + b, err := ioutil.ReadAll(rsp.Body) + if err != nil { + return nil, err + } + + mr := &Message{ + Header: make(map[string]string), + Body: b, + } + + for k, v := range rsp.Header { + if len(v) > 0 { + mr.Header[k] = v[0] + } else { + mr.Header[k] = "" + } + } + + return mr, nil +} + +func (h *HttpTransportClient) Close() error { + return nil +} + +func (h *HttpTransportSocket) Recv() (*Message, error) { + b, err := ioutil.ReadAll(h.r.Body) + if err != nil { + return nil, err + } + + m := &Message{ + Header: make(map[string]string), + Body: b, + } + + for k, v := range h.r.Header { + if len(v) > 0 { + m.Header[k] = v[0] + } else { + m.Header[k] = "" + } + } + + return m, nil +} + +func (h *HttpTransportSocket) WriteHeader(k string, v string) { + h.w.Header().Set(k, v) +} + +func (h *HttpTransportSocket) Write(b []byte) error { + _, err := h.w.Write(b) + return err +} + +func (h *HttpTransportServer) Addr() string { + return h.listener.Addr().String() +} + +func (h *HttpTransportServer) Close() error { + return h.listener.Close() +} + +func (h *HttpTransportServer) Serve(fn func(Socket)) error { + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fn(&HttpTransportSocket{ + r: r, + w: w, + }) + }), + } + + return srv.Serve(h.listener) +} + +func (h *HttpTransport) NewClient(name, addr string) (Client, error) { + return &HttpTransportClient{ + client: h.client, + target: addr, + }, nil +} + +func (h *HttpTransport) NewServer(name, addr string) (Server, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + return &HttpTransportServer{ + listener: l, + }, nil +} + +func NewHttpTransport(addrs []string) *HttpTransport { + client := &http.Client{} + client.Transport = &headerRoundTripper{http.DefaultTransport} + + return &HttpTransport{client: client} +} diff --git a/transport/nats_transport.go b/transport/nats_transport.go new file mode 100644 index 00000000..49a896b7 --- /dev/null +++ b/transport/nats_transport.go @@ -0,0 +1,151 @@ +package transport + +import ( + "bytes" + "encoding/json" + "strings" + "time" + + "github.com/apcera/nats" +) + +type NatsTransport struct{} + +type NatsTransportClient struct { + conn *nats.Conn + target string +} + +type NatsTransportSocket struct { + m *nats.Msg + hdr map[string]string + buf *bytes.Buffer +} + +type NatsTransportServer struct { + conn *nats.Conn + name string + exit chan bool +} + +func (n *NatsTransportClient) Send(m *Message) (*Message, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, err + } + + rsp, err := n.conn.Request(n.target, b, time.Second*10) + if err != nil { + return nil, err + } + + var mr *Message + if err := json.Unmarshal(rsp.Data, &mr); err != nil { + return nil, err + } + + return mr, nil +} + +func (n *NatsTransportClient) Close() error { + n.conn.Close() + return nil +} + +func (n *NatsTransportSocket) Recv() (*Message, error) { + var m *Message + + if err := json.Unmarshal(n.m.Data, &m); err != nil { + return nil, err + } + + return m, nil +} + +func (n *NatsTransportSocket) WriteHeader(k string, v string) { + n.hdr[k] = v +} + +func (n *NatsTransportSocket) Write(b []byte) error { + _, err := n.buf.Write(b) + return err +} + +func (n *NatsTransportServer) Addr() string { + return "127.0.0.1:4222" +} + +func (n *NatsTransportServer) Close() error { + n.exit <- true + n.conn.Close() + return nil +} + +func (n *NatsTransportServer) Serve(fn func(Socket)) error { + s, err := n.conn.QueueSubscribe(n.name, "queue:"+n.name, func(m *nats.Msg) { + buf := bytes.NewBuffer(nil) + hdr := make(map[string]string) + + fn(&NatsTransportSocket{ + m: m, + hdr: hdr, + buf: buf, + }) + + mrsp := &Message{ + Header: hdr, + Body: buf.Bytes(), + } + + b, err := json.Marshal(mrsp) + if err != nil { + return + } + + n.conn.Publish(m.Reply, b) + buf.Reset() + }) + if err != nil { + return err + } + + <-n.exit + return s.Unsubscribe() +} + +func (n *NatsTransport) NewClient(name, addr string) (Client, error) { + if !strings.HasPrefix(addr, "nats://") { + addr = nats.DefaultURL + } + + c, err := nats.Connect(addr) + if err != nil { + return nil, err + } + + return &NatsTransportClient{ + conn: c, + target: name, + }, nil +} + +func (n *NatsTransport) NewServer(name, addr string) (Server, error) { + if !strings.HasPrefix(addr, "nats://") { + addr = nats.DefaultURL + } + + c, err := nats.Connect(addr) + if err != nil { + return nil, err + } + + return &NatsTransportServer{ + name: name, + conn: c, + exit: make(chan bool, 1), + }, nil +} + +func NewNatsTransport(addrs []string) *NatsTransport { + return &NatsTransport{} +} diff --git a/transport/rabbitmq_channel.go b/transport/rabbitmq_channel.go new file mode 100644 index 00000000..fcd9fdb7 --- /dev/null +++ b/transport/rabbitmq_channel.go @@ -0,0 +1,128 @@ +package transport + +// +// All credit to Mondo +// https://github.com/mondough/typhon +// + +import ( + "errors" + + "github.com/nu7hatch/gouuid" + "github.com/streadway/amqp" +) + +type RabbitChannel struct { + uuid string + connection *amqp.Connection + channel *amqp.Channel +} + +func NewRabbitChannel(conn *amqp.Connection) (*RabbitChannel, error) { + id, err := uuid.NewV4() + if err != nil { + return nil, err + } + rabbitCh := &RabbitChannel{ + uuid: id.String(), + connection: conn, + } + if err := rabbitCh.Connect(); err != nil { + return nil, err + } + return rabbitCh, nil + +} + +func (r *RabbitChannel) Connect() error { + var err error + r.channel, err = r.connection.Channel() + if err != nil { + return err + } + return nil +} + +func (r *RabbitChannel) Close() error { + if r.channel == nil { + return errors.New("Channel is nil") + } + return r.channel.Close() +} + +func (r *RabbitChannel) Publish(exchange, routingKey string, message amqp.Publishing) error { + if r.channel == nil { + return errors.New("Channel is nil") + } + return r.channel.Publish(exchange, routingKey, false, false, message) +} + +func (r *RabbitChannel) DeclareExchange(exchange string) error { + return r.channel.ExchangeDeclare( + exchange, // name + "topic", // kind + false, // durable + false, // autoDelete + false, // internal + false, // noWait + nil, // args + ) +} + +func (r *RabbitChannel) DeclareQueue(queue string) error { + _, err := r.channel.QueueDeclare( + queue, // name + false, // durable + true, // autoDelete + false, // exclusive + false, // noWait + nil, // args + ) + return err +} + +func (r *RabbitChannel) DeclareDurableQueue(queue string) error { + _, err := r.channel.QueueDeclare( + queue, // name + true, // durable + false, // autoDelete + false, // exclusive + false, // noWait + nil, // args + ) + return err +} + +func (r *RabbitChannel) DeclareReplyQueue(queue string) error { + _, err := r.channel.QueueDeclare( + queue, // name + false, // durable + true, // autoDelete + true, // exclusive + false, // noWait + nil, // args + ) + return err +} + +func (r *RabbitChannel) ConsumeQueue(queue string) (<-chan amqp.Delivery, error) { + return r.channel.Consume( + queue, // queue + r.uuid, // consumer + true, // autoAck + false, // exclusive + false, // nolocal + false, // nowait + nil, // args + ) +} + +func (r *RabbitChannel) BindQueue(queue, exchange string) error { + return r.channel.QueueBind( + queue, // name + queue, // key + exchange, // exchange + false, // noWait + nil, // args + ) +} diff --git a/transport/rabbitmq_connection.go b/transport/rabbitmq_connection.go new file mode 100644 index 00000000..0e8424d3 --- /dev/null +++ b/transport/rabbitmq_connection.go @@ -0,0 +1,143 @@ +package transport + +// +// All credit to Mondo +// https://github.com/mondough/typhon +// + +import ( + "sync" + "time" + + "github.com/streadway/amqp" +) + +var ( + DefaultExchange = "micro" + DefaultRabbitURL = "amqp://guest:guest@127.0.0.1:5672" +) + +type RabbitConnection struct { + Connection *amqp.Connection + Channel *RabbitChannel + ExchangeChannel *RabbitChannel + notify chan bool + exchange string + url string + + connected bool + + mtx sync.Mutex + closeChan chan struct{} + closed bool +} + +func (r *RabbitConnection) Init() chan bool { + go r.Connect(r.notify) + return r.notify +} + +func (r *RabbitConnection) Connect(connected chan bool) { + for { + if err := r.tryToConnect(); err != nil { + time.Sleep(1 * time.Second) + continue + } + connected <- true + r.connected = true + notifyClose := make(chan *amqp.Error) + r.Connection.NotifyClose(notifyClose) + + // Block until we get disconnected, or shut down + select { + case <-notifyClose: + // Spin around and reconnect + r.connected = false + case <-r.closeChan: + // Shut down connection + if err := r.Connection.Close(); err != nil { + } + r.connected = false + return + } + } +} + +func (r *RabbitConnection) IsConnected() bool { + return r.connected +} + +func (r *RabbitConnection) Close() { + r.mtx.Lock() + defer r.mtx.Unlock() + + if r.closed { + return + } + + close(r.closeChan) + r.closed = true +} + +func (r *RabbitConnection) tryToConnect() error { + var err error + r.Connection, err = amqp.Dial(r.url) + if err != nil { + return err + } + r.Channel, err = NewRabbitChannel(r.Connection) + if err != nil { + return err + } + r.Channel.DeclareExchange(r.exchange) + r.ExchangeChannel, err = NewRabbitChannel(r.Connection) + if err != nil { + return err + } + return nil +} + +func (r *RabbitConnection) Consume(serverName string) (<-chan amqp.Delivery, error) { + consumerChannel, err := NewRabbitChannel(r.Connection) + if err != nil { + return nil, err + } + + err = consumerChannel.DeclareQueue(serverName) + if err != nil { + return nil, err + } + + deliveries, err := consumerChannel.ConsumeQueue(serverName) + if err != nil { + return nil, err + } + + err = consumerChannel.BindQueue(serverName, r.exchange) + if err != nil { + return nil, err + } + + return deliveries, nil +} + +func (r *RabbitConnection) Publish(exchange, routingKey string, msg amqp.Publishing) error { + return r.ExchangeChannel.Publish(exchange, routingKey, msg) +} + +func NewRabbitConnection(exchange, url string) *RabbitConnection { + if len(url) == 0 { + url = DefaultRabbitURL + } + + if len(exchange) == 0 { + exchange = DefaultExchange + } + + return &RabbitConnection{ + exchange: DefaultExchange, + url: DefaultRabbitURL, + notify: make(chan bool, 1), + closeChan: make(chan struct{}), + } +} diff --git a/transport/rabbitmq_transport.go b/transport/rabbitmq_transport.go new file mode 100644 index 00000000..78a85e1c --- /dev/null +++ b/transport/rabbitmq_transport.go @@ -0,0 +1,232 @@ +package transport + +import ( + "bytes" + "fmt" + "sync" + "time" + + "errors" + uuid "github.com/nu7hatch/gouuid" + "github.com/streadway/amqp" +) + +type RabbitMQTransport struct { + conn *RabbitConnection +} + +type RabbitMQTransportClient struct { + once sync.Once + conn *RabbitConnection + target string + replyTo string + + sync.Mutex + inflight map[string]chan amqp.Delivery +} + +type RabbitMQTransportSocket struct { + d *amqp.Delivery + hdr amqp.Table + buf *bytes.Buffer +} + +type RabbitMQTransportServer struct { + conn *RabbitConnection + name string +} + +func (h *RabbitMQTransportClient) init() { + <-h.conn.Init() + if err := h.conn.Channel.DeclareReplyQueue(h.replyTo); err != nil { + return + } + deliveries, err := h.conn.Channel.ConsumeQueue(h.replyTo) + if err != nil { + return + } + go func() { + for delivery := range deliveries { + go h.handle(delivery) + } + }() +} + +func (h *RabbitMQTransportClient) handle(delivery amqp.Delivery) { + ch := h.getReq(delivery.CorrelationId) + if ch == nil { + return + } + select { + case ch <- delivery: + default: + } +} + +func (h *RabbitMQTransportClient) putReq(id string) chan amqp.Delivery { + h.Lock() + ch := make(chan amqp.Delivery, 1) + h.inflight[id] = ch + h.Unlock() + return ch +} + +func (h *RabbitMQTransportClient) getReq(id string) chan amqp.Delivery { + h.Lock() + defer h.Unlock() + if ch, ok := h.inflight[id]; ok { + delete(h.inflight, id) + return ch + } + return nil +} + +func (h *RabbitMQTransportClient) Send(m *Message) (*Message, error) { + h.once.Do(h.init) + + if !h.conn.IsConnected() { + return nil, errors.New("Not connected to AMQP") + } + + id, err := uuid.NewV4() + if err != nil { + return nil, err + } + + replyChan := h.putReq(id.String()) + + headers := amqp.Table{} + + for k, v := range m.Header { + headers[k] = v + } + + message := amqp.Publishing{ + CorrelationId: id.String(), + Timestamp: time.Now().UTC(), + Body: m.Body, + ReplyTo: h.replyTo, + Headers: headers, + } + + if err := h.conn.Publish("micro", h.target, message); err != nil { + h.getReq(id.String()) + return nil, err + } + + select { + case d := <-replyChan: + mr := &Message{ + Header: make(map[string]string), + Body: d.Body, + } + + for k, v := range d.Headers { + mr.Header[k] = fmt.Sprintf("%v", v) + } + + return mr, nil + case <-time.After(time.Second * 10): + return nil, errors.New("timed out") + } +} + +func (h *RabbitMQTransportClient) Close() error { + h.conn.Close() + return nil +} + +func (h *RabbitMQTransportSocket) Recv() (*Message, error) { + m := &Message{ + Header: make(map[string]string), + Body: h.d.Body, + } + + for k, v := range h.d.Headers { + m.Header[k] = fmt.Sprintf("%v", v) + } + + return m, nil +} + +func (h *RabbitMQTransportSocket) WriteHeader(k string, v string) { + h.hdr[k] = v +} + +func (h *RabbitMQTransportSocket) Write(b []byte) error { + _, err := h.buf.Write(b) + return err +} + +func (h *RabbitMQTransportServer) Addr() string { + return h.conn.Connection.LocalAddr().String() +} + +func (h *RabbitMQTransportServer) Close() error { + h.conn.Close() + return nil +} + +func (h *RabbitMQTransportServer) Serve(fn func(Socket)) error { + deliveries, err := h.conn.Consume(h.name) + if err != nil { + return err + } + + handler := func(d amqp.Delivery) { + buf := bytes.NewBuffer(nil) + headers := amqp.Table{} + + fn(&RabbitMQTransportSocket{ + d: &d, + hdr: headers, + buf: buf, + }) + + msg := amqp.Publishing{ + CorrelationId: d.CorrelationId, + Timestamp: time.Now().UTC(), + Body: buf.Bytes(), + Headers: headers, + } + + h.conn.Publish("", d.ReplyTo, msg) + buf.Reset() + } + + for d := range deliveries { + go handler(d) + } + + return nil +} + +func (h *RabbitMQTransport) NewClient(name, addr string) (Client, error) { + id, err := uuid.NewV4() + if err != nil { + return nil, err + } + + return &RabbitMQTransportClient{ + conn: h.conn, + target: name, + inflight: make(map[string]chan amqp.Delivery), + replyTo: fmt.Sprintf("replyTo-%s", id.String()), + }, nil +} + +func (h *RabbitMQTransport) NewServer(name, addr string) (Server, error) { + conn := NewRabbitConnection("", "") + <-conn.Init() + + return &RabbitMQTransportServer{ + name: name, + conn: conn, + }, nil +} + +func NewRabbitMQTransport(addrs []string) *RabbitMQTransport { + return &RabbitMQTransport{ + conn: NewRabbitConnection("", ""), + } +} diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 00000000..6c349aee --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,40 @@ +package transport + +type Message struct { + Header map[string]string + Body []byte +} + +type Socket interface { + Recv() (*Message, error) + WriteHeader(string, string) + Write([]byte) error +} + +type Client interface { + Send(*Message) (*Message, error) + Close() error +} + +type Server interface { + Addr() string + Close() error + Serve(func(Socket)) error +} + +type Transport interface { + NewClient(name, addr string) (Client, error) + NewServer(name, addr string) (Server, error) +} + +var ( + DefaultTransport Transport = NewHttpTransport([]string{}) +) + +func NewClient(name, addr string) (Client, error) { + return DefaultTransport.NewClient(name, addr) +} + +func NewServer(name, addr string) (Server, error) { + return DefaultTransport.NewServer(name, addr) +}