diff --git a/codec.go b/codec.go index cb27ec6..b92b63b 100644 --- a/codec.go +++ b/codec.go @@ -2,7 +2,6 @@ package grpc import ( "encoding/json" - "fmt" "strings" b "bytes" @@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error { } func (protoCodec) Marshal(v interface{}) ([]byte, error) { - return proto.Marshal(v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return nil, codec.ErrInvalidMessage + } + return proto.Marshal(m) } func (protoCodec) Unmarshal(data []byte, v interface{}) error { - return proto.Unmarshal(data, v.(proto.Message)) + m, ok := v.(proto.Message) + if !ok { + return codec.ErrInvalidMessage + } + return proto.Unmarshal(data, m) } func (protoCodec) Name() string { @@ -85,7 +92,6 @@ func (protoCodec) Name() string { func (jsonCodec) Marshal(v interface{}) ([]byte, error) { if pb, ok := v.(proto.Message); ok { s, err := jsonpbMarshaler.MarshalToString(pb) - return []byte(s), err } @@ -109,7 +115,7 @@ func (jsonCodec) Name() string { func (bytesCodec) Marshal(v interface{}) ([]byte, error) { b, ok := v.(*[]byte) if !ok { - return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) + return nil, codec.ErrInvalidMessage } return *b, nil } @@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) { func (bytesCodec) Unmarshal(data []byte, v interface{}) error { b, ok := v.(*[]byte) if !ok { - return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) + return codec.ErrInvalidMessage } *b = data return nil diff --git a/context.go b/context.go new file mode 100644 index 0000000..0b246d7 --- /dev/null +++ b/context.go @@ -0,0 +1,16 @@ +package grpc + +import ( + "context" + + "github.com/micro/go-micro/v2/server" +) + +func setServerOption(k, v interface{}) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, k, v) + } +} diff --git a/grpc.go b/grpc.go index b21c58e..001ea1b 100644 --- a/grpc.go +++ b/grpc.go @@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int { func (g *grpcServer) getCredentials() credentials.TransportCredentials { if g.opts.Context != nil { - if v := g.opts.Context.Value(tlsAuth{}); v != nil { - tls := v.(*tls.Config) - return credentials.NewTLS(tls) + if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil { + return credentials.NewTLS(v) } } return nil @@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption { return nil } - v := g.opts.Context.Value(grpcOptions{}) - - if v == nil { - return nil - } - - opts, ok := v.([]grpc.ServerOption) - - if !ok { + opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption) + if !ok || opts == nil { return nil } @@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { codecs := make(map[string]encoding.Codec) if g.opts.Context != nil { - if v := g.opts.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } } if c, ok := codecs[contentType]; ok { @@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error { g.Lock() - _, ok = g.subscribers[sub] - if ok { + if _, ok = g.subscribers[sub]; ok { return fmt.Errorf("subscriber %v already exists", sub) } + g.subscribers[sub] = nil g.Unlock() return nil diff --git a/options.go b/options.go index a46b914..64a8173 100644 --- a/options.go +++ b/options.go @@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option { if o.Context == nil { o.Context = context.Background() } - if v := o.Context.Value(codecsKey{}); v != nil { - codecs = v.(map[string]encoding.Codec) + if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil { + codecs = v } codecs[contentType] = c o.Context = context.WithValue(o.Context, codecsKey{}, codecs) @@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option { // AuthTLS should be used to setup a secure authentication using TLS func AuthTLS(t *tls.Config) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, tlsAuth{}, t) - } + return setServerOption(tlsAuth{}, t) } // Listener specifies the net.Listener to use instead of the default func Listener(l net.Listener) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, netListener{}, l) - } + return setServerOption(netListener{}, l) } // Options to be used to configure gRPC options func Options(opts ...grpc.ServerOption) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, grpcOptions{}, opts) - } + return setServerOption(grpcOptions{}, opts) } // @@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option { // send. Default maximum message size is 4 MB. // func MaxMsgSize(s int) server.Option { - return func(o *server.Options) { - if o.Context == nil { - o.Context = context.Background() - } - o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s) - } + return setServerOption(maxMsgSizeKey{}, s) } func newOptions(opt ...server.Option) server.Options { opts := server.Options{ - Codecs: make(map[string]codec.NewCodec), - Metadata: map[string]string{}, + Codecs: make(map[string]codec.NewCodec), + Metadata: map[string]string{}, + Broker: broker.DefaultBroker, + Registry: registry.DefaultRegistry, + Transport: transport.DefaultTransport, + Address: server.DefaultAddress, + Name: server.DefaultName, + Id: server.DefaultId, + Version: server.DefaultVersion, } for _, o := range opt { o(&opts) } - if opts.Broker == nil { - opts.Broker = broker.DefaultBroker - } - - if opts.Registry == nil { - opts.Registry = registry.DefaultRegistry - } - - if opts.Transport == nil { - opts.Transport = transport.DefaultTransport - } - - if len(opts.Address) == 0 { - opts.Address = server.DefaultAddress - } - - if len(opts.Name) == 0 { - opts.Name = server.DefaultName - } - - if len(opts.Id) == 0 { - opts.Id = server.DefaultId - } - - if len(opts.Version) == 0 { - opts.Version = server.DefaultVersion - } - return opts }