diff --git a/proto.go b/proto.go index 9d2455c..9aeb470 100644 --- a/proto.go +++ b/proto.go @@ -57,59 +57,32 @@ func (c *protoCodec) ReadHeader(conn io.Reader, m *codec.Message, t codec.Messag } func (c *protoCodec) ReadBody(conn io.Reader, v interface{}) error { - switch m := v.(type) { - case nil: + if v == nil { return nil - case *codec.Frame: - buf, err := io.ReadAll(conn) - if err != nil { - return err - } else if len(buf) == 0 { - return nil - } - m.Data = buf - return nil - case proto.Message: - buf, err := io.ReadAll(conn) - if err != nil { - return err - } else if len(buf) == 0 { - return nil - } - if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { - if nm, ok := nv.(proto.Message); ok { - m = nm - } - } - return proto.Unmarshal(buf, m) } - return codec.ErrInvalidMessage + + buf, err := io.ReadAll(conn) + if err != nil { + return err + } else if len(buf) == 0 { + return nil + } + return c.Unmarshal(buf, v) } func (c *protoCodec) Write(conn io.Writer, m *codec.Message, v interface{}) error { - switch m := v.(type) { - case nil: + if v == nil { return nil - case *codec.Frame: - _, err := conn.Write(m.Data) - return err - case proto.Message: - if nv, nerr := rutil.StructFieldByTag(v, codec.DefaultTagName, flattenTag); nerr == nil { - if nm, ok := nv.(proto.Message); ok { - m = nm - } - } - - buf, err := proto.Marshal(m) - if err != nil { - return err - } else if len(buf) == 0 { - return codec.ErrInvalidMessage - } - _, err = conn.Write(buf) - return err } - return codec.ErrInvalidMessage + + buf, err := c.Marshal(v) + if err != nil { + return err + } else if len(buf) == 0 { + return codec.ErrInvalidMessage + } + _, err = conn.Write(buf) + return err } func (c *protoCodec) String() string {