diff --git a/api/handler/broker/broker.go b/api/handler/broker/broker.go index bf4ccf60..29d78a02 100644 --- a/api/handler/broker/broker.go +++ b/api/handler/broker/broker.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -26,6 +27,7 @@ const ( ) type brokerHandler struct { + once atomic.Value opts handler.Options u websocket.Upgrader } @@ -42,7 +44,6 @@ type conn struct { } var ( - once sync.Once contentType = "text/plain" ) @@ -155,10 +156,15 @@ func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { br := b.opts.Service.Client().Options().Broker // Setup the broker - once.Do(func() { - br.Init() - br.Connect() - }) + if !b.once.Load().(bool) { + if err := br.Init(); err != nil { + http.Error(w, err.Error(), 500) + } + if err := br.Connect(); err != nil { + http.Error(w, err.Error(), 500) + } + b.once.Store(true) + } // Parse r.ParseForm() @@ -235,7 +241,7 @@ func (b *brokerHandler) String() string { } func NewHandler(opts ...handler.Option) handler.Handler { - return &brokerHandler{ + h := &brokerHandler{ u: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -245,6 +251,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { }, opts: handler.NewOptions(opts...), } + h.once.Store(true) + return h } func WithCors(cors map[string]bool, opts ...handler.Option) handler.Handler { diff --git a/auth/jwt/jwt.go b/auth/jwt/jwt.go index 540bec43..e0719682 100644 --- a/auth/jwt/jwt.go +++ b/auth/jwt/jwt.go @@ -100,7 +100,10 @@ func (s *svc) Validate(token string) (*auth.Account, error) { return nil, ErrInvalidToken } - claims := res.Claims.(*AuthClaims) + claims, ok := res.Claims.(*AuthClaims) + if !ok { + return nil, ErrInvalidToken + } return &auth.Account{ Id: claims.Id, diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 1c7f3512..7c5a0084 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -6,7 +6,7 @@ import ( "crypto/tls" "fmt" "os" - "sync" + "sync/atomic" "time" "github.com/micro/go-micro/v2/broker" @@ -24,9 +24,9 @@ import ( ) type grpcClient struct { - once sync.Once opts client.Options pool *pool + once atomic.Value } func init() { @@ -570,9 +570,12 @@ func (g *grpcClient) Publish(ctx context.Context, p client.Message, opts ...clie body = b } - g.once.Do(func() { - g.opts.Broker.Connect() - }) + if !g.once.Load().(bool) { + if err = g.opts.Broker.Connect(); err != nil { + return errors.InternalServerError("go.micro.client", err.Error()) + } + g.once.Store(true) + } topic := p.Topic() @@ -641,9 +644,9 @@ func newClient(opts ...client.Option) client.Client { } rc := &grpcClient{ - once: sync.Once{}, opts: options, } + rc.once.Store(false) rc.pool = newPool(options.PoolSize, options.PoolTTL, rc.poolMaxIdle(), rc.poolMaxStreams()) diff --git a/client/rpc_client.go b/client/rpc_client.go index b701d1ca..8b4b806b 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "sync" "sync/atomic" "time" @@ -22,7 +21,7 @@ import ( ) type rpcClient struct { - once sync.Once + once atomic.Value opts Options pool pool.Pool seq uint64 @@ -38,11 +37,11 @@ func newRpcClient(opt ...Option) Client { ) rc := &rpcClient{ - once: sync.Once{}, opts: opts, pool: p, seq: 0, } + rc.once.Store(false) c := Client(rc) @@ -645,9 +644,12 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt body = b.Bytes() } - r.once.Do(func() { - r.opts.Broker.Connect() - }) + if !r.once.Load().(bool) { + if err = r.opts.Broker.Connect(); err != nil { + return errors.InternalServerError("go.micro.client", err.Error()) + } + r.once.Store(true) + } return r.opts.Broker.Publish(topic, &broker.Message{ Header: md, diff --git a/codec/codec.go b/codec/codec.go index b4feb0a4..107bdb35 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -2,6 +2,7 @@ package codec import ( + "errors" "io" ) @@ -12,6 +13,10 @@ const ( Event ) +var ( + ErrInvalidMessage = errors.New("invalid message") +) + type MessageType int // Takes in a connection/buffer and returns a new Codec diff --git a/codec/proto/proto.go b/codec/proto/proto.go index c2c4f382..87073619 100644 --- a/codec/proto/proto.go +++ b/codec/proto/proto.go @@ -25,13 +25,17 @@ func (c *Codec) ReadBody(b interface{}) error { if err != nil { return err } - return proto.Unmarshal(buf, b.(proto.Message)) + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(buf, m) } func (c *Codec) Write(m *codec.Message, b interface{}) error { p, ok := b.(proto.Message) if !ok { - return nil + return codec.ErrInvalidMessage } buf, err := proto.Marshal(p) if err != nil { diff --git a/codec/protorpc/protorpc.go b/codec/protorpc/protorpc.go index 4e4b2ee8..41f52fff 100644 --- a/codec/protorpc/protorpc.go +++ b/codec/protorpc/protorpc.go @@ -56,8 +56,12 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error { if err != nil { return err } - // Of course this is a protobuf! Trust me or detonate the program. - data, err = proto.Marshal(b.(proto.Message)) + // dont trust or incoming message + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + data, err = proto.Marshal(m) if err != nil { return err } @@ -100,7 +104,11 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error { } } case codec.Event: - data, err := proto.Marshal(b.(proto.Message)) + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + data, err := proto.Marshal(m) if err != nil { return err } diff --git a/server/grpc/codec.go b/server/grpc/codec.go index cb27ec6e..b92b63b4 100644 --- a/server/grpc/codec.go +++ b/server/grpc/codec.go @@ -2,7 +2,6 @@ package grpc import ( "encoding/json" - "fmt" "strings" b "bytes" @@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error { } func (protoCodec) Marshal(v interface{}) ([]byte, error) { - return proto.Marshal(v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return nil, codec.ErrInvalidMessage + } + return proto.Marshal(m) } func (protoCodec) Unmarshal(data []byte, v interface{}) error { - return proto.Unmarshal(data, v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(data, m) } func (protoCodec) Name() string { @@ -85,7 +92,6 @@ func (protoCodec) Name() string { func (jsonCodec) Marshal(v interface{}) ([]byte, error) { if pb, ok := v.(proto.Message); ok { s, err := jsonpbMarshaler.MarshalToString(pb) - return []byte(s), err } @@ -109,7 +115,7 @@ func (jsonCodec) Name() string { func (bytesCodec) Marshal(v interface{}) ([]byte, error) { b, ok := v.(*[]byte) if !ok { - return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) + return nil, codec.ErrInvalidMessage } return *b, nil } @@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) { func (bytesCodec) Unmarshal(data []byte, v interface{}) error { b, ok := v.(*[]byte) if !ok { - return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) + return codec.ErrInvalidMessage } *b = data return nil diff --git a/server/grpc/context.go b/server/grpc/context.go new file mode 100644 index 00000000..0b246d7e --- /dev/null +++ b/server/grpc/context.go @@ -0,0 +1,16 @@ +package grpc + +import ( + "context" + + "github.com/micro/go-micro/v2/server" +) + +func setServerOption(k, v interface{}) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, k, v) + } +} diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index b21c58e0..001ea1b5 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int { func (g *grpcServer) getCredentials() credentials.TransportCredentials { if g.opts.Context != nil { - if v := g.opts.Context.Value(tlsAuth{}); v != nil { - tls := v.(*tls.Config) - return credentials.NewTLS(tls) + if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil { + return credentials.NewTLS(v) } } return nil @@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption { return nil } - v := g.opts.Context.Value(grpcOptions{}) - - if v == nil { - return nil - } - - opts, ok := v.([]grpc.ServerOption) - - if !ok { + opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption) + if !ok || opts == nil { return nil } @@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { codecs := make(map[string]encoding.Codec) if g.opts.Context != nil { - if v := g.opts.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } } if c, ok := codecs[contentType]; ok { @@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error { g.Lock() - _, ok = g.subscribers[sub] - if ok { + if _, ok = g.subscribers[sub]; ok { return fmt.Errorf("subscriber %v already exists", sub) } + g.subscribers[sub] = nil g.Unlock() return nil diff --git a/server/grpc/options.go b/server/grpc/options.go index a46b9147..64a8173d 100644 --- a/server/grpc/options.go +++ b/server/grpc/options.go @@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option { if o.Context == nil { o.Context = context.Background() } - if v := o.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } codecs[contentType] = c o.Context = context.WithValue(o.Context, codecsKey{}, codecs) @@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option { // AuthTLS should be used to setup a secure authentication using TLS func AuthTLS(t *tls.Config) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, tlsAuth{}, t) - } + return setServerOption(tlsAuth{}, t) } // Listener specifies the net.Listener to use instead of the default func Listener(l net.Listener) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, netListener{}, l) - } + return setServerOption(netListener{}, l) } // Options to be used to configure gRPC options func Options(opts ...grpc.ServerOption) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, grpcOptions{}, opts) - } + return setServerOption(grpcOptions{}, opts) } // @@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option { // send. Default maximum message size is 4 MB. // func MaxMsgSize(s int) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s) - } + return setServerOption(maxMsgSizeKey{}, s) } func newOptions(opt ...server.Option) server.Options { opts := server.Options{ - Codecs: make(map[string]codec.NewCodec), - Metadata: map[string]string{}, + Codecs: make(map[string]codec.NewCodec), + Metadata: map[string]string{}, + Broker: broker.DefaultBroker, + Registry: registry.DefaultRegistry, + Transport: transport.DefaultTransport, + Address: server.DefaultAddress, + Name: server.DefaultName, + Id: server.DefaultId, + Version: server.DefaultVersion, } for _, o := range opt { o(&opts) } - if opts.Broker == nil { - opts.Broker = broker.DefaultBroker - } - - if opts.Registry == nil { - opts.Registry = registry.DefaultRegistry - } - - if opts.Transport == nil { - opts.Transport = transport.DefaultTransport - } - - if len(opts.Address) == 0 { - opts.Address = server.DefaultAddress - } - - if len(opts.Name) == 0 { - opts.Name = server.DefaultName - } - - if len(opts.Id) == 0 { - opts.Id = server.DefaultId - } - - if len(opts.Version) == 0 { - opts.Version = server.DefaultVersion - } - return opts }