transport cruft

This commit is contained in:
Asim 2015-05-20 22:57:19 +01:00
parent 4b51a55993
commit 50e44726f5
10 changed files with 970 additions and 164 deletions

View File

@ -3,14 +3,13 @@ package client
import (
"bytes"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"time"
"github.com/myodc/go-micro/errors"
"github.com/myodc/go-micro/registry"
"github.com/myodc/go-micro/transport"
rpc "github.com/youtube/vitess/go/rpcplus"
js "github.com/youtube/vitess/go/rpcplus/jsonrpc"
pb "github.com/youtube/vitess/go/rpcplus/pbrpc"
@ -22,7 +21,9 @@ type headerRoundTripper struct {
r http.RoundTripper
}
type RpcClient struct{}
type RpcClient struct {
transport transport.Transport
}
func init() {
rand.Seed(time.Now().UnixNano())
@ -71,48 +72,43 @@ func (r *RpcClient) call(address, path string, request Request, response interfa
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error writing request: %v", err))
}
client := &http.Client{}
client.Transport = &headerRoundTripper{http.DefaultTransport}
request.Headers().Set("Content-Type", request.ContentType())
hreq := &http.Request{
Method: "POST",
URL: &url.URL{
Scheme: "http",
Host: address,
Path: path,
},
Header: request.Headers().(http.Header),
Body: buf,
ContentLength: int64(reqB.Len()),
Host: address,
msg := &transport.Message{
Header: make(map[string]string),
Body: reqB.Bytes(),
}
rsp, err := client.Do(hreq)
h, _ := request.Headers().(http.Header)
for k, v := range h {
if len(v) > 0 {
msg.Header[k] = v[0]
}
}
msg.Header["Content-Type"] = request.ContentType()
c, err := r.transport.NewClient(request.Service(), address)
if err != nil {
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err))
}
defer rsp.Body.Close()
b, err := ioutil.ReadAll(rsp.Body)
rsp, err := c.Send(msg)
if err != nil {
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error reading response: %v", err))
return errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err))
}
rspB := bytes.NewBuffer(b)
rspB := bytes.NewBuffer(rsp.Body)
defer rspB.Reset()
rBuf := &buffer{
rspB,
}
switch rsp.Header.Get("Content-Type") {
switch rsp.Header["Content-Type"] {
case "application/octet-stream":
cc = pb.NewClientCodec(rBuf)
case "application/json":
cc = js.NewClientCodec(rBuf)
default:
return errors.InternalServerError("go.micro.client", string(b))
return errors.InternalServerError("go.micro.client", string(rsp.Body))
}
pRsp := &rpc.Response{}
@ -167,5 +163,7 @@ func (r *RpcClient) NewJsonRequest(service, method string, request interface{})
}
func NewRpcClient() *RpcClient {
return &RpcClient{}
return &RpcClient{
transport: transport.DefaultTransport,
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/myodc/go-micro/registry"
"github.com/myodc/go-micro/server"
"github.com/myodc/go-micro/store"
"github.com/myodc/go-micro/transport"
)
var (
@ -30,7 +31,6 @@ var (
cli.StringFlag{
Name: "broker_address",
EnvVar: "MICRO_BROKER_ADDRESS",
Value: ":0",
Usage: "Comma-separated list of broker addresses",
},
cli.StringFlag{
@ -42,7 +42,6 @@ var (
cli.StringFlag{
Name: "registry_address",
EnvVar: "MICRO_REGISTRY_ADDRESS",
Value: "127.0.0.1:8500",
Usage: "Comma-separated list of registry addresses",
},
cli.StringFlag{
@ -54,42 +53,63 @@ var (
cli.StringFlag{
Name: "store_address",
EnvVar: "MICRO_STORE_ADDRESS",
Value: "127.0.0.1:8500",
Usage: "Comma-separated list of store addresses",
},
cli.StringFlag{
Name: "transport",
EnvVar: "MICRO_TRANSPORT",
Value: "http",
Usage: "Transport mechanism used; http, rabbitmq, etc",
},
cli.StringFlag{
Name: "transport_address",
EnvVar: "MICRO_TRANSPORT_ADDRESS",
Usage: "Comma-separated list of transport addresses",
},
}
)
func Setup(c *cli.Context) error {
server.Address = c.String("server_address")
broker_addrs := strings.Split(c.String("broker_address"), ",")
bAddrs := strings.Split(c.String("broker_address"), ",")
switch c.String("broker") {
case "http":
broker.DefaultBroker = broker.NewHttpBroker(broker_addrs)
broker.DefaultBroker = broker.NewHttpBroker(bAddrs)
case "nats":
broker.DefaultBroker = broker.NewNatsBroker(broker_addrs)
broker.DefaultBroker = broker.NewNatsBroker(bAddrs)
}
registry_addrs := strings.Split(c.String("registry_address"), ",")
rAddrs := strings.Split(c.String("registry_address"), ",")
switch c.String("registry") {
case "kubernetes":
registry.DefaultRegistry = registry.NewKubernetesRegistry(registry_addrs)
registry.DefaultRegistry = registry.NewKubernetesRegistry(rAddrs)
case "consul":
registry.DefaultRegistry = registry.NewConsulRegistry(registry_addrs)
registry.DefaultRegistry = registry.NewConsulRegistry(rAddrs)
}
store_addrs := strings.Split(c.String("store_address"), ",")
sAddrs := strings.Split(c.String("store_address"), ",")
switch c.String("store") {
case "memcached":
store.DefaultStore = store.NewMemcacheStore(store_addrs)
store.DefaultStore = store.NewMemcacheStore(sAddrs)
case "memory":
store.DefaultStore = store.NewMemoryStore(store_addrs)
store.DefaultStore = store.NewMemoryStore(sAddrs)
case "etcd":
store.DefaultStore = store.NewEtcdStore(store_addrs)
store.DefaultStore = store.NewEtcdStore(sAddrs)
}
tAddrs := strings.Split(c.String("transport_address"), ",")
switch c.String("transport") {
case "http":
transport.DefaultTransport = transport.NewHttpTransport(tAddrs)
case "rabbitmq":
transport.DefaultTransport = transport.NewRabbitMQTransport(tAddrs)
case "nats":
transport.DefaultTransport = transport.NewNatsTransport(tAddrs)
}
return nil

View File

@ -2,27 +2,23 @@ package server
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"runtime/debug"
"strconv"
"sync"
"github.com/bradfitz/http2"
log "github.com/golang/glog"
"github.com/myodc/go-micro/errors"
"github.com/myodc/go-micro/transport"
rpc "github.com/youtube/vitess/go/rpcplus"
js "github.com/youtube/vitess/go/rpcplus/jsonrpc"
pb "github.com/youtube/vitess/go/rpcplus/pbrpc"
"golang.org/x/net/context"
)
type RpcServer struct {
mtx sync.RWMutex
rpc *rpc.Server
address string
exit chan chan error
mtx sync.RWMutex
address string
transport transport.Transport
rpc *rpc.Server
exit chan chan error
}
var (
@ -30,92 +26,14 @@ var (
RpcPath = "/_rpc"
)
func executeRequestSafely(c *serverContext, r *http.Request) {
defer func() {
if x := recover(); x != nil {
log.Warningf("Panicked on request: %v", r)
log.Warningf("%v: %v", x, string(debug.Stack()))
err := errors.InternalServerError("go.micro.server", "Unexpected error")
c.WriteHeader(500)
c.Write([]byte(err.Error()))
}
}()
http.DefaultServeMux.ServeHTTP(c, r)
}
func (s *RpcServer) handler(w http.ResponseWriter, r *http.Request) {
c := &serverContext{
req: &serverRequest{r},
outHeader: w.Header(),
}
ctxs.Lock()
ctxs.m[r] = c
ctxs.Unlock()
defer func() {
ctxs.Lock()
delete(ctxs.m, r)
ctxs.Unlock()
}()
// Patch up RemoteAddr so it looks reasonable.
if addr := r.Header.Get("X-Forwarded-For"); len(addr) > 0 {
r.RemoteAddr = addr
} else {
// Should not normally reach here, but pick a sensible default anyway.
r.RemoteAddr = "127.0.0.1"
}
// The address in the headers will most likely be of these forms:
// 123.123.123.123
// 2001:db8::1
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
// Assume the remote address is only a host; add a default port.
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
}
executeRequestSafely(c, r)
c.outHeader = nil // make sure header changes aren't respected any more
// Avoid nil Write call if c.Write is never called.
if c.outCode != 0 {
w.WriteHeader(c.outCode)
}
if c.outBody != nil {
w.Write(c.outBody)
}
}
func (s *RpcServer) Address() string {
s.mtx.RLock()
defer s.mtx.RUnlock()
return s.address
}
func (s *RpcServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
serveCtx := getServerContext(req)
// TODO: get user scope from context
// check access
if req.Method != "POST" {
err := errors.BadRequest("go.micro.server", "Method not allowed")
http.Error(w, err.Error(), http.StatusMethodNotAllowed)
return
}
defer req.Body.Close()
b, err := ioutil.ReadAll(req.Body)
func (s *RpcServer) serve(sock transport.Socket) {
// serveCtx := getServerContext(req)
msg, err := sock.Recv()
if err != nil {
errr := errors.InternalServerError("go.micro.server", fmt.Sprintf("Error reading request body: %v", err))
w.WriteHeader(500)
w.Write([]byte(errr.Error()))
log.Errorf("Erroring reading request body: %v", err)
return
}
rbq := bytes.NewBuffer(b)
rbq := bytes.NewBuffer(msg.Body)
rsp := bytes.NewBuffer(nil)
defer rsp.Reset()
defer rbq.Reset()
@ -126,36 +44,34 @@ func (s *RpcServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
var cc rpc.ServerCodec
switch req.Header.Get("Content-Type") {
switch msg.Header["Content-Type"] {
case "application/octet-stream":
cc = pb.NewServerCodec(buf)
case "application/json":
cc = js.NewServerCodec(buf)
default:
err = errors.InternalServerError("go.micro.server", fmt.Sprintf("Unsupported content-type: %v", req.Header.Get("Content-Type")))
w.WriteHeader(500)
w.Write([]byte(err.Error()))
return
// return nil, errors.InternalServerError("go.micro.server", fmt.Sprintf("Unsupported content-type: %v", req.Header.Get("Content-Type")))
}
ctx := newContext(&ctx{}, serveCtx)
err = s.rpc.ServeRequestWithContext(ctx, cc)
//ctx := newContext(&ctx{}, serveCtx)
err = s.rpc.ServeRequestWithContext(context.Background(), cc)
if err != nil {
// This should not be possible.
w.WriteHeader(500)
w.Write([]byte(err.Error()))
log.Errorf("Erroring serving request: %v", err)
return
}
w.Header().Set("Content-Type", req.Header.Get("Content-Type"))
w.Header().Set("Content-Length", strconv.Itoa(rsp.Len()))
w.Write(rsp.Bytes())
sock.WriteHeader("Content-Type", msg.Header["Content-Type"])
sock.Write(rsp.Bytes())
}
func (s *RpcServer) Address() string {
s.mtx.RLock()
address := s.address
s.mtx.RUnlock()
return address
}
func (s *RpcServer) Init() error {
log.Infof("Rpc handler %s", RpcPath)
http.Handle(RpcPath, s)
return nil
}
@ -180,28 +96,22 @@ func (s *RpcServer) Register(r Receiver) error {
func (s *RpcServer) Start() error {
registerHealthChecker(http.DefaultServeMux)
l, err := net.Listen("tcp", s.address)
ts, err := s.transport.NewServer(Name, s.address)
if err != nil {
return err
}
log.Infof("Listening on %s", l.Addr().String())
log.Infof("Listening on %s", ts.Addr())
s.mtx.Lock()
s.address = l.Addr().String()
s.mtx.Unlock()
s.mtx.RLock()
s.address = ts.Addr()
s.mtx.RUnlock()
srv := &http.Server{
Handler: http.HandlerFunc(s.handler),
}
http2.ConfigureServer(srv, nil)
go srv.Serve(l)
go ts.Serve(s.serve)
go func() {
ch := <-s.exit
ch <- l.Close()
ch <- ts.Close()
}()
return nil
@ -215,8 +125,9 @@ func (s *RpcServer) Stop() error {
func NewRpcServer(address string) *RpcServer {
return &RpcServer{
rpc: rpc.NewServer(),
address: address,
exit: make(chan chan error),
address: address,
transport: transport.DefaultTransport,
rpc: rpc.NewServer(),
exit: make(chan chan error),
}
}

13
transport/buffer.go Normal file
View File

@ -0,0 +1,13 @@
package transport
import (
"io"
)
type buffer struct {
io.ReadWriter
}
func (b *buffer) Close() error {
return nil
}

170
transport/http_transport.go Normal file
View File

@ -0,0 +1,170 @@
package transport
import (
"bytes"
"io/ioutil"
"net"
"net/http"
"net/url"
)
type headerRoundTripper struct {
r http.RoundTripper
}
type HttpTransport struct {
client *http.Client
}
type HttpTransportClient struct {
client *http.Client
target string
}
type HttpTransportSocket struct {
r *http.Request
w http.ResponseWriter
}
type HttpTransportServer struct {
listener net.Listener
}
func (t *headerRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
r.Header.Set("X-Client-Version", "1.0")
return t.r.RoundTrip(r)
}
func (h *HttpTransportClient) Send(m *Message) (*Message, error) {
header := make(http.Header)
for k, v := range m.Header {
header.Set(k, v)
}
reqB := bytes.NewBuffer(m.Body)
defer reqB.Reset()
buf := &buffer{
reqB,
}
hreq := &http.Request{
Method: "POST",
URL: &url.URL{
Scheme: "http",
Host: h.target,
// Path: path,
},
Header: header,
Body: buf,
ContentLength: int64(reqB.Len()),
Host: h.target,
}
rsp, err := h.client.Do(hreq)
if err != nil {
return nil, err
}
defer rsp.Body.Close()
b, err := ioutil.ReadAll(rsp.Body)
if err != nil {
return nil, err
}
mr := &Message{
Header: make(map[string]string),
Body: b,
}
for k, v := range rsp.Header {
if len(v) > 0 {
mr.Header[k] = v[0]
} else {
mr.Header[k] = ""
}
}
return mr, nil
}
func (h *HttpTransportClient) Close() error {
return nil
}
func (h *HttpTransportSocket) Recv() (*Message, error) {
b, err := ioutil.ReadAll(h.r.Body)
if err != nil {
return nil, err
}
m := &Message{
Header: make(map[string]string),
Body: b,
}
for k, v := range h.r.Header {
if len(v) > 0 {
m.Header[k] = v[0]
} else {
m.Header[k] = ""
}
}
return m, nil
}
func (h *HttpTransportSocket) WriteHeader(k string, v string) {
h.w.Header().Set(k, v)
}
func (h *HttpTransportSocket) Write(b []byte) error {
_, err := h.w.Write(b)
return err
}
func (h *HttpTransportServer) Addr() string {
return h.listener.Addr().String()
}
func (h *HttpTransportServer) Close() error {
return h.listener.Close()
}
func (h *HttpTransportServer) Serve(fn func(Socket)) error {
srv := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fn(&HttpTransportSocket{
r: r,
w: w,
})
}),
}
return srv.Serve(h.listener)
}
func (h *HttpTransport) NewClient(name, addr string) (Client, error) {
return &HttpTransportClient{
client: h.client,
target: addr,
}, nil
}
func (h *HttpTransport) NewServer(name, addr string) (Server, error) {
l, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &HttpTransportServer{
listener: l,
}, nil
}
func NewHttpTransport(addrs []string) *HttpTransport {
client := &http.Client{}
client.Transport = &headerRoundTripper{http.DefaultTransport}
return &HttpTransport{client: client}
}

151
transport/nats_transport.go Normal file
View File

@ -0,0 +1,151 @@
package transport
import (
"bytes"
"encoding/json"
"strings"
"time"
"github.com/apcera/nats"
)
type NatsTransport struct{}
type NatsTransportClient struct {
conn *nats.Conn
target string
}
type NatsTransportSocket struct {
m *nats.Msg
hdr map[string]string
buf *bytes.Buffer
}
type NatsTransportServer struct {
conn *nats.Conn
name string
exit chan bool
}
func (n *NatsTransportClient) Send(m *Message) (*Message, error) {
b, err := json.Marshal(m)
if err != nil {
return nil, err
}
rsp, err := n.conn.Request(n.target, b, time.Second*10)
if err != nil {
return nil, err
}
var mr *Message
if err := json.Unmarshal(rsp.Data, &mr); err != nil {
return nil, err
}
return mr, nil
}
func (n *NatsTransportClient) Close() error {
n.conn.Close()
return nil
}
func (n *NatsTransportSocket) Recv() (*Message, error) {
var m *Message
if err := json.Unmarshal(n.m.Data, &m); err != nil {
return nil, err
}
return m, nil
}
func (n *NatsTransportSocket) WriteHeader(k string, v string) {
n.hdr[k] = v
}
func (n *NatsTransportSocket) Write(b []byte) error {
_, err := n.buf.Write(b)
return err
}
func (n *NatsTransportServer) Addr() string {
return "127.0.0.1:4222"
}
func (n *NatsTransportServer) Close() error {
n.exit <- true
n.conn.Close()
return nil
}
func (n *NatsTransportServer) Serve(fn func(Socket)) error {
s, err := n.conn.QueueSubscribe(n.name, "queue:"+n.name, func(m *nats.Msg) {
buf := bytes.NewBuffer(nil)
hdr := make(map[string]string)
fn(&NatsTransportSocket{
m: m,
hdr: hdr,
buf: buf,
})
mrsp := &Message{
Header: hdr,
Body: buf.Bytes(),
}
b, err := json.Marshal(mrsp)
if err != nil {
return
}
n.conn.Publish(m.Reply, b)
buf.Reset()
})
if err != nil {
return err
}
<-n.exit
return s.Unsubscribe()
}
func (n *NatsTransport) NewClient(name, addr string) (Client, error) {
if !strings.HasPrefix(addr, "nats://") {
addr = nats.DefaultURL
}
c, err := nats.Connect(addr)
if err != nil {
return nil, err
}
return &NatsTransportClient{
conn: c,
target: name,
}, nil
}
func (n *NatsTransport) NewServer(name, addr string) (Server, error) {
if !strings.HasPrefix(addr, "nats://") {
addr = nats.DefaultURL
}
c, err := nats.Connect(addr)
if err != nil {
return nil, err
}
return &NatsTransportServer{
name: name,
conn: c,
exit: make(chan bool, 1),
}, nil
}
func NewNatsTransport(addrs []string) *NatsTransport {
return &NatsTransport{}
}

View File

@ -0,0 +1,128 @@
package transport
//
// All credit to Mondo
// https://github.com/mondough/typhon
//
import (
"errors"
"github.com/nu7hatch/gouuid"
"github.com/streadway/amqp"
)
type RabbitChannel struct {
uuid string
connection *amqp.Connection
channel *amqp.Channel
}
func NewRabbitChannel(conn *amqp.Connection) (*RabbitChannel, error) {
id, err := uuid.NewV4()
if err != nil {
return nil, err
}
rabbitCh := &RabbitChannel{
uuid: id.String(),
connection: conn,
}
if err := rabbitCh.Connect(); err != nil {
return nil, err
}
return rabbitCh, nil
}
func (r *RabbitChannel) Connect() error {
var err error
r.channel, err = r.connection.Channel()
if err != nil {
return err
}
return nil
}
func (r *RabbitChannel) Close() error {
if r.channel == nil {
return errors.New("Channel is nil")
}
return r.channel.Close()
}
func (r *RabbitChannel) Publish(exchange, routingKey string, message amqp.Publishing) error {
if r.channel == nil {
return errors.New("Channel is nil")
}
return r.channel.Publish(exchange, routingKey, false, false, message)
}
func (r *RabbitChannel) DeclareExchange(exchange string) error {
return r.channel.ExchangeDeclare(
exchange, // name
"topic", // kind
false, // durable
false, // autoDelete
false, // internal
false, // noWait
nil, // args
)
}
func (r *RabbitChannel) DeclareQueue(queue string) error {
_, err := r.channel.QueueDeclare(
queue, // name
false, // durable
true, // autoDelete
false, // exclusive
false, // noWait
nil, // args
)
return err
}
func (r *RabbitChannel) DeclareDurableQueue(queue string) error {
_, err := r.channel.QueueDeclare(
queue, // name
true, // durable
false, // autoDelete
false, // exclusive
false, // noWait
nil, // args
)
return err
}
func (r *RabbitChannel) DeclareReplyQueue(queue string) error {
_, err := r.channel.QueueDeclare(
queue, // name
false, // durable
true, // autoDelete
true, // exclusive
false, // noWait
nil, // args
)
return err
}
func (r *RabbitChannel) ConsumeQueue(queue string) (<-chan amqp.Delivery, error) {
return r.channel.Consume(
queue, // queue
r.uuid, // consumer
true, // autoAck
false, // exclusive
false, // nolocal
false, // nowait
nil, // args
)
}
func (r *RabbitChannel) BindQueue(queue, exchange string) error {
return r.channel.QueueBind(
queue, // name
queue, // key
exchange, // exchange
false, // noWait
nil, // args
)
}

View File

@ -0,0 +1,143 @@
package transport
//
// All credit to Mondo
// https://github.com/mondough/typhon
//
import (
"sync"
"time"
"github.com/streadway/amqp"
)
var (
DefaultExchange = "micro"
DefaultRabbitURL = "amqp://guest:guest@127.0.0.1:5672"
)
type RabbitConnection struct {
Connection *amqp.Connection
Channel *RabbitChannel
ExchangeChannel *RabbitChannel
notify chan bool
exchange string
url string
connected bool
mtx sync.Mutex
closeChan chan struct{}
closed bool
}
func (r *RabbitConnection) Init() chan bool {
go r.Connect(r.notify)
return r.notify
}
func (r *RabbitConnection) Connect(connected chan bool) {
for {
if err := r.tryToConnect(); err != nil {
time.Sleep(1 * time.Second)
continue
}
connected <- true
r.connected = true
notifyClose := make(chan *amqp.Error)
r.Connection.NotifyClose(notifyClose)
// Block until we get disconnected, or shut down
select {
case <-notifyClose:
// Spin around and reconnect
r.connected = false
case <-r.closeChan:
// Shut down connection
if err := r.Connection.Close(); err != nil {
}
r.connected = false
return
}
}
}
func (r *RabbitConnection) IsConnected() bool {
return r.connected
}
func (r *RabbitConnection) Close() {
r.mtx.Lock()
defer r.mtx.Unlock()
if r.closed {
return
}
close(r.closeChan)
r.closed = true
}
func (r *RabbitConnection) tryToConnect() error {
var err error
r.Connection, err = amqp.Dial(r.url)
if err != nil {
return err
}
r.Channel, err = NewRabbitChannel(r.Connection)
if err != nil {
return err
}
r.Channel.DeclareExchange(r.exchange)
r.ExchangeChannel, err = NewRabbitChannel(r.Connection)
if err != nil {
return err
}
return nil
}
func (r *RabbitConnection) Consume(serverName string) (<-chan amqp.Delivery, error) {
consumerChannel, err := NewRabbitChannel(r.Connection)
if err != nil {
return nil, err
}
err = consumerChannel.DeclareQueue(serverName)
if err != nil {
return nil, err
}
deliveries, err := consumerChannel.ConsumeQueue(serverName)
if err != nil {
return nil, err
}
err = consumerChannel.BindQueue(serverName, r.exchange)
if err != nil {
return nil, err
}
return deliveries, nil
}
func (r *RabbitConnection) Publish(exchange, routingKey string, msg amqp.Publishing) error {
return r.ExchangeChannel.Publish(exchange, routingKey, msg)
}
func NewRabbitConnection(exchange, url string) *RabbitConnection {
if len(url) == 0 {
url = DefaultRabbitURL
}
if len(exchange) == 0 {
exchange = DefaultExchange
}
return &RabbitConnection{
exchange: DefaultExchange,
url: DefaultRabbitURL,
notify: make(chan bool, 1),
closeChan: make(chan struct{}),
}
}

View File

@ -0,0 +1,232 @@
package transport
import (
"bytes"
"fmt"
"sync"
"time"
"errors"
uuid "github.com/nu7hatch/gouuid"
"github.com/streadway/amqp"
)
type RabbitMQTransport struct {
conn *RabbitConnection
}
type RabbitMQTransportClient struct {
once sync.Once
conn *RabbitConnection
target string
replyTo string
sync.Mutex
inflight map[string]chan amqp.Delivery
}
type RabbitMQTransportSocket struct {
d *amqp.Delivery
hdr amqp.Table
buf *bytes.Buffer
}
type RabbitMQTransportServer struct {
conn *RabbitConnection
name string
}
func (h *RabbitMQTransportClient) init() {
<-h.conn.Init()
if err := h.conn.Channel.DeclareReplyQueue(h.replyTo); err != nil {
return
}
deliveries, err := h.conn.Channel.ConsumeQueue(h.replyTo)
if err != nil {
return
}
go func() {
for delivery := range deliveries {
go h.handle(delivery)
}
}()
}
func (h *RabbitMQTransportClient) handle(delivery amqp.Delivery) {
ch := h.getReq(delivery.CorrelationId)
if ch == nil {
return
}
select {
case ch <- delivery:
default:
}
}
func (h *RabbitMQTransportClient) putReq(id string) chan amqp.Delivery {
h.Lock()
ch := make(chan amqp.Delivery, 1)
h.inflight[id] = ch
h.Unlock()
return ch
}
func (h *RabbitMQTransportClient) getReq(id string) chan amqp.Delivery {
h.Lock()
defer h.Unlock()
if ch, ok := h.inflight[id]; ok {
delete(h.inflight, id)
return ch
}
return nil
}
func (h *RabbitMQTransportClient) Send(m *Message) (*Message, error) {
h.once.Do(h.init)
if !h.conn.IsConnected() {
return nil, errors.New("Not connected to AMQP")
}
id, err := uuid.NewV4()
if err != nil {
return nil, err
}
replyChan := h.putReq(id.String())
headers := amqp.Table{}
for k, v := range m.Header {
headers[k] = v
}
message := amqp.Publishing{
CorrelationId: id.String(),
Timestamp: time.Now().UTC(),
Body: m.Body,
ReplyTo: h.replyTo,
Headers: headers,
}
if err := h.conn.Publish("micro", h.target, message); err != nil {
h.getReq(id.String())
return nil, err
}
select {
case d := <-replyChan:
mr := &Message{
Header: make(map[string]string),
Body: d.Body,
}
for k, v := range d.Headers {
mr.Header[k] = fmt.Sprintf("%v", v)
}
return mr, nil
case <-time.After(time.Second * 10):
return nil, errors.New("timed out")
}
}
func (h *RabbitMQTransportClient) Close() error {
h.conn.Close()
return nil
}
func (h *RabbitMQTransportSocket) Recv() (*Message, error) {
m := &Message{
Header: make(map[string]string),
Body: h.d.Body,
}
for k, v := range h.d.Headers {
m.Header[k] = fmt.Sprintf("%v", v)
}
return m, nil
}
func (h *RabbitMQTransportSocket) WriteHeader(k string, v string) {
h.hdr[k] = v
}
func (h *RabbitMQTransportSocket) Write(b []byte) error {
_, err := h.buf.Write(b)
return err
}
func (h *RabbitMQTransportServer) Addr() string {
return h.conn.Connection.LocalAddr().String()
}
func (h *RabbitMQTransportServer) Close() error {
h.conn.Close()
return nil
}
func (h *RabbitMQTransportServer) Serve(fn func(Socket)) error {
deliveries, err := h.conn.Consume(h.name)
if err != nil {
return err
}
handler := func(d amqp.Delivery) {
buf := bytes.NewBuffer(nil)
headers := amqp.Table{}
fn(&RabbitMQTransportSocket{
d: &d,
hdr: headers,
buf: buf,
})
msg := amqp.Publishing{
CorrelationId: d.CorrelationId,
Timestamp: time.Now().UTC(),
Body: buf.Bytes(),
Headers: headers,
}
h.conn.Publish("", d.ReplyTo, msg)
buf.Reset()
}
for d := range deliveries {
go handler(d)
}
return nil
}
func (h *RabbitMQTransport) NewClient(name, addr string) (Client, error) {
id, err := uuid.NewV4()
if err != nil {
return nil, err
}
return &RabbitMQTransportClient{
conn: h.conn,
target: name,
inflight: make(map[string]chan amqp.Delivery),
replyTo: fmt.Sprintf("replyTo-%s", id.String()),
}, nil
}
func (h *RabbitMQTransport) NewServer(name, addr string) (Server, error) {
conn := NewRabbitConnection("", "")
<-conn.Init()
return &RabbitMQTransportServer{
name: name,
conn: conn,
}, nil
}
func NewRabbitMQTransport(addrs []string) *RabbitMQTransport {
return &RabbitMQTransport{
conn: NewRabbitConnection("", ""),
}
}

40
transport/transport.go Normal file
View File

@ -0,0 +1,40 @@
package transport
type Message struct {
Header map[string]string
Body []byte
}
type Socket interface {
Recv() (*Message, error)
WriteHeader(string, string)
Write([]byte) error
}
type Client interface {
Send(*Message) (*Message, error)
Close() error
}
type Server interface {
Addr() string
Close() error
Serve(func(Socket)) error
}
type Transport interface {
NewClient(name, addr string) (Client, error)
NewServer(name, addr string) (Server, error)
}
var (
DefaultTransport Transport = NewHttpTransport([]string{})
)
func NewClient(name, addr string) (Client, error) {
return DefaultTransport.NewClient(name, addr)
}
func NewServer(name, addr string) (Server, error) {
return DefaultTransport.NewServer(name, addr)
}