Merge branch 'master' into auth-resolver

This commit is contained in:
ben-toogood 2020-04-06 14:43:22 +01:00 committed by GitHub
commit 7f07e1a642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 78 additions and 55 deletions

View File

@ -135,9 +135,19 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
// receive from stream and send to client // receive from stream and send to client
for { for {
select {
case <-ctx.Done():
return
case <-stream.Context().Done():
return
default:
// read backend response body // read backend response body
buf, err := rsp.Read() buf, err := rsp.Read()
if err != nil { if err != nil {
// wants to avoid import grpc/status.Status
if strings.Contains(err.Error(), "context canceled") {
return
}
if logger.V(logger.ErrorLevel, logger.DefaultLogger) { if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
logger.Error(err) logger.Error(err)
} }
@ -158,6 +168,7 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
return return
} }
} }
}
} }
// writeLoop // writeLoop
@ -166,8 +177,18 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) {
defer stream.Close() defer stream.Close()
for { for {
select {
case <-stream.Context().Done():
return
default:
buf, op, err := wsutil.ReadClientData(rw) buf, op, err := wsutil.ReadClientData(rw)
if err != nil { if err != nil {
if wserr, ok := err.(wsutil.ClosedError); ok {
switch wserr.Code {
case ws.StatusNormalClosure, ws.StatusNoStatusRcvd:
return
}
}
if logger.V(logger.ErrorLevel, logger.DefaultLogger) { if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
logger.Error(err) logger.Error(err)
} }
@ -184,7 +205,6 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) {
// default to trying json // default to trying json
// if the extracted payload isn't empty lets use it // if the extracted payload isn't empty lets use it
request := &raw.Frame{Data: buf} request := &raw.Frame{Data: buf}
if err := stream.Send(request); err != nil { if err := stream.Send(request); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) { if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
logger.Error(err) logger.Error(err)
@ -192,6 +212,7 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) {
return return
} }
} }
}
} }
func isStream(r *http.Request, srv *api.Service) bool { func isStream(r *http.Request, srv *api.Service) bool {

View File

@ -133,22 +133,19 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func namespaceFromRequest(req *http.Request) (string, error) { func namespaceFromRequest(req *http.Request) (string, error) {
// determine the host, e.g. dev.micro.mu:8080 // determine the host, e.g. dev.micro.mu:8080
host := req.URL.Host host := req.URL.Hostname()
if len(host) == 0 { if len(host) == 0 {
host = req.Host // fallback to req.Host
host, _, _ = net.SplitHostPort(req.Host)
} }
logger.Infof("Host is %v", host)
// check for an ip address // check for an ip address
if net.ParseIP(host) != nil { if net.ParseIP(host) != nil {
return auth.DefaultNamespace, nil return auth.DefaultNamespace, nil
} }
// split the host to remove the port
host, _, err := net.SplitHostPort(req.Host)
if err != nil {
return "", err
}
// check for dev enviroment // check for dev enviroment
if host == "localhost" || host == "127.0.0.1" { if host == "localhost" || host == "127.0.0.1" {
return auth.DefaultNamespace, nil return auth.DefaultNamespace, nil

View File

@ -175,7 +175,7 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
// Get the namespace for the request // Get the namespace for the request
namespace, ok := metadata.Get(ctx, auth.NamespaceKey) namespace, ok := metadata.Get(ctx, auth.NamespaceKey)
if !ok { if !ok {
logger.Errorf("Missing request namespace") logger.Debugf("Missing request namespace")
namespace = auth.DefaultNamespace namespace = auth.DefaultNamespace
} }
@ -188,9 +188,9 @@ func AuthHandler(fn func() auth.Auth) server.HandlerWrapper {
// Check the accounts namespace matches the namespace we're operating // Check the accounts namespace matches the namespace we're operating
// within. If not forbid the request and log the occurance. // within. If not forbid the request and log the occurance.
if account.Namespace != namespace { if account.Namespace != namespace {
logger.Warnf("Cross namespace request forbidden: account %v (%v) requested access to %v %v in the %v namespace", logger.Debugf("Cross namespace request forbidden: account %v (%v) requested access to %v %v in the %v namespace",
account.ID, account.Namespace, req.Service(), req.Endpoint(), namespace) account.ID, account.Namespace, req.Service(), req.Endpoint(), namespace)
return errors.Forbidden(req.Service(), "cross namespace request") // return errors.Forbidden(req.Service(), "cross namespace request")
} }
// construct the resource // construct the resource

View File

@ -11,6 +11,7 @@ import (
"github.com/micro/go-micro/v2/registry" "github.com/micro/go-micro/v2/registry"
) )
//Options for web
type Options struct { type Options struct {
Name string Name string
Version string Version string
@ -75,7 +76,7 @@ func newOptions(opts ...Option) Options {
return opt return opt
} }
// Server name // Name of Web
func Name(n string) Option { func Name(n string) Option {
return func(o *Options) { return func(o *Options) {
o.Name = n o.Name = n
@ -92,7 +93,7 @@ func Icon(ico string) Option {
} }
} }
// Unique server id //Id for Unique server id
func Id(id string) Option { func Id(id string) Option {
return func(o *Options) { return func(o *Options) {
o.Id = id o.Id = id
@ -120,7 +121,7 @@ func Address(a string) Option {
} }
} }
// The address to advertise for discovery - host:port //Advertise The address to advertise for discovery - host:port
func Advertise(a string) Option { func Advertise(a string) Option {
return func(o *Options) { return func(o *Options) {
o.Advertise = a o.Advertise = a
@ -143,26 +144,28 @@ func Registry(r registry.Registry) Option {
} }
} }
// Register the service with a TTL //RegisterTTL Register the service with a TTL
func RegisterTTL(t time.Duration) Option { func RegisterTTL(t time.Duration) Option {
return func(o *Options) { return func(o *Options) {
o.RegisterTTL = t o.RegisterTTL = t
} }
} }
// Register the service with at interval //RegisterInterval Register the service with at interval
func RegisterInterval(t time.Duration) Option { func RegisterInterval(t time.Duration) Option {
return func(o *Options) { return func(o *Options) {
o.RegisterInterval = t o.RegisterInterval = t
} }
} }
//Handler for custom handler
func Handler(h http.Handler) Option { func Handler(h http.Handler) Option {
return func(o *Options) { return func(o *Options) {
o.Handler = h o.Handler = h
} }
} }
//Server for custom Server
func Server(srv *http.Server) Option { func Server(srv *http.Server) Option {
return func(o *Options) { return func(o *Options) {
o.Server = srv o.Server = srv

View File

@ -268,7 +268,7 @@ func (s *service) stop() error {
func (s *service) Client() *http.Client { func (s *service) Client() *http.Client {
rt := mhttp.NewRoundTripper( rt := mhttp.NewRoundTripper(
mhttp.WithRegistry(registry.DefaultRegistry), mhttp.WithRegistry(s.opts.Registry),
) )
return &http.Client{ return &http.Client{
Transport: rt, Transport: rt,

View File

@ -20,8 +20,10 @@ type Service interface {
Run() error Run() error
} }
//Option for web
type Option func(o *Options) type Option func(o *Options)
//Web basic Defaults
var ( var (
// For serving // For serving
DefaultName = "go-web" DefaultName = "go-web"