RPC stream client/server mutex fix (#884)

* Unlock RPC client while actually receiving a message

As receiving a message might block for a long time, unblocking the client allows to let it send messages in the meanwhile without using 'tricks'

* Unlock RPC server while actually receiving a message

As receiving a message might block for a long time, unblocking the client allows to let it send messages in the meanwhile without using 'tricks'

* Protect Close() against race conditions

* Concurrency and Sequence tests
This commit is contained in:
Maarten Bezemer 2020-01-12 10:13:14 +01:00 committed by Asim Aslam
parent fa5b3ee9d9
commit 50b20413d3
3 changed files with 159 additions and 8 deletions

View File

@ -83,7 +83,10 @@ func (r *rpcStream) Recv(msg interface{}) error {
var resp codec.Message
if err := r.codec.ReadHeader(&resp, codec.Response); err != nil {
r.Unlock()
err := r.codec.ReadHeader(&resp, codec.Response)
r.Lock()
if err != nil {
if err == io.EOF && !r.isClosed() {
r.err = io.ErrUnexpectedEOF
return io.ErrUnexpectedEOF
@ -102,11 +105,17 @@ func (r *rpcStream) Recv(msg interface{}) error {
} else {
r.err = io.EOF
}
if err := r.codec.ReadBody(nil); err != nil {
r.Unlock()
err = r.codec.ReadBody(nil)
r.Lock()
if err != nil {
r.err = err
}
default:
if err := r.codec.ReadBody(msg); err != nil {
r.Unlock()
err = r.codec.ReadBody(msg)
r.Lock()
if err != nil {
r.err = err
}
}
@ -121,11 +130,15 @@ func (r *rpcStream) Error() error {
}
func (r *rpcStream) Close() error {
r.RLock()
select {
case <-r.closed:
r.RUnlock()
return nil
default:
close(r.closed)
r.RUnlock()
// send the end of stream message
if r.sendEOS {

View File

@ -48,13 +48,13 @@ func (r *rpcStream) Send(msg interface{}) error {
}
func (r *rpcStream) Recv(msg interface{}) error {
r.Lock()
defer r.Unlock()
req := new(codec.Message)
req.Type = codec.Request
if err := r.codec.ReadHeader(req, req.Type); err != nil {
err := r.codec.ReadHeader(req, req.Type)
r.Lock()
defer r.Unlock()
if err != nil {
// discard body
r.codec.ReadBody(nil)
r.err = err
@ -67,7 +67,9 @@ func (r *rpcStream) Recv(msg interface{}) error {
switch req.Error {
case lastStreamResponseError.Error():
// discard body
r.Unlock()
r.codec.ReadBody(nil)
r.Lock()
r.err = io.EOF
return io.EOF
default:
@ -77,7 +79,10 @@ func (r *rpcStream) Recv(msg interface{}) error {
// we need to stay up to date with sequence numbers
r.id = req.Id
if err := r.codec.ReadBody(msg); err != nil {
r.Unlock()
err = r.codec.ReadBody(msg)
r.Lock()
if err != nil {
r.err = err
return err
}

133
server/rpc_stream_test.go Normal file
View File

@ -0,0 +1,133 @@
package server
import (
"bytes"
"fmt"
"io"
"math/rand"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/micro/go-micro/codec/json"
protoCodec "github.com/micro/go-micro/codec/proto"
)
// protoStruct implements proto.Message
type protoStruct struct {
Payload string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"`
}
func (m *protoStruct) Reset() { *m = protoStruct{} }
func (m *protoStruct) String() string { return proto.CompactTextString(m) }
func (*protoStruct) ProtoMessage() {}
// safeBuffer throws away everything and wont Read data back
type safeBuffer struct {
sync.RWMutex
buf []byte
off int
}
func (b *safeBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
// Cannot retain p, so we must copy it:
p2 := make([]byte, len(p))
copy(p2, p)
b.Lock()
b.buf = append(b.buf, p2...)
b.Unlock()
return len(p2), nil
}
func (b *safeBuffer) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
b.RLock()
n = copy(p, b.buf[b.off:])
b.RUnlock()
if n == 0 {
return 0, io.EOF
}
b.off += n
return n, nil
}
func (b *safeBuffer) Close() error {
return nil
}
func TestRPCStream_Sequence(t *testing.T) {
buffer := new(bytes.Buffer)
rwc := readWriteCloser{
rbuf: buffer,
wbuf: buffer,
}
codec := json.NewCodec(&rwc)
streamServer := rpcStream{
codec: codec,
request: &rpcRequest{
codec: codec,
},
}
// Check if sequence is correct
for i := 0; i < 1000; i++ {
if err := streamServer.Send(fmt.Sprintf(`{"test":"value %d"}`, i)); err != nil {
t.Errorf("Unexpected Send error: %s", err)
}
}
for i := 0; i < 1000; i++ {
var msg string
if err := streamServer.Recv(&msg); err != nil {
t.Errorf("Unexpected Recv error: %s", err)
}
if msg != fmt.Sprintf(`{"test":"value %d"}`, i) {
t.Errorf("Unexpected msg: %s", msg)
}
}
}
func TestRPCStream_Concurrency(t *testing.T) {
buffer := new(safeBuffer)
codec := protoCodec.NewCodec(buffer)
streamServer := rpcStream{
codec: codec,
request: &rpcRequest{
codec: codec,
},
}
var wg sync.WaitGroup
// Check if race conditions happen
for i := 0; i < 10; i++ {
wg.Add(2)
go func() {
for i := 0; i < 50; i++ {
msg := protoStruct{Payload: "test"}
<-time.After(time.Duration(rand.Intn(50)) * time.Millisecond)
if err := streamServer.Send(msg); err != nil {
t.Errorf("Unexpected Send error: %s", err)
}
}
wg.Done()
}()
go func() {
for i := 0; i < 50; i++ {
<-time.After(time.Duration(rand.Intn(50)) * time.Millisecond)
if err := streamServer.Recv(&protoStruct{}); err != nil {
t.Errorf("Unexpected Recv error: %s", err)
}
}
wg.Done()
}()
}
wg.Wait()
}