RPC stream client/server mutex fix (#884)
* Unlock RPC client while actually receiving a message As receiving a message might block for a long time, unblocking the client allows to let it send messages in the meanwhile without using 'tricks' * Unlock RPC server while actually receiving a message As receiving a message might block for a long time, unblocking the client allows to let it send messages in the meanwhile without using 'tricks' * Protect Close() against race conditions * Concurrency and Sequence tests
This commit is contained in:
parent
fa5b3ee9d9
commit
50b20413d3
@ -83,7 +83,10 @@ func (r *rpcStream) Recv(msg interface{}) error {
|
|||||||
|
|
||||||
var resp codec.Message
|
var resp codec.Message
|
||||||
|
|
||||||
if err := r.codec.ReadHeader(&resp, codec.Response); err != nil {
|
r.Unlock()
|
||||||
|
err := r.codec.ReadHeader(&resp, codec.Response)
|
||||||
|
r.Lock()
|
||||||
|
if err != nil {
|
||||||
if err == io.EOF && !r.isClosed() {
|
if err == io.EOF && !r.isClosed() {
|
||||||
r.err = io.ErrUnexpectedEOF
|
r.err = io.ErrUnexpectedEOF
|
||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
@ -102,11 +105,17 @@ func (r *rpcStream) Recv(msg interface{}) error {
|
|||||||
} else {
|
} else {
|
||||||
r.err = io.EOF
|
r.err = io.EOF
|
||||||
}
|
}
|
||||||
if err := r.codec.ReadBody(nil); err != nil {
|
r.Unlock()
|
||||||
|
err = r.codec.ReadBody(nil)
|
||||||
|
r.Lock()
|
||||||
|
if err != nil {
|
||||||
r.err = err
|
r.err = err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if err := r.codec.ReadBody(msg); err != nil {
|
r.Unlock()
|
||||||
|
err = r.codec.ReadBody(msg)
|
||||||
|
r.Lock()
|
||||||
|
if err != nil {
|
||||||
r.err = err
|
r.err = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -121,11 +130,15 @@ func (r *rpcStream) Error() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *rpcStream) Close() error {
|
func (r *rpcStream) Close() error {
|
||||||
|
r.RLock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-r.closed:
|
case <-r.closed:
|
||||||
|
r.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
close(r.closed)
|
close(r.closed)
|
||||||
|
r.RUnlock()
|
||||||
|
|
||||||
// send the end of stream message
|
// send the end of stream message
|
||||||
if r.sendEOS {
|
if r.sendEOS {
|
||||||
|
@ -48,13 +48,13 @@ func (r *rpcStream) Send(msg interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *rpcStream) Recv(msg interface{}) error {
|
func (r *rpcStream) Recv(msg interface{}) error {
|
||||||
r.Lock()
|
|
||||||
defer r.Unlock()
|
|
||||||
|
|
||||||
req := new(codec.Message)
|
req := new(codec.Message)
|
||||||
req.Type = codec.Request
|
req.Type = codec.Request
|
||||||
|
|
||||||
if err := r.codec.ReadHeader(req, req.Type); err != nil {
|
err := r.codec.ReadHeader(req, req.Type)
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
if err != nil {
|
||||||
// discard body
|
// discard body
|
||||||
r.codec.ReadBody(nil)
|
r.codec.ReadBody(nil)
|
||||||
r.err = err
|
r.err = err
|
||||||
@ -67,7 +67,9 @@ func (r *rpcStream) Recv(msg interface{}) error {
|
|||||||
switch req.Error {
|
switch req.Error {
|
||||||
case lastStreamResponseError.Error():
|
case lastStreamResponseError.Error():
|
||||||
// discard body
|
// discard body
|
||||||
|
r.Unlock()
|
||||||
r.codec.ReadBody(nil)
|
r.codec.ReadBody(nil)
|
||||||
|
r.Lock()
|
||||||
r.err = io.EOF
|
r.err = io.EOF
|
||||||
return io.EOF
|
return io.EOF
|
||||||
default:
|
default:
|
||||||
@ -77,7 +79,10 @@ func (r *rpcStream) Recv(msg interface{}) error {
|
|||||||
|
|
||||||
// we need to stay up to date with sequence numbers
|
// we need to stay up to date with sequence numbers
|
||||||
r.id = req.Id
|
r.id = req.Id
|
||||||
if err := r.codec.ReadBody(msg); err != nil {
|
r.Unlock()
|
||||||
|
err = r.codec.ReadBody(msg)
|
||||||
|
r.Lock()
|
||||||
|
if err != nil {
|
||||||
r.err = err
|
r.err = err
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
133
server/rpc_stream_test.go
Normal file
133
server/rpc_stream_test.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/micro/go-micro/codec/json"
|
||||||
|
protoCodec "github.com/micro/go-micro/codec/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// protoStruct implements proto.Message
|
||||||
|
type protoStruct struct {
|
||||||
|
Payload string `protobuf:"bytes,1,opt,name=service,proto3" json:"service,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *protoStruct) Reset() { *m = protoStruct{} }
|
||||||
|
func (m *protoStruct) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*protoStruct) ProtoMessage() {}
|
||||||
|
|
||||||
|
// safeBuffer throws away everything and wont Read data back
|
||||||
|
type safeBuffer struct {
|
||||||
|
sync.RWMutex
|
||||||
|
buf []byte
|
||||||
|
off int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *safeBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
// Cannot retain p, so we must copy it:
|
||||||
|
p2 := make([]byte, len(p))
|
||||||
|
copy(p2, p)
|
||||||
|
b.Lock()
|
||||||
|
b.buf = append(b.buf, p2...)
|
||||||
|
b.Unlock()
|
||||||
|
return len(p2), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *safeBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
b.RLock()
|
||||||
|
n = copy(p, b.buf[b.off:])
|
||||||
|
b.RUnlock()
|
||||||
|
if n == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
b.off += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *safeBuffer) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPCStream_Sequence(t *testing.T) {
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
rwc := readWriteCloser{
|
||||||
|
rbuf: buffer,
|
||||||
|
wbuf: buffer,
|
||||||
|
}
|
||||||
|
codec := json.NewCodec(&rwc)
|
||||||
|
streamServer := rpcStream{
|
||||||
|
codec: codec,
|
||||||
|
request: &rpcRequest{
|
||||||
|
codec: codec,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if sequence is correct
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
if err := streamServer.Send(fmt.Sprintf(`{"test":"value %d"}`, i)); err != nil {
|
||||||
|
t.Errorf("Unexpected Send error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
var msg string
|
||||||
|
if err := streamServer.Recv(&msg); err != nil {
|
||||||
|
t.Errorf("Unexpected Recv error: %s", err)
|
||||||
|
}
|
||||||
|
if msg != fmt.Sprintf(`{"test":"value %d"}`, i) {
|
||||||
|
t.Errorf("Unexpected msg: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRPCStream_Concurrency(t *testing.T) {
|
||||||
|
buffer := new(safeBuffer)
|
||||||
|
codec := protoCodec.NewCodec(buffer)
|
||||||
|
streamServer := rpcStream{
|
||||||
|
codec: codec,
|
||||||
|
request: &rpcRequest{
|
||||||
|
codec: codec,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// Check if race conditions happen
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
msg := protoStruct{Payload: "test"}
|
||||||
|
<-time.After(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||||
|
if err := streamServer.Send(msg); err != nil {
|
||||||
|
t.Errorf("Unexpected Send error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
<-time.After(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||||
|
if err := streamServer.Recv(&protoStruct{}); err != nil {
|
||||||
|
t.Errorf("Unexpected Recv error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user