diff --git a/client/rpc_client.go b/client/rpc_client.go index 3a2d579d..d1dc2a74 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -86,7 +86,7 @@ func (r *RpcClient) call(address, path string, request Request, response interfa msg.Header["Content-Type"] = request.ContentType() - c, err := r.opts.transport.NewClient(address) + c, err := r.opts.transport.Dial(address) if err != nil { return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } diff --git a/server/rpc_server.go b/server/rpc_server.go index 4f45d1b2..9955c8ea 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -26,10 +26,10 @@ var ( RpcPath = "/_rpc" ) -func (s *RpcServer) serve(sock transport.Socket) { +func (s *RpcServer) accept(sock transport.Socket) { // serveCtx := getServerContext(req) - msg, err := sock.Recv() - if err != nil { + var msg transport.Message + if err := sock.Recv(&msg); err != nil { return } @@ -55,13 +55,16 @@ func (s *RpcServer) serve(sock transport.Socket) { } //ctx := newContext(&ctx{}, serveCtx) - err = s.rpc.ServeRequestWithContext(context.Background(), cc) - if err != nil { + if err := s.rpc.ServeRequestWithContext(context.Background(), cc); err != nil { return } - sock.WriteHeader("Content-Type", msg.Header["Content-Type"]) - sock.Write(rsp.Bytes()) + sock.Send(&transport.Message{ + Header: map[string]string{ + "Content-Type": msg.Header["Content-Type"], + }, + Body: rsp.Bytes(), + }) } func (s *RpcServer) Address() string { @@ -96,7 +99,7 @@ func (s *RpcServer) Register(r Receiver) error { func (s *RpcServer) Start() error { registerHealthChecker(http.DefaultServeMux) - ts, err := s.opts.transport.NewServer(s.address) + ts, err := s.opts.transport.Listen(s.address) if err != nil { return err } @@ -107,7 +110,7 @@ func (s *RpcServer) Start() error { s.address = ts.Addr() s.mtx.RUnlock() - go ts.Serve(s.serve) + go ts.Accept(s.accept) go func() { ch := <-s.exit diff --git a/transport/http_transport.go b/transport/http_transport.go index c63d854a..278a524e 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -2,6 +2,7 @@ package transport import ( "bytes" + "errors" "io/ioutil" "net" "net/http" @@ -26,7 +27,7 @@ type HttpTransportSocket struct { w http.ResponseWriter } -type HttpTransportServer struct { +type HttpTransportListener struct { listener net.Listener } @@ -92,46 +93,55 @@ func (h *HttpTransportClient) Close() error { return nil } -func (h *HttpTransportSocket) Recv() (*Message, error) { - b, err := ioutil.ReadAll(h.r.Body) - if err != nil { - return nil, err +func (h *HttpTransportSocket) Recv(m *Message) error { + if m == nil { + return errors.New("message passed in is nil") } - m := &Message{ + b, err := ioutil.ReadAll(h.r.Body) + if err != nil { + return err + } + + mr := &Message{ Header: make(map[string]string), Body: b, } for k, v := range h.r.Header { if len(v) > 0 { - m.Header[k] = v[0] + mr.Header[k] = v[0] } else { - m.Header[k] = "" + mr.Header[k] = "" } } - return m, nil + *m = *mr + return nil } -func (h *HttpTransportSocket) WriteHeader(k string, v string) { - h.w.Header().Set(k, v) -} +func (h *HttpTransportSocket) Send(m *Message) error { + for k, v := range m.Header { + h.w.Header().Set(k, v) + } -func (h *HttpTransportSocket) Write(b []byte) error { - _, err := h.w.Write(b) + _, err := h.w.Write(m.Body) return err } -func (h *HttpTransportServer) Addr() string { +func (h *HttpTransportSocket) Close() error { + return nil +} + +func (h *HttpTransportListener) Addr() string { return h.listener.Addr().String() } -func (h *HttpTransportServer) Close() error { +func (h *HttpTransportListener) Close() error { return h.listener.Close() } -func (h *HttpTransportServer) Serve(fn func(Socket)) error { +func (h *HttpTransportListener) Accept(fn func(Socket)) error { srv := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn(&HttpTransportSocket{ @@ -144,20 +154,20 @@ func (h *HttpTransportServer) Serve(fn func(Socket)) error { return srv.Serve(h.listener) } -func (h *HttpTransport) NewClient(addr string) (Client, error) { +func (h *HttpTransport) Dial(addr string) (Client, error) { return &HttpTransportClient{ ht: h, addr: addr, }, nil } -func (h *HttpTransport) NewServer(addr string) (Server, error) { +func (h *HttpTransport) Listen(addr string) (Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err } - return &HttpTransportServer{ + return &HttpTransportListener{ listener: l, }, nil } diff --git a/transport/nats_transport.go b/transport/nats_transport.go index fe0fb188..faa0296e 100644 --- a/transport/nats_transport.go +++ b/transport/nats_transport.go @@ -1,8 +1,8 @@ package transport import ( - "bytes" "encoding/json" + "errors" "strings" "time" @@ -19,12 +19,11 @@ type NatsTransportClient struct { } type NatsTransportSocket struct { - m *nats.Msg - hdr map[string]string - buf *bytes.Buffer + conn *nats.Conn + m *nats.Msg } -type NatsTransportServer struct { +type NatsTransportListener struct { conn *nats.Conn addr string exit chan bool @@ -54,58 +53,45 @@ func (n *NatsTransportClient) Close() error { return nil } -func (n *NatsTransportSocket) Recv() (*Message, error) { - var m *Message - - if err := json.Unmarshal(n.m.Data, &m); err != nil { - return nil, err +func (n *NatsTransportSocket) Recv(m *Message) error { + if m == nil { + return errors.New("message passed in is nil") } - return m, nil + if err := json.Unmarshal(n.m.Data, &m); err != nil { + return err + } + return nil } -func (n *NatsTransportSocket) WriteHeader(k string, v string) { - n.hdr[k] = v +func (n *NatsTransportSocket) Send(m *Message) error { + b, err := json.Marshal(m) + if err != nil { + return err + } + return n.conn.Publish(n.m.Reply, b) } -func (n *NatsTransportSocket) Write(b []byte) error { - _, err := n.buf.Write(b) - return err +func (n *NatsTransportSocket) Close() error { + return nil } -func (n *NatsTransportServer) Addr() string { +func (n *NatsTransportListener) Addr() string { return n.addr } -func (n *NatsTransportServer) Close() error { +func (n *NatsTransportListener) Close() error { n.exit <- true n.conn.Close() return nil } -func (n *NatsTransportServer) Serve(fn func(Socket)) error { +func (n *NatsTransportListener) Accept(fn func(Socket)) error { s, err := n.conn.Subscribe(n.addr, func(m *nats.Msg) { - buf := bytes.NewBuffer(nil) - hdr := make(map[string]string) - fn(&NatsTransportSocket{ - m: m, - hdr: hdr, - buf: buf, + conn: n.conn, + m: m, }) - - mrsp := &Message{ - Header: hdr, - Body: buf.Bytes(), - } - - b, err := json.Marshal(mrsp) - if err != nil { - return - } - - n.conn.Publish(m.Reply, b) - buf.Reset() }) if err != nil { return err @@ -115,7 +101,7 @@ func (n *NatsTransportServer) Serve(fn func(Socket)) error { return s.Unsubscribe() } -func (n *NatsTransport) NewClient(addr string) (Client, error) { +func (n *NatsTransport) Dial(addr string) (Client, error) { cAddr := nats.DefaultURL if len(n.addrs) > 0 && strings.HasPrefix(n.addrs[0], "nats://") { @@ -133,7 +119,7 @@ func (n *NatsTransport) NewClient(addr string) (Client, error) { }, nil } -func (n *NatsTransport) NewServer(addr string) (Server, error) { +func (n *NatsTransport) Listen(addr string) (Listener, error) { cAddr := nats.DefaultURL if len(n.addrs) > 0 && strings.HasPrefix(n.addrs[0], "nats://") { @@ -145,7 +131,7 @@ func (n *NatsTransport) NewServer(addr string) (Server, error) { return nil, err } - return &NatsTransportServer{ + return &NatsTransportListener{ addr: nats.NewInbox(), conn: c, exit: make(chan bool, 1), diff --git a/transport/rabbitmq_channel.go b/transport/rabbitmq_channel.go index 0c215cd9..1e2633d4 100644 --- a/transport/rabbitmq_channel.go +++ b/transport/rabbitmq_channel.go @@ -11,18 +11,18 @@ import ( "github.com/streadway/amqp" ) -type RabbitChannel struct { +type rabbitMQChannel struct { uuid string connection *amqp.Connection channel *amqp.Channel } -func NewRabbitChannel(conn *amqp.Connection) (*RabbitChannel, error) { +func newRabbitChannel(conn *amqp.Connection) (*rabbitMQChannel, error) { id, err := uuid.NewV4() if err != nil { return nil, err } - rabbitCh := &RabbitChannel{ + rabbitCh := &rabbitMQChannel{ uuid: id.String(), connection: conn, } @@ -33,7 +33,7 @@ func NewRabbitChannel(conn *amqp.Connection) (*RabbitChannel, error) { } -func (r *RabbitChannel) Connect() error { +func (r *rabbitMQChannel) Connect() error { var err error r.channel, err = r.connection.Channel() if err != nil { @@ -42,21 +42,21 @@ func (r *RabbitChannel) Connect() error { return nil } -func (r *RabbitChannel) Close() error { +func (r *rabbitMQChannel) Close() error { if r.channel == nil { return errors.New("Channel is nil") } return r.channel.Close() } -func (r *RabbitChannel) Publish(exchange, key string, message amqp.Publishing) error { +func (r *rabbitMQChannel) Publish(exchange, key string, message amqp.Publishing) error { if r.channel == nil { return errors.New("Channel is nil") } return r.channel.Publish(exchange, key, false, false, message) } -func (r *RabbitChannel) DeclareExchange(exchange string) error { +func (r *rabbitMQChannel) DeclareExchange(exchange string) error { return r.channel.ExchangeDeclare( exchange, // name "topic", // kind @@ -68,7 +68,7 @@ func (r *RabbitChannel) DeclareExchange(exchange string) error { ) } -func (r *RabbitChannel) DeclareQueue(queue string) error { +func (r *rabbitMQChannel) DeclareQueue(queue string) error { _, err := r.channel.QueueDeclare( queue, // name false, // durable @@ -80,7 +80,7 @@ func (r *RabbitChannel) DeclareQueue(queue string) error { return err } -func (r *RabbitChannel) DeclareDurableQueue(queue string) error { +func (r *rabbitMQChannel) DeclareDurableQueue(queue string) error { _, err := r.channel.QueueDeclare( queue, // name true, // durable @@ -92,7 +92,7 @@ func (r *RabbitChannel) DeclareDurableQueue(queue string) error { return err } -func (r *RabbitChannel) DeclareReplyQueue(queue string) error { +func (r *rabbitMQChannel) DeclareReplyQueue(queue string) error { _, err := r.channel.QueueDeclare( queue, // name false, // durable @@ -104,7 +104,7 @@ func (r *RabbitChannel) DeclareReplyQueue(queue string) error { return err } -func (r *RabbitChannel) ConsumeQueue(queue string) (<-chan amqp.Delivery, error) { +func (r *rabbitMQChannel) ConsumeQueue(queue string) (<-chan amqp.Delivery, error) { return r.channel.Consume( queue, // queue r.uuid, // consumer @@ -116,7 +116,7 @@ func (r *RabbitChannel) ConsumeQueue(queue string) (<-chan amqp.Delivery, error) ) } -func (r *RabbitChannel) BindQueue(queue, exchange string) error { +func (r *rabbitMQChannel) BindQueue(queue, exchange string) error { return r.channel.QueueBind( queue, // name queue, // key diff --git a/transport/rabbitmq_connection.go b/transport/rabbitmq_connection.go index a4353b01..ecd9c13b 100644 --- a/transport/rabbitmq_connection.go +++ b/transport/rabbitmq_connection.go @@ -17,10 +17,10 @@ var ( DefaultRabbitURL = "amqp://guest:guest@127.0.0.1:5672" ) -type RabbitConnection struct { +type rabbitMQConn struct { Connection *amqp.Connection - Channel *RabbitChannel - ExchangeChannel *RabbitChannel + Channel *rabbitMQChannel + ExchangeChannel *rabbitMQChannel notify chan bool exchange string url string @@ -32,12 +32,33 @@ type RabbitConnection struct { closed bool } -func (r *RabbitConnection) Init() chan bool { +func newRabbitMQConn(exchange string, urls []string) *rabbitMQConn { + var url string + + if len(urls) > 0 && strings.HasPrefix(urls[0], "amqp://") { + url = urls[0] + } else { + url = DefaultRabbitURL + } + + if len(exchange) == 0 { + exchange = DefaultExchange + } + + return &rabbitMQConn{ + exchange: exchange, + url: url, + notify: make(chan bool, 1), + close: make(chan bool), + } +} + +func (r *rabbitMQConn) Init() chan bool { go r.Connect(r.notify) return r.notify } -func (r *RabbitConnection) Connect(connected chan bool) { +func (r *rabbitMQConn) Connect(connected chan bool) { for { if err := r.tryToConnect(); err != nil { time.Sleep(1 * time.Second) @@ -63,11 +84,11 @@ func (r *RabbitConnection) Connect(connected chan bool) { } } -func (r *RabbitConnection) IsConnected() bool { +func (r *rabbitMQConn) IsConnected() bool { return r.connected } -func (r *RabbitConnection) Close() { +func (r *rabbitMQConn) Close() { r.mtx.Lock() defer r.mtx.Unlock() @@ -79,26 +100,26 @@ func (r *RabbitConnection) Close() { r.closed = true } -func (r *RabbitConnection) tryToConnect() error { +func (r *rabbitMQConn) tryToConnect() error { var err error r.Connection, err = amqp.Dial(r.url) if err != nil { return err } - r.Channel, err = NewRabbitChannel(r.Connection) + r.Channel, err = newRabbitChannel(r.Connection) if err != nil { return err } r.Channel.DeclareExchange(r.exchange) - r.ExchangeChannel, err = NewRabbitChannel(r.Connection) + r.ExchangeChannel, err = newRabbitChannel(r.Connection) if err != nil { return err } return nil } -func (r *RabbitConnection) Consume(queue string) (<-chan amqp.Delivery, error) { - consumerChannel, err := NewRabbitChannel(r.Connection) +func (r *rabbitMQConn) Consume(queue string) (<-chan amqp.Delivery, error) { + consumerChannel, err := newRabbitChannel(r.Connection) if err != nil { return nil, err } @@ -121,27 +142,6 @@ func (r *RabbitConnection) Consume(queue string) (<-chan amqp.Delivery, error) { return deliveries, nil } -func (r *RabbitConnection) Publish(exchange, key string, msg amqp.Publishing) error { +func (r *rabbitMQConn) Publish(exchange, key string, msg amqp.Publishing) error { return r.ExchangeChannel.Publish(exchange, key, msg) } - -func NewRabbitConnection(exchange string, urls []string) *RabbitConnection { - var url string - - if len(urls) > 0 && strings.HasPrefix(urls[0], "amqp://") { - url = urls[0] - } else { - url = DefaultRabbitURL - } - - if len(exchange) == 0 { - exchange = DefaultExchange - } - - return &RabbitConnection{ - exchange: exchange, - url: url, - notify: make(chan bool, 1), - close: make(chan bool), - } -} diff --git a/transport/rabbitmq_transport.go b/transport/rabbitmq_transport.go index c33392cb..731cfca2 100644 --- a/transport/rabbitmq_transport.go +++ b/transport/rabbitmq_transport.go @@ -1,7 +1,6 @@ package transport import ( - "bytes" "fmt" "sync" "time" @@ -12,7 +11,7 @@ import ( ) type RabbitMQTransport struct { - conn *RabbitConnection + conn *rabbitMQConn addrs []string } @@ -27,13 +26,12 @@ type RabbitMQTransportClient struct { } type RabbitMQTransportSocket struct { - d *amqp.Delivery - hdr amqp.Table - buf *bytes.Buffer + conn *rabbitMQConn + d *amqp.Delivery } -type RabbitMQTransportServer struct { - conn *RabbitConnection +type RabbitMQTransportListener struct { + conn *rabbitMQConn addr string } @@ -136,62 +134,63 @@ func (r *RabbitMQTransportClient) Close() error { return nil } -func (r *RabbitMQTransportSocket) Recv() (*Message, error) { - m := &Message{ +func (r *RabbitMQTransportSocket) Recv(m *Message) error { + if m == nil { + return errors.New("message passed in is nil") + } + + mr := &Message{ Header: make(map[string]string), Body: r.d.Body, } for k, v := range r.d.Headers { - m.Header[k] = fmt.Sprintf("%v", v) + mr.Header[k] = fmt.Sprintf("%v", v) } - return m, nil + *m = *mr + return nil } -func (r *RabbitMQTransportSocket) WriteHeader(k string, v string) { - r.hdr[k] = v +func (r *RabbitMQTransportSocket) Send(m *Message) error { + msg := amqp.Publishing{ + CorrelationId: r.d.CorrelationId, + Timestamp: time.Now().UTC(), + Body: m.Body, + Headers: amqp.Table{}, + } + + for k, v := range m.Header { + msg.Headers[k] = v + } + + return r.conn.Publish("", r.d.ReplyTo, msg) } -func (r *RabbitMQTransportSocket) Write(b []byte) error { - _, err := r.buf.Write(b) - return err +func (r *RabbitMQTransportSocket) Close() error { + return nil } -func (r *RabbitMQTransportServer) Addr() string { +func (r *RabbitMQTransportListener) Addr() string { return r.addr } -func (r *RabbitMQTransportServer) Close() error { +func (r *RabbitMQTransportListener) Close() error { r.conn.Close() return nil } -func (r *RabbitMQTransportServer) Serve(fn func(Socket)) error { +func (r *RabbitMQTransportListener) Accept(fn func(Socket)) error { deliveries, err := r.conn.Consume(r.addr) if err != nil { return err } handler := func(d amqp.Delivery) { - buf := bytes.NewBuffer(nil) - headers := amqp.Table{} - fn(&RabbitMQTransportSocket{ - d: &d, - hdr: headers, - buf: buf, + d: &d, + conn: r.conn, }) - - msg := amqp.Publishing{ - CorrelationId: d.CorrelationId, - Timestamp: time.Now().UTC(), - Body: buf.Bytes(), - Headers: headers, - } - - r.conn.Publish("", d.ReplyTo, msg) - buf.Reset() } for d := range deliveries { @@ -201,7 +200,7 @@ func (r *RabbitMQTransportServer) Serve(fn func(Socket)) error { return nil } -func (r *RabbitMQTransport) NewClient(addr string) (Client, error) { +func (r *RabbitMQTransport) Dial(addr string) (Client, error) { id, err := uuid.NewV4() if err != nil { return nil, err @@ -215,16 +214,16 @@ func (r *RabbitMQTransport) NewClient(addr string) (Client, error) { }, nil } -func (r *RabbitMQTransport) NewServer(addr string) (Server, error) { +func (r *RabbitMQTransport) Listen(addr string) (Listener, error) { id, err := uuid.NewV4() if err != nil { return nil, err } - conn := NewRabbitConnection("", r.addrs) + conn := newRabbitMQConn("", r.addrs) <-conn.Init() - return &RabbitMQTransportServer{ + return &RabbitMQTransportListener{ addr: id.String(), conn: conn, }, nil @@ -232,7 +231,7 @@ func (r *RabbitMQTransport) NewServer(addr string) (Server, error) { func NewRabbitMQTransport(addrs []string) *RabbitMQTransport { return &RabbitMQTransport{ - conn: NewRabbitConnection("", addrs), + conn: newRabbitMQConn("", addrs), addrs: addrs, } } diff --git a/transport/transport.go b/transport/transport.go index 2c4de967..0b53e1d5 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -6,9 +6,9 @@ type Message struct { } type Socket interface { - Recv() (*Message, error) - WriteHeader(string, string) - Write([]byte) error + Recv(*Message) error + Send(*Message) error + Close() error } type Client interface { @@ -16,25 +16,25 @@ type Client interface { Close() error } -type Server interface { +type Listener interface { Addr() string Close() error - Serve(func(Socket)) error + Accept(func(Socket)) error } type Transport interface { - NewClient(addr string) (Client, error) - NewServer(addr string) (Server, error) + Dial(addr string) (Client, error) + Listen(addr string) (Listener, error) } var ( DefaultTransport Transport = NewHttpTransport([]string{}) ) -func NewClient(addr string) (Client, error) { - return DefaultTransport.NewClient(addr) +func Dial(addr string) (Client, error) { + return DefaultTransport.Dial(addr) } -func NewServer(addr string) (Server, error) { - return DefaultTransport.NewServer(addr) +func Listen(addr string) (Listener, error) { + return DefaultTransport.Listen(addr) }