From 58598d0fe02f3a5bb4b384970d52bfb798d22c24 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Wed, 19 Feb 2020 02:05:38 +0300 Subject: [PATCH] fixes for safe conversation and avoid panics (#1213) * fixes for safe convertation Signed-off-by: Vasiliy Tolstov * fix client publish panic If broker connect returns error we dont check it status and use it later to publish message, mostly this is unexpected because broker connection failed and we cant use it. Also proposed solution have benefit - we flag connection status only when we have succeseful broker connection Signed-off-by: Vasiliy Tolstov * api/handler/broker: fix possible broker publish panic Signed-off-by: Vasiliy Tolstov --- api/handler/broker/broker.go | 20 +++++++--- auth/jwt/jwt.go | 5 ++- client/grpc/grpc.go | 15 +++++--- client/rpc_client.go | 14 ++++--- codec/codec.go | 5 +++ codec/proto/proto.go | 8 +++- codec/protorpc/protorpc.go | 14 +++++-- server/grpc/codec.go | 18 ++++++--- server/grpc/context.go | 16 ++++++++ server/grpc/grpc.go | 24 ++++-------- server/grpc/options.go | 71 ++++++++---------------------------- 11 files changed, 108 insertions(+), 102 deletions(-) create mode 100644 server/grpc/context.go diff --git a/api/handler/broker/broker.go b/api/handler/broker/broker.go index bf4ccf60..29d78a02 100644 --- a/api/handler/broker/broker.go +++ b/api/handler/broker/broker.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -26,6 +27,7 @@ const ( ) type brokerHandler struct { + once atomic.Value opts handler.Options u websocket.Upgrader } @@ -42,7 +44,6 @@ type conn struct { } var ( - once sync.Once contentType = "text/plain" ) @@ -155,10 +156,15 @@ func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { br := b.opts.Service.Client().Options().Broker // Setup the broker - once.Do(func() { - br.Init() - br.Connect() - }) + if !b.once.Load().(bool) { + if err := br.Init(); err != nil { + http.Error(w, err.Error(), 500) + } + if err := br.Connect(); err != nil { + http.Error(w, err.Error(), 500) + } + b.once.Store(true) + } // Parse r.ParseForm() @@ -235,7 +241,7 @@ func (b *brokerHandler) String() string { } func NewHandler(opts ...handler.Option) handler.Handler { - return &brokerHandler{ + h := &brokerHandler{ u: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -245,6 +251,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { }, opts: handler.NewOptions(opts...), } + h.once.Store(true) + return h } func WithCors(cors map[string]bool, opts ...handler.Option) handler.Handler { diff --git a/auth/jwt/jwt.go b/auth/jwt/jwt.go index 540bec43..e0719682 100644 --- a/auth/jwt/jwt.go +++ b/auth/jwt/jwt.go @@ -100,7 +100,10 @@ func (s *svc) Validate(token string) (*auth.Account, error) { return nil, ErrInvalidToken } - claims := res.Claims.(*AuthClaims) + claims, ok := res.Claims.(*AuthClaims) + if !ok { + return nil, ErrInvalidToken + } return &auth.Account{ Id: claims.Id, diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 1c7f3512..7c5a0084 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -6,7 +6,7 @@ import ( "crypto/tls" "fmt" "os" - "sync" + "sync/atomic" "time" "github.com/micro/go-micro/v2/broker" @@ -24,9 +24,9 @@ import ( ) type grpcClient struct { - once sync.Once opts client.Options pool *pool + once atomic.Value } func init() { @@ -570,9 +570,12 @@ func (g *grpcClient) Publish(ctx context.Context, p client.Message, opts ...clie body = b } - g.once.Do(func() { - g.opts.Broker.Connect() - }) + if !g.once.Load().(bool) { + if err = g.opts.Broker.Connect(); err != nil { + return errors.InternalServerError("go.micro.client", err.Error()) + } + g.once.Store(true) + } topic := p.Topic() @@ -641,9 +644,9 @@ func newClient(opts ...client.Option) client.Client { } rc := &grpcClient{ - once: sync.Once{}, opts: options, } + rc.once.Store(false) rc.pool = newPool(options.PoolSize, options.PoolTTL, rc.poolMaxIdle(), rc.poolMaxStreams()) diff --git a/client/rpc_client.go b/client/rpc_client.go index b701d1ca..8b4b806b 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "sync" "sync/atomic" "time" @@ -22,7 +21,7 @@ import ( ) type rpcClient struct { - once sync.Once + once atomic.Value opts Options pool pool.Pool seq uint64 @@ -38,11 +37,11 @@ func newRpcClient(opt ...Option) Client { ) rc := &rpcClient{ - once: sync.Once{}, opts: opts, pool: p, seq: 0, } + rc.once.Store(false) c := Client(rc) @@ -645,9 +644,12 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt body = b.Bytes() } - r.once.Do(func() { - r.opts.Broker.Connect() - }) + if !r.once.Load().(bool) { + if err = r.opts.Broker.Connect(); err != nil { + return errors.InternalServerError("go.micro.client", err.Error()) + } + r.once.Store(true) + } return r.opts.Broker.Publish(topic, &broker.Message{ Header: md, diff --git a/codec/codec.go b/codec/codec.go index b4feb0a4..107bdb35 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -2,6 +2,7 @@ package codec import ( + "errors" "io" ) @@ -12,6 +13,10 @@ const ( Event ) +var ( + ErrInvalidMessage = errors.New("invalid message") +) + type MessageType int // Takes in a connection/buffer and returns a new Codec diff --git a/codec/proto/proto.go b/codec/proto/proto.go index c2c4f382..87073619 100644 --- a/codec/proto/proto.go +++ b/codec/proto/proto.go @@ -25,13 +25,17 @@ func (c *Codec) ReadBody(b interface{}) error { if err != nil { return err } - return proto.Unmarshal(buf, b.(proto.Message)) + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(buf, m) } func (c *Codec) Write(m *codec.Message, b interface{}) error { p, ok := b.(proto.Message) if !ok { - return nil + return codec.ErrInvalidMessage } buf, err := proto.Marshal(p) if err != nil { diff --git a/codec/protorpc/protorpc.go b/codec/protorpc/protorpc.go index 4e4b2ee8..41f52fff 100644 --- a/codec/protorpc/protorpc.go +++ b/codec/protorpc/protorpc.go @@ -56,8 +56,12 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error { if err != nil { return err } - // Of course this is a protobuf! Trust me or detonate the program. - data, err = proto.Marshal(b.(proto.Message)) + // dont trust or incoming message + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + data, err = proto.Marshal(m) if err != nil { return err } @@ -100,7 +104,11 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error { } } case codec.Event: - data, err := proto.Marshal(b.(proto.Message)) + m, ok := b.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + data, err := proto.Marshal(m) if err != nil { return err } diff --git a/server/grpc/codec.go b/server/grpc/codec.go index cb27ec6e..b92b63b4 100644 --- a/server/grpc/codec.go +++ b/server/grpc/codec.go @@ -2,7 +2,6 @@ package grpc import ( "encoding/json" - "fmt" "strings" b "bytes" @@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error { } func (protoCodec) Marshal(v interface{}) ([]byte, error) { - return proto.Marshal(v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return nil, codec.ErrInvalidMessage + } + return proto.Marshal(m) } func (protoCodec) Unmarshal(data []byte, v interface{}) error { - return proto.Unmarshal(data, v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(data, m) } func (protoCodec) Name() string { @@ -85,7 +92,6 @@ func (protoCodec) Name() string { func (jsonCodec) Marshal(v interface{}) ([]byte, error) { if pb, ok := v.(proto.Message); ok { s, err := jsonpbMarshaler.MarshalToString(pb) - return []byte(s), err } @@ -109,7 +115,7 @@ func (jsonCodec) Name() string { func (bytesCodec) Marshal(v interface{}) ([]byte, error) { b, ok := v.(*[]byte) if !ok { - return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) + return nil, codec.ErrInvalidMessage } return *b, nil } @@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) { func (bytesCodec) Unmarshal(data []byte, v interface{}) error { b, ok := v.(*[]byte) if !ok { - return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) + return codec.ErrInvalidMessage } *b = data return nil diff --git a/server/grpc/context.go b/server/grpc/context.go new file mode 100644 index 00000000..0b246d7e --- /dev/null +++ b/server/grpc/context.go @@ -0,0 +1,16 @@ +package grpc + +import ( + "context" + + "github.com/micro/go-micro/v2/server" +) + +func setServerOption(k, v interface{}) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, k, v) + } +} diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index b21c58e0..001ea1b5 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int { func (g *grpcServer) getCredentials() credentials.TransportCredentials { if g.opts.Context != nil { - if v := g.opts.Context.Value(tlsAuth{}); v != nil { - tls := v.(*tls.Config) - return credentials.NewTLS(tls) + if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil { + return credentials.NewTLS(v) } } return nil @@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption { return nil } - v := g.opts.Context.Value(grpcOptions{}) - - if v == nil { - return nil - } - - opts, ok := v.([]grpc.ServerOption) - - if !ok { + opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption) + if !ok || opts == nil { return nil } @@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { codecs := make(map[string]encoding.Codec) if g.opts.Context != nil { - if v := g.opts.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } } if c, ok := codecs[contentType]; ok { @@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error { g.Lock() - _, ok = g.subscribers[sub] - if ok { + if _, ok = g.subscribers[sub]; ok { return fmt.Errorf("subscriber %v already exists", sub) } + g.subscribers[sub] = nil g.Unlock() return nil diff --git a/server/grpc/options.go b/server/grpc/options.go index a46b9147..64a8173d 100644 --- a/server/grpc/options.go +++ b/server/grpc/options.go @@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option { if o.Context == nil { o.Context = context.Background() } - if v := o.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } codecs[contentType] = c o.Context = context.WithValue(o.Context, codecsKey{}, codecs) @@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option { // AuthTLS should be used to setup a secure authentication using TLS func AuthTLS(t *tls.Config) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, tlsAuth{}, t) - } + return setServerOption(tlsAuth{}, t) } // Listener specifies the net.Listener to use instead of the default func Listener(l net.Listener) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, netListener{}, l) - } + return setServerOption(netListener{}, l) } // Options to be used to configure gRPC options func Options(opts ...grpc.ServerOption) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, grpcOptions{}, opts) - } + return setServerOption(grpcOptions{}, opts) } // @@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option { // send. Default maximum message size is 4 MB. // func MaxMsgSize(s int) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s) - } + return setServerOption(maxMsgSizeKey{}, s) } func newOptions(opt ...server.Option) server.Options { opts := server.Options{ - Codecs: make(map[string]codec.NewCodec), - Metadata: map[string]string{}, + Codecs: make(map[string]codec.NewCodec), + Metadata: map[string]string{}, + Broker: broker.DefaultBroker, + Registry: registry.DefaultRegistry, + Transport: transport.DefaultTransport, + Address: server.DefaultAddress, + Name: server.DefaultName, + Id: server.DefaultId, + Version: server.DefaultVersion, } for _, o := range opt { o(&opts) } - if opts.Broker == nil { - opts.Broker = broker.DefaultBroker - } - - if opts.Registry == nil { - opts.Registry = registry.DefaultRegistry - } - - if opts.Transport == nil { - opts.Transport = transport.DefaultTransport - } - - if len(opts.Address) == 0 { - opts.Address = server.DefaultAddress - } - - if len(opts.Name) == 0 { - opts.Name = server.DefaultName - } - - if len(opts.Id) == 0 { - opts.Id = server.DefaultId - } - - if len(opts.Version) == 0 { - opts.Version = server.DefaultVersion - } - return opts }