Merge pull request #578 from unistack-org/memory
memory transport: fix race cond on channel close
This commit is contained in:
		| @@ -12,6 +12,10 @@ import ( | |||||||
| 	"github.com/micro/go-micro/transport" | 	"github.com/micro/go-micro/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	r = rand.New(rand.NewSource(time.Now().UnixNano())) | ||||||
|  | ) | ||||||
|  |  | ||||||
| type memorySocket struct { | type memorySocket struct { | ||||||
| 	recv chan *transport.Message | 	recv chan *transport.Message | ||||||
| 	send chan *transport.Message | 	send chan *transport.Message | ||||||
| @@ -22,6 +26,7 @@ type memorySocket struct { | |||||||
|  |  | ||||||
| 	local  string | 	local  string | ||||||
| 	remote string | 	remote string | ||||||
|  | 	sync.RWMutex | ||||||
| } | } | ||||||
|  |  | ||||||
| type memoryClient struct { | type memoryClient struct { | ||||||
| @@ -34,16 +39,18 @@ type memoryListener struct { | |||||||
| 	exit chan bool | 	exit chan bool | ||||||
| 	conn chan *memorySocket | 	conn chan *memorySocket | ||||||
| 	opts transport.ListenOptions | 	opts transport.ListenOptions | ||||||
|  | 	sync.RWMutex | ||||||
| } | } | ||||||
|  |  | ||||||
| type memoryTransport struct { | type memoryTransport struct { | ||||||
| 	opts transport.Options | 	opts transport.Options | ||||||
|  | 	sync.RWMutex | ||||||
| 	sync.Mutex |  | ||||||
| 	listeners map[string]*memoryListener | 	listeners map[string]*memoryListener | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ms *memorySocket) Recv(m *transport.Message) error { | func (ms *memorySocket) Recv(m *transport.Message) error { | ||||||
|  | 	ms.RLock() | ||||||
|  | 	defer ms.RUnlock() | ||||||
| 	select { | 	select { | ||||||
| 	case <-ms.exit: | 	case <-ms.exit: | ||||||
| 		return errors.New("connection closed") | 		return errors.New("connection closed") | ||||||
| @@ -64,6 +71,8 @@ func (ms *memorySocket) Remote() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (ms *memorySocket) Send(m *transport.Message) error { | func (ms *memorySocket) Send(m *transport.Message) error { | ||||||
|  | 	ms.RLock() | ||||||
|  | 	defer ms.RUnlock() | ||||||
| 	select { | 	select { | ||||||
| 	case <-ms.exit: | 	case <-ms.exit: | ||||||
| 		return errors.New("connection closed") | 		return errors.New("connection closed") | ||||||
| @@ -75,6 +84,8 @@ func (ms *memorySocket) Send(m *transport.Message) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (ms *memorySocket) Close() error { | func (ms *memorySocket) Close() error { | ||||||
|  | 	ms.RLock() | ||||||
|  | 	defer ms.RUnlock() | ||||||
| 	select { | 	select { | ||||||
| 	case <-ms.exit: | 	case <-ms.exit: | ||||||
| 		return nil | 		return nil | ||||||
| @@ -89,6 +100,8 @@ func (m *memoryListener) Addr() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (m *memoryListener) Close() error { | func (m *memoryListener) Close() error { | ||||||
|  | 	m.Lock() | ||||||
|  | 	defer m.Unlock() | ||||||
| 	select { | 	select { | ||||||
| 	case <-m.exit: | 	case <-m.exit: | ||||||
| 		return nil | 		return nil | ||||||
| @@ -117,8 +130,8 @@ func (m *memoryListener) Accept(fn func(transport.Socket)) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { | func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) { | ||||||
| 	m.Lock() | 	m.RLock() | ||||||
| 	defer m.Unlock() | 	defer m.RUnlock() | ||||||
|  |  | ||||||
| 	listener, ok := m.listeners[addr] | 	listener, ok := m.listeners[addr] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| @@ -165,7 +178,6 @@ func (m *memoryTransport) Listen(addr string, opts ...transport.ListenOption) (t | |||||||
|  |  | ||||||
| 	// if zero port then randomly assign one | 	// if zero port then randomly assign one | ||||||
| 	if len(parts) > 1 && parts[len(parts)-1] == "0" { | 	if len(parts) > 1 && parts[len(parts)-1] == "0" { | ||||||
| 		r := rand.New(rand.NewSource(time.Now().UnixNano())) |  | ||||||
| 		i := r.Intn(10000) | 		i := r.Intn(10000) | ||||||
| 		// set addr with port | 		// set addr with port | ||||||
| 		addr = fmt.Sprintf("%s:%d", parts[:len(parts)-1], 10000+i) | 		addr = fmt.Sprintf("%s:%d", parts[:len(parts)-1], 10000+i) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user