diff --git a/proto.go b/proto.go index 88f6e0b..9d2455c 100644 --- a/proto.go +++ b/proto.go @@ -24,7 +24,7 @@ func (c *protoCodec) Marshal(v interface{}) ([]byte, error) { case proto.Message: if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { if nm, ok := nv.(proto.Message); ok { - return proto.Marshal(nm) + m = nm } } return proto.Marshal(m) @@ -44,7 +44,7 @@ func (c *protoCodec) Unmarshal(d []byte, v interface{}) error { case proto.Message: if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { if nm, ok := nv.(proto.Message); ok { - return proto.Unmarshal(d, nm) + m = nm } } return proto.Unmarshal(d, m) @@ -78,7 +78,7 @@ func (c *protoCodec) ReadBody(conn io.Reader, v interface{}) error { } if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { if nm, ok := nv.(proto.Message); ok { - return proto.Unmarshal(buf, nm) + m = nm } } return proto.Unmarshal(buf, m) @@ -94,17 +94,13 @@ func (c *protoCodec) Write(conn io.Writer, m *codec.Message, v interface{}) erro _, err := conn.Write(m.Data) return err case proto.Message: - var buf []byte - var err error - if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { if nm, ok := nv.(proto.Message); ok { - buf, err = proto.Marshal(nm) + m = nm } - } else { - buf, err = proto.Marshal(m) } + buf, err := proto.Marshal(m) if err != nil { return err } else if len(buf) == 0 {