From 6d803d9e4514ee3320fa1f0231939ee0e5907493 Mon Sep 17 00:00:00 2001 From: ben-toogood Date: Wed, 4 Mar 2020 11:40:53 +0000 Subject: [PATCH] Implement api/server/cors (#1294) --- api/server/cors/cors.go | 43 +++++++++++++++++++++++++++++++++++++++++ api/server/http/http.go | 9 ++++++++- api/server/options.go | 7 +++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 api/server/cors/cors.go diff --git a/api/server/cors/cors.go b/api/server/cors/cors.go new file mode 100644 index 00000000..090d6632 --- /dev/null +++ b/api/server/cors/cors.go @@ -0,0 +1,43 @@ +package cors + +import ( + "net/http" +) + +// CombinedCORSHandler wraps a server and provides CORS headers +func CombinedCORSHandler(h http.Handler) http.Handler { + return corsHandler{h} +} + +type corsHandler struct { + handler http.Handler +} + +func (c corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + SetHeaders(w, r) + + if r.Method == "OPTIONS" { + return + } + + c.handler.ServeHTTP(w, r) +} + +// SetHeaders sets the CORS headers +func SetHeaders(w http.ResponseWriter, r *http.Request) { + set := func(w http.ResponseWriter, k, v string) { + if v := w.Header().Get(k); len(v) > 0 { + return + } + w.Header().Set(k, v) + } + + if origin := r.Header.Get("Origin"); len(origin) > 0 { + set(w, "Access-Control-Allow-Origin", origin) + } else { + set(w, "Access-Control-Allow-Origin", "*") + } + + set(w, "Access-Control-Allow-Methods", "POST, PATCH, GET, OPTIONS, PUT, DELETE") + set(w, "Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") +} diff --git a/api/server/http/http.go b/api/server/http/http.go index f5fe5d24..214f203e 100644 --- a/api/server/http/http.go +++ b/api/server/http/http.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/handlers" "github.com/micro/go-micro/v2/api/server" + "github.com/micro/go-micro/v2/api/server/cors" log "github.com/micro/go-micro/v2/logger" ) @@ -45,7 +46,13 @@ func (s *httpServer) Init(opts ...server.Option) error { } func (s *httpServer) Handle(path string, handler http.Handler) { - s.mux.Handle(path, handlers.CombinedLoggingHandler(os.Stdout, handler)) + h := handlers.CombinedLoggingHandler(os.Stdout, handler) + + if s.opts.EnableCORS { + h = cors.CombinedCORSHandler(h) + } + + s.mux.Handle(path, h) } func (s *httpServer) Start() error { diff --git a/api/server/options.go b/api/server/options.go index 687e0926..99be1a03 100644 --- a/api/server/options.go +++ b/api/server/options.go @@ -10,12 +10,19 @@ type Option func(o *Options) type Options struct { EnableACME bool + EnableCORS bool ACMEProvider acme.Provider EnableTLS bool ACMEHosts []string TLSConfig *tls.Config } +func EnableCORS(b bool) Option { + return func(o *Options) { + o.EnableCORS = b + } +} + func EnableACME(b bool) Option { return func(o *Options) { o.EnableACME = b