package rpc import ( "context" "encoding/json" "net/http" "strings" "github.com/gorilla/websocket" "github.com/micro/go-micro/v2/api" "github.com/micro/go-micro/v2/client" "github.com/micro/go-micro/v2/client/selector" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } // serveWebsocket will stream rpc back over websockets assuming json func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { // upgrade the connection conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } // close on exit defer conn.Close() // wait for the first request so we know _, p, err := conn.ReadMessage() if err != nil { return } // send to backend // default to trying json var request json.RawMessage // if the extracted payload isn't empty lets use it if len(p) > 0 { request = json.RawMessage(p) } // create a request to the backend req := c.NewRequest( service.Name, service.Endpoint.Name, &request, client.WithContentType("application/json"), ) so := selector.WithStrategy(strategy(service.Services)) // create a new stream stream, err := c.Stream(ctx, req, client.WithSelectOption(so)) if err != nil { return } // send the first request for the client // since if err := stream.Send(request); err != nil { return } go writeLoop(conn, stream) resp := stream.Response() // receive from stream and send to client for { // read backend response body body, err := resp.Read() if err != nil { return } // write the response if err := conn.WriteMessage(websocket.TextMessage, body); err != nil { return } } } // writeLoop func writeLoop(conn *websocket.Conn, stream client.Stream) { // close stream when done defer stream.Close() for { _, p, err := conn.ReadMessage() if err != nil { return } // send to backend // default to trying json var request json.RawMessage // if the extracted payload isn't empty lets use it if len(p) > 0 { request = json.RawMessage(p) } if err := stream.Send(request); err != nil { return } } } func isStream(r *http.Request, srv *api.Service) bool { // check if it's a web socket if !isWebSocket(r) { return false } // check if the endpoint supports streaming for _, service := range srv.Services { for _, ep := range service.Endpoints { // skip if it doesn't match the name if ep.Name != srv.Endpoint.Name { continue } // matched if the name if v := ep.Metadata["stream"]; v == "true" { return true } } } return false } func isWebSocket(r *http.Request) bool { contains := func(key, val string) bool { vv := strings.Split(r.Header.Get(key), ",") for _, v := range vv { if val == strings.ToLower(strings.TrimSpace(v)) { return true } } return false } if contains("Connection", "upgrade") && contains("Upgrade", "websocket") { return true } return false }