diff --git a/codec.go b/codec.go index 681a104..cfa97f9 100644 --- a/codec.go +++ b/codec.go @@ -3,19 +3,22 @@ package grpc import ( "encoding/json" "fmt" + "strings" "github.com/golang/protobuf/proto" "github.com/micro/go-micro/codec" "github.com/micro/go-micro/codec/bytes" "github.com/micro/go-micro/codec/jsonrpc" "github.com/micro/go-micro/codec/protorpc" + "google.golang.org/grpc" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/metadata" ) type jsonCodec struct{} type bytesCodec struct{} type protoCodec struct{} -type wrapCodec struct { encoding.Codec } +type wrapCodec struct{ encoding.Codec } var ( defaultGRPCCodecs = map[string]encoding.Codec{ @@ -39,24 +42,24 @@ var ( ) func (w wrapCodec) String() string { - return w.Codec.Name() + return w.Codec.Name() } func (w wrapCodec) Marshal(v interface{}) ([]byte, error) { - b, ok := v.(*bytes.Frame) - if ok { - return b.Data, nil - } - return w.Codec.Marshal(v) + b, ok := v.(*bytes.Frame) + if ok { + return b.Data, nil + } + return w.Codec.Marshal(v) } func (w wrapCodec) Unmarshal(data []byte, v interface{}) error { - b, ok := v.(*bytes.Frame) - if ok { - b.Data = data - return nil - } - return w.Codec.Unmarshal(data, v) + b, ok := v.(*bytes.Frame) + if ok { + b.Data = data + return nil + } + return w.Codec.Unmarshal(data, v) } func (protoCodec) Marshal(v interface{}) ([]byte, error) { @@ -103,3 +106,61 @@ func (bytesCodec) Unmarshal(data []byte, v interface{}) error { func (bytesCodec) Name() string { return "bytes" } + +type grpcCodec struct { + // headers + id string + target string + method string + endpoint string + + s grpc.ServerStream + c encoding.Codec +} + +func (g *grpcCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error { + md, _ := metadata.FromIncomingContext(g.s.Context()) + if m == nil { + m = new(codec.Message) + } + if m.Header == nil { + m.Header = make(map[string]string) + } + for k, v := range md { + m.Header[k] = strings.Join(v, ",") + } + m.Id = g.id + m.Target = g.target + m.Method = g.method + m.Endpoint = g.endpoint + return nil +} + +func (g *grpcCodec) ReadBody(v interface{}) error { + // caller has requested a frame + if f, ok := v.(*bytes.Frame); ok { + return g.s.RecvMsg(f) + } + return g.s.RecvMsg(v) +} + +func (g *grpcCodec) Write(m *codec.Message, v interface{}) error { + // if we don't have a body + if v != nil { + b, err := g.c.Marshal(v) + if err != nil { + return err + } + m.Body = b + } + // write the body using the framing codec + return g.s.SendMsg(&bytes.Frame{m.Body}) +} + +func (g *grpcCodec) Close() error { + return nil +} + +func (g *grpcCodec) String() string { + return g.c.Name() +} diff --git a/grpc.go b/grpc.go index 1b33081..2f6e1d4 100644 --- a/grpc.go +++ b/grpc.go @@ -56,6 +56,7 @@ type grpcServer struct { } func init() { + encoding.RegisterCodec(wrapCodec{protoCodec{}}) encoding.RegisterCodec(wrapCodec{jsonCodec{}}) encoding.RegisterCodec(wrapCodec{bytesCodec{}}) } @@ -211,14 +212,30 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error { // process via router if g.opts.Router != nil { - // create a client.Request - request := &rpcRequest{ - service: g.opts.Name, - contentType: ct, - method: fmt.Sprintf("%s.%s", serviceName, methodName), + cc, err := g.newGRPCCodec(ct) + if err != nil { + return errors.InternalServerError("go.micro.server", err.Error()) + } + codec := &grpcCodec{ + method: fmt.Sprintf("%s.%s", serviceName, methodName), + endpoint: fmt.Sprintf("%s.%s", serviceName, methodName), + target: g.opts.Name, + s: stream, + c: cc, } - response := &rpcResponse{} + // create a client.Request + request := &rpcRequest{ + service: mgrpc.ServiceFromMethod(fullMethod), + contentType: ct, + method: fmt.Sprintf("%s.%s", serviceName, methodName), + codec: codec, + } + + response := &rpcResponse{ + header: make(map[string]string), + codec: codec, + } // create a wrapped function handler := func(ctx context.Context, req server.Request, rsp interface{}) error { diff --git a/request.go b/request.go index 951c1a1..617b9a7 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package grpc import ( "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/codec/bytes" ) type rpcRequest struct { @@ -46,7 +47,11 @@ func (r *rpcRequest) Header() map[string]string { } func (r *rpcRequest) Read() ([]byte, error) { - return r.body, nil + f := &bytes.Frame{} + if err := r.codec.ReadBody(f); err != nil { + return nil, err + } + return f.Data, nil } func (r *rpcRequest) Stream() bool { diff --git a/response.go b/response.go index 451b1f4..f13ad89 100644 --- a/response.go +++ b/response.go @@ -1,15 +1,11 @@ package grpc import ( - "net/http" - "github.com/micro/go-micro/codec" - "github.com/micro/go-micro/transport" ) type rpcResponse struct { header map[string]string - socket transport.Socket codec codec.Codec } @@ -24,12 +20,8 @@ func (r *rpcResponse) WriteHeader(hdr map[string]string) { } func (r *rpcResponse) Write(b []byte) error { - if _, ok := r.header["Content-Type"]; !ok { - r.header["Content-Type"] = http.DetectContentType(b) - } - - return r.socket.Send(&transport.Message{ + return r.codec.Write(&codec.Message{ Header: r.header, Body: b, - }) + }, nil) }