diff --git a/server/grpc/grpc.go b/server/grpc/grpc.go index 71d456ba..972f4ba8 100644 --- a/server/grpc/grpc.go +++ b/server/grpc/grpc.go @@ -183,7 +183,17 @@ func (g *grpcServer) getListener() net.Listener { return nil } -func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) error { +func (g *grpcServer) handler(srv interface{}, stream grpc.ServerStream) (err error) { + defer func() { + if r := recover(); r != nil { + if logger.V(logger.ErrorLevel, logger.DefaultLogger) { + logger.Error("panic recovered: ", r) + logger.Error(string(debug.Stack())) + } + err = errors.InternalServerError("go.micro.server", "panic recovered: %v", r) + } + }() + if g.wg != nil { g.wg.Add(1) defer g.wg.Done() @@ -367,15 +377,6 @@ func (g *grpcServer) processRequest(stream grpc.ServerStream, service *service, // define the handler func fn := func(ctx context.Context, req server.Request, rsp interface{}) (err error) { - defer func() { - if r := recover(); r != nil { - if logger.V(logger.ErrorLevel, logger.DefaultLogger) { - logger.Error("panic recovered: ", r) - logger.Error(string(debug.Stack())) - } - err = errors.InternalServerError("go.micro.server", "panic recovered: %v", r) - } - }() returnValues = function.Call([]reflect.Value{service.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)}) // The return value for the method is an error. diff --git a/server/grpc/grpc_test.go b/server/grpc/grpc_test.go index 76fd6658..2a74f97a 100644 --- a/server/grpc/grpc_test.go +++ b/server/grpc/grpc_test.go @@ -60,6 +60,11 @@ func (s *testServer) Call(ctx context.Context, req *pb.Request, rsp *pb.Response return &errors.Error{Id: "1", Code: 99, Detail: "detail"} } + if req.Name == "Panic" { + // make it panic + panic("handler panic") + } + rsp.Msg = "Hello " + req.Name return nil } @@ -205,3 +210,101 @@ func TestGRPCServer(t *testing.T) { } } } + +// TestGRPCServerWithPanicWrapper test grpc server with panic wrapper +// gRPC server should not crash when wrapper crashed +func TestGRPCServerWithPanicWrapper(t *testing.T) { + r := rmemory.NewRegistry() + b := bmemory.NewBroker() + tr := tgrpc.NewTransport() + s := gsrv.NewServer( + server.Broker(b), + server.Name("foo"), + server.Registry(r), + server.Transport(tr), + server.WrapHandler(func(hf server.HandlerFunc) server.HandlerFunc { + return func(ctx context.Context, req server.Request, rsp interface{}) error { + // make it panic + panic("wrapper panic") + } + }), + ) + + h := &testServer{} + pb.RegisterTestHandler(s, h) + + if err := s.Start(); err != nil { + t.Fatalf("failed to start: %v", err) + } + + // check registration + services, err := r.GetService("foo") + if err != nil || len(services) == 0 { + t.Fatalf("failed to get service: %v # %d", err, len(services)) + } + + defer func() { + if err := s.Stop(); err != nil { + t.Fatalf("failed to stop: %v", err) + } + }() + + cc, err := grpc.Dial(s.Options().Address, grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial server: %v", err) + } + + rsp := pb.Response{} + if err := cc.Invoke(context.Background(), "/test.Test/Call", &pb.Request{Name: "John"}, &rsp); err == nil { + t.Fatal("this must return error, as wrapper should be panic") + } + + // both wrapper and handler should panic + rsp = pb.Response{} + if err := cc.Invoke(context.Background(), "/test.Test/Call", &pb.Request{Name: "Panic"}, &rsp); err == nil { + t.Fatal("this must return error, as wrapper and handler should be panic") + } +} + +// TestGRPCServerWithPanicWrapper test grpc server with panic handler +// gRPC server should not crash when handler crashed +func TestGRPCServerWithPanicHandler(t *testing.T) { + r := rmemory.NewRegistry() + b := bmemory.NewBroker() + tr := tgrpc.NewTransport() + s := gsrv.NewServer( + server.Broker(b), + server.Name("foo"), + server.Registry(r), + server.Transport(tr), + ) + + h := &testServer{} + pb.RegisterTestHandler(s, h) + + if err := s.Start(); err != nil { + t.Fatalf("failed to start: %v", err) + } + + // check registration + services, err := r.GetService("foo") + if err != nil || len(services) == 0 { + t.Fatalf("failed to get service: %v # %d", err, len(services)) + } + + defer func() { + if err := s.Stop(); err != nil { + t.Fatalf("failed to stop: %v", err) + } + }() + + cc, err := grpc.Dial(s.Options().Address, grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial server: %v", err) + } + + rsp := pb.Response{} + if err := cc.Invoke(context.Background(), "/test.Test/Call", &pb.Request{Name: "Panic"}, &rsp); err == nil { + t.Fatal("this must return error, as handler should be panic") + } +}