From 02839cfba5d74b0fc4ac4d6b251987c6dbf5410b Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Thu, 26 Mar 2020 14:29:28 +0300 Subject: [PATCH] api/handler: use http.MaxBytesReader and buffer pool (#1415) * api/handler: use http.MaxBytesReader protect api handlers from OOM cases Signed-off-by: Vasiliy Tolstov --- api/handler/api/api.go | 6 ++++ api/handler/api/util.go | 16 ++++++--- api/handler/broker/broker.go | 22 +++++++++--- api/handler/cloudevents/cloudevents.go | 15 +++++--- api/handler/cloudevents/event.go | 12 +++++-- api/handler/event/event.go | 29 ++++++++++----- api/handler/options.go | 22 ++++++++++-- api/handler/registry/registry.go | 49 ++++++++++++++++---------- api/handler/rpc/rpc.go | 26 ++++++++++---- 9 files changed, 146 insertions(+), 51 deletions(-) diff --git a/api/handler/api/api.go b/api/handler/api/api.go index 7a14fb4f..85c7be75 100644 --- a/api/handler/api/api.go +++ b/api/handler/api/api.go @@ -24,6 +24,12 @@ const ( // API handler is the default handler which takes api.Request and returns api.Response func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if a.opts.MaxRecvSize > 0 { + bsize = a.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) request, err := requestToProto(r) if err != nil { er := errors.InternalServerError("go.micro.api", err.Error()) diff --git a/api/handler/api/util.go b/api/handler/api/util.go index f66f611b..1caf0ca4 100644 --- a/api/handler/api/util.go +++ b/api/handler/api/util.go @@ -2,7 +2,6 @@ package api import ( "fmt" - "io/ioutil" "mime" "net" "net/http" @@ -11,6 +10,12 @@ import ( api "github.com/micro/go-micro/v2/api/proto" "github.com/micro/go-micro/v2/client/selector" "github.com/micro/go-micro/v2/registry" + "github.com/oxtoacart/bpool" +) + +var ( + // need to calculate later to specify useful defaults + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) func requestToProto(r *http.Request) (*api.Request, error) { @@ -39,9 +44,12 @@ func requestToProto(r *http.Request) (*api.Request, error) { case "application/x-www-form-urlencoded": // expect form vals in Post data default: - - data, _ := ioutil.ReadAll(r.Body) - req.Body = string(data) + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err = buf.ReadFrom(r.Body); err != nil { + return nil, err + } + req.Body = buf.String() } } diff --git a/api/handler/broker/broker.go b/api/handler/broker/broker.go index 8f82446b..7c04c0aa 100644 --- a/api/handler/broker/broker.go +++ b/api/handler/broker/broker.go @@ -3,7 +3,6 @@ package broker import ( "encoding/json" - "io/ioutil" "net/http" "net/url" "strings" @@ -15,6 +14,11 @@ import ( "github.com/micro/go-micro/v2/api/handler" "github.com/micro/go-micro/v2/broker" "github.com/micro/go-micro/v2/logger" + "github.com/oxtoacart/bpool" +) + +var ( + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) const ( @@ -155,6 +159,13 @@ func (c *conn) writeLoop() { } func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if b.opts.MaxRecvSize > 0 { + bsize = b.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + br := b.opts.Service.Client().Options().Broker // Setup the broker @@ -191,14 +202,15 @@ func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Read body - b, err := ioutil.ReadAll(r.Body) - if err != nil { + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { http.Error(w, err.Error(), 500) return } - // Set body - msg.Body = b + msg.Body = buf.Bytes() + // Set body // Publish br.Publish(topic, msg) diff --git a/api/handler/cloudevents/cloudevents.go b/api/handler/cloudevents/cloudevents.go index 2eec44ef..630412da 100644 --- a/api/handler/cloudevents/cloudevents.go +++ b/api/handler/cloudevents/cloudevents.go @@ -12,7 +12,7 @@ import ( ) type event struct { - options handler.Options + opts handler.Options } var ( @@ -58,10 +58,17 @@ func evRoute(ns, p string) (string, string) { } func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if e.opts.MaxRecvSize > 0 { + bsize = e.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + // request to topic:event // create event // publish to topic - topic, _ := evRoute(e.options.Namespace, r.URL.Path) + topic, _ := evRoute(e.opts.Namespace, r.URL.Path) // create event ev, err := FromRequest(r) @@ -71,7 +78,7 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // get client - c := e.options.Service.Client() + c := e.opts.Service.Client() // create publication p := c.NewMessage(topic, ev) @@ -89,6 +96,6 @@ func (e *event) String() string { func NewHandler(opts ...handler.Option) handler.Handler { return &event{ - options: handler.NewOptions(opts...), + opts: handler.NewOptions(opts...), } } diff --git a/api/handler/cloudevents/event.go b/api/handler/cloudevents/event.go index 4869188b..0463792c 100644 --- a/api/handler/cloudevents/event.go +++ b/api/handler/cloudevents/event.go @@ -24,7 +24,6 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "mime" "net/http" "strings" @@ -32,9 +31,14 @@ import ( "unicode" "github.com/google/uuid" + "github.com/oxtoacart/bpool" validator "gopkg.in/go-playground/validator.v9" ) +var ( + bufferPool = bpool.NewSizedBufferPool(1024, 8) +) + const ( // TransformationVersion is indicative of the revision of how Event Gateway transforms a request into CloudEvents format. TransformationVersion = "0.1" @@ -97,10 +101,12 @@ func FromRequest(r *http.Request) (*Event, error) { // Read request body body := []byte{} if r.Body != nil { - body, err = ioutil.ReadAll(r.Body) - if err != nil { + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { return nil, err } + body = buf.Bytes() } var event *Event diff --git a/api/handler/event/event.go b/api/handler/event/event.go index a819d947..27393165 100644 --- a/api/handler/event/event.go +++ b/api/handler/event/event.go @@ -4,7 +4,6 @@ package event import ( "encoding/json" "fmt" - "io/ioutil" "net/http" "path" "regexp" @@ -15,10 +14,15 @@ import ( "github.com/micro/go-micro/v2/api/handler" proto "github.com/micro/go-micro/v2/api/proto" "github.com/micro/go-micro/v2/util/ctx" + "github.com/oxtoacart/bpool" +) + +var ( + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) type event struct { - options handler.Options + opts handler.Options } var ( @@ -64,11 +68,18 @@ func evRoute(ns, p string) (string, string) { } func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if e.opts.MaxRecvSize > 0 { + bsize = e.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + // request to topic:event // create event // publish to topic - topic, action := evRoute(e.options.Namespace, r.URL.Path) + topic, action := evRoute(e.opts.Namespace, r.URL.Path) // create event ev := &proto.Event{ @@ -96,16 +107,18 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { bytes, _ := json.Marshal(r.URL.Query()) ev.Data = string(bytes) } else { - b, err := ioutil.ReadAll(r.Body) - if err != nil { + // Read body + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { http.Error(w, err.Error(), 500) return } - ev.Data = string(b) + ev.Data = buf.String() } // get client - c := e.options.Service.Client() + c := e.opts.Service.Client() // create publication p := c.NewMessage(topic, ev) @@ -123,6 +136,6 @@ func (e *event) String() string { func NewHandler(opts ...handler.Option) handler.Handler { return &event{ - options: handler.NewOptions(opts...), + opts: handler.NewOptions(opts...), } } diff --git a/api/handler/options.go b/api/handler/options.go index d7403270..6feeec24 100644 --- a/api/handler/options.go +++ b/api/handler/options.go @@ -5,10 +5,15 @@ import ( "github.com/micro/go-micro/v2/api/router" ) +var ( + DefaultMaxRecvSize int64 = 1024 * 1024 * 10 // 10Mb +) + type Options struct { - Namespace string - Router router.Router - Service micro.Service + MaxRecvSize int64 + Namespace string + Router router.Router + Service micro.Service } type Option func(o *Options) @@ -30,6 +35,10 @@ func NewOptions(opts ...Option) Options { WithNamespace("go.micro.api")(&options) } + if options.MaxRecvSize == 0 { + options.MaxRecvSize = DefaultMaxRecvSize + } + return options } @@ -53,3 +62,10 @@ func WithService(s micro.Service) Option { o.Service = s } } + +// WithmaxRecvSize specifies max body size +func WithMaxRecvSize(size int64) Option { + return func(o *Options) { + o.MaxRecvSize = size + } +} diff --git a/api/handler/registry/registry.go b/api/handler/registry/registry.go index 4aca36da..b0eeaf5a 100644 --- a/api/handler/registry/registry.go +++ b/api/handler/registry/registry.go @@ -3,7 +3,6 @@ package registry import ( "encoding/json" - "io/ioutil" "net/http" "strconv" "time" @@ -11,6 +10,11 @@ import ( "github.com/gorilla/websocket" "github.com/micro/go-micro/v2/api/handler" "github.com/micro/go-micro/v2/registry" + "github.com/oxtoacart/bpool" +) + +var ( + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) const ( @@ -29,12 +33,15 @@ type registryHandler struct { func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) { r.ParseForm() - b, err := ioutil.ReadAll(r.Body) - if err != nil { + defer r.Body.Close() + + // Read body + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { http.Error(w, err.Error(), 500) return } - defer r.Body.Close() var opts []registry.RegisterOption @@ -47,13 +54,11 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) { } var service *registry.Service - err = json.Unmarshal(b, &service) - if err != nil { + if err := json.NewDecoder(buf).Decode(&service); err != nil { http.Error(w, err.Error(), 500) return } - err = rh.reg.Register(service, opts...) - if err != nil { + if err := rh.reg.Register(service, opts...); err != nil { http.Error(w, err.Error(), 500) return } @@ -61,21 +66,22 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) { func (rh *registryHandler) del(w http.ResponseWriter, r *http.Request) { r.ParseForm() - b, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), 500) - return - } defer r.Body.Close() - var service *registry.Service - err = json.Unmarshal(b, &service) - if err != nil { + // Read body + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { http.Error(w, err.Error(), 500) return } - err = rh.reg.Deregister(service) - if err != nil { + + var service *registry.Service + if err := json.NewDecoder(buf).Decode(&service); err != nil { + http.Error(w, err.Error(), 500) + return + } + if err := rh.reg.Deregister(service); err != nil { http.Error(w, err.Error(), 500) return } @@ -187,6 +193,13 @@ func watch(rw registry.Watcher, w http.ResponseWriter, r *http.Request) { } func (rh *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if rh.opts.MaxRecvSize > 0 { + bsize = rh.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + switch r.Method { case "GET": rh.get(w, r) diff --git a/api/handler/rpc/rpc.go b/api/handler/rpc/rpc.go index a246d996..3059207d 100644 --- a/api/handler/rpc/rpc.go +++ b/api/handler/rpc/rpc.go @@ -4,7 +4,6 @@ package rpc import ( "encoding/json" "io" - "io/ioutil" "net/http" "strconv" "strings" @@ -22,6 +21,7 @@ import ( "github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/util/ctx" + "github.com/oxtoacart/bpool" ) const ( @@ -45,6 +45,8 @@ var ( "application/proto-rpc", "application/octet-stream", } + + bufferPool = bpool.NewSizedBufferPool(1024, 8) ) type rpcHandler struct { @@ -69,6 +71,13 @@ func strategy(services []*registry.Service) selector.Strategy { } func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bsize := handler.DefaultMaxRecvSize + if h.opts.MaxRecvSize > 0 { + bsize = h.opts.MaxRecvSize + } + + r.Body = http.MaxBytesReader(w, r.Body, bsize) + defer r.Body.Close() var service *api.Service @@ -240,8 +249,8 @@ func requestPayload(r *http.Request) ([]byte, error) { if err := c.ReadBody(&raw); err != nil { return nil, err } - b, _ := raw.Marshal() - return b, nil + b, err := raw.Marshal() + return b, err case strings.Contains(ct, "application/www-x-form-urlencoded"): r.ParseForm() @@ -252,8 +261,8 @@ func requestPayload(r *http.Request) ([]byte, error) { } // marshal - b, _ := json.Marshal(vals) - return b, nil + b, err := json.Marshal(vals) + return b, err // TODO: application/grpc } @@ -265,7 +274,12 @@ func requestPayload(r *http.Request) ([]byte, error) { return qson.ToJSON(r.URL.RawQuery) } case "PATCH", "POST": - return ioutil.ReadAll(r.Body) + buf := bufferPool.Get() + defer bufferPool.Put(buf) + if _, err := buf.ReadFrom(r.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil } return []byte{}, nil