Channel rather than mutex to check is closed

This commit is contained in:
Asim 2015-12-28 19:11:10 +00:00
parent 10d2ad0de9
commit bffd55f500
2 changed files with 23 additions and 8 deletions

View File

@ -136,9 +136,12 @@ func (r *rpcClient) stream(ctx context.Context, address string, req Request) (St
return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err))
} }
var once sync.Once
stream := &rpcStream{ stream := &rpcStream{
context: ctx, context: ctx,
request: req, request: req,
once: once,
closed: make(chan bool),
codec: newRpcPlusCodec(msg, c, cf), codec: newRpcPlusCodec(msg, c, cf),
} }

View File

@ -13,13 +13,25 @@ import (
type rpcStream struct { type rpcStream struct {
sync.RWMutex sync.RWMutex
seq uint64 seq uint64
closed bool once sync.Once
closed chan bool
err error err error
request Request request Request
codec clientCodec codec clientCodec
context context.Context context context.Context
} }
func (r *rpcStream) isClosed() bool {
select {
case _, ok := <-r.closed:
if !ok {
return true
}
default:
}
return false
}
func (r *rpcStream) Context() context.Context { func (r *rpcStream) Context() context.Context {
return r.context return r.context
} }
@ -32,7 +44,7 @@ func (r *rpcStream) Send(msg interface{}) error {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if r.closed { if r.isClosed() {
r.err = errShutdown r.err = errShutdown
return errShutdown return errShutdown
} }
@ -57,14 +69,14 @@ func (r *rpcStream) Recv(msg interface{}) error {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if r.closed { if r.isClosed() {
r.err = errShutdown r.err = errShutdown
return errShutdown return errShutdown
} }
var resp response var resp response
if err := r.codec.ReadResponseHeader(&resp); err != nil { if err := r.codec.ReadResponseHeader(&resp); err != nil {
if err == io.EOF && !r.closed { if err == io.EOF && !r.isClosed() {
r.err = io.ErrUnexpectedEOF r.err = io.ErrUnexpectedEOF
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
@ -91,7 +103,7 @@ func (r *rpcStream) Recv(msg interface{}) error {
} }
} }
if r.err != nil && r.err != io.EOF && !r.closed { if r.err != nil && r.err != io.EOF && !r.isClosed() {
log.Println("rpc: client protocol error:", r.err) log.Println("rpc: client protocol error:", r.err)
} }
@ -105,8 +117,8 @@ func (r *rpcStream) Error() error {
} }
func (r *rpcStream) Close() error { func (r *rpcStream) Close() error {
r.Lock() r.once.Do(func() {
defer r.Unlock() close(r.closed)
r.closed = true })
return r.codec.Close() return r.codec.Close()
} }