diff --git a/transport/memory/memory.go b/transport/memory/memory.go index 51f8d01a..ed269367 100644 --- a/transport/memory/memory.go +++ b/transport/memory/memory.go @@ -2,6 +2,7 @@ package memory import ( + "context" "errors" "fmt" "math/rand" @@ -24,6 +25,10 @@ type memorySocket struct { local string remote string + + // for send/recv transport.Timeout + timeout time.Duration + ctx context.Context sync.RWMutex } @@ -33,11 +38,13 @@ type memoryClient struct { } type memoryListener struct { - addr string - exit chan bool - conn chan *memorySocket - opts transport.ListenOptions + addr string + exit chan bool + conn chan *memorySocket + lopts transport.ListenOptions + topts transport.Options sync.RWMutex + ctx context.Context } type memoryTransport struct { @@ -49,7 +56,17 @@ type memoryTransport struct { func (ms *memorySocket) Recv(m *transport.Message) error { ms.RLock() defer ms.RUnlock() + + ctx := ms.ctx + if ms.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout) + defer cancel() + } + select { + case <-ctx.Done(): + return ctx.Err() case <-ms.exit: return errors.New("connection closed") case <-ms.lexit: @@ -71,7 +88,17 @@ func (ms *memorySocket) Remote() string { func (ms *memorySocket) Send(m *transport.Message) error { ms.RLock() defer ms.RUnlock() + + ctx := ms.ctx + if ms.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout) + defer cancel() + } + select { + case <-ctx.Done(): + return ctx.Err() case <-ms.exit: return errors.New("connection closed") case <-ms.lexit: @@ -116,12 +143,14 @@ func (m *memoryListener) Accept(fn func(transport.Socket)) error { return nil case c := <-m.conn: go fn(&memorySocket{ - lexit: c.lexit, - exit: c.exit, - send: c.recv, - recv: c.send, - local: c.Remote(), - remote: c.Local(), + lexit: c.lexit, + exit: c.exit, + send: c.recv, + recv: c.send, + local: c.Remote(), + remote: c.Local(), + timeout: m.topts.Timeout, + ctx: m.topts.Context, }) } } @@ -143,12 +172,14 @@ func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (trans client := &memoryClient{ &memorySocket{ - send: make(chan *transport.Message), - recv: make(chan *transport.Message), - exit: make(chan bool), - lexit: listener.exit, - local: addr, - remote: addr, + send: make(chan *transport.Message), + recv: make(chan *transport.Message), + exit: make(chan bool), + lexit: listener.exit, + local: addr, + remote: addr, + timeout: m.opts.Timeout, + ctx: m.opts.Context, }, options, } @@ -196,10 +227,12 @@ func (m *memoryTransport) Listen(addr string, opts ...transport.ListenOption) (t } listener := &memoryListener{ - opts: options, - addr: addr, - conn: make(chan *memorySocket), - exit: make(chan bool), + lopts: options, + topts: m.opts, + addr: addr, + conn: make(chan *memorySocket), + exit: make(chan bool), + ctx: m.opts.Context, } m.listeners[addr] = listener @@ -223,12 +256,18 @@ func (m *memoryTransport) String() string { } func NewTransport(opts ...transport.Option) transport.Transport { - rand.Seed(time.Now().UnixNano()) var options transport.Options + + rand.Seed(time.Now().UnixNano()) + for _, o := range opts { o(&options) } + if options.Context == nil { + options.Context = context.Background() + } + return &memoryTransport{ opts: options, listeners: make(map[string]*memoryListener),