transport memory: add Send/Recv Timeout

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2019-08-03 15:39:44 +03:00
parent d250ac736f
commit e1709026e4

View File

@ -2,6 +2,7 @@
package memory package memory
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
@ -24,6 +25,10 @@ type memorySocket struct {
local string local string
remote string remote string
// for send/recv transport.Timeout
timeout time.Duration
ctx context.Context
sync.RWMutex sync.RWMutex
} }
@ -33,11 +38,13 @@ type memoryClient struct {
} }
type memoryListener struct { type memoryListener struct {
addr string addr string
exit chan bool exit chan bool
conn chan *memorySocket conn chan *memorySocket
opts transport.ListenOptions lopts transport.ListenOptions
topts transport.Options
sync.RWMutex sync.RWMutex
ctx context.Context
} }
type memoryTransport struct { type memoryTransport struct {
@ -49,7 +56,17 @@ type memoryTransport struct {
func (ms *memorySocket) Recv(m *transport.Message) error { func (ms *memorySocket) Recv(m *transport.Message) error {
ms.RLock() ms.RLock()
defer ms.RUnlock() defer ms.RUnlock()
ctx := ms.ctx
if ms.timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
defer cancel()
}
select { select {
case <-ctx.Done():
return ctx.Err()
case <-ms.exit: case <-ms.exit:
return errors.New("connection closed") return errors.New("connection closed")
case <-ms.lexit: case <-ms.lexit:
@ -71,7 +88,17 @@ func (ms *memorySocket) Remote() string {
func (ms *memorySocket) Send(m *transport.Message) error { func (ms *memorySocket) Send(m *transport.Message) error {
ms.RLock() ms.RLock()
defer ms.RUnlock() defer ms.RUnlock()
ctx := ms.ctx
if ms.timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
defer cancel()
}
select { select {
case <-ctx.Done():
return ctx.Err()
case <-ms.exit: case <-ms.exit:
return errors.New("connection closed") return errors.New("connection closed")
case <-ms.lexit: case <-ms.lexit:
@ -116,12 +143,14 @@ func (m *memoryListener) Accept(fn func(transport.Socket)) error {
return nil return nil
case c := <-m.conn: case c := <-m.conn:
go fn(&memorySocket{ go fn(&memorySocket{
lexit: c.lexit, lexit: c.lexit,
exit: c.exit, exit: c.exit,
send: c.recv, send: c.recv,
recv: c.send, recv: c.send,
local: c.Remote(), local: c.Remote(),
remote: c.Local(), remote: c.Local(),
timeout: m.topts.Timeout,
ctx: m.topts.Context,
}) })
} }
} }
@ -143,12 +172,14 @@ func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (trans
client := &memoryClient{ client := &memoryClient{
&memorySocket{ &memorySocket{
send: make(chan *transport.Message), send: make(chan *transport.Message),
recv: make(chan *transport.Message), recv: make(chan *transport.Message),
exit: make(chan bool), exit: make(chan bool),
lexit: listener.exit, lexit: listener.exit,
local: addr, local: addr,
remote: addr, remote: addr,
timeout: m.opts.Timeout,
ctx: m.opts.Context,
}, },
options, options,
} }
@ -196,10 +227,12 @@ func (m *memoryTransport) Listen(addr string, opts ...transport.ListenOption) (t
} }
listener := &memoryListener{ listener := &memoryListener{
opts: options, lopts: options,
addr: addr, topts: m.opts,
conn: make(chan *memorySocket), addr: addr,
exit: make(chan bool), conn: make(chan *memorySocket),
exit: make(chan bool),
ctx: m.opts.Context,
} }
m.listeners[addr] = listener m.listeners[addr] = listener
@ -223,12 +256,18 @@ func (m *memoryTransport) String() string {
} }
func NewTransport(opts ...transport.Option) transport.Transport { func NewTransport(opts ...transport.Option) transport.Transport {
rand.Seed(time.Now().UnixNano())
var options transport.Options var options transport.Options
rand.Seed(time.Now().UnixNano())
for _, o := range opts { for _, o := range opts {
o(&options) o(&options)
} }
if options.Context == nil {
options.Context = context.Background()
}
return &memoryTransport{ return &memoryTransport{
opts: options, opts: options,
listeners: make(map[string]*memoryListener), listeners: make(map[string]*memoryListener),