micro/server/grpc/grpc.go
Vasiliy Tolstov 58598d0fe0
fixes for safe conversation and avoid panics (#1213)
* fixes for safe convertation

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* fix client publish panic

If broker connect returns error we dont check it status and use
it later to publish message, mostly this is unexpected because
broker connection failed and we cant use it.
Also proposed solution have benefit - we flag connection status
only when we have succeseful broker connection

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* api/handler/broker: fix possible broker publish panic

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
2020-02-18 23:05:38 +00:00

935 lines
21 KiB
Go

// Package grpc provides a grpc server
package grpc
import (
"context"
"crypto/tls"
"fmt"
"net"
"reflect"
"runtime/debug"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/micro/go-micro/v2/broker"
"github.com/micro/go-micro/v2/codec"
"github.com/micro/go-micro/v2/errors"
meta "github.com/micro/go-micro/v2/metadata"
"github.com/micro/go-micro/v2/registry"
"github.com/micro/go-micro/v2/server"
"github.com/micro/go-micro/v2/util/addr"
mgrpc "github.com/micro/go-micro/v2/util/grpc"
"github.com/micro/go-micro/v2/util/log"
mnet "github.com/micro/go-micro/v2/util/net"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
var (
// DefaultMaxMsgSize define maximum message size that server can send
// or receive. Default value is 4MB.
DefaultMaxMsgSize = 1024 * 1024 * 4
)
const (
defaultContentType = "application/grpc"
)
type grpcServer struct {
rpc *rServer
srv *grpc.Server
exit chan chan error
wg *sync.WaitGroup
sync.RWMutex
opts server.Options
handlers map[string]server.Handler
subscribers map[*subscriber][]broker.Subscriber
// marks the serve as started
started bool
// used for first registration
registered bool
}
func init() {
encoding.RegisterCodec(wrapCodec{jsonCodec{}})
encoding.RegisterCodec(wrapCodec{protoCodec{}})
encoding.RegisterCodec(wrapCodec{bytesCodec{}})
}
func newGRPCServer(opts ...server.Option) server.Server {
options := newOptions(opts...)
// create a grpc server
srv := &grpcServer{
opts: options,
rpc: &rServer{
serviceMap: make(map[string]*service),
},
handlers: make(map[string]server.Handler),
subscribers: make(map[*subscriber][]broker.Subscriber),
exit: make(chan chan error),
wg: wait(options.Context),
}
// configure the grpc server
srv.configure()
return srv
}
type grpcRouter struct {
h func(context.Context, server.Request, interface{}) error
m func(context.Context, server.Message) error
}
func (r grpcRouter) ProcessMessage(ctx context.Context, msg server.Message) error {
return r.m(ctx, msg)
}
func (r grpcRouter) ServeRequest(ctx context.Context, req server.Request, rsp server.Response) error {
return r.h(ctx, req, rsp)
}
func (g *grpcServer) configure(opts ...server.Option) {
// Don't reprocess where there's no config
if len(opts) == 0 && g.srv != nil {
return
}
for _, o := range opts {
o(&g.opts)
}
maxMsgSize := g.getMaxMsgSize()
gopts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(maxMsgSize),
grpc.MaxSendMsgSize(maxMsgSize),
grpc.UnknownServiceHandler(g.handler),
}
if creds := g.getCredentials(); creds != nil {
gopts = append(gopts, grpc.Creds(creds))
}
if opts := g.getGrpcOptions(); opts != nil {
gopts = append(gopts, opts...)
}
g.srv = grpc.NewServer(gopts...)
}
func (g *grpcServer) getMaxMsgSize() int {
if g.opts.Context == nil {
return DefaultMaxMsgSize
}
s, ok := g.opts.Context.Value(maxMsgSizeKey{}).(int)
if !ok {
return DefaultMaxMsgSize
}
return s
}
func (g *grpcServer) getCredentials() credentials.TransportCredentials {
if g.opts.Context != nil {
if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil {
return credentials.NewTLS(v)
}
}
return nil
}
func (g *grpcServer) getGrpcOptions() []grpc.ServerOption {
if g.opts.Context == nil {
return nil
}
opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption)
if !ok || opts == nil {
return nil
}
return opts
}
func (g *grpcServer) getListener() net.Listener {
if g.opts.Context != nil {
if v := g.opts.Context.Value(netListener{}); v != nil {
if l, ok := v.(net.Listener); ok {
return l
}
}
}
return nil
}
func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error {
if g.wg != nil {
g.wg.Add(1)
defer g.wg.Done()
}
fullMethod, ok := grpc.MethodFromServerStream(stream)
if !ok {
return status.Errorf(codes.Internal, "method does not exist in context")
}
serviceName, methodName, err := mgrpc.ServiceMethod(fullMethod)
if err != nil {
return status.New(codes.InvalidArgument, err.Error()).Err()
}
// get grpc metadata
gmd, ok := metadata.FromIncomingContext(stream.Context())
if !ok {
gmd = metadata.MD{}
}
// copy the metadata to go-micro.metadata
md := meta.Metadata{}
for k, v := range gmd {
md[k] = strings.Join(v, ", ")
}
// timeout for server deadline
to := md["timeout"]
// get content type
ct := defaultContentType
if ctype, ok := md["x-content-type"]; ok {
ct = ctype
}
if ctype, ok := md["content-type"]; ok {
ct = ctype
}
delete(md, "x-content-type")
delete(md, "timeout")
// create new context
ctx := meta.NewContext(stream.Context(), md)
// get peer from context
if p, ok := peer.FromContext(stream.Context()); ok {
md["Remote"] = p.Addr.String()
ctx = peer.NewContext(ctx, p)
}
// set the timeout if we have it
if len(to) > 0 {
if n, err := strconv.ParseUint(to, 10, 64); err == nil {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(n))
defer cancel()
}
}
// process via router
if g.opts.Router != nil {
cc, err := g.newGRPCCodec(ct)
if err != nil {
return errors.InternalServerError("go.micro.server", err.Error())
}
codec := &grpcCodec{
method: fmt.Sprintf("%s.%s", serviceName, methodName),
endpoint: fmt.Sprintf("%s.%s", serviceName, methodName),
target: g.opts.Name,
s: stream,
c: cc,
}
// create a client.Request
request := &rpcRequest{
service: mgrpc.ServiceFromMethod(fullMethod),
contentType: ct,
method: fmt.Sprintf("%s.%s", serviceName, methodName),
codec: codec,
stream: true,
}
response := &rpcResponse{
header: make(map[string]string),
codec: codec,
}
// create a wrapped function
handler := func(ctx context.Context, req server.Request, rsp interface{}) error {
return g.opts.Router.ServeRequest(ctx, req, rsp.(server.Response))
}
// execute the wrapper for it
for i := len(g.opts.HdlrWrappers); i > 0; i-- {
handler = g.opts.HdlrWrappers[i-1](handler)
}
r := grpcRouter{h: handler}
// serve the actual request using the request router
if err := r.ServeRequest(ctx, request, response); err != nil {
return status.Errorf(codes.Internal, err.Error())
}
return nil
}
// process the standard request flow
g.rpc.mu.Lock()
service := g.rpc.serviceMap[serviceName]
g.rpc.mu.Unlock()
if service == nil {
return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %s", serviceName)).Err()
}
mtype := service.method[methodName]
if mtype == nil {
return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %s.%s", serviceName, methodName)).Err()
}
// process unary
if !mtype.stream {
return g.processRequest(stream, service, mtype, ct, ctx)
}
// process stream
return g.processStream(stream, service, mtype, ct, ctx)
}
func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, mtype *methodType, ct string, ctx context.Context) error {
for {
var argv, replyv reflect.Value
// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Ptr {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
// Unmarshal request
if err := stream.RecvMsg(argv.Interface()); err != nil {
return err
}
if argIsValue {
argv = argv.Elem()
}
// reply value
replyv = reflect.New(mtype.ReplyType.Elem())
function := mtype.method.Func
var returnValues []reflect.Value
cc, err := g.newGRPCCodec(ct)
if err != nil {
return errors.InternalServerError("go.micro.server", err.Error())
}
b, err := cc.Marshal(argv.Interface())
if err != nil {
return err
}
// create a client.Request
r := &rpcRequest{
service: g.opts.Name,
contentType: ct,
method: fmt.Sprintf("%s.%s", service.name, mtype.method.Name),
body: b,
payload: argv.Interface(),
}
// define the handler func
fn := func(ctx context.Context, req server.Request, rsp interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
log.Log("panic recovered: ", r)
log.Logf(string(debug.Stack()))
err = errors.InternalServerError("go.micro.server", "panic recovered: %v", r)
}
}()
returnValues = function.Call([]reflect.Value{service.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
// The return value for the method is an error.
if rerr := returnValues[0].Interface(); rerr != nil {
err = rerr.(error)
}
return err
}
// wrap the handler func
for i := len(g.opts.HdlrWrappers); i > 0; i-- {
fn = g.opts.HdlrWrappers[i-1](fn)
}
statusCode := codes.OK
statusDesc := ""
// execute the handler
if appErr := fn(ctx, r, replyv.Interface()); appErr != nil {
var errStatus *status.Status
switch verr := appErr.(type) {
case *errors.Error:
// micro.Error now proto based and we can attach it to grpc status
statusCode = microError(verr)
statusDesc = verr.Error()
errStatus, err = status.New(statusCode, statusDesc).WithDetails(verr)
if err != nil {
return err
}
case proto.Message:
// user defined error that proto based we can attach it to grpc status
statusCode = convertCode(appErr)
statusDesc = appErr.Error()
errStatus, err = status.New(statusCode, statusDesc).WithDetails(verr)
if err != nil {
return err
}
case *rpcError:
// rpcError handling may be we have ability to attach it to details?
statusCode = verr.code
statusDesc = verr.desc
errStatus = status.New(statusCode, statusDesc)
default:
// default case user pass own error type that not proto based
statusCode = convertCode(verr)
statusDesc = verr.Error()
errStatus = status.New(statusCode, statusDesc)
}
return errStatus.Err()
}
if err := stream.SendMsg(replyv.Interface()); err != nil {
return err
}
return status.New(statusCode, statusDesc).Err()
}
}
func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, mtype *methodType, ct string, ctx context.Context) error {
opts := g.opts
r := &rpcRequest{
service: opts.Name,
contentType: ct,
method: fmt.Sprintf("%s.%s", service.name, mtype.method.Name),
stream: true,
}
ss := &rpcStream{
request: r,
s: stream,
}
function := mtype.method.Func
var returnValues []reflect.Value
// Invoke the method, providing a new value for the reply.
fn := func(ctx context.Context, req server.Request, stream interface{}) error {
returnValues = function.Call([]reflect.Value{service.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
if err := returnValues[0].Interface(); err != nil {
return err.(error)
}
return nil
}
for i := len(opts.HdlrWrappers); i > 0; i-- {
fn = opts.HdlrWrappers[i-1](fn)
}
statusCode := codes.OK
statusDesc := ""
appErr := fn(ctx, r, ss)
if appErr != nil {
var err error
var errStatus *status.Status
switch verr := appErr.(type) {
case *errors.Error:
// micro.Error now proto based and we can attach it to grpc status
statusCode = microError(verr)
statusDesc = verr.Error()
errStatus, err = status.New(statusCode, statusDesc).WithDetails(verr)
if err != nil {
return err
}
case proto.Message:
// user defined error that proto based we can attach it to grpc status
statusCode = convertCode(appErr)
statusDesc = appErr.Error()
errStatus, err = status.New(statusCode, statusDesc).WithDetails(verr)
if err != nil {
return err
}
case *rpcError:
// rpcError handling may be we have ability to attach it to details?
statusCode = verr.code
statusDesc = verr.desc
errStatus = status.New(statusCode, statusDesc)
default:
// default case user pass own error type that not proto based
statusCode = convertCode(verr)
statusDesc = verr.Error()
errStatus = status.New(statusCode, statusDesc)
}
return errStatus.Err()
}
return status.New(statusCode, statusDesc).Err()
}
func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) {
codecs := make(map[string]encoding.Codec)
if g.opts.Context != nil {
if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil {
codecs = v
}
}
if c, ok := codecs[contentType]; ok {
return c, nil
}
if c, ok := defaultGRPCCodecs[contentType]; ok {
return c, nil
}
return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType)
}
func (g *grpcServer) newCodec(contentType string) (codec.NewCodec, error) {
if cf, ok := g.opts.Codecs[contentType]; ok {
return cf, nil
}
if cf, ok := defaultRPCCodecs[contentType]; ok {
return cf, nil
}
return nil, fmt.Errorf("Unsupported Content-Type: %s", contentType)
}
func (g *grpcServer) Options() server.Options {
g.RLock()
opts := g.opts
g.RUnlock()
return opts
}
func (g *grpcServer) Init(opts ...server.Option) error {
g.configure(opts...)
return nil
}
func (g *grpcServer) NewHandler(h interface{}, opts ...server.HandlerOption) server.Handler {
return newRpcHandler(h, opts...)
}
func (g *grpcServer) Handle(h server.Handler) error {
if err := g.rpc.register(h.Handler()); err != nil {
return err
}
g.handlers[h.Name()] = h
return nil
}
func (g *grpcServer) NewSubscriber(topic string, sb interface{}, opts ...server.SubscriberOption) server.Subscriber {
return newSubscriber(topic, sb, opts...)
}
func (g *grpcServer) Subscribe(sb server.Subscriber) error {
sub, ok := sb.(*subscriber)
if !ok {
return fmt.Errorf("invalid subscriber: expected *subscriber")
}
if len(sub.handlers) == 0 {
return fmt.Errorf("invalid subscriber: no handler functions")
}
if err := validateSubscriber(sb); err != nil {
return err
}
g.Lock()
if _, ok = g.subscribers[sub]; ok {
return fmt.Errorf("subscriber %v already exists", sub)
}
g.subscribers[sub] = nil
g.Unlock()
return nil
}
func (g *grpcServer) Register() error {
var err error
var advt, host, port string
// parse address for host, port
config := g.opts
// check the advertise address first
// if it exists then use it, otherwise
// use the address
if len(config.Advertise) > 0 {
advt = config.Advertise
} else {
advt = config.Address
}
if cnt := strings.Count(advt, ":"); cnt >= 1 {
// ipv6 address in format [host]:port or ipv4 host:port
host, port, err = net.SplitHostPort(advt)
if err != nil {
return err
}
} else {
host = advt
}
addr, err := addr.Extract(host)
if err != nil {
return err
}
// make copy of metadata
md := make(meta.Metadata)
for k, v := range config.Metadata {
md[k] = v
}
// register service
node := &registry.Node{
Id: config.Name + "-" + config.Id,
Address: mnet.HostPort(addr, port),
Metadata: md,
}
node.Metadata["broker"] = config.Broker.String()
node.Metadata["registry"] = config.Registry.String()
node.Metadata["server"] = g.String()
node.Metadata["transport"] = g.String()
node.Metadata["protocol"] = "grpc"
g.RLock()
// Maps are ordered randomly, sort the keys for consistency
var handlerList []string
for n, e := range g.handlers {
// Only advertise non internal handlers
if !e.Options().Internal {
handlerList = append(handlerList, n)
}
}
sort.Strings(handlerList)
var subscriberList []*subscriber
for e := range g.subscribers {
// Only advertise non internal subscribers
if !e.Options().Internal {
subscriberList = append(subscriberList, e)
}
}
sort.Slice(subscriberList, func(i, j int) bool {
return subscriberList[i].topic > subscriberList[j].topic
})
endpoints := make([]*registry.Endpoint, 0, len(handlerList)+len(subscriberList))
for _, n := range handlerList {
endpoints = append(endpoints, g.handlers[n].Endpoints()...)
}
for _, e := range subscriberList {
endpoints = append(endpoints, e.Endpoints()...)
}
g.RUnlock()
service := &registry.Service{
Name: config.Name,
Version: config.Version,
Nodes: []*registry.Node{node},
Endpoints: endpoints,
}
g.Lock()
registered := g.registered
g.Unlock()
if !registered {
log.Logf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
}
// create registry options
rOpts := []registry.RegisterOption{registry.RegisterTTL(config.RegisterTTL)}
if err := config.Registry.Register(service, rOpts...); err != nil {
return err
}
// already registered? don't need to register subscribers
if registered {
return nil
}
g.Lock()
defer g.Unlock()
g.registered = true
for sb := range g.subscribers {
handler := g.createSubHandler(sb, g.opts)
var opts []broker.SubscribeOption
if queue := sb.Options().Queue; len(queue) > 0 {
opts = append(opts, broker.Queue(queue))
}
if cx := sb.Options().Context; cx != nil {
opts = append(opts, broker.SubscribeContext(cx))
}
if !sb.Options().AutoAck {
opts = append(opts, broker.DisableAutoAck())
}
log.Logf("Subscribing to topic: %s", sb.Topic())
sub, err := config.Broker.Subscribe(sb.Topic(), handler, opts...)
if err != nil {
return err
}
g.subscribers[sb] = []broker.Subscriber{sub}
}
return nil
}
func (g *grpcServer) Deregister() error {
var err error
var advt, host, port string
config := g.opts
// check the advertise address first
// if it exists then use it, otherwise
// use the address
if len(config.Advertise) > 0 {
advt = config.Advertise
} else {
advt = config.Address
}
if cnt := strings.Count(advt, ":"); cnt >= 1 {
// ipv6 address in format [host]:port or ipv4 host:port
host, port, err = net.SplitHostPort(advt)
if err != nil {
return err
}
} else {
host = advt
}
addr, err := addr.Extract(host)
if err != nil {
return err
}
node := &registry.Node{
Id: config.Name + "-" + config.Id,
Address: mnet.HostPort(addr, port),
}
service := &registry.Service{
Name: config.Name,
Version: config.Version,
Nodes: []*registry.Node{node},
}
log.Logf("Deregistering node: %s", node.Id)
if err := config.Registry.Deregister(service); err != nil {
return err
}
g.Lock()
if !g.registered {
g.Unlock()
return nil
}
g.registered = false
for sb, subs := range g.subscribers {
for _, sub := range subs {
log.Logf("Unsubscribing from topic: %s", sub.Topic())
sub.Unsubscribe()
}
g.subscribers[sb] = nil
}
g.Unlock()
return nil
}
func (g *grpcServer) Start() error {
g.RLock()
if g.started {
g.RUnlock()
return nil
}
g.RUnlock()
config := g.Options()
// micro: config.Transport.Listen(config.Address)
var ts net.Listener
if l := g.getListener(); l != nil {
ts = l
} else {
var err error
// check the tls config for secure connect
if tc := config.TLSConfig; tc != nil {
ts, err = tls.Listen("tcp", config.Address, tc)
// otherwise just plain tcp listener
} else {
ts, err = net.Listen("tcp", config.Address)
}
if err != nil {
return err
}
}
log.Logf("Server [grpc] Listening on %s", ts.Addr().String())
g.Lock()
g.opts.Address = ts.Addr().String()
g.Unlock()
// only connect if we're subscribed
if len(g.subscribers) > 0 {
// connect to the broker
if err := config.Broker.Connect(); err != nil {
return err
}
log.Logf("Broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address())
}
// announce self to the world
if err := g.Register(); err != nil {
log.Log("Server register error: ", err)
}
// micro: go ts.Accept(s.accept)
go func() {
if err := g.srv.Serve(ts); err != nil {
log.Log("gRPC Server start error: ", err)
}
}()
go func() {
t := new(time.Ticker)
// only process if it exists
if g.opts.RegisterInterval > time.Duration(0) {
// new ticker
t = time.NewTicker(g.opts.RegisterInterval)
}
// return error chan
var ch chan error
Loop:
for {
select {
// register self on interval
case <-t.C:
if err := g.Register(); err != nil {
log.Log("Server register error: ", err)
}
// wait for exit
case ch = <-g.exit:
break Loop
}
}
// deregister self
if err := g.Deregister(); err != nil {
log.Log("Server deregister error: ", err)
}
// wait for waitgroup
if g.wg != nil {
g.wg.Wait()
}
// stop the grpc server
exit := make(chan bool)
go func() {
g.srv.GracefulStop()
close(exit)
}()
select {
case <-exit:
case <-time.After(time.Second):
g.srv.Stop()
}
// close transport
ch <- nil
log.Logf("Broker [%s] Disconnected from %s", config.Broker.String(), config.Broker.Address())
// disconnect broker
config.Broker.Disconnect()
}()
// mark the server as started
g.Lock()
g.started = true
g.Unlock()
return nil
}
func (g *grpcServer) Stop() error {
g.RLock()
if !g.started {
g.RUnlock()
return nil
}
g.RUnlock()
ch := make(chan error)
g.exit <- ch
var err error
select {
case err = <-ch:
g.Lock()
g.started = false
g.Unlock()
}
return err
}
func (g *grpcServer) String() string {
return "grpc"
}
func NewServer(opts ...server.Option) server.Server {
return newGRPCServer(opts...)
}