api/handler: use http.MaxBytesReader and buffer pool (#1415)

* api/handler: use http.MaxBytesReader

protect api handlers from OOM cases

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
2020-03-26 14:29:28 +03:00
committed by GitHub
parent 776a7d6cd6
commit 02839cfba5
9 changed files with 146 additions and 51 deletions

View File

@@ -3,7 +3,6 @@ package registry
import (
"encoding/json"
"io/ioutil"
"net/http"
"strconv"
"time"
@@ -11,6 +10,11 @@ import (
"github.com/gorilla/websocket"
"github.com/micro/go-micro/v2/api/handler"
"github.com/micro/go-micro/v2/registry"
"github.com/oxtoacart/bpool"
)
var (
bufferPool = bpool.NewSizedBufferPool(1024, 8)
)
const (
@@ -29,12 +33,15 @@ type registryHandler struct {
func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
defer r.Body.Close()
// Read body
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500)
return
}
defer r.Body.Close()
var opts []registry.RegisterOption
@@ -47,13 +54,11 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
}
var service *registry.Service
err = json.Unmarshal(b, &service)
if err != nil {
if err := json.NewDecoder(buf).Decode(&service); err != nil {
http.Error(w, err.Error(), 500)
return
}
err = rh.reg.Register(service, opts...)
if err != nil {
if err := rh.reg.Register(service, opts...); err != nil {
http.Error(w, err.Error(), 500)
return
}
@@ -61,21 +66,22 @@ func (rh *registryHandler) add(w http.ResponseWriter, r *http.Request) {
func (rh *registryHandler) del(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer r.Body.Close()
var service *registry.Service
err = json.Unmarshal(b, &service)
if err != nil {
// Read body
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500)
return
}
err = rh.reg.Deregister(service)
if err != nil {
var service *registry.Service
if err := json.NewDecoder(buf).Decode(&service); err != nil {
http.Error(w, err.Error(), 500)
return
}
if err := rh.reg.Deregister(service); err != nil {
http.Error(w, err.Error(), 500)
return
}
@@ -187,6 +193,13 @@ func watch(rw registry.Watcher, w http.ResponseWriter, r *http.Request) {
}
func (rh *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
bsize := handler.DefaultMaxRecvSize
if rh.opts.MaxRecvSize > 0 {
bsize = rh.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
switch r.Method {
case "GET":
rh.get(w, r)