commit 8a3538eb77bedcccccf797e6e5b7200840dfdc27 Author: Asim Aslam Date: Mon Jun 3 18:44:43 2019 +0100 Further consolidate the libraries diff --git a/README.md b/README.md new file mode 100644 index 0000000..9a36428 --- /dev/null +++ b/README.md @@ -0,0 +1,30 @@ +# GRPC Server + +The grpc server is a [micro.Server](https://godoc.org/github.com/micro/go-micro/server#Server) compatible server. + +## Overview + +The server makes use of the [google.golang.org/grpc](google.golang.org/grpc) framework for the underlying server +but continues to use micro handler signatures and protoc-gen-micro generated code. + +## Usage + +Specify the server to your micro service + +```go +import ( + "github.com/micro/go-micro" + "github.com/micro/go-plugins/server/grpc" +) + +func main() { + service := micro.NewService( + // This needs to be first as it replaces the underlying server + // which causes any configuration set before it + // to be discarded + micro.Server(grpc.NewServer()), + micro.Name("greeter"), + ) +} +``` +**NOTE**: Setting the gRPC server and/or client causes the underlying the server/client to be replaced which causes any previous configuration set on that server/client to be discarded. It is therefore recommended to set gRPC server/client before any other configuration \ No newline at end of file diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..c43bb23 --- /dev/null +++ b/buffer.go @@ -0,0 +1,14 @@ +package grpc + +import ( + "bytes" +) + +type buffer struct { + *bytes.Buffer +} + +func (b *buffer) Close() error { + b.Buffer.Reset() + return nil +} diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..50a96ef --- /dev/null +++ b/codec.go @@ -0,0 +1,82 @@ +package grpc + +import ( + "encoding/json" + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/codec/jsonrpc" + "github.com/micro/go-micro/codec/protorpc" + "google.golang.org/grpc/encoding" +) + +type jsonCodec struct{} +type bytesCodec struct{} +type protoCodec struct{} + +var ( + defaultGRPCCodecs = map[string]encoding.Codec{ + "application/json": jsonCodec{}, + "application/proto": protoCodec{}, + "application/protobuf": protoCodec{}, + "application/octet-stream": protoCodec{}, + "application/grpc": protoCodec{}, + "application/grpc+json": jsonCodec{}, + "application/grpc+proto": protoCodec{}, + "application/grpc+bytes": bytesCodec{}, + } + + defaultRPCCodecs = map[string]codec.NewCodec{ + "application/json": jsonrpc.NewCodec, + "application/json-rpc": jsonrpc.NewCodec, + "application/protobuf": protorpc.NewCodec, + "application/proto-rpc": protorpc.NewCodec, + "application/octet-stream": protorpc.NewCodec, + } +) + +func (protoCodec) Marshal(v interface{}) ([]byte, error) { + return proto.Marshal(v.(proto.Message)) +} + +func (protoCodec) Unmarshal(data []byte, v interface{}) error { + return proto.Unmarshal(data, v.(proto.Message)) +} + +func (protoCodec) Name() string { + return "proto" +} + +func (jsonCodec) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (jsonCodec) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +func (jsonCodec) Name() string { + return "json" +} + +func (bytesCodec) Marshal(v interface{}) ([]byte, error) { + b, ok := v.(*[]byte) + if !ok { + return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) + } + return *b, nil +} + +func (bytesCodec) Unmarshal(data []byte, v interface{}) error { + b, ok := v.(*[]byte) + if !ok { + return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v) + } + *b = data + return nil +} + +func (bytesCodec) Name() string { + return "bytes" +} diff --git a/debug.go b/debug.go new file mode 100644 index 0000000..5f23534 --- /dev/null +++ b/debug.go @@ -0,0 +1,15 @@ +package grpc + +import ( + "github.com/micro/go-micro/server" + "github.com/micro/go-micro/server/debug" +) + +// We use this to wrap any debug handlers so we preserve the signature Debug.{Method} +type Debug struct { + debug.DebugHandler +} + +func registerDebugHandler(s server.Server) { + s.Handle(s.NewHandler(&Debug{s.Options().DebugHandler}, server.InternalHandler(true))) +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..4001622 --- /dev/null +++ b/error.go @@ -0,0 +1,42 @@ +package grpc + +import ( + "net/http" + + "github.com/micro/go-micro/errors" + "google.golang.org/grpc/codes" +) + +func microError(err *errors.Error) codes.Code { + switch err { + case nil: + return codes.OK + } + + switch err.Code { + case http.StatusOK: + return codes.OK + case http.StatusBadRequest: + return codes.InvalidArgument + case http.StatusRequestTimeout: + return codes.DeadlineExceeded + case http.StatusNotFound: + return codes.NotFound + case http.StatusConflict: + return codes.AlreadyExists + case http.StatusForbidden: + return codes.PermissionDenied + case http.StatusUnauthorized: + return codes.Unauthenticated + case http.StatusPreconditionFailed: + return codes.FailedPrecondition + case http.StatusNotImplemented: + return codes.Unimplemented + case http.StatusInternalServerError: + return codes.Internal + case http.StatusServiceUnavailable: + return codes.Unavailable + } + + return codes.Unknown +} diff --git a/extractor.go b/extractor.go new file mode 100644 index 0000000..89c00ce --- /dev/null +++ b/extractor.go @@ -0,0 +1,120 @@ +package grpc + +import ( + "fmt" + "reflect" + "strings" + + "github.com/micro/go-micro/registry" +) + +func extractValue(v reflect.Type, d int) *registry.Value { + if d == 3 { + return nil + } + if v == nil { + return nil + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + arg := ®istry.Value{ + Name: v.Name(), + Type: v.Name(), + } + + switch v.Kind() { + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + val := extractValue(f.Type, d+1) + if val == nil { + continue + } + + // if we can find a json tag use it + if tags := f.Tag.Get("json"); len(tags) > 0 { + parts := strings.Split(tags, ",") + if parts[0] == "-" || parts[0] == "omitempty" { + continue + } + val.Name = parts[0] + } + + // if there's no name default it + if len(val.Name) == 0 { + val.Name = v.Field(i).Name + } + + arg.Values = append(arg.Values, val) + } + case reflect.Slice: + p := v.Elem() + if p.Kind() == reflect.Ptr { + p = p.Elem() + } + arg.Type = "[]" + p.Name() + val := extractValue(v.Elem(), d+1) + if val != nil { + arg.Values = append(arg.Values, val) + } + } + + return arg +} + +func extractEndpoint(method reflect.Method) *registry.Endpoint { + if method.PkgPath != "" { + return nil + } + + var rspType, reqType reflect.Type + var stream bool + mt := method.Type + + switch mt.NumIn() { + case 3: + reqType = mt.In(1) + rspType = mt.In(2) + case 4: + reqType = mt.In(2) + rspType = mt.In(3) + default: + return nil + } + + // are we dealing with a stream? + switch rspType.Kind() { + case reflect.Func, reflect.Interface: + stream = true + } + + request := extractValue(reqType, 0) + response := extractValue(rspType, 0) + + return ®istry.Endpoint{ + Name: method.Name, + Request: request, + Response: response, + Metadata: map[string]string{ + "stream": fmt.Sprintf("%v", stream), + }, + } +} + +func extractSubValue(typ reflect.Type) *registry.Value { + var reqType reflect.Type + switch typ.NumIn() { + case 1: + reqType = typ.In(0) + case 2: + reqType = typ.In(1) + case 3: + reqType = typ.In(2) + default: + return nil + } + return extractValue(reqType, 0) +} diff --git a/extractor_test.go b/extractor_test.go new file mode 100644 index 0000000..ebbd0e2 --- /dev/null +++ b/extractor_test.go @@ -0,0 +1,65 @@ +package grpc + +import ( + "context" + "reflect" + "testing" + + "github.com/micro/go-micro/registry" +) + +type testHandler struct{} + +type testRequest struct{} + +type testResponse struct{} + +func (t *testHandler) Test(ctx context.Context, req *testRequest, rsp *testResponse) error { + return nil +} + +func TestExtractEndpoint(t *testing.T) { + handler := &testHandler{} + typ := reflect.TypeOf(handler) + + var endpoints []*registry.Endpoint + + for m := 0; m < typ.NumMethod(); m++ { + if e := extractEndpoint(typ.Method(m)); e != nil { + endpoints = append(endpoints, e) + } + } + + if i := len(endpoints); i != 1 { + t.Errorf("Expected 1 endpoint, have %d", i) + } + + if endpoints[0].Name != "Test" { + t.Errorf("Expected handler Test, got %s", endpoints[0].Name) + } + + if endpoints[0].Request == nil { + t.Error("Expected non nil request") + } + + if endpoints[0].Response == nil { + t.Error("Expected non nil request") + } + + if endpoints[0].Request.Name != "testRequest" { + t.Errorf("Expected testRequest got %s", endpoints[0].Request.Name) + } + + if endpoints[0].Response.Name != "testResponse" { + t.Errorf("Expected testResponse got %s", endpoints[0].Response.Name) + } + + if endpoints[0].Request.Type != "testRequest" { + t.Errorf("Expected testRequest type got %s", endpoints[0].Request.Type) + } + + if endpoints[0].Response.Type != "testResponse" { + t.Errorf("Expected testResponse type got %s", endpoints[0].Response.Type) + } + +} diff --git a/grpc.go b/grpc.go new file mode 100644 index 0000000..54197fc --- /dev/null +++ b/grpc.go @@ -0,0 +1,731 @@ +// Package grpc provides a grpc server +package grpc + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/micro/go-micro/broker" + "github.com/micro/go-micro/cmd" + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/errors" + meta "github.com/micro/go-micro/metadata" + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/server" + "github.com/micro/go-micro/util/addr" + mgrpc "github.com/micro/go-micro/util/grpc" + "github.com/micro/go-micro/util/log" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +var ( + // DefaultMaxMsgSize define maximum message size that server can send + // or receive. Default value is 4MB. + DefaultMaxMsgSize = 1024 * 1024 * 4 +) + +const ( + defaultContentType = "application/grpc" +) + +type grpcServer struct { + rpc *rServer + srv *grpc.Server + exit chan chan error + wg *sync.WaitGroup + + sync.RWMutex + opts server.Options + handlers map[string]server.Handler + subscribers map[*subscriber][]broker.Subscriber + // used for first registration + registered bool +} + +func init() { + encoding.RegisterCodec(jsonCodec{}) + encoding.RegisterCodec(bytesCodec{}) + + cmd.DefaultServers["grpc"] = NewServer +} + +func newGRPCServer(opts ...server.Option) server.Server { + options := newOptions(opts...) + + // create a grpc server + srv := &grpcServer{ + opts: options, + rpc: &rServer{ + serviceMap: make(map[string]*service), + }, + handlers: make(map[string]server.Handler), + subscribers: make(map[*subscriber][]broker.Subscriber), + exit: make(chan chan error), + wg: wait(options.Context), + } + + // configure the grpc server + srv.configure() + + return srv +} + +func (g *grpcServer) configure(opts ...server.Option) { + // Don't reprocess where there's no config + if len(opts) == 0 && g.srv != nil { + return + } + + for _, o := range opts { + o(&g.opts) + } + + maxMsgSize := g.getMaxMsgSize() + + gopts := []grpc.ServerOption{ + grpc.MaxRecvMsgSize(maxMsgSize), + grpc.MaxSendMsgSize(maxMsgSize), + grpc.UnknownServiceHandler(g.handler), + } + + if creds := g.getCredentials(); creds != nil { + gopts = append(gopts, grpc.Creds(creds)) + } + + if opts := g.getGrpcOptions(); opts != nil { + gopts = append(gopts, opts...) + } + + g.srv = grpc.NewServer(gopts...) +} + +func (g *grpcServer) getMaxMsgSize() int { + if g.opts.Context == nil { + return DefaultMaxMsgSize + } + s, ok := g.opts.Context.Value(maxMsgSizeKey{}).(int) + if !ok { + return DefaultMaxMsgSize + } + return s +} + +func (g *grpcServer) getCredentials() credentials.TransportCredentials { + if g.opts.Context != nil { + if v := g.opts.Context.Value(tlsAuth{}); v != nil { + tls := v.(*tls.Config) + return credentials.NewTLS(tls) + } + } + return nil +} + +func (g *grpcServer) getGrpcOptions() []grpc.ServerOption { + if g.opts.Context == nil { + return nil + } + + v := g.opts.Context.Value(grpcOptions{}) + + if v == nil { + return nil + } + + opts, ok := v.([]grpc.ServerOption) + + if !ok { + return nil + } + + return opts +} + +func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error { + if g.wg != nil { + g.wg.Add(1) + defer g.wg.Done() + } + + fullMethod, ok := grpc.MethodFromServerStream(stream) + if !ok { + return grpc.Errorf(codes.Internal, "method does not exist in context") + } + + serviceName, methodName, err := mgrpc.ServiceMethod(fullMethod) + if err != nil { + return status.New(codes.InvalidArgument, err.Error()).Err() + } + + g.rpc.mu.Lock() + service := g.rpc.serviceMap[serviceName] + g.rpc.mu.Unlock() + + if service == nil { + return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() + } + + mtype := service.method[methodName] + if mtype == nil { + return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %v", service)).Err() + } + + // get grpc metadata + gmd, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + gmd = metadata.MD{} + } + + // copy the metadata to go-micro.metadata + md := meta.Metadata{} + for k, v := range gmd { + md[k] = strings.Join(v, ", ") + } + + // timeout for server deadline + to := md["timeout"] + + // get content type + ct := defaultContentType + if ctype, ok := md["x-content-type"]; ok { + ct = ctype + } + + delete(md, "x-content-type") + delete(md, "timeout") + + // create new context + ctx := meta.NewContext(stream.Context(), md) + + // set the timeout if we have it + if len(to) > 0 { + if n, err := strconv.ParseUint(to, 10, 64); err == nil { + ctx, _ = context.WithTimeout(ctx, time.Duration(n)) + } + } + + // process unary + if !mtype.stream { + return g.processRequest(stream, service, mtype, ct, ctx) + } + + // process stream + return g.processStream(stream, service, mtype, ct, ctx) +} + +func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, mtype *methodType, ct string, ctx context.Context) error { + for { + var argv, replyv reflect.Value + + // Decode the argument value. + argIsValue := false // if true, need to indirect before calling. + if mtype.ArgType.Kind() == reflect.Ptr { + argv = reflect.New(mtype.ArgType.Elem()) + } else { + argv = reflect.New(mtype.ArgType) + argIsValue = true + } + + // Unmarshal request + if err := stream.RecvMsg(argv.Interface()); err != nil { + return err + } + + if argIsValue { + argv = argv.Elem() + } + + // reply value + replyv = reflect.New(mtype.ReplyType.Elem()) + + function := mtype.method.Func + var returnValues []reflect.Value + + cc, err := g.newGRPCCodec(ct) + if err != nil { + return errors.InternalServerError("go.micro.server", err.Error()) + } + b, err := cc.Marshal(argv.Interface()) + if err != nil { + return err + } + + // create a client.Request + r := &rpcRequest{ + service: g.opts.Name, + contentType: ct, + method: fmt.Sprintf("%s.%s", service.name, mtype.method.Name), + body: b, + payload: argv.Interface(), + } + + // define the handler func + fn := func(ctx context.Context, req server.Request, rsp interface{}) error { + returnValues = function.Call([]reflect.Value{service.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)}) + + // The return value for the method is an error. + if err := returnValues[0].Interface(); err != nil { + return err.(error) + } + + return nil + } + + // wrap the handler func + for i := len(g.opts.HdlrWrappers); i > 0; i-- { + fn = g.opts.HdlrWrappers[i-1](fn) + } + + statusCode := codes.OK + statusDesc := "" + + // execute the handler + if appErr := fn(ctx, r, replyv.Interface()); appErr != nil { + if err, ok := appErr.(*rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else if err, ok := appErr.(*errors.Error); ok { + statusCode = microError(err) + statusDesc = appErr.Error() + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() + } + return status.New(statusCode, statusDesc).Err() + } + if err := stream.SendMsg(replyv.Interface()); err != nil { + return err + } + return status.New(statusCode, statusDesc).Err() + } +} + +func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, mtype *methodType, ct string, ctx context.Context) error { + opts := g.opts + + r := &rpcRequest{ + service: opts.Name, + contentType: ct, + method: fmt.Sprintf("%s.%s", service.name, mtype.method.Name), + stream: true, + } + + ss := &rpcStream{ + request: r, + s: stream, + } + + function := mtype.method.Func + var returnValues []reflect.Value + + // Invoke the method, providing a new value for the reply. + fn := func(ctx context.Context, req server.Request, stream interface{}) error { + returnValues = function.Call([]reflect.Value{service.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) + if err := returnValues[0].Interface(); err != nil { + return err.(error) + } + + return nil + } + + for i := len(opts.HdlrWrappers); i > 0; i-- { + fn = opts.HdlrWrappers[i-1](fn) + } + + statusCode := codes.OK + statusDesc := "" + + appErr := fn(ctx, r, ss) + if appErr != nil { + if err, ok := appErr.(*rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else if err, ok := appErr.(*errors.Error); ok { + statusCode = microError(err) + statusDesc = appErr.Error() + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() + } + } + + return status.New(statusCode, statusDesc).Err() +} + +func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) { + codecs := make(map[string]encoding.Codec) + if g.opts.Context != nil { + if v := g.opts.Context.Value(codecsKey{}); v != nil { + codecs = v.(map[string]encoding.Codec) + } + } + if c, ok := codecs[contentType]; ok { + return c, nil + } + if c, ok := defaultGRPCCodecs[contentType]; ok { + return c, nil + } + return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType) +} + +func (g *grpcServer) newCodec(contentType string) (codec.NewCodec, error) { + if cf, ok := g.opts.Codecs[contentType]; ok { + return cf, nil + } + if cf, ok := defaultRPCCodecs[contentType]; ok { + return cf, nil + } + return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType) +} + +func (g *grpcServer) Options() server.Options { + opts := g.opts + return opts +} + +func (g *grpcServer) Init(opts ...server.Option) error { + g.configure(opts...) + return nil +} + +func (g *grpcServer) NewHandler(h interface{}, opts ...server.HandlerOption) server.Handler { + return newRpcHandler(h, opts...) +} + +func (g *grpcServer) Handle(h server.Handler) error { + if err := g.rpc.register(h.Handler()); err != nil { + return err + } + + g.handlers[h.Name()] = h + return nil +} + +func (g *grpcServer) NewSubscriber(topic string, sb interface{}, opts ...server.SubscriberOption) server.Subscriber { + return newSubscriber(topic, sb, opts...) +} + +func (g *grpcServer) Subscribe(sb server.Subscriber) error { + sub, ok := sb.(*subscriber) + if !ok { + return fmt.Errorf("invalid subscriber: expected *subscriber") + } + if len(sub.handlers) == 0 { + return fmt.Errorf("invalid subscriber: no handler functions") + } + + if err := validateSubscriber(sb); err != nil { + return err + } + + g.Lock() + + _, ok = g.subscribers[sub] + if ok { + return fmt.Errorf("subscriber %v already exists", sub) + } + g.subscribers[sub] = nil + g.Unlock() + return nil +} + +func (g *grpcServer) Register() error { + // parse address for host, port + config := g.opts + var advt, host string + var port int + + // check the advertise address first + // if it exists then use it, otherwise + // use the address + if len(config.Advertise) > 0 { + advt = config.Advertise + } else { + advt = config.Address + } + + parts := strings.Split(advt, ":") + if len(parts) > 1 { + host = strings.Join(parts[:len(parts)-1], ":") + port, _ = strconv.Atoi(parts[len(parts)-1]) + } else { + host = parts[0] + } + + addr, err := addr.Extract(host) + if err != nil { + return err + } + + // register service + node := ®istry.Node{ + Id: config.Name + "-" + config.Id, + Address: addr, + Port: port, + Metadata: config.Metadata, + } + + node.Metadata["broker"] = config.Broker.String() + node.Metadata["registry"] = config.Registry.String() + node.Metadata["server"] = g.String() + node.Metadata["transport"] = g.String() + // node.Metadata["transport"] = config.Transport.String() + + g.RLock() + // Maps are ordered randomly, sort the keys for consistency + var handlerList []string + for n, e := range g.handlers { + // Only advertise non internal handlers + if !e.Options().Internal { + handlerList = append(handlerList, n) + } + } + sort.Strings(handlerList) + + var subscriberList []*subscriber + for e := range g.subscribers { + // Only advertise non internal subscribers + if !e.Options().Internal { + subscriberList = append(subscriberList, e) + } + } + sort.Slice(subscriberList, func(i, j int) bool { + return subscriberList[i].topic > subscriberList[j].topic + }) + + var endpoints []*registry.Endpoint + for _, n := range handlerList { + endpoints = append(endpoints, g.handlers[n].Endpoints()...) + } + for _, e := range subscriberList { + endpoints = append(endpoints, e.Endpoints()...) + } + g.RUnlock() + + service := ®istry.Service{ + Name: config.Name, + Version: config.Version, + Nodes: []*registry.Node{node}, + Endpoints: endpoints, + } + + g.Lock() + registered := g.registered + g.Unlock() + + if !registered { + log.Logf("Registering node: %s", node.Id) + } + + // create registry options + rOpts := []registry.RegisterOption{registry.RegisterTTL(config.RegisterTTL)} + + if err := config.Registry.Register(service, rOpts...); err != nil { + return err + } + + // already registered? don't need to register subscribers + if registered { + return nil + } + + g.Lock() + defer g.Unlock() + + g.registered = true + + for sb, _ := range g.subscribers { + handler := g.createSubHandler(sb, g.opts) + var opts []broker.SubscribeOption + if queue := sb.Options().Queue; len(queue) > 0 { + opts = append(opts, broker.Queue(queue)) + } + + if !sb.Options().AutoAck { + opts = append(opts, broker.DisableAutoAck()) + } + + sub, err := config.Broker.Subscribe(sb.Topic(), handler, opts...) + if err != nil { + return err + } + g.subscribers[sb] = []broker.Subscriber{sub} + } + + return nil +} + +func (g *grpcServer) Deregister() error { + config := g.opts + var advt, host string + var port int + + // check the advertise address first + // if it exists then use it, otherwise + // use the address + if len(config.Advertise) > 0 { + advt = config.Advertise + } else { + advt = config.Address + } + + parts := strings.Split(advt, ":") + if len(parts) > 1 { + host = strings.Join(parts[:len(parts)-1], ":") + port, _ = strconv.Atoi(parts[len(parts)-1]) + } else { + host = parts[0] + } + + addr, err := addr.Extract(host) + if err != nil { + return err + } + + node := ®istry.Node{ + Id: config.Name + "-" + config.Id, + Address: addr, + Port: port, + } + + service := ®istry.Service{ + Name: config.Name, + Version: config.Version, + Nodes: []*registry.Node{node}, + } + + log.Logf("Deregistering node: %s", node.Id) + if err := config.Registry.Deregister(service); err != nil { + return err + } + + g.Lock() + + if !g.registered { + g.Unlock() + return nil + } + + g.registered = false + + for sb, subs := range g.subscribers { + for _, sub := range subs { + log.Logf("Unsubscribing from topic: %s", sub.Topic()) + sub.Unsubscribe() + } + g.subscribers[sb] = nil + } + + g.Unlock() + return nil +} + +func (g *grpcServer) Start() error { + registerDebugHandler(g) + config := g.opts + + // micro: config.Transport.Listen(config.Address) + ts, err := net.Listen("tcp", config.Address) + if err != nil { + return err + } + + log.Logf("Server [grpc] Listening on %s", ts.Addr().String()) + g.Lock() + g.opts.Address = ts.Addr().String() + g.Unlock() + + // connect to the broker + if err := config.Broker.Connect(); err != nil { + return err + } + + log.Logf("Broker [%s] Listening on %s", config.Broker.String(), config.Broker.Address()) + + // announce self to the world + if err := g.Register(); err != nil { + log.Log("Server register error: ", err) + } + + // micro: go ts.Accept(s.accept) + go func() { + if err := g.srv.Serve(ts); err != nil { + log.Log("gRPC Server start error: ", err) + } + }() + + go func() { + t := new(time.Ticker) + + // only process if it exists + if g.opts.RegisterInterval > time.Duration(0) { + // new ticker + t = time.NewTicker(g.opts.RegisterInterval) + } + + // return error chan + var ch chan error + + Loop: + for { + select { + // register self on interval + case <-t.C: + if err := g.Register(); err != nil { + log.Log("Server register error: ", err) + } + // wait for exit + case ch = <-g.exit: + break Loop + } + } + + // deregister self + if err := g.Deregister(); err != nil { + log.Log("Server deregister error: ", err) + } + + // wait for waitgroup + if g.wg != nil { + g.wg.Wait() + } + + // stop the grpc server + g.srv.GracefulStop() + + // close transport + ch <- nil + + // disconnect broker + config.Broker.Disconnect() + }() + + return nil +} + +func (g *grpcServer) Stop() error { + ch := make(chan error) + g.exit <- ch + return <-ch +} + +func (g *grpcServer) String() string { + return "grpc" +} + +func NewServer(opts ...server.Option) server.Server { + return newGRPCServer(opts...) +} diff --git a/grpc_test.go b/grpc_test.go new file mode 100644 index 0000000..d489e03 --- /dev/null +++ b/grpc_test.go @@ -0,0 +1,66 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/micro/go-micro/registry/memory" + "github.com/micro/go-micro/server" + "google.golang.org/grpc" + + pb "github.com/micro/examples/greeter/srv/proto/hello" +) + +// server is used to implement helloworld.GreeterServer. +type sayServer struct{} + +// SayHello implements helloworld.GreeterServer +func (s *sayServer) Hello(ctx context.Context, req *pb.Request, rsp *pb.Response) error { + rsp.Msg = "Hello " + req.Name + return nil +} + +func TestGRPCServer(t *testing.T) { + r := memory.NewRegistry() + s := NewServer( + server.Name("foo"), + server.Registry(r), + ) + + pb.RegisterSayHandler(s, &sayServer{}) + + if err := s.Start(); err != nil { + t.Fatalf("failed to start: %v", err) + } + + // check registration + services, err := r.GetService("foo") + if err != nil || len(services) == 0 { + t.Fatalf("failed to get service: %v # %d", err, len(services)) + } + + defer func() { + if err := s.Stop(); err != nil { + t.Fatalf("failed to stop: %v", err) + } + }() + + cc, err := grpc.Dial(s.Options().Address, grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial server: %v", err) + } + + testMethods := []string{"/helloworld.Say/Hello", "/greeter.helloworld.Say/Hello"} + + for _, method := range testMethods { + rsp := pb.Response{} + + if err := cc.Invoke(context.Background(), method, &pb.Request{Name: "John"}, &rsp); err != nil { + t.Fatalf("error calling server: %v", err) + } + + if rsp.Msg != "Hello John" { + t.Fatalf("Got unexpected response %v", rsp.Msg) + } + } +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..f41a1fe --- /dev/null +++ b/handler.go @@ -0,0 +1,66 @@ +package grpc + +import ( + "reflect" + + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/server" +) + +type rpcHandler struct { + name string + handler interface{} + endpoints []*registry.Endpoint + opts server.HandlerOptions +} + +func newRpcHandler(handler interface{}, opts ...server.HandlerOption) server.Handler { + options := server.HandlerOptions{ + Metadata: make(map[string]map[string]string), + } + + for _, o := range opts { + o(&options) + } + + typ := reflect.TypeOf(handler) + hdlr := reflect.ValueOf(handler) + name := reflect.Indirect(hdlr).Type().Name() + + var endpoints []*registry.Endpoint + + for m := 0; m < typ.NumMethod(); m++ { + if e := extractEndpoint(typ.Method(m)); e != nil { + e.Name = name + "." + e.Name + + for k, v := range options.Metadata[e.Name] { + e.Metadata[k] = v + } + + endpoints = append(endpoints, e) + } + } + + return &rpcHandler{ + name: name, + handler: handler, + endpoints: endpoints, + opts: options, + } +} + +func (r *rpcHandler) Name() string { + return r.name +} + +func (r *rpcHandler) Handler() interface{} { + return r.handler +} + +func (r *rpcHandler) Endpoints() []*registry.Endpoint { + return r.endpoints +} + +func (r *rpcHandler) Options() server.HandlerOptions { + return r.opts +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..65d82fc --- /dev/null +++ b/options.go @@ -0,0 +1,113 @@ +package grpc + +import ( + "context" + "crypto/tls" + + "github.com/micro/go-micro/broker" + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/server" + "github.com/micro/go-micro/server/debug" + "github.com/micro/go-micro/transport" + "google.golang.org/grpc" + "google.golang.org/grpc/encoding" +) + +type codecsKey struct{} +type tlsAuth struct{} +type maxMsgSizeKey struct{} +type grpcOptions struct{} + +// gRPC Codec to be used to encode/decode requests for a given content type +func Codec(contentType string, c encoding.Codec) server.Option { + return func(o *server.Options) { + codecs := make(map[string]encoding.Codec) + if o.Context == nil { + o.Context = context.Background() + } + if v := o.Context.Value(codecsKey{}); v != nil { + codecs = v.(map[string]encoding.Codec) + } + codecs[contentType] = c + o.Context = context.WithValue(o.Context, codecsKey{}, codecs) + } +} + +// AuthTLS should be used to setup a secure authentication using TLS +func AuthTLS(t *tls.Config) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, tlsAuth{}, t) + } +} + +// Options to be used to configure gRPC options +func Options(opts ...grpc.ServerOption) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, grpcOptions{}, opts) + } +} + +// +// MaxMsgSize set the maximum message in bytes the server can receive and +// send. Default maximum message size is 4 MB. +// +func MaxMsgSize(s int) server.Option { + return func(o *server.Options) { + if o.Context == nil { + o.Context = context.Background() + } + o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s) + } +} + +func newOptions(opt ...server.Option) server.Options { + opts := server.Options{ + Codecs: make(map[string]codec.NewCodec), + Metadata: map[string]string{}, + } + + for _, o := range opt { + o(&opts) + } + + if opts.Broker == nil { + opts.Broker = broker.DefaultBroker + } + + if opts.Registry == nil { + opts.Registry = registry.DefaultRegistry + } + + if opts.Transport == nil { + opts.Transport = transport.DefaultTransport + } + + if opts.DebugHandler == nil { + opts.DebugHandler = debug.DefaultDebugHandler + } + + if len(opts.Address) == 0 { + opts.Address = server.DefaultAddress + } + + if len(opts.Name) == 0 { + opts.Name = server.DefaultName + } + + if len(opts.Id) == 0 { + opts.Id = server.DefaultId + } + + if len(opts.Version) == 0 { + opts.Version = server.DefaultVersion + } + + return opts +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..951c1a1 --- /dev/null +++ b/request.go @@ -0,0 +1,70 @@ +package grpc + +import ( + "github.com/micro/go-micro/codec" +) + +type rpcRequest struct { + service string + method string + contentType string + codec codec.Codec + header map[string]string + body []byte + stream bool + payload interface{} +} + +type rpcMessage struct { + topic string + contentType string + payload interface{} +} + +func (r *rpcRequest) ContentType() string { + return r.contentType +} + +func (r *rpcRequest) Service() string { + return r.service +} + +func (r *rpcRequest) Method() string { + return r.method +} + +func (r *rpcRequest) Endpoint() string { + return r.method +} + +func (r *rpcRequest) Codec() codec.Reader { + return r.codec +} + +func (r *rpcRequest) Header() map[string]string { + return r.header +} + +func (r *rpcRequest) Read() ([]byte, error) { + return r.body, nil +} + +func (r *rpcRequest) Stream() bool { + return r.stream +} + +func (r *rpcRequest) Body() interface{} { + return r.payload +} + +func (r *rpcMessage) ContentType() string { + return r.contentType +} + +func (r *rpcMessage) Topic() string { + return r.topic +} + +func (r *rpcMessage) Payload() interface{} { + return r.payload +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..140d7e5 --- /dev/null +++ b/server.go @@ -0,0 +1,180 @@ +package grpc + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Meh, we need to get rid of this shit + +import ( + "context" + "errors" + "reflect" + "sync" + "unicode" + "unicode/utf8" + + "github.com/micro/go-micro/server" + "github.com/micro/go-micro/util/log" +) + +var ( + // Precompute the reflect type for error. Can't use error directly + // because Typeof takes an empty interface value. This is annoying. + typeOfError = reflect.TypeOf((*error)(nil)).Elem() +) + +type methodType struct { + method reflect.Method + ArgType reflect.Type + ReplyType reflect.Type + ContextType reflect.Type + stream bool +} + +type service struct { + name string // name of service + rcvr reflect.Value // receiver of methods for the service + typ reflect.Type // type of the receiver + method map[string]*methodType // registered methods +} + +// server represents an RPC Server. +type rServer struct { + mu sync.Mutex // protects the serviceMap + serviceMap map[string]*service +} + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// Is this type exported or a builtin? +func isExportedOrBuiltinType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + // PkgPath will be non-empty even for an exported type, + // so we need to check the type name as well. + return isExported(t.Name()) || t.PkgPath() == "" +} + +// prepareEndpoint() returns a methodType for the provided method or nil +// in case if the method was unsuitable. +func prepareEndpoint(method reflect.Method) *methodType { + mtype := method.Type + mname := method.Name + var replyType, argType, contextType reflect.Type + var stream bool + + // Endpoint() must be exported. + if method.PkgPath != "" { + return nil + } + + switch mtype.NumIn() { + case 3: + // assuming streaming + argType = mtype.In(2) + contextType = mtype.In(1) + stream = true + case 4: + // method that takes a context + argType = mtype.In(2) + replyType = mtype.In(3) + contextType = mtype.In(1) + default: + log.Log("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn()) + return nil + } + + if stream { + // check stream type + streamType := reflect.TypeOf((*server.Stream)(nil)).Elem() + if !argType.Implements(streamType) { + log.Log(mname, "argument does not implement Streamer interface:", argType) + return nil + } + } else { + // if not stream check the replyType + + // First arg need not be a pointer. + if !isExportedOrBuiltinType(argType) { + log.Log(mname, "argument type not exported:", argType) + return nil + } + + if replyType.Kind() != reflect.Ptr { + log.Log("method", mname, "reply type not a pointer:", replyType) + return nil + } + + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + log.Log("method", mname, "reply type not exported:", replyType) + return nil + } + } + + // Endpoint() needs one out. + if mtype.NumOut() != 1 { + log.Log("method", mname, "has wrong number of outs:", mtype.NumOut()) + return nil + } + // The return type of the method must be error. + if returnType := mtype.Out(0); returnType != typeOfError { + log.Log("method", mname, "returns", returnType.String(), "not error") + return nil + } + return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} +} + +func (server *rServer) register(rcvr interface{}) error { + server.mu.Lock() + defer server.mu.Unlock() + if server.serviceMap == nil { + server.serviceMap = make(map[string]*service) + } + s := new(service) + s.typ = reflect.TypeOf(rcvr) + s.rcvr = reflect.ValueOf(rcvr) + sname := reflect.Indirect(s.rcvr).Type().Name() + if sname == "" { + log.Fatal("rpc: no service name for type", s.typ.String()) + } + if !isExported(sname) { + s := "rpc Register: type " + sname + " is not exported" + log.Log(s) + return errors.New(s) + } + if _, present := server.serviceMap[sname]; present { + return errors.New("rpc: service already defined: " + sname) + } + s.name = sname + s.method = make(map[string]*methodType) + + // Install the methods + for m := 0; m < s.typ.NumMethod(); m++ { + method := s.typ.Method(m) + if mt := prepareEndpoint(method); mt != nil { + s.method[method.Name] = mt + } + } + + if len(s.method) == 0 { + s := "rpc Register: type " + sname + " has no exported methods of suitable type" + log.Log(s) + return errors.New(s) + } + server.serviceMap[s.name] = s + return nil +} + +func (m *methodType) prepareContext(ctx context.Context) reflect.Value { + if contextv := reflect.ValueOf(ctx); contextv.IsValid() { + return contextv + } + return reflect.Zero(m.ContextType) +} diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..ae31a0b --- /dev/null +++ b/stream.go @@ -0,0 +1,38 @@ +package grpc + +import ( + "context" + + "github.com/micro/go-micro/server" + "google.golang.org/grpc" +) + +// rpcStream implements a server side Stream. +type rpcStream struct { + s grpc.ServerStream + request server.Request +} + +func (r *rpcStream) Close() error { + return nil +} + +func (r *rpcStream) Error() error { + return nil +} + +func (r *rpcStream) Request() server.Request { + return r.request +} + +func (r *rpcStream) Context() context.Context { + return r.s.Context() +} + +func (r *rpcStream) Send(m interface{}) error { + return r.s.SendMsg(m) +} + +func (r *rpcStream) Recv(m interface{}) error { + return r.s.RecvMsg(m) +} diff --git a/subscriber.go b/subscriber.go new file mode 100644 index 0000000..56cf6db --- /dev/null +++ b/subscriber.go @@ -0,0 +1,262 @@ +package grpc + +import ( + "bytes" + "context" + "fmt" + "reflect" + + "github.com/micro/go-micro/broker" + "github.com/micro/go-micro/codec" + "github.com/micro/go-micro/metadata" + "github.com/micro/go-micro/registry" + "github.com/micro/go-micro/server" +) + +const ( + subSig = "func(context.Context, interface{}) error" +) + +type handler struct { + method reflect.Value + reqType reflect.Type + ctxType reflect.Type +} + +type subscriber struct { + topic string + rcvr reflect.Value + typ reflect.Type + subscriber interface{} + handlers []*handler + endpoints []*registry.Endpoint + opts server.SubscriberOptions +} + +func newSubscriber(topic string, sub interface{}, opts ...server.SubscriberOption) server.Subscriber { + + options := server.SubscriberOptions{ + AutoAck: true, + } + + for _, o := range opts { + o(&options) + } + + var endpoints []*registry.Endpoint + var handlers []*handler + + if typ := reflect.TypeOf(sub); typ.Kind() == reflect.Func { + h := &handler{ + method: reflect.ValueOf(sub), + } + + switch typ.NumIn() { + case 1: + h.reqType = typ.In(0) + case 2: + h.ctxType = typ.In(0) + h.reqType = typ.In(1) + } + + handlers = append(handlers, h) + + endpoints = append(endpoints, ®istry.Endpoint{ + Name: "Func", + Request: extractSubValue(typ), + Metadata: map[string]string{ + "topic": topic, + "subscriber": "true", + }, + }) + } else { + hdlr := reflect.ValueOf(sub) + name := reflect.Indirect(hdlr).Type().Name() + + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + h := &handler{ + method: method.Func, + } + + switch method.Type.NumIn() { + case 2: + h.reqType = method.Type.In(1) + case 3: + h.ctxType = method.Type.In(1) + h.reqType = method.Type.In(2) + } + + handlers = append(handlers, h) + + endpoints = append(endpoints, ®istry.Endpoint{ + Name: name + "." + method.Name, + Request: extractSubValue(method.Type), + Metadata: map[string]string{ + "topic": topic, + "subscriber": "true", + }, + }) + } + } + + return &subscriber{ + rcvr: reflect.ValueOf(sub), + typ: reflect.TypeOf(sub), + topic: topic, + subscriber: sub, + handlers: handlers, + endpoints: endpoints, + opts: options, + } +} + +func validateSubscriber(sub server.Subscriber) error { + typ := reflect.TypeOf(sub.Subscriber()) + var argType reflect.Type + + if typ.Kind() == reflect.Func { + name := "Func" + switch typ.NumIn() { + case 2: + argType = typ.In(1) + default: + return fmt.Errorf("subscriber %v takes wrong number of args: %v required signature %s", name, typ.NumIn(), subSig) + } + if !isExportedOrBuiltinType(argType) { + return fmt.Errorf("subscriber %v argument type not exported: %v", name, argType) + } + if typ.NumOut() != 1 { + return fmt.Errorf("subscriber %v has wrong number of outs: %v require signature %s", + name, typ.NumOut(), subSig) + } + if returnType := typ.Out(0); returnType != typeOfError { + return fmt.Errorf("subscriber %v returns %v not error", name, returnType.String()) + } + } else { + hdlr := reflect.ValueOf(sub.Subscriber()) + name := reflect.Indirect(hdlr).Type().Name() + + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + + switch method.Type.NumIn() { + case 3: + argType = method.Type.In(2) + default: + return fmt.Errorf("subscriber %v.%v takes wrong number of args: %v required signature %s", + name, method.Name, method.Type.NumIn(), subSig) + } + + if !isExportedOrBuiltinType(argType) { + return fmt.Errorf("%v argument type not exported: %v", name, argType) + } + if method.Type.NumOut() != 1 { + return fmt.Errorf( + "subscriber %v.%v has wrong number of outs: %v require signature %s", + name, method.Name, method.Type.NumOut(), subSig) + } + if returnType := method.Type.Out(0); returnType != typeOfError { + return fmt.Errorf("subscriber %v.%v returns %v not error", name, method.Name, returnType.String()) + } + } + } + + return nil +} + +func (g *grpcServer) createSubHandler(sb *subscriber, opts server.Options) broker.Handler { + return func(p broker.Publication) error { + msg := p.Message() + ct := msg.Header["Content-Type"] + cf, err := g.newCodec(ct) + if err != nil { + return err + } + + hdr := make(map[string]string) + for k, v := range msg.Header { + hdr[k] = v + } + delete(hdr, "Content-Type") + ctx := metadata.NewContext(context.Background(), hdr) + + for i := 0; i < len(sb.handlers); i++ { + handler := sb.handlers[i] + + var isVal bool + var req reflect.Value + + if handler.reqType.Kind() == reflect.Ptr { + req = reflect.New(handler.reqType.Elem()) + } else { + req = reflect.New(handler.reqType) + isVal = true + } + if isVal { + req = req.Elem() + } + + b := &buffer{bytes.NewBuffer(msg.Body)} + co := cf(b) + defer co.Close() + + if err := co.ReadHeader(&codec.Message{}, codec.Publication); err != nil { + return err + } + + if err := co.ReadBody(req.Interface()); err != nil { + return err + } + + fn := func(ctx context.Context, msg server.Message) error { + var vals []reflect.Value + if sb.typ.Kind() != reflect.Func { + vals = append(vals, sb.rcvr) + } + if handler.ctxType != nil { + vals = append(vals, reflect.ValueOf(ctx)) + } + + vals = append(vals, reflect.ValueOf(msg.Payload())) + + returnValues := handler.method.Call(vals) + if err := returnValues[0].Interface(); err != nil { + return err.(error) + } + return nil + } + + for i := len(opts.SubWrappers); i > 0; i-- { + fn = opts.SubWrappers[i-1](fn) + } + + g.wg.Add(1) + go func() { + defer g.wg.Done() + fn(ctx, &rpcMessage{ + topic: sb.topic, + contentType: ct, + payload: req.Interface(), + }) + }() + } + return nil + } +} + +func (s *subscriber) Topic() string { + return s.topic +} + +func (s *subscriber) Subscriber() interface{} { + return s.subscriber +} + +func (s *subscriber) Endpoints() []*registry.Endpoint { + return s.endpoints +} + +func (s *subscriber) Options() server.SubscriberOptions { + return s.opts +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..0583548 --- /dev/null +++ b/util.go @@ -0,0 +1,60 @@ +package grpc + +import ( + "context" + "fmt" + "io" + "os" + "sync" + + "google.golang.org/grpc/codes" +) + +// rpcError defines the status from an RPC. +type rpcError struct { + code codes.Code + desc string +} + +func (e *rpcError) Error() string { + return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) +} + +// convertCode converts a standard Go error into its canonical code. Note that +// this is only used to translate the error returned by the server applications. +func convertCode(err error) codes.Code { + switch err { + case nil: + return codes.OK + case io.EOF: + return codes.OutOfRange + case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF: + return codes.FailedPrecondition + case os.ErrInvalid: + return codes.InvalidArgument + case context.Canceled: + return codes.Canceled + case context.DeadlineExceeded: + return codes.DeadlineExceeded + } + switch { + case os.IsExist(err): + return codes.AlreadyExists + case os.IsNotExist(err): + return codes.NotFound + case os.IsPermission(err): + return codes.PermissionDenied + } + return codes.Unknown +} + +func wait(ctx context.Context) *sync.WaitGroup { + if ctx == nil { + return nil + } + wg, ok := ctx.Value("wait").(*sync.WaitGroup) + if !ok { + return nil + } + return wg +}