micro/codec/protorpc/protorpc.go

187 lines
3.6 KiB
Go
Raw Normal View History

2018-11-20 13:06:13 +03:00
// Package protorpc provides a net/rpc proto-rpc codec. See envelope.proto for the format.
2015-11-28 21:54:38 +03:00
package protorpc
import (
"bytes"
"fmt"
"io"
2019-01-08 18:38:25 +03:00
"strconv"
"sync"
"github.com/golang/protobuf/proto"
"github.com/micro/go-micro/v2/codec"
)
// flusher is implemented by connections that buffer their output. The codec
// checks for it after writing a message and, when present, calls Flush so
// the frames are pushed out.
type flusher interface {
	// Flush pushes any buffered data out, reporting a failure as an error.
	Flush() error
}
// protoCodec is a codec that frames protobuf messages over a single
// io.ReadWriteCloser.
type protoCodec struct {
	// Mutex is held while request and response frames are written.
	sync.Mutex

	// rwc is the underlying connection all frames are read from and
	// written to.
	rwc io.ReadWriteCloser

	// mt is the message type passed to the most recent ReadHeader call;
	// ReadBody uses it to decide where the body comes from.
	mt codec.MessageType

	// buf holds the raw event payload copied in by ReadHeader until
	// ReadBody decodes it.
	buf *bytes.Buffer
}
// Close discards any buffered payload and then closes the underlying
// connection, returning whatever error that close reports.
func (c *protoCodec) Close() error {
	c.buf.Reset()

	err := c.rwc.Close()
	return err
}
func (c *protoCodec) String() string {
2015-11-28 21:54:38 +03:00
return "proto-rpc"
}
2019-01-08 18:38:25 +03:00
// id converts a decimal message id into the sequence number carried in the
// protobuf envelope and returns a pointer to it.
//
// The id is parsed as an unsigned 64-bit integer first so that every value
// produced by formatting a uint64 sequence number round-trips. Input that is
// not a plain unsigned number (for example with a leading sign) falls back
// to a signed parse, with negative values wrapping on conversion. Anything
// that cannot be parsed at all yields 0.
func id(id string) *uint64 {
	seq, err := strconv.ParseUint(id, 10, 64)
	if err != nil {
		// ParseUint rejects sign prefixes; retry as signed so "+5" and
		// "-1" keep producing the values they always did.
		n, serr := strconv.ParseInt(id, 10, 64)
		if serr != nil {
			n = 0
		}
		seq = uint64(n)
	}
	return &seq
}
func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
switch m.Type {
case codec.Request:
c.Lock()
defer c.Unlock()
// This is protobuf, of course we copy it.
2019-01-18 13:12:57 +03:00
pbr := &Request{ServiceMethod: &m.Method, Seq: id(m.Id)}
data, err := proto.Marshal(pbr)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
// dont trust or incoming message
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err = proto.Marshal(m)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if flusher, ok := c.rwc.(flusher); ok {
2018-11-13 11:56:21 +03:00
if err = flusher.Flush(); err != nil {
return err
}
}
2019-01-13 15:15:35 +03:00
case codec.Response, codec.Error:
c.Lock()
defer c.Unlock()
2019-01-18 13:12:57 +03:00
rtmp := &Response{ServiceMethod: &m.Method, Seq: id(m.Id), Error: &m.Error}
data, err := proto.Marshal(rtmp)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if pb, ok := b.(proto.Message); ok {
data, err = proto.Marshal(pb)
if err != nil {
return err
}
} else {
data = nil
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if flusher, ok := c.rwc.(flusher); ok {
2018-11-13 11:56:21 +03:00
if err = flusher.Flush(); err != nil {
return err
}
}
2019-07-07 14:44:09 +03:00
case codec.Event:
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err := proto.Marshal(m)
if err != nil {
return err
}
c.rwc.Write(data)
default:
return fmt.Errorf("Unrecognised message type: %v", m.Type)
}
return nil
}
func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
c.buf.Reset()
c.mt = mt
switch mt {
case codec.Request:
data, err := ReadNetString(c.rwc)
if err != nil {
return err
}
2015-11-28 19:39:25 +03:00
rtmp := new(Request)
err = proto.Unmarshal(data, rtmp)
if err != nil {
return err
}
2019-01-18 13:12:57 +03:00
m.Method = rtmp.GetServiceMethod()
2019-01-08 18:38:25 +03:00
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
case codec.Response:
data, err := ReadNetString(c.rwc)
if err != nil {
return err
}
2015-11-28 19:39:25 +03:00
rtmp := new(Response)
err = proto.Unmarshal(data, rtmp)
if err != nil {
return err
}
2019-01-18 13:12:57 +03:00
m.Method = rtmp.GetServiceMethod()
2019-01-08 18:38:25 +03:00
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
2016-04-01 11:28:52 +03:00
m.Error = rtmp.GetError()
2019-07-07 14:44:09 +03:00
case codec.Event:
2018-11-13 11:56:21 +03:00
_, err := io.Copy(c.buf, c.rwc)
return err
default:
return fmt.Errorf("Unrecognised message type: %v", mt)
}
return nil
}
func (c *protoCodec) ReadBody(b interface{}) error {
var data []byte
switch c.mt {
case codec.Request, codec.Response:
var err error
data, err = ReadNetString(c.rwc)
if err != nil {
return err
}
2019-07-07 14:44:09 +03:00
case codec.Event:
data = c.buf.Bytes()
default:
return fmt.Errorf("Unrecognised message type: %v", c.mt)
}
if b != nil {
return proto.Unmarshal(data, b.(proto.Message))
}
return nil
}
// NewCodec returns a proto-rpc codec that reads from and writes to rwc,
// starting with an empty internal buffer.
func NewCodec(rwc io.ReadWriteCloser) codec.Codec {
	pc := &protoCodec{rwc: rwc}
	pc.buf = bytes.NewBuffer(nil)
	return pc
}