diff --git a/api/handler/rpc/stream.go b/api/handler/rpc/stream.go index a4741769..d54f933e 100644 --- a/api/handler/rpc/stream.go +++ b/api/handler/rpc/stream.go @@ -135,27 +135,38 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, // receive from stream and send to client for { - // read backend response body - buf, err := rsp.Read() - if err != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error(err) - } + select { + case <-ctx.Done(): return - } + case <-stream.Context().Done(): + return + default: + // read backend response body + buf, err := rsp.Read() + if err != nil { + // wants to avoid import grpc/status.Status + if strings.Contains(err.Error(), "context canceled") { + return + } + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } - // write the response - if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error(err) + // write the response + if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return } - return - } - if err = rw.Flush(); err != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error(err) + if err = rw.Flush(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return } - return } } } @@ -166,30 +177,40 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) { defer stream.Close() for { - buf, op, err := wsutil.ReadClientData(rw) - if err != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error(err) - } + select { + case <-stream.Context().Done(): return - } - switch op { default: - // not relevant - continue - case ws.OpText, ws.OpBinary: - break - } - // send to backend - // default to trying json - // if the extracted payload isn't empty lets use it - request := &raw.Frame{Data: buf} - - if err := stream.Send(request); err != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error(err) + buf, op, err := wsutil.ReadClientData(rw) + if err != nil { + wserr := err.(wsutil.ClosedError) + switch wserr.Code { + case ws.StatusNormalClosure, ws.StatusNoStatusRcvd: + return + default: + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + } + switch op { + default: + // not relevant + continue + case ws.OpText, ws.OpBinary: + break + } + // send to backend + // default to trying json + // if the extracted payload isn't empty lets use it + request := &raw.Frame{Data: buf} + if err := stream.Send(request); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return } - return } } }