diff --git a/grpc.go b/grpc.go index 74a2745..bde2e4c 100644 --- a/grpc.go +++ b/grpc.go @@ -81,6 +81,14 @@ func newGRPCServer(opts ...server.Option) server.Server { return srv } +type grpcRouter struct { + h func(context.Context, server.Request, interface{}) error +} + +func (r grpcRouter) ServeRequest(ctx context.Context, req server.Request, rsp server.Response) error { + return r.h(ctx, req, rsp) +} + func (g *grpcServer) configure(opts ...server.Option) { // Don't reprocess where there's no config if len(opts) == 0 && g.srv != nil { @@ -167,19 +175,6 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error { return status.New(codes.InvalidArgument, err.Error()).Err() } - g.rpc.mu.Lock() - service := g.rpc.serviceMap[serviceName] - g.rpc.mu.Unlock() - - if service == nil { - return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() - } - - mtype := service.method[methodName] - if mtype == nil { - return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() - } - // get grpc metadata gmd, ok := metadata.FromIncomingContext(stream.Context()) if !ok { @@ -214,6 +209,51 @@ func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error { } } + // process via router + if g.opts.Router != nil { + // create a client.Request + request := &rpcRequest{ + service: g.opts.Name, + contentType: ct, + method: fmt.Sprintf("%s.%s", serviceName, methodName), + } + + response := &rpcResponse{} + + // create a wrapped function + handler := func(ctx context.Context, req server.Request, rsp interface{}) error { + return g.opts.Router.ServeRequest(ctx, req, rsp.(server.Response)) + } + + // execute the wrapper for it + for i := len(g.opts.HdlrWrappers); i > 0; i-- { + handler = g.opts.HdlrWrappers[i-1](handler) + } + + r := grpcRouter{handler} + + // serve the actual request using the request router + if err := r.ServeRequest(ctx, request, response); err != nil { + return status.Errorf(codes.Internal, err.Error()) + } + + return nil + } + + // process the standard request flow + g.rpc.mu.Lock() + service := g.rpc.serviceMap[serviceName] + g.rpc.mu.Unlock() + + if service == nil { + return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() + } + + mtype := service.method[methodName] + if mtype == nil { + return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() + } + // process unary if !mtype.stream { return g.processRequest(stream, service, mtype, ct, ctx) diff --git a/response.go b/response.go new file mode 100644 index 0000000..451b1f4 --- /dev/null +++ b/response.go @@ -0,0 +1,35 @@ +package grpc + +import ( + "net/http" + + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/transport" +) + +type rpcResponse struct { + header map[string]string + socket transport.Socket + codec codec.Codec +} + +func (r *rpcResponse) Codec() codec.Writer { + return r.codec +} + +func (r *rpcResponse) WriteHeader(hdr map[string]string) { + for k, v := range hdr { + r.header[k] = v + } +} + +func (r *rpcResponse) Write(b []byte) error { + if _, ok := r.header["Content-Type"]; !ok { + r.header["Content-Type"] = http.DetectContentType(b) + } + + return r.socket.Send(&transport.Message{ + Header: r.header, + Body: b, + }) +}