From 10f1bd592f21b9fd5f918ad76591f96c34841208 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Mon, 3 Jun 2019 18:44:43 +0100 Subject: [PATCH] Further consolidate the libraries --- rpc.go | 307 ++++++++++++++++++++++++++++++++++++++++++++++++++++ rpc_test.go | 95 ++++++++++++++++ 2 files changed, 402 insertions(+) create mode 100644 rpc.go create mode 100644 rpc_test.go diff --git a/rpc.go b/rpc.go new file mode 100644 index 0000000..28c6724 --- /dev/null +++ b/rpc.go @@ -0,0 +1,307 @@ +// Package rpc is a go-micro rpc handler. +package rpc + +import ( + "encoding/json" + "io" + "io/ioutil" + "net/http" + "strconv" + "strings" + + "github.com/joncalhoun/qson" + "github.com/micro/go-micro/api" + "github.com/micro/go-micro/api/handler" + proto "github.com/micro/go-micro/api/internal/proto" + "github.com/micro/go-micro/client" + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/codec/jsonrpc" + "github.com/micro/go-micro/codec/protorpc" + "github.com/micro/go-micro/errors" + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/selector" + "github.com/micro/go-micro/util/ctx" +) + +const ( + Handler = "rpc" +) + +var ( + // supported json codecs + jsonCodecs = []string{ + "application/grpc+json", + "application/json", + "application/json-rpc", + } + + // support proto codecs + protoCodecs = []string{ + "application/grpc", + "application/grpc+proto", + "application/proto", + "application/protobuf", + "application/proto-rpc", + "application/octet-stream", + } +) + +type rpcHandler struct { + opts handler.Options + s *api.Service +} + +type buffer struct { + io.ReadCloser +} + +func (b *buffer) Write(_ []byte) (int, error) { + return 0, nil +} + +// strategy is a hack for selection +func strategy(services []*registry.Service) selector.Strategy { + return func(_ []*registry.Service) selector.Next { + // ignore input to this function, use services above + return selector.Random(services) + } +} + +func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + var service *api.Service + + if h.s != nil { + // we were given the service + service = h.s + } else if h.opts.Router != nil { + // try get service from router + s, err := h.opts.Router.Route(r) + if err != nil { + writeError(w, r, errors.InternalServerError("go.micro.api", err.Error())) + return + } + service = s + } else { + // we have no way of routing the request + writeError(w, r, errors.InternalServerError("go.micro.api", "no route found")) + return + } + + // only allow post when we have the router + if r.Method != "GET" && (h.opts.Router != nil && r.Method != "POST") { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ct := r.Header.Get("Content-Type") + + // Strip charset from Content-Type (like `application/json; charset=UTF-8`) + if idx := strings.IndexRune(ct, ';'); idx >= 0 { + ct = ct[:idx] + } + + // micro client + c := h.opts.Service.Client() + + // create strategy + so := selector.WithStrategy(strategy(service.Services)) + + // get payload + br, err := requestPayload(r) + if err != nil { + writeError(w, r, err) + return + } + + // create context + cx := ctx.FromRequest(r) + + var rsp []byte + + switch { + // json codecs + case hasCodec(ct, jsonCodecs): + var request json.RawMessage + // if the extracted payload isn't empty lets use it + if len(br) > 0 { + request = json.RawMessage(br) + } + + // create request/response + var response json.RawMessage + + req := c.NewRequest( + service.Name, + service.Endpoint.Name, + &request, + client.WithContentType(ct), + ) + + // make the call + if err := c.Call(cx, req, &response, client.WithSelectOption(so)); err != nil { + writeError(w, r, err) + return + } + + // marshall response + rsp, _ = response.MarshalJSON() + // proto codecs + case hasCodec(ct, protoCodecs): + request := &proto.Message{} + // if the extracted payload isn't empty lets use it + if len(br) > 0 { + request = proto.NewMessage(br) + } + + // create request/response + response := &proto.Message{} + + req := c.NewRequest( + service.Name, + service.Endpoint.Name, + request, + client.WithContentType(ct), + ) + + // make the call + if err := c.Call(cx, req, response, client.WithSelectOption(so)); err != nil { + writeError(w, r, err) + return + } + + // marshall response + rsp, _ = response.Marshal() + default: + http.Error(w, "Unsupported Content-Type", 400) + return + } + + // write the response + writeResponse(w, r, rsp) +} + +func (rh *rpcHandler) String() string { + return "rpc" +} + +func hasCodec(ct string, codecs []string) bool { + for _, codec := range codecs { + if ct == codec { + return true + } + } + return false +} + +// requestPayload takes a *http.Request. +// If the request is a GET the query string parameters are extracted and marshaled to JSON and the raw bytes are returned. +// If the request method is a POST the request body is read and returned +func requestPayload(r *http.Request) ([]byte, error) { + // we have to decode json-rpc and proto-rpc because we suck + // well actually because there's no proxy codec right now + switch r.Header.Get("Content-Type") { + case "application/json-rpc": + msg := codec.Message{ + Type: codec.Request, + Header: make(map[string]string), + } + c := jsonrpc.NewCodec(&buffer{r.Body}) + if err := c.ReadHeader(&msg, codec.Request); err != nil { + return nil, err + } + var raw json.RawMessage + if err := c.ReadBody(&raw); err != nil { + return nil, err + } + return ([]byte)(raw), nil + case "application/proto-rpc", "application/octet-stream": + msg := codec.Message{ + Type: codec.Request, + Header: make(map[string]string), + } + c := protorpc.NewCodec(&buffer{r.Body}) + if err := c.ReadHeader(&msg, codec.Request); err != nil { + return nil, err + } + var raw proto.Message + if err := c.ReadBody(&raw); err != nil { + return nil, err + } + b, _ := raw.Marshal() + return b, nil + } + + // otherwise as per usual + + switch r.Method { + case "GET": + if len(r.URL.RawQuery) > 0 { + return qson.ToJSON(r.URL.RawQuery) + } + case "PATCH", "POST": + return ioutil.ReadAll(r.Body) + } + + return []byte{}, nil +} + +func writeError(w http.ResponseWriter, r *http.Request, err error) { + ce := errors.Parse(err.Error()) + + switch ce.Code { + case 0: + // assuming it's totally screwed + ce.Code = 500 + ce.Id = "go.micro.api" + ce.Status = http.StatusText(500) + ce.Detail = "error during request: " + ce.Detail + w.WriteHeader(500) + default: + w.WriteHeader(int(ce.Code)) + } + + // response content type + w.Header().Set("Content-Type", "application/json") + + // Set trailers + if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + w.Header().Set("Trailer", "grpc-status") + w.Header().Set("Trailer", "grpc-message") + w.Header().Set("grpc-status", "13") + w.Header().Set("grpc-message", ce.Detail) + } + + w.Write([]byte(ce.Error())) +} + +func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) { + w.Header().Set("Content-Type", r.Header.Get("Content-Type")) + w.Header().Set("Content-Length", strconv.Itoa(len(rsp))) + + // Set trailers + if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + w.Header().Set("Trailer", "grpc-status") + w.Header().Set("Trailer", "grpc-message") + w.Header().Set("grpc-status", "0") + w.Header().Set("grpc-message", "") + } + + // write response + w.Write(rsp) +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.NewOptions(opts...) + return &rpcHandler{ + opts: options, + } +} + +func WithService(s *api.Service, opts ...handler.Option) handler.Handler { + options := handler.NewOptions(opts...) + return &rpcHandler{ + opts: options, + s: s, + } +} diff --git a/rpc_test.go b/rpc_test.go new file mode 100644 index 0000000..2804a84 --- /dev/null +++ b/rpc_test.go @@ -0,0 +1,95 @@ +package rpc + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/micro/go-micro/api/proto" +) + +func TestRequestPayloadFromRequest(t *testing.T) { + + // our test event so that we can validate serialising / deserializing of true protos works + protoEvent := go_api.Event{ + Name: "Test", + } + + protoBytes, err := proto.Marshal(&protoEvent) + if err != nil { + t.Fatal("Failed to marshal proto", err) + } + + jsonBytes, err := json.Marshal(protoEvent) + if err != nil { + t.Fatal("Failed to marshal proto to JSON ", err) + } + + t.Run("extracting a proto from a POST request", func(t *testing.T) { + r, err := http.NewRequest("POST", "http://localhost/my/path", bytes.NewReader(protoBytes)) + if err != nil { + t.Fatalf("Failed to created http.Request: %v", err) + } + + extByte, err := requestPayload(r) + if err != nil { + t.Fatalf("Failed to extract payload from request: %v", err) + } + if string(extByte) != string(protoBytes) { + t.Fatalf("Expected %v and %v to match", string(extByte), string(protoBytes)) + } + }) + + t.Run("extracting JSON from a POST request", func(t *testing.T) { + r, err := http.NewRequest("POST", "http://localhost/my/path", bytes.NewReader(jsonBytes)) + if err != nil { + t.Fatalf("Failed to created http.Request: %v", err) + } + + extByte, err := requestPayload(r) + if err != nil { + t.Fatalf("Failed to extract payload from request: %v", err) + } + if string(extByte) != string(jsonBytes) { + t.Fatalf("Expected %v and %v to match", string(extByte), string(jsonBytes)) + } + }) + + t.Run("extracting params from a GET request", func(t *testing.T) { + + r, err := http.NewRequest("GET", "http://localhost/my/path", nil) + if err != nil { + t.Fatalf("Failed to created http.Request: %v", err) + } + + q := r.URL.Query() + q.Add("name", "Test") + r.URL.RawQuery = q.Encode() + + extByte, err := requestPayload(r) + if err != nil { + t.Fatalf("Failed to extract payload from request: %v", err) + } + if string(extByte) != string(jsonBytes) { + t.Fatalf("Expected %v and %v to match", string(extByte), string(jsonBytes)) + } + }) + + t.Run("GET request with no params", func(t *testing.T) { + + r, err := http.NewRequest("GET", "http://localhost/my/path", nil) + if err != nil { + t.Fatalf("Failed to created http.Request: %v", err) + } + + extByte, err := requestPayload(r) + if err != nil { + t.Fatalf("Failed to extract payload from request: %v", err) + } + if string(extByte) != "" { + t.Fatalf("Expected %v and %v to match", string(extByte), "") + } + }) +}