From 70cc7c93ef29138f8d97abd3114f6e8a9b6bf87e Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Wed, 19 Feb 2020 02:05:38 +0300 Subject: [PATCH] fixes for safe conversation and avoid panics (#1213) * fixes for safe convertation Signed-off-by: Vasiliy Tolstov * fix client publish panic If broker connect returns error we dont check it status and use it later to publish message, mostly this is unexpected because broker connection failed and we cant use it. Also proposed solution have benefit - we flag connection status only when we have succeseful broker connection Signed-off-by: Vasiliy Tolstov * api/handler/broker: fix possible broker publish panic Signed-off-by: Vasiliy Tolstov --- codec.go | 18 +++++++++----- context.go | 16 ++++++++++++ grpc.go | 24 ++++++------------ options.go | 71 ++++++++++++------------------------------------------ 4 files changed, 51 insertions(+), 78 deletions(-) create mode 100644 context.go diff --git a/codec.go b/codec.go index cb27ec6..b92b63b 100644 --- a/codec.go +++ b/codec.go @@ -2,7 +2,6 @@ package grpc import ( "encoding/json" - "fmt" "strings" b "bytes" @@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error { } func (protoCodec) Marshal(v interface{}) ([]byte, error) { - return proto.Marshal(v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return nil, codec.ErrInvalidMessage + } + return proto.Marshal(m) } func (protoCodec) Unmarshal(data []byte, v interface{}) error { - return proto.Unmarshal(data, v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(data, m) } func (protoCodec) Name() string { @@ -85,7 +92,6 @@ func (protoCodec) Name() string { func (jsonCodec) Marshal(v interface{}) ([]byte, error) { if pb, ok := v.(proto.Message); ok { s, err := jsonpbMarshaler.MarshalToString(pb) - return []byte(s), err } @@ -109,7 +115,7 @@ func (jsonCodec) Name() string { func (bytesCodec) Marshal(v interface{}) ([]byte, error) { b, ok := v.(*[]byte) if !ok { - return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) + return nil, codec.ErrInvalidMessage } return *b, nil } @@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) { func (bytesCodec) Unmarshal(data []byte, v interface{}) error { b, ok := v.(*[]byte) if !ok { - return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) + return codec.ErrInvalidMessage } *b = data return nil diff --git a/context.go b/context.go new file mode 100644 index 0000000..0b246d7 --- /dev/null +++ b/context.go @@ -0,0 +1,16 @@ +package grpc + +import ( + "context" + + "github.com/micro/go-micro/v2/server" +) + +func setServerOption(k, v interface{}) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, k, v) + } +} diff --git a/grpc.go b/grpc.go index b21c58e..001ea1b 100644 --- a/grpc.go +++ b/grpc.go @@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int { func (g *grpcServer) getCredentials() credentials.TransportCredentials { if g.opts.Context != nil { - if v := g.opts.Context.Value(tlsAuth{}); v != nil { - tls := v.(*tls.Config) - return credentials.NewTLS(tls) + if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil { + return credentials.NewTLS(v) } } return nil @@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption { return nil } - v := g.opts.Context.Value(grpcOptions{}) - - if v == nil { - return nil - } - - opts, ok := v.([]grpc.ServerOption) - - if !ok { + opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption) + if !ok || opts == nil { return nil } @@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { codecs := make(map[string]encoding.Codec) if g.opts.Context != nil { - if v := g.opts.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } } if c, ok := codecs[contentType]; ok { @@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error { g.Lock() - _, ok = g.subscribers[sub] - if ok { + if _, ok = g.subscribers[sub]; ok { return fmt.Errorf("subscriber %v already exists", sub) } + g.subscribers[sub] = nil g.Unlock() return nil diff --git a/options.go b/options.go index a46b914..64a8173 100644 --- a/options.go +++ b/options.go @@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option { if o.Context == nil { o.Context = context.Background() } - if v := o.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } codecs[contentType] = c o.Context = context.WithValue(o.Context, codecsKey{}, codecs) @@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option { // AuthTLS should be used to setup a secure authentication using TLS func AuthTLS(t *tls.Config) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, tlsAuth{}, t) - } + return setServerOption(tlsAuth{}, t) } // Listener specifies the net.Listener to use instead of the default func Listener(l net.Listener) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, netListener{}, l) - } + return setServerOption(netListener{}, l) } // Options to be used to configure gRPC options func Options(opts ...grpc.ServerOption) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, grpcOptions{}, opts) - } + return setServerOption(grpcOptions{}, opts) } // @@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option { // send. Default maximum message size is 4 MB. // func MaxMsgSize(s int) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s) - } + return setServerOption(maxMsgSizeKey{}, s) } func newOptions(opt ...server.Option) server.Options { opts := server.Options{ - Codecs: make(map[string]codec.NewCodec), - Metadata: map[string]string{}, + Codecs: make(map[string]codec.NewCodec), + Metadata: map[string]string{}, + Broker: broker.DefaultBroker, + Registry: registry.DefaultRegistry, + Transport: transport.DefaultTransport, + Address: server.DefaultAddress, + Name: server.DefaultName, + Id: server.DefaultId, + Version: server.DefaultVersion, } for _, o := range opt { o(&opts) } - if opts.Broker == nil { - opts.Broker = broker.DefaultBroker - } - - if opts.Registry == nil { - opts.Registry = registry.DefaultRegistry - } - - if opts.Transport == nil { - opts.Transport = transport.DefaultTransport - } - - if len(opts.Address) == 0 { - opts.Address = server.DefaultAddress - } - - if len(opts.Name) == 0 { - opts.Name = server.DefaultName - } - - if len(opts.Id) == 0 { - opts.Id = server.DefaultId - } - - if len(opts.Version) == 0 { - opts.Version = server.DefaultVersion - } - return opts }