diff --git a/codec_test.go b/codec_test.go index 9fa7392..1ddfe84 100644 --- a/codec_test.go +++ b/codec_test.go @@ -3,8 +3,40 @@ package proto import ( "bytes" "testing" + + "github.com/unistack-org/micro/v3/codec" ) +func TestFrame(t *testing.T) { + s := &codec.Frame{Data: []byte("test")} + + buf, err := NewCodec().Marshal(s) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, []byte(`test`)) { + t.Fatalf("bytes not equal %s != %s", buf, `test`) + } +} + +func TestFrameFlatten(t *testing.T) { + s := &struct { + One string + Name *codec.Frame `json:"name" codec:"flatten"` + }{ + One: "xx", + Name: &codec.Frame{Data: []byte("test")}, + } + + buf, err := NewCodec().Marshal(s) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, []byte(`test`)) { + t.Fatalf("bytes not equal %s != %s", buf, `test`) + } +} + func TestReadBody(t *testing.T) { t.Skip("skip as no proto") s := &struct { diff --git a/proto.go b/proto.go index 271a5eb..1fae9b2 100644 --- a/proto.go +++ b/proto.go @@ -13,55 +13,61 @@ type protoCodec struct { opts codec.Options } +var _ codec.Codec = &protoCodec{} + const ( flattenTag = "flatten" ) func (c *protoCodec) Marshal(v interface{}, opts ...codec.Option) ([]byte, error) { - switch m := v.(type) { - case nil: + if v == nil { return nil, nil - case *codec.Frame: - return m.Data, nil - case proto.Message: - options := c.opts - for _, o := range opts { - o(&options) - } - - if nv, nerr := rutil.StructFieldByTag(v, options.TagName, flattenTag); nerr == nil { - if nm, ok := nv.(proto.Message); ok { - m = nm - } - } - return proto.Marshal(m) } - return nil, codec.ErrInvalidMessage + + options := c.opts + for _, o := range opts { + o(&options) + } + + if nv, nerr := rutil.StructFieldByTag(v, options.TagName, flattenTag); nerr == nil { + v = nv + } + + if m, ok := v.(*codec.Frame); ok { + return m.Data, nil + } + + if _, ok := v.(proto.Message); !ok { + return nil, codec.ErrInvalidMessage + } + + return proto.Marshal(v.(proto.Message)) } func (c *protoCodec) Unmarshal(d []byte, v interface{}, opts ...codec.Option) error { - if len(d) == 0 { + if v == nil || len(d) == 0 { return nil } - switch m := v.(type) { - case nil: - return nil - case *codec.Frame: - m.Data = d - case proto.Message: - options := c.opts - for _, o := range opts { - o(&options) - } - if nv, nerr := rutil.StructFieldByTag(v, options.TagName, flattenTag); nerr == nil { - if nm, ok := nv.(proto.Message); ok { - m = nm - } - } - return proto.Unmarshal(d, m) + options := c.opts + for _, o := range opts { + o(&options) } - return codec.ErrInvalidMessage + + if nv, nerr := rutil.StructFieldByTag(v, options.TagName, flattenTag); nerr == nil { + v = nv + } + + if m, ok := v.(*codec.Frame); ok { + m.Data = d + return nil + } + + if _, ok := v.(proto.Message); !ok { + return codec.ErrInvalidMessage + } + + return proto.Unmarshal(d, v.(proto.Message)) } func (c *protoCodec) ReadHeader(conn io.Reader, m *codec.Message, t codec.MessageType) error {