diff --git a/rpc.go b/rpc.go index a246d99..3059207 100644 --- a/rpc.go +++ b/rpc.go @@ -4,7 +4,6 @@ package rpc import ( "encoding/json" "io" - "io/ioutil" "net/http" "strconv" "strings" @@ -22,6 +21,7 @@ import ( "github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/util/ctx" + "github.com/oxtoacart/bpool" ) const ( @@ -45,6 +45,8 @@ var ( "application/proto-rpc", "application/octet-stream", } + + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) type rpcHandler struct { @@ -69,6 +71,13 @@ func strategy(services []*registry.Service) selector.Strategy { } func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if h.opts.MaxRecvSize > 0 { + bsize = h.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + defer r.Body.Close() var service *api.Service @@ -240,8 +249,8 @@ func requestPayload(r *http.Request) ([]byte, error) { if err := c.ReadBody(&raw); err != nil { return nil, err } - b, _ := raw.Marshal() - return b, nil + b, err := raw.Marshal() + return b, err case strings.Contains(ct, "application/www-x-form-urlencoded"): r.ParseForm() @@ -252,8 +261,8 @@ func requestPayload(r *http.Request) ([]byte, error) { } // marshal - b, _ := json.Marshal(vals) - return b, nil + b, err := json.Marshal(vals) + return b, err // TODO: application/grpc } @@ -265,7 +274,12 @@ func requestPayload(r *http.Request) ([]byte, error) { return qson.ToJSON(r.URL.RawQuery) } case "PATCH", "POST": - return ioutil.ReadAll(r.Body) + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil } return []byte{}, nil