api/handler: use http.MaxBytesReader and buffer pool (#1415)

* api/handler: use http.MaxBytesReader

protect api handlers from OOM cases

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2020-03-26 14:29:28 +03:00
parent 244c9fdb90
commit bf74b4394e

26
rpc.go
View File

@ -4,7 +4,6 @@ package rpc
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -22,6 +21,7 @@ import (
"github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/logger"
"github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/registry"
"github.com/micro/go-micro/v2/util/ctx" "github.com/micro/go-micro/v2/util/ctx"
"github.com/oxtoacart/bpool"
) )
const ( const (
@ -45,6 +45,8 @@ var (
"application/proto-rpc", "application/proto-rpc",
"application/octet-stream", "application/octet-stream",
} }
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
type rpcHandler struct { type rpcHandler struct {
@ -69,6 +71,13 @@ func strategy(services []*registry.Service) selector.Strategy {
} }
func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if h.opts.MaxRecvSize > 0 {
bsize = h.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
defer r.Body.Close() defer r.Body.Close()
var service *api.Service var service *api.Service
@ -240,8 +249,8 @@ func requestPayload(r *http.Request) ([]byte, error) {
if err := c.ReadBody(&raw); err != nil { if err := c.ReadBody(&raw); err != nil {
return nil, err return nil, err
} }
b, _ := raw.Marshal() b, err := raw.Marshal()
return b, nil return b, err
case strings.Contains(ct, "application/www-x-form-urlencoded"): case strings.Contains(ct, "application/www-x-form-urlencoded"):
r.ParseForm() r.ParseForm()
@ -252,8 +261,8 @@ func requestPayload(r *http.Request) ([]byte, error) {
} }
// marshal // marshal
b, _ := json.Marshal(vals) b, err := json.Marshal(vals)
return b, nil return b, err
// TODO: application/grpc // TODO: application/grpc
} }
@ -265,7 +274,12 @@ func requestPayload(r *http.Request) ([]byte, error) {
return qson.ToJSON(r.URL.RawQuery) return qson.ToJSON(r.URL.RawQuery)
} }
case "PATCH", "POST": case "PATCH", "POST":
return ioutil.ReadAll(r.Body) buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
} }
return []byte{}, nil return []byte{}, nil