api/handler/rpc: binary streaming support (#1466)
* api/handler/rpc: binary streaming support Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org> * fixup Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org> * fix Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org> * fix sec webscoekt protol Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
		| @@ -1,107 +1,194 @@ | ||||
| package rpc | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/gobwas/httphead" | ||||
| 	"github.com/gobwas/ws" | ||||
| 	"github.com/gobwas/ws/wsutil" | ||||
| 	"github.com/micro/go-micro/v2/api" | ||||
| 	"github.com/micro/go-micro/v2/client" | ||||
| 	"github.com/micro/go-micro/v2/client/selector" | ||||
| 	raw "github.com/micro/go-micro/v2/codec/bytes" | ||||
| 	"github.com/micro/go-micro/v2/logger" | ||||
| ) | ||||
|  | ||||
| var upgrader = websocket.Upgrader{ | ||||
| 	ReadBufferSize:  1024, | ||||
| 	WriteBufferSize: 1024, | ||||
| } | ||||
|  | ||||
| // serveWebsocket will stream rpc back over websockets assuming json | ||||
| func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) { | ||||
| 	// upgrade the connection | ||||
| 	conn, err := upgrader.Upgrade(w, r, nil) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	// close on exit | ||||
| 	defer conn.Close() | ||||
| 	var op ws.OpCode | ||||
|  | ||||
| 	// wait for the first request so we know | ||||
| 	_, p, err := conn.ReadMessage() | ||||
| 	ct := r.Header.Get("Content-Type") | ||||
| 	// Strip charset from Content-Type (like `application/json; charset=UTF-8`) | ||||
| 	if idx := strings.IndexRune(ct, ';'); idx >= 0 { | ||||
| 		ct = ct[:idx] | ||||
| 	} | ||||
|  | ||||
| 	// check proto from request | ||||
| 	switch ct { | ||||
| 	case "application/json": | ||||
| 		op = ws.OpText | ||||
| 	default: | ||||
| 		op = ws.OpBinary | ||||
| 	} | ||||
|  | ||||
| 	hdr := make(http.Header) | ||||
| 	if proto, ok := r.Header["Sec-WebSocket-Protocol"]; ok { | ||||
| 		for _, p := range proto { | ||||
| 			switch p { | ||||
| 			case "binary": | ||||
| 				hdr["Sec-WebSocket-Protocol"] = []string{"binary"} | ||||
| 				op = ws.OpBinary | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	payload, err := requestPayload(r) | ||||
| 	if err != nil { | ||||
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// send to backend | ||||
| 	// default to trying json | ||||
| 	var request json.RawMessage | ||||
| 	// if the extracted payload isn't empty lets use it | ||||
| 	if len(p) > 0 { | ||||
| 		request = json.RawMessage(p) | ||||
| 	upgrader := ws.HTTPUpgrader{Timeout: 5 * time.Second, | ||||
| 		Protocol: func(proto string) bool { | ||||
| 			if strings.Contains(proto, "binary") { | ||||
| 				return true | ||||
| 			} | ||||
| 			// fallback to support all protocols now | ||||
| 			return true | ||||
| 		}, | ||||
| 		Extension: func(httphead.Option) bool { | ||||
| 			// disable extensions for compatibility | ||||
| 			return false | ||||
| 		}, | ||||
| 		Header: hdr, | ||||
| 	} | ||||
|  | ||||
| 	// create a request to the backend | ||||
| 	conn, rw, _, err := upgrader.Upgrade(r, w) | ||||
| 	if err != nil { | ||||
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	defer func() { | ||||
| 		if err := conn.Close(); err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	var request interface{} | ||||
| 	if !bytes.Equal(payload, []byte(`{}`)) { | ||||
| 		switch ct { | ||||
| 		case "application/json", "": | ||||
| 			m := json.RawMessage(payload) | ||||
| 			request = &m | ||||
| 		default: | ||||
| 			request = &raw.Frame{Data: payload} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// we always need to set content type for message | ||||
| 	if ct == "" { | ||||
| 		ct = "application/json" | ||||
| 	} | ||||
| 	req := c.NewRequest( | ||||
| 		service.Name, | ||||
| 		service.Endpoint.Name, | ||||
| 		&request, | ||||
| 		client.WithContentType("application/json"), | ||||
| 		request, | ||||
| 		client.WithContentType(ct), | ||||
| 		client.StreamingRequest(), | ||||
| 	) | ||||
|  | ||||
| 	so := selector.WithStrategy(strategy(service.Services)) | ||||
|  | ||||
| 	// create a new stream | ||||
| 	stream, err := c.Stream(ctx, req, client.WithSelectOption(so)) | ||||
| 	if err != nil { | ||||
| 		if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 			logger.Error(err) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// send the first request for the client | ||||
| 	// since | ||||
| 	if err := stream.Send(request); err != nil { | ||||
| 		return | ||||
| 	if request != nil { | ||||
| 		if err = stream.Send(request); err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	go writeLoop(conn, stream) | ||||
| 	go writeLoop(rw, stream) | ||||
|  | ||||
| 	resp := stream.Response() | ||||
| 	rsp := stream.Response() | ||||
|  | ||||
| 	// receive from stream and send to client | ||||
| 	for { | ||||
| 		// read backend response body | ||||
| 		body, err := resp.Read() | ||||
| 		buf, err := rsp.Read() | ||||
| 		if err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// write the response | ||||
| 		if err := conn.WriteMessage(websocket.TextMessage, body); err != nil { | ||||
| 		if err := wsutil.WriteServerMessage(rw, op, buf); err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 		if err = rw.Flush(); err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // writeLoop | ||||
| func writeLoop(conn *websocket.Conn, stream client.Stream) { | ||||
| func writeLoop(rw io.ReadWriter, stream client.Stream) { | ||||
| 	// close stream when done | ||||
| 	defer stream.Close() | ||||
|  | ||||
| 	for { | ||||
| 		_, p, err := conn.ReadMessage() | ||||
| 		buf, op, err := wsutil.ReadClientData(rw) | ||||
| 		if err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		switch op { | ||||
| 		default: | ||||
| 			// not relevant | ||||
| 			continue | ||||
| 		case ws.OpText, ws.OpBinary: | ||||
| 			break | ||||
| 		} | ||||
| 		// send to backend | ||||
| 		// default to trying json | ||||
| 		var request json.RawMessage | ||||
| 		// if the extracted payload isn't empty lets use it | ||||
| 		if len(p) > 0 { | ||||
| 			request = json.RawMessage(p) | ||||
| 		} | ||||
| 		request := &raw.Frame{Data: buf} | ||||
|  | ||||
| 		if err := stream.Send(request); err != nil { | ||||
| 			if logger.V(logger.ErrorLevel, logger.DefaultLogger) { | ||||
| 				logger.Error(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| @@ -112,7 +199,6 @@ func isStream(r *http.Request, srv *api.Service) bool { | ||||
| 	if !isWebSocket(r) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// check if the endpoint supports streaming | ||||
| 	for _, service := range srv.Services { | ||||
| 		for _, ep := range service.Endpoints { | ||||
| @@ -120,14 +206,12 @@ func isStream(r *http.Request, srv *api.Service) bool { | ||||
| 			if ep.Name != srv.Endpoint.Name { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// matched if the name | ||||
| 			if v := ep.Metadata["stream"]; v == "true" { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user