diff --git a/proto/proto.go b/proto/proto.go index d912616..2699399 100644 --- a/proto/proto.go +++ b/proto/proto.go @@ -38,14 +38,14 @@ func (c *protoCodec) Marshal(v interface{}, opts ...codec.Option) ([]byte, error return m.Data, nil } - switch v.(type) { - case proto.Message, newproto.Message: - break - default: - return nil, codec.ErrInvalidMessage + switch m := v.(type) { + case proto.Message: + return proto.Marshal(m) + case newproto.Message: + return proto.Marshal(m) } - return proto.Marshal(v.(proto.Message)) + return nil, codec.ErrInvalidMessage } func (c *protoCodec) Unmarshal(d []byte, v interface{}, opts ...codec.Option) error { @@ -67,14 +67,15 @@ func (c *protoCodec) Unmarshal(d []byte, v interface{}, opts ...codec.Option) er return nil } - switch v.(type) { - case proto.Message, newproto.Message: - break - default: - return codec.ErrInvalidMessage + switch m := v.(type) { + case proto.Message: + return proto.Unmarshal(d, m) + case newproto.Message: + return proto.Unmarshal(d, m) + } - return proto.Unmarshal(d, v.(proto.Message)) + return codec.ErrInvalidMessage } func (c *protoCodec) ReadHeader(conn io.Reader, m *codec.Message, t codec.MessageType) error {