From 032a93615039f7ed771e8e0d09ec0bb92d40e3b5 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Thu, 2 Apr 2020 12:13:04 +0300 Subject: [PATCH] api/handler/rpc: binary streaming support (#1466) * api/handler/rpc: binary streaming support Signed-off-by: Vasiliy Tolstov * fixup Signed-off-by: Vasiliy Tolstov * fix Signed-off-by: Vasiliy Tolstov * fix sec webscoekt protol Signed-off-by: Vasiliy Tolstov --- rpc.go | 4 +- stream.go | 170 ++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 130 insertions(+), 44 deletions(-) diff --git a/rpc.go b/rpc.go index b3ab0db..83a0d87 100644 --- a/rpc.go +++ b/rpc.go @@ -129,9 +129,11 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // set merged context to request *r = *r.Clone(cx) - // if stream we currently only support json if isStream(r, service) { + // drop older context as it can have timeouts and create new + // md, _ := metadata.FromContext(cx) + //serveWebsocket(context.TODO(), w, r, service, c) serveWebsocket(cx, w, r, service, c) return } diff --git a/stream.go b/stream.go index 29eee41..a474176 100644 --- a/stream.go +++ b/stream.go @@ -1,107 +1,194 @@ package rpc import ( + "bytes" "context" "encoding/json" + "io" "net/http" "strings" + "time" - "github.com/gorilla/websocket" + "github.com/gobwas/httphead" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" "github.com/micro/go-micro/v2/api" "github.com/micro/go-micro/v2/client" "github.com/micro/go-micro/v2/client/selector" + raw "github.com/micro/go-micro/v2/codec/bytes" + "github.com/micro/go-micro/v2/logger" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - // serveWebsocket will stream rpc back over websockets assuming json func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { - // upgrade the connection - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - // close on exit - defer conn.Close() + var op ws.OpCode - // wait for the first request so we know - _, p, err := conn.ReadMessage() + ct := r.Header.Get("Content-Type") + // Strip charset from Content-Type (like `application/json; charset=UTF-8`) + if idx := strings.IndexRune(ct, ';'); idx >= 0 { + ct = ct[:idx] + } + + // check proto from request + switch ct { + case "application/json": + op = ws.OpText + default: + op = ws.OpBinary + } + + hdr := make(http.Header) + if proto, ok := r.Header["Sec-WebSocket-Protocol"]; ok { + for _, p := range proto { + switch p { + case "binary": + hdr["Sec-WebSocket-Protocol"] = []string{"binary"} + op = ws.OpBinary + } + } + } + payload, err := requestPayload(r) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - // send to backend - // default to trying json - var request json.RawMessage - // if the extracted payload isn't empty lets use it - if len(p) > 0 { - request = json.RawMessage(p) + upgrader := ws.HTTPUpgrader{Timeout: 5 * time.Second, + Protocol: func(proto string) bool { + if strings.Contains(proto, "binary") { + return true + } + // fallback to support all protocols now + return true + }, + Extension: func(httphead.Option) bool { + // disable extensions for compatibility + return false + }, + Header: hdr, } - // create a request to the backend + conn, rw, _, err := upgrader.Upgrade(r, w) + if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + + defer func() { + if err := conn.Close(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + }() + + var request interface{} + if !bytes.Equal(payload, []byte(`{}`)) { + switch ct { + case "application/json", "": + m := json.RawMessage(payload) + request = &m + default: + request = &raw.Frame{Data: payload} + } + } + + // we always need to set content type for message + if ct == "" { + ct = "application/json" + } req := c.NewRequest( service.Name, service.Endpoint.Name, - &request, - client.WithContentType("application/json"), + request, + client.WithContentType(ct), + client.StreamingRequest(), ) so := selector.WithStrategy(strategy(service.Services)) - // create a new stream stream, err := c.Stream(ctx, req, client.WithSelectOption(so)) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - // send the first request for the client - // since - if err := stream.Send(request); err != nil { - return + if request != nil { + if err = stream.Send(request); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } } - go writeLoop(conn, stream) + go writeLoop(rw, stream) - resp := stream.Response() + rsp := stream.Response() // receive from stream and send to client for { // read backend response body - body, err := resp.Read() + buf, err := rsp.Read() if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } // write the response - if err := conn.WriteMessage(websocket.TextMessage, body); err != nil { + if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + if err = rw.Flush(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } } } // writeLoop -func writeLoop(conn *websocket.Conn, stream client.Stream) { +func writeLoop(rw io.ReadWriter, stream client.Stream) { // close stream when done defer stream.Close() for { - _, p, err := conn.ReadMessage() + buf, op, err := wsutil.ReadClientData(rw) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - + switch op { + default: + // not relevant + continue + case ws.OpText, ws.OpBinary: + break + } // send to backend // default to trying json - var request json.RawMessage // if the extracted payload isn't empty lets use it - if len(p) > 0 { - request = json.RawMessage(p) - } + request := &raw.Frame{Data: buf} if err := stream.Send(request); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } } @@ -112,7 +199,6 @@ func isStream(r *http.Request, srv *api.Service) bool { if !isWebSocket(r) { return false } - // check if the endpoint supports streaming for _, service := range srv.Services { for _, ep := range service.Endpoints { @@ -120,14 +206,12 @@ func isStream(r *http.Request, srv *api.Service) bool { if ep.Name != srv.Endpoint.Name { continue } - // matched if the name if v := ep.Metadata["stream"]; v == "true" { return true } } } - return false }