api/handler: use http.MaxBytesReader and buffer pool (#1415)

* api/handler: use http.MaxBytesReader

protect api handlers from OOM cases

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2020-03-26 14:29:28 +03:00 committed by GitHub
parent 776a7d6cd6
commit 02839cfba5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 146 additions and 51 deletions

View File

@ -24,6 +24,12 @@ const (
// API handler is the default handler which takes api.Request and returns api.Response // API handler is the default handler which takes api.Request and returns api.Response
func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if a.opts.MaxRecvSize > 0 {
bsize = a.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
request, err := requestToProto(r) request, err := requestToProto(r)
if err != nil { if err != nil {
er := errors.InternalServerError("go.micro.api", err.Error()) er := errors.InternalServerError("go.micro.api", err.Error())

View File

@ -2,7 +2,6 @@ package api
import ( import (
"fmt" "fmt"
"io/ioutil"
"mime" "mime"
"net" "net"
"net/http" "net/http"
@ -11,6 +10,12 @@ import (
api "github.com/micro/go-micro/v2/api/proto" api "github.com/micro/go-micro/v2/api/proto"
"github.com/micro/go-micro/v2/client/selector" "github.com/micro/go-micro/v2/client/selector"
"github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/registry"
"github.com/oxtoacart/bpool"
)
var (
// need to calculate later to specify useful defaults
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
func requestToProto(r *http.Request) (*api.Request, error) { func requestToProto(r *http.Request) (*api.Request, error) {
@ -39,9 +44,12 @@ func requestToProto(r *http.Request) (*api.Request, error) {
case "application/x-www-form-urlencoded": case "application/x-www-form-urlencoded":
// expect form vals in Post data // expect form vals in Post data
default: default:
buf := bufferPool.Get()
data, _ := ioutil.ReadAll(r.Body) defer bufferPool.Put(buf)
req.Body = string(data) if _, err = buf.ReadFrom(r.Body); err != nil {
return nil, err
}
req.Body = buf.String()
} }
} }

View File

@ -3,7 +3,6 @@ package broker
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -15,6 +14,11 @@ import (
"github.com/micro/go-micro/v2/api/handler" "github.com/micro/go-micro/v2/api/handler"
"github.com/micro/go-micro/v2/broker" "github.com/micro/go-micro/v2/broker"
"github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/logger"
"github.com/oxtoacart/bpool"
)
var (
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
const ( const (
@ -155,6 +159,13 @@ func (c *conn) writeLoop() {
} }
func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if b.opts.MaxRecvSize > 0 {
bsize = b.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
br := b.opts.Service.Client().Options().Broker br := b.opts.Service.Client().Options().Broker
// Setup the broker // Setup the broker
@ -191,14 +202,15 @@ func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// Read body // Read body
b, err := ioutil.ReadAll(r.Body) buf := bufferPool.Get()
if err != nil { defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
// Set body // Set body
msg.Body = b msg.Body = buf.Bytes()
// Set body
// Publish // Publish
br.Publish(topic, msg) br.Publish(topic, msg)

View File

@ -12,7 +12,7 @@ import (
) )
type event struct { type event struct {
options handler.Options opts handler.Options
} }
var ( var (
@ -58,10 +58,17 @@ func evRoute(ns, p string) (string, string) {
} }
func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if e.opts.MaxRecvSize > 0 {
bsize = e.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
// request to topic:event // request to topic:event
// create event // create event
// publish to topic // publish to topic
topic, _ := evRoute(e.options.Namespace, r.URL.Path) topic, _ := evRoute(e.opts.Namespace, r.URL.Path)
// create event // create event
ev, err := FromRequest(r) ev, err := FromRequest(r)
@ -71,7 +78,7 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// get client // get client
c := e.options.Service.Client() c := e.opts.Service.Client()
// create publication // create publication
p := c.NewMessage(topic, ev) p := c.NewMessage(topic, ev)
@ -89,6 +96,6 @@ func (e *event) String() string {
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
return &event{ return &event{
options: handler.NewOptions(opts...), opts: handler.NewOptions(opts...),
} }
} }

View File

@ -24,7 +24,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"mime" "mime"
"net/http" "net/http"
"strings" "strings"
@ -32,9 +31,14 @@ import (
"unicode" "unicode"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/oxtoacart/bpool"
validator "gopkg.in/go-playground/validator.v9" validator "gopkg.in/go-playground/validator.v9"
) )
var (
bufferPool = bpool.NewSizedBufferPool(1024, 8)
)
const ( const (
// TransformationVersion is indicative of the revision of how Event Gateway transforms a request into CloudEvents format. // TransformationVersion is indicative of the revision of how Event Gateway transforms a request into CloudEvents format.
TransformationVersion = "0.1" TransformationVersion = "0.1"
@ -97,10 +101,12 @@ func FromRequest(r *http.Request) (*Event, error) {
// Read request body // Read request body
body := []byte{} body := []byte{}
if r.Body != nil { if r.Body != nil {
body, err = ioutil.ReadAll(r.Body) buf := bufferPool.Get()
if err != nil { defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
return nil, err return nil, err
} }
body = buf.Bytes()
} }
var event *Event var event *Event

View File

@ -4,7 +4,6 @@ package event
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"path" "path"
"regexp" "regexp"
@ -15,10 +14,15 @@ import (
"github.com/micro/go-micro/v2/api/handler" "github.com/micro/go-micro/v2/api/handler"
proto "github.com/micro/go-micro/v2/api/proto" proto "github.com/micro/go-micro/v2/api/proto"
"github.com/micro/go-micro/v2/util/ctx" "github.com/micro/go-micro/v2/util/ctx"
"github.com/oxtoacart/bpool"
)
var (
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
type event struct { type event struct {
options handler.Options opts handler.Options
} }
var ( var (
@ -64,11 +68,18 @@ func evRoute(ns, p string) (string, string) {
} }
func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if e.opts.MaxRecvSize > 0 {
bsize = e.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
// request to topic:event // request to topic:event
// create event // create event
// publish to topic // publish to topic
topic, action := evRoute(e.options.Namespace, r.URL.Path) topic, action := evRoute(e.opts.Namespace, r.URL.Path)
// create event // create event
ev := &proto.Event{ ev := &proto.Event{
@ -96,16 +107,18 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bytes, _ := json.Marshal(r.URL.Query()) bytes, _ := json.Marshal(r.URL.Query())
ev.Data = string(bytes) ev.Data = string(bytes)
} else { } else {
b, err := ioutil.ReadAll(r.Body) // Read body
if err != nil { buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
ev.Data = string(b) ev.Data = buf.String()
} }
// get client // get client
c := e.options.Service.Client() c := e.opts.Service.Client()
// create publication // create publication
p := c.NewMessage(topic, ev) p := c.NewMessage(topic, ev)
@ -123,6 +136,6 @@ func (e *event) String() string {
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
return &event{ return &event{
options: handler.NewOptions(opts...), opts: handler.NewOptions(opts...),
} }
} }

View File

@ -5,7 +5,12 @@ import (
"github.com/micro/go-micro/v2/api/router" "github.com/micro/go-micro/v2/api/router"
) )
var (
DefaultMaxRecvSize int64 = 1024 * 1024 * 10 // 10Mb
)
type Options struct { type Options struct {
MaxRecvSize int64
Namespace string Namespace string
Router router.Router Router router.Router
Service micro.Service Service micro.Service
@ -30,6 +35,10 @@ func NewOptions(opts ...Option) Options {
WithNamespace("go.micro.api")(&options) WithNamespace("go.micro.api")(&options)
} }
if options.MaxRecvSize == 0 {
options.MaxRecvSize = DefaultMaxRecvSize
}
return options return options
} }
@ -53,3 +62,10 @@ func WithService(s micro.Service) Option {
o.Service = s o.Service = s
} }
} }
// WithmaxRecvSize specifies max body size
func WithMaxRecvSize(size int64) Option {
return func(o *Options) {
o.MaxRecvSize = size
}
}

View File

@ -3,7 +3,6 @@ package registry
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -11,6 +10,11 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/micro/go-micro/v2/api/handler" "github.com/micro/go-micro/v2/api/handler"
"github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/registry"
"github.com/oxtoacart/bpool"
)
var (
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
const ( const (
@ -29,12 +33,15 @@ type registryHandler struct {
func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) { func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
b, err := ioutil.ReadAll(r.Body) defer r.Body.Close()
if err != nil {
// Read body
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
defer r.Body.Close()
var opts []registry.RegisterOption var opts []registry.RegisterOption
@ -47,13 +54,11 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
} }
var service *registry.Service var service *registry.Service
err = json.Unmarshal(b, &service) if err := json.NewDecoder(buf).Decode(&service); err != nil {
if err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
err = rh.reg.Register(service, opts...) if err := rh.reg.Register(service, opts...); err != nil {
if err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
@ -61,21 +66,22 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
func (rh *registryHandler) del(w http.ResponseWriter, r *http.Request) { func (rh *registryHandler) del(w http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer r.Body.Close() defer r.Body.Close()
var service *registry.Service // Read body
err = json.Unmarshal(b, &service) buf := bufferPool.Get()
if err != nil { defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
err = rh.reg.Deregister(service)
if err != nil { var service *registry.Service
if err := json.NewDecoder(buf).Decode(&service); err != nil {
http.Error(w, err.Error(), 500)
return
}
if err := rh.reg.Deregister(service); err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
@ -187,6 +193,13 @@ func watch(rw registry.Watcher, w http.ResponseWriter, r *http.Request) {
} }
func (rh *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (rh *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if rh.opts.MaxRecvSize > 0 {
bsize = rh.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
switch r.Method { switch r.Method {
case "GET": case "GET":
rh.get(w, r) rh.get(w, r)

View File

@ -4,7 +4,6 @@ package rpc
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -22,6 +21,7 @@ import (
"github.com/micro/go-micro/v2/logger" "github.com/micro/go-micro/v2/logger"
"github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/registry"
"github.com/micro/go-micro/v2/util/ctx" "github.com/micro/go-micro/v2/util/ctx"
"github.com/oxtoacart/bpool"
) )
const ( const (
@ -45,6 +45,8 @@ var (
"application/proto-rpc", "application/proto-rpc",
"application/octet-stream", "application/octet-stream",
} }
bufferPool = bpool.NewSizedBufferPool(1024, 8)
) )
type rpcHandler struct { type rpcHandler struct {
@ -69,6 +71,13 @@ func strategy(services []*registry.Service) selector.Strategy {
} }
func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if h.opts.MaxRecvSize > 0 {
bsize = h.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
defer r.Body.Close() defer r.Body.Close()
var service *api.Service var service *api.Service
@ -240,8 +249,8 @@ func requestPayload(r *http.Request) ([]byte, error) {
if err := c.ReadBody(&raw); err != nil { if err := c.ReadBody(&raw); err != nil {
return nil, err return nil, err
} }
b, _ := raw.Marshal() b, err := raw.Marshal()
return b, nil return b, err
case strings.Contains(ct, "application/www-x-form-urlencoded"): case strings.Contains(ct, "application/www-x-form-urlencoded"):
r.ParseForm() r.ParseForm()
@ -252,8 +261,8 @@ func requestPayload(r *http.Request) ([]byte, error) {
} }
// marshal // marshal
b, _ := json.Marshal(vals) b, err := json.Marshal(vals)
return b, nil return b, err
// TODO: application/grpc // TODO: application/grpc
} }
@ -265,7 +274,12 @@ func requestPayload(r *http.Request) ([]byte, error) {
return qson.ToJSON(r.URL.RawQuery) return qson.ToJSON(r.URL.RawQuery)
} }
case "PATCH", "POST": case "PATCH", "POST":
return ioutil.ReadAll(r.Body) buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
} }
return []byte{}, nil return []byte{}, nil