fixes for safe conversation and avoid panics (#1213)

* fixes for safe convertation

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* fix client publish panic

If broker connect returns error we dont check it status and use
it later to publish message, mostly this is unexpected because
broker connection failed and we cant use it.
Also proposed solution have benefit - we flag connection status
only when we have succeseful broker connection

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* api/handler/broker: fix possible broker publish panic

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Василий Толстов 2020-02-19 02:05:38 +03:00
parent d89db33c07
commit 70cc7c93ef
4 changed files with 51 additions and 78 deletions

View File

@ -2,7 +2,6 @@ package grpc
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
b "bytes" b "bytes"
@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error {
} }
func (protoCodec) Marshal(v interface{}) ([]byte, error) { func (protoCodec) Marshal(v interface{}) ([]byte, error) {
return proto.Marshal(v.(proto.Message)) m, ok := v.(proto.Message)
if !ok {
return nil, codec.ErrInvalidMessage
}
return proto.Marshal(m)
} }
func (protoCodec) Unmarshal(data []byte, v interface{}) error { func (protoCodec) Unmarshal(data []byte, v interface{}) error {
return proto.Unmarshal(data, v.(proto.Message)) m, ok := v.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
return proto.Unmarshal(data, m)
} }
func (protoCodec) Name() string { func (protoCodec) Name() string {
@ -85,7 +92,6 @@ func (protoCodec) Name() string {
func (jsonCodec) Marshal(v interface{}) ([]byte, error) { func (jsonCodec) Marshal(v interface{}) ([]byte, error) {
if pb, ok := v.(proto.Message); ok { if pb, ok := v.(proto.Message); ok {
s, err := jsonpbMarshaler.MarshalToString(pb) s, err := jsonpbMarshaler.MarshalToString(pb)
return []byte(s), err return []byte(s), err
} }
@ -109,7 +115,7 @@ func (jsonCodec) Name() string {
func (bytesCodec) Marshal(v interface{}) ([]byte, error) { func (bytesCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.(*[]byte) b, ok := v.(*[]byte)
if !ok { if !ok {
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) return nil, codec.ErrInvalidMessage
} }
return *b, nil return *b, nil
} }
@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) {
func (bytesCodec) Unmarshal(data []byte, v interface{}) error { func (bytesCodec) Unmarshal(data []byte, v interface{}) error {
b, ok := v.(*[]byte) b, ok := v.(*[]byte)
if !ok { if !ok {
return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) return codec.ErrInvalidMessage
} }
*b = data *b = data
return nil return nil

16
context.go Normal file
View File

@ -0,0 +1,16 @@
package grpc
import (
"context"
"github.com/micro/go-micro/v2/server"
)
func setServerOption(k, v interface{}) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, k, v)
}
}

24
grpc.go
View File

@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int {
func (g *grpcServer) getCredentials() credentials.TransportCredentials { func (g *grpcServer) getCredentials() credentials.TransportCredentials {
if g.opts.Context != nil { if g.opts.Context != nil {
if v := g.opts.Context.Value(tlsAuth{}); v != nil { if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil {
tls := v.(*tls.Config) return credentials.NewTLS(v)
return credentials.NewTLS(tls)
} }
} }
return nil return nil
@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption {
return nil return nil
} }
v := g.opts.Context.Value(grpcOptions{}) opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption)
if !ok || opts == nil {
if v == nil {
return nil
}
opts, ok := v.([]grpc.ServerOption)
if !ok {
return nil return nil
} }
@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m
func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) {
codecs := make(map[string]encoding.Codec) codecs := make(map[string]encoding.Codec)
if g.opts.Context != nil { if g.opts.Context != nil {
if v := g.opts.Context.Value(codecsKey{}); v != nil { if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil {
codecs = v.(map[string]encoding.Codec) codecs = v
} }
} }
if c, ok := codecs[contentType]; ok { if c, ok := codecs[contentType]; ok {
@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error {
g.Lock() g.Lock()
_, ok = g.subscribers[sub] if _, ok = g.subscribers[sub]; ok {
if ok {
return fmt.Errorf("subscriber %v already exists", sub) return fmt.Errorf("subscriber %v already exists", sub)
} }
g.subscribers[sub] = nil g.subscribers[sub] = nil
g.Unlock() g.Unlock()
return nil return nil

View File

@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option {
if o.Context == nil { if o.Context == nil {
o.Context = context.Background() o.Context = context.Background()
} }
if v := o.Context.Value(codecsKey{}); v != nil { if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil {
codecs = v.(map[string]encoding.Codec) codecs = v
} }
codecs[contentType] = c codecs[contentType] = c
o.Context = context.WithValue(o.Context, codecsKey{}, codecs) o.Context = context.WithValue(o.Context, codecsKey{}, codecs)
@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option {
// AuthTLS should be used to setup a secure authentication using TLS // AuthTLS should be used to setup a secure authentication using TLS
func AuthTLS(t *tls.Config) server.Option { func AuthTLS(t *tls.Config) server.Option {
return func(o *server.Options) { return setServerOption(tlsAuth{}, t)
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, tlsAuth{}, t)
}
} }
// Listener specifies the net.Listener to use instead of the default // Listener specifies the net.Listener to use instead of the default
func Listener(l net.Listener) server.Option { func Listener(l net.Listener) server.Option {
return func(o *server.Options) { return setServerOption(netListener{}, l)
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, netListener{}, l)
}
} }
// Options to be used to configure gRPC options // Options to be used to configure gRPC options
func Options(opts ...grpc.ServerOption) server.Option { func Options(opts ...grpc.ServerOption) server.Option {
return func(o *server.Options) { return setServerOption(grpcOptions{}, opts)
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, grpcOptions{}, opts)
}
} }
// //
@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option {
// send. Default maximum message size is 4 MB. // send. Default maximum message size is 4 MB.
// //
func MaxMsgSize(s int) server.Option { func MaxMsgSize(s int) server.Option {
return func(o *server.Options) { return setServerOption(maxMsgSizeKey{}, s)
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s)
}
} }
func newOptions(opt ...server.Option) server.Options { func newOptions(opt ...server.Option) server.Options {
opts := server.Options{ opts := server.Options{
Codecs: make(map[string]codec.NewCodec), Codecs: make(map[string]codec.NewCodec),
Metadata: map[string]string{}, Metadata: map[string]string{},
Broker: broker.DefaultBroker,
Registry: registry.DefaultRegistry,
Transport: transport.DefaultTransport,
Address: server.DefaultAddress,
Name: server.DefaultName,
Id: server.DefaultId,
Version: server.DefaultVersion,
} }
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
if opts.Broker == nil {
opts.Broker = broker.DefaultBroker
}
if opts.Registry == nil {
opts.Registry = registry.DefaultRegistry
}
if opts.Transport == nil {
opts.Transport = transport.DefaultTransport
}
if len(opts.Address) == 0 {
opts.Address = server.DefaultAddress
}
if len(opts.Name) == 0 {
opts.Name = server.DefaultName
}
if len(opts.Id) == 0 {
opts.Id = server.DefaultId
}
if len(opts.Version) == 0 {
opts.Version = server.DefaultVersion
}
return opts return opts
} }