diff --git a/api/handler/rpc/rpc.go b/api/handler/rpc/rpc.go index b3ab0db5..83a0d877 100644 --- a/api/handler/rpc/rpc.go +++ b/api/handler/rpc/rpc.go @@ -129,9 +129,11 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // set merged context to request *r = *r.Clone(cx) - // if stream we currently only support json if isStream(r, service) { + // drop older context as it can have timeouts and create new + // md, _ := metadata.FromContext(cx) + //serveWebsocket(context.TODO(), w, r, service, c) serveWebsocket(cx, w, r, service, c) return } diff --git a/api/handler/rpc/stream.go b/api/handler/rpc/stream.go index 29eee41b..a4741769 100644 --- a/api/handler/rpc/stream.go +++ b/api/handler/rpc/stream.go @@ -1,107 +1,194 @@ package rpc import ( + "bytes" "context" "encoding/json" + "io" "net/http" "strings" + "time" - "github.com/gorilla/websocket" + "github.com/gobwas/httphead" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" "github.com/micro/go-micro/v2/api" "github.com/micro/go-micro/v2/client" "github.com/micro/go-micro/v2/client/selector" + raw "github.com/micro/go-micro/v2/codec/bytes" + "github.com/micro/go-micro/v2/logger" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - // serveWebsocket will stream rpc back over websockets assuming json func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { - // upgrade the connection - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - // close on exit - defer conn.Close() + var op ws.OpCode - // wait for the first request so we know - _, p, err := conn.ReadMessage() + ct := r.Header.Get("Content-Type") + // Strip charset from Content-Type (like `application/json; charset=UTF-8`) + if idx := strings.IndexRune(ct, ';'); idx >= 0 { + ct = ct[:idx] + } + + // check proto from request + switch ct { + case "application/json": + op = ws.OpText + default: + op = ws.OpBinary + } + + hdr := make(http.Header) + if proto, ok := r.Header["Sec-WebSocket-Protocol"]; ok { + for _, p := range proto { + switch p { + case "binary": + hdr["Sec-WebSocket-Protocol"] = []string{"binary"} + op = ws.OpBinary + } + } + } + payload, err := requestPayload(r) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - // send to backend - // default to trying json - var request json.RawMessage - // if the extracted payload isn't empty lets use it - if len(p) > 0 { - request = json.RawMessage(p) + upgrader := ws.HTTPUpgrader{Timeout: 5 * time.Second, + Protocol: func(proto string) bool { + if strings.Contains(proto, "binary") { + return true + } + // fallback to support all protocols now + return true + }, + Extension: func(httphead.Option) bool { + // disable extensions for compatibility + return false + }, + Header: hdr, } - // create a request to the backend + conn, rw, _, err := upgrader.Upgrade(r, w) + if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + + defer func() { + if err := conn.Close(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + }() + + var request interface{} + if !bytes.Equal(payload, []byte(`{}`)) { + switch ct { + case "application/json", "": + m := json.RawMessage(payload) + request = &m + default: + request = &raw.Frame{Data: payload} + } + } + + // we always need to set content type for message + if ct == "" { + ct = "application/json" + } req := c.NewRequest( service.Name, service.Endpoint.Name, - &request, - client.WithContentType("application/json"), + request, + client.WithContentType(ct), + client.StreamingRequest(), ) so := selector.WithStrategy(strategy(service.Services)) - // create a new stream stream, err := c.Stream(ctx, req, client.WithSelectOption(so)) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - // send the first request for the client - // since - if err := stream.Send(request); err != nil { - return + if request != nil { + if err = stream.Send(request); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } } - go writeLoop(conn, stream) + go writeLoop(rw, stream) - resp := stream.Response() + rsp := stream.Response() // receive from stream and send to client for { // read backend response body - body, err := resp.Read() + buf, err := rsp.Read() if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } // write the response - if err := conn.WriteMessage(websocket.TextMessage, body); err != nil { + if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } + return + } + if err = rw.Flush(); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } } } // writeLoop -func writeLoop(conn *websocket.Conn, stream client.Stream) { +func writeLoop(rw io.ReadWriter, stream client.Stream) { // close stream when done defer stream.Close() for { - _, p, err := conn.ReadMessage() + buf, op, err := wsutil.ReadClientData(rw) if err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } - + switch op { + default: + // not relevant + continue + case ws.OpText, ws.OpBinary: + break + } // send to backend // default to trying json - var request json.RawMessage // if the extracted payload isn't empty lets use it - if len(p) > 0 { - request = json.RawMessage(p) - } + request := &raw.Frame{Data: buf} if err := stream.Send(request); err != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error(err) + } return } } @@ -112,7 +199,6 @@ func isStream(r *http.Request, srv *api.Service) bool { if !isWebSocket(r) { return false } - // check if the endpoint supports streaming for _, service := range srv.Services { for _, ep := range service.Endpoints { @@ -120,14 +206,12 @@ func isStream(r *http.Request, srv *api.Service) bool { if ep.Name != srv.Endpoint.Name { continue } - // matched if the name if v := ep.Metadata["stream"]; v == "true" { return true } } } - return false } diff --git a/api/router/static/static.go b/api/router/static/static.go index 3cf3bebe..3c2c25d4 100644 --- a/api/router/static/static.go +++ b/api/router/static/static.go @@ -15,6 +15,7 @@ import ( "github.com/micro/go-micro/v2/api/router" "github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/metadata" + "github.com/micro/go-micro/v2/registry" ) type endpoint struct { @@ -163,13 +164,23 @@ func (r *staticRouter) Endpoint(req *http.Request) (*api.Service, error) { // hack for stream endpoint if ep.apiep.Stream { - for _, svc := range services { + svcs := registry.Copy(services) + for _, svc := range svcs { + if len(svc.Endpoints) == 0 { + e := ®istry.Endpoint{} + e.Name = strings.Join(epf[1:], ".") + e.Metadata = make(map[string]string) + e.Metadata["stream"] = "true" + svc.Endpoints = append(svc.Endpoints, e) + } for _, e := range svc.Endpoints { e.Name = strings.Join(epf[1:], ".") e.Metadata = make(map[string]string) e.Metadata["stream"] = "true" } } + + services = svcs } svc := &api.Service{ @@ -180,6 +191,7 @@ func (r *staticRouter) Endpoint(req *http.Request) (*api.Service, error) { Host: ep.apiep.Host, Method: ep.apiep.Method, Path: ep.apiep.Path, + Stream: ep.apiep.Stream, }, Services: services, } diff --git a/go.mod b/go.mod index 7a8615a6..c18607a6 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,9 @@ require ( github.com/ghodss/yaml v1.0.0 github.com/go-acme/lego/v3 v3.3.0 github.com/go-playground/universal-translator v0.17.0 // indirect + github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee + github.com/gobwas/pool v0.2.0 // indirect + github.com/gobwas/ws v1.0.3 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.4.0 // indirect @@ -57,3 +60,5 @@ require ( gopkg.in/src-d/go-git.v4 v4.13.1 gopkg.in/telegram-bot-api.v4 v4.6.4 ) + +replace github.com/coreos/bbolt => go.etcd.io/bbolt v1.3.4 diff --git a/go.sum b/go.sum index 898b26f4..7d8a1ee1 100644 --- a/go.sum +++ b/go.sum @@ -140,6 +140,12 @@ github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTM github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.3 h1:ZOigqf7iBxkA4jdQ3am7ATzdlOFp9YzA6NmuvEEZc9g= +github.com/gobwas/ws v1.0.3/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=