Compare commits

..

No commits in common. "v3" and "v3.10.5" have entirely different histories.
v3 ... v3.10.5

17 changed files with 1825 additions and 1204 deletions

24
.gitignore vendored
View File

@ -1,24 +0,0 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
bin
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
# General
.DS_Store
.idea
.vscode

View File

@ -1,6 +1,8 @@
package grpc package grpc
import ( import (
"io"
"go.unistack.org/micro/v3/codec" "go.unistack.org/micro/v3/codec"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
) )
@ -47,3 +49,29 @@ func (w *wrapGrpcCodec) Unmarshal(d []byte, v interface{}, opts ...codec.Option)
} }
return w.Codec.Unmarshal(d, v) return w.Codec.Unmarshal(d, v)
} }
func (w *wrapGrpcCodec) ReadHeader(conn io.Reader, m *codec.Message, mt codec.MessageType) error {
return nil
}
func (w *wrapGrpcCodec) ReadBody(conn io.Reader, v interface{}) error {
if m, ok := v.(*codec.Frame); ok {
_, err := conn.Read(m.Data)
return err
}
return codec.ErrInvalidMessage
}
func (w *wrapGrpcCodec) Write(conn io.Writer, m *codec.Message, v interface{}) error {
// if we don't have a body
if v != nil {
b, err := w.Marshal(v)
if err != nil {
return err
}
m.Body = b
}
// write the body using the framing codec
_, err := conn.Write(m.Body)
return err
}

View File

@ -1,4 +0,0 @@
package grpc
//go:generate go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
//go:generate sh -c "protoc -I./proto -I$(go list -f '{{ .Dir }}' -m go.unistack.org/micro-proto/v3) -I. --go-grpc_out=paths=source_relative:./proto --go_out=paths=source_relative:./proto proto/test.proto"

21
go.mod
View File

@ -1,20 +1,11 @@
module go.unistack.org/micro-server-grpc/v3 module go.unistack.org/micro-server-grpc/v3
go 1.22 go 1.16
toolchain go1.23.1
require ( require (
github.com/golang/protobuf v1.5.4 github.com/golang/protobuf v1.5.2
go.unistack.org/micro/v3 v3.10.97 go.unistack.org/micro/v3 v3.10.14
golang.org/x/net v0.30.0 golang.org/x/net v0.5.0
google.golang.org/grpc v1.67.1 google.golang.org/grpc v1.52.3
google.golang.org/protobuf v1.35.1 google.golang.org/protobuf v1.28.1
)
require (
go.unistack.org/micro-proto/v3 v3.4.1 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect
) )

1053
go.sum

File diff suppressed because it is too large Load Diff

360
grpc.go
View File

@ -1,5 +1,5 @@
// Package grpc provides a grpc server // Package grpc provides a grpc server
package grpc package grpc // import "go.unistack.org/micro-server-grpc/v3"
import ( import (
"context" "context"
@ -7,30 +7,22 @@ import (
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
"slices" "runtime/debug"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
greflection "google.golang.org/grpc/reflection"
reflectionv1pb "google.golang.org/grpc/reflection/grpc_reflection_v1"
// nolint: staticcheck // nolint: staticcheck
oldproto "github.com/golang/protobuf/proto" oldproto "github.com/golang/protobuf/proto"
"go.unistack.org/micro/v3/broker" "go.unistack.org/micro/v3/broker"
"go.unistack.org/micro/v3/codec" "go.unistack.org/micro/v3/codec"
"go.unistack.org/micro/v3/errors" "go.unistack.org/micro/v3/errors"
"go.unistack.org/micro/v3/logger" "go.unistack.org/micro/v3/logger"
"go.unistack.org/micro/v3/metadata" metadata "go.unistack.org/micro/v3/metadata"
"go.unistack.org/micro/v3/meter"
"go.unistack.org/micro/v3/options"
"go.unistack.org/micro/v3/register" "go.unistack.org/micro/v3/register"
"go.unistack.org/micro/v3/semconv"
"go.unistack.org/micro/v3/server" "go.unistack.org/micro/v3/server"
msync "go.unistack.org/micro/v3/sync"
"go.unistack.org/micro/v3/tracer"
"golang.org/x/net/netutil" "golang.org/x/net/netutil"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -52,28 +44,15 @@ type ServerReflection struct {
} }
*/ */
type streamWrapper struct {
ctx context.Context
grpc.ServerStream
}
func (w *streamWrapper) Context() context.Context {
if w.ctx != nil {
return w.ctx
}
return w.ServerStream.Context()
}
type Server struct { type Server struct {
handlers map[string]server.Handler handlers map[string]server.Handler
srv *grpc.Server srv *grpc.Server
exit chan chan error exit chan chan error
wg *msync.WaitGroup wg *sync.WaitGroup
rsvc *register.Service rsvc *register.Service
subscribers map[*subscriber][]broker.Subscriber subscribers map[*subscriber][]broker.Subscriber
rpc *rServer rpc *rServer
opts server.Options opts server.Options
unknownHandler grpc.StreamHandler
sync.RWMutex sync.RWMutex
started bool started bool
registered bool registered bool
@ -92,8 +71,6 @@ func newServer(opts ...server.Option) *Server {
exit: make(chan chan error), exit: make(chan chan error),
} }
g.opts.Meter = g.opts.Meter.Clone(meter.Labels("type", "grpc"))
return g return g
} }
@ -121,6 +98,27 @@ func (g *Server) configure(opts ...server.Option) error {
o(&g.opts) o(&g.opts)
} }
if err := g.opts.Register.Init(); err != nil {
return err
}
if err := g.opts.Broker.Init(); err != nil {
return err
}
if err := g.opts.Tracer.Init(); err != nil {
return err
}
if err := g.opts.Logger.Init(); err != nil {
return err
}
if err := g.opts.Meter.Init(); err != nil {
return err
}
if err := g.opts.Transport.Init(); err != nil {
return err
}
g.wg = g.opts.Wait
if g.opts.Context != nil { if g.opts.Context != nil {
if codecs, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && codecs != nil { if codecs, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && codecs != nil {
for k, v := range codecs { for k, v := range codecs {
@ -155,19 +153,8 @@ func (g *Server) configure(opts ...server.Option) error {
} }
g.srv = grpc.NewServer(gopts...) g.srv = grpc.NewServer(gopts...)
if v, ok := g.opts.Context.Value(reflectionKey{}).(Reflector); ok { if v, ok := g.opts.Context.Value(reflectionKey{}).(bool); ok {
reflectionv1pb.RegisterServerReflectionServer( g.reflection = v
g.srv,
greflection.NewServerV1(greflection.ServerOptions{
Services: v,
DescriptorResolver: v,
ExtensionResolver: v,
}),
)
}
if h, ok := g.opts.Context.Value(unknownServiceHandlerKey{}).(grpc.StreamHandler); ok {
g.unknownHandler = h
} }
if restart { if restart {
@ -178,9 +165,12 @@ func (g *Server) configure(opts ...server.Option) error {
} }
func (g *Server) getMaxMsgSize() int { func (g *Server) getMaxMsgSize() int {
if g.opts.Context == nil {
return codec.DefaultMaxMsgSize
}
s, ok := g.opts.Context.Value(maxMsgSizeKey{}).(int) s, ok := g.opts.Context.Value(maxMsgSizeKey{}).(int)
if !ok { if !ok {
return 4 * 1024 * 1024 return codec.DefaultMaxMsgSize
} }
return s return s
} }
@ -198,58 +188,52 @@ func (g *Server) getGrpcOptions() []grpc.ServerOption {
return opts return opts
} }
func (g *Server) handler(srv interface{}, stream grpc.ServerStream) error { func (g *Server) handler(srv interface{}, stream grpc.ServerStream) (err error) {
var err error
ctx := stream.Context()
fullMethod, ok := grpc.MethodFromServerStream(stream) fullMethod, ok := grpc.MethodFromServerStream(stream)
if !ok { if !ok {
return status.Errorf(codes.Internal, "method does not exist in context") return status.Errorf(codes.Internal, "method does not exist in context")
} }
serviceName, methodName, err := serviceMethod(fullMethod)
if err != nil {
return status.New(codes.InvalidArgument, err.Error()).Err()
}
defer func() {
if r := recover(); r != nil {
g.RLock()
config := g.opts
g.RUnlock()
if config.Logger.V(logger.ErrorLevel) {
config.Logger.Errorf(config.Context, "panic in %s.%s recovered: %v", serviceName, methodName, r)
config.Logger.Error(config.Context, string(debug.Stack()))
}
err = errors.InternalServerError(g.opts.Name, "panic in %s.%s recovered: %v", serviceName, methodName, r)
} else if err != nil {
g.RLock()
config := g.opts
g.RUnlock()
if config.Logger.V(logger.ErrorLevel) {
config.Logger.Errorf(config.Context, "grpc handler %s.%s got error: %s", serviceName, methodName, err)
}
}
}()
if g.wg != nil {
g.wg.Add(1)
defer g.wg.Done()
}
// get grpc metadata // get grpc metadata
gmd, ok := gmetadata.FromIncomingContext(ctx) gmd, ok := gmetadata.FromIncomingContext(stream.Context())
if !ok { if !ok {
gmd = gmetadata.MD{} gmd = gmetadata.MD{}
} }
var serviceName, methodName string
serviceName, methodName, err = serviceMethod(fullMethod)
if err != nil {
err = status.New(codes.InvalidArgument, err.Error()).Err()
return err
}
endpointName := serviceName + "/" + methodName
ts := time.Now()
var sp tracer.Span
if !slices.Contains(tracer.DefaultSkipEndpoints, endpointName) {
ctx, sp = g.opts.Tracer.Start(ctx, "rpc-server",
tracer.WithSpanKind(tracer.SpanKindServer),
tracer.WithSpanLabels(
"endpoint", endpointName,
"server", "grpc",
),
)
defer func() {
st := status.Convert(err)
if st != nil || st.Code() != codes.OK {
sp.SetStatus(tracer.SpanStatusError, err.Error())
}
sp.Finish()
}()
}
md := metadata.New(len(gmd)) md := metadata.New(len(gmd))
for k, v := range gmd { for k, v := range gmd {
md[k] = strings.Join(v, ", ") md.Set(k, strings.Join(v, ", "))
} }
md.Set("Path", fullMethod)
md.Set("Micro-Server", "grpc")
md.Set(metadata.HeaderEndpoint, methodName)
md.Set(metadata.HeaderService, serviceName)
var td string var td string
// timeout for server deadline // timeout for server deadline
@ -287,42 +271,17 @@ func (g *Server) handler(srv interface{}, stream grpc.ServerStream) error {
} }
// create new context // create new context
ctx = metadata.NewIncomingContext(ctx, md) ctx := metadata.NewIncomingContext(stream.Context(), md)
stream = &streamWrapper{ctx, stream}
if !slices.Contains(meter.DefaultSkipEndpoints, endpointName) {
g.opts.Meter.Counter(semconv.ServerRequestInflight, "endpoint", endpointName, "server", "grpc").Inc()
defer func() {
te := time.Since(ts)
g.opts.Meter.Summary(semconv.ServerRequestLatencyMicroseconds, "endpoint", endpointName, "server", "grpc").Update(te.Seconds())
g.opts.Meter.Histogram(semconv.ServerRequestDurationSeconds, "endpoint", endpointName, "server", "grpc").Update(te.Seconds())
g.opts.Meter.Counter(semconv.ServerRequestInflight, "endpoint", endpointName, "server", "grpc").Dec()
st := status.Convert(err)
if st == nil || st.Code() == codes.OK {
g.opts.Meter.Counter(semconv.ServerRequestTotal, "endpoint", endpointName, "server", "grpc", "status", "success", "code", strconv.Itoa(int(codes.OK))).Inc()
} else {
g.opts.Meter.Counter(semconv.ServerRequestTotal, "endpoint", endpointName, "server", "grpc", "status", "failure", "code", strconv.Itoa(int(st.Code()))).Inc()
}
}()
}
if g.opts.Wait != nil {
g.opts.Wait.Add(1)
defer g.opts.Wait.Done()
}
// get peer from context // get peer from context
if p, ok := peer.FromContext(ctx); ok { if p, ok := peer.FromContext(stream.Context()); ok {
md.Set("Remote", p.Addr.String()) md.Set("Remote", p.Addr.String())
ctx = peer.NewContext(ctx, p) ctx = peer.NewContext(ctx, p)
} }
// set the timeout if we have it // set the timeout if we have it
if len(td) > 0 { if len(td) > 0 {
var n uint64 if n, err := strconv.ParseUint(td, 10, 64); err == nil {
if n, err = strconv.ParseUint(td, 10, 64); err == nil {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(n)) ctx, cancel = context.WithTimeout(ctx, time.Duration(n))
defer cancel() defer cancel()
@ -333,34 +292,54 @@ func (g *Server) handler(srv interface{}, stream grpc.ServerStream) error {
svc := g.rpc.serviceMap[serviceName] svc := g.rpc.serviceMap[serviceName]
g.rpc.mu.RUnlock() g.rpc.mu.RUnlock()
if svc == nil { /*
if g.unknownHandler != nil { if svc == nil && g.reflection && methodName == "ServerReflectionInfo" {
err = g.unknownHandler(srv, stream) rfl := &ServerReflection{srv: g.srv, s: &serverReflectionServer{s: g.srv}}
return err svc = &service{}
svc.typ = reflect.TypeOf(rfl)
svc.rcvr = reflect.ValueOf(rfl)
svc.name = reflect.Indirect(svc.rcvr).Type().Name()
svc.method = make(map[string]*methodType)
typ := reflect.TypeOf(rfl)
if me, ok := typ.MethodByName("ServerReflectionInfo"); ok {
g.rpc.mu.Lock()
ep, err := prepareEndpoint(me)
if ep != nil && err != nil {
svc.method["ServerReflectionInfo"] = ep
} else if err != nil {
return status.New(codes.Unimplemented, err.Error()).Err()
} }
err = status.New(codes.Unimplemented, fmt.Sprintf("unknown service %s", serviceName)).Err() g.rpc.mu.Unlock()
return err }
}
*/
if svc == nil {
if g.opts.Context != nil {
if h, ok := g.opts.Context.Value(unknownServiceHandlerKey{}).(grpc.StreamHandler); ok {
return h(srv, stream)
}
}
return status.New(codes.Unimplemented, fmt.Sprintf("unknown service %s", serviceName)).Err()
} }
mtype := svc.method[methodName] mtype := svc.method[methodName]
if mtype == nil { if mtype == nil {
if g.unknownHandler != nil { if g.opts.Context != nil {
err = g.unknownHandler(srv, stream) if h, ok := g.opts.Context.Value(unknownServiceHandlerKey{}).(grpc.StreamHandler); ok {
return err return h(srv, stream)
} }
err = status.New(codes.Unimplemented, fmt.Sprintf("unknown service method %s.%s", serviceName, methodName)).Err() }
return err return status.New(codes.Unimplemented, fmt.Sprintf("unknown service method %s.%s", serviceName, methodName)).Err()
} }
// process unary // process unary
if !mtype.stream { if !mtype.stream {
err = g.processRequest(ctx, stream, svc, mtype, ct) return g.processRequest(ctx, stream, svc, mtype, ct)
} else {
// process stream
err = g.processStream(ctx, stream, svc, mtype, ct)
} }
return err // process stream
return g.processStream(ctx, stream, svc, mtype, ct)
} }
func (g *Server) processRequest(ctx context.Context, stream grpc.ServerStream, service *service, mtype *methodType, ct string) error { func (g *Server) processRequest(ctx context.Context, stream grpc.ServerStream, service *service, mtype *methodType, ct string) error {
@ -412,11 +391,10 @@ func (g *Server) processRequest(ctx context.Context, stream grpc.ServerStream, s
return err return err
} }
g.opts.Hooks.EachNext(func(hook options.Hook) { // wrap the handler func
if h, ok := hook.(server.HookHandler); ok { for i := len(g.opts.HdlrWrappers); i > 0; i-- {
fn = h(fn) fn = g.opts.HdlrWrappers[i-1](fn)
} }
})
statusCode := codes.OK statusCode := codes.OK
statusDesc := "" statusDesc := ""
@ -449,7 +427,7 @@ func (g *Server) processRequest(ctx context.Context, stream grpc.ServerStream, s
config := g.opts config := g.opts
g.RUnlock() g.RUnlock()
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "handler error will not be transferred properly, must return *errors.Error or proto.Message") config.Logger.Warn(config.Context, "handler error will not be transferred properly, must return *errors.Error or proto.Message")
} }
// default case user pass own error type that not proto based // default case user pass own error type that not proto based
statusCode = convertCode(verr) statusCode = convertCode(verr)
@ -465,8 +443,53 @@ func (g *Server) processRequest(ctx context.Context, stream grpc.ServerStream, s
} }
return status.New(statusCode, statusDesc).Err() return status.New(statusCode, statusDesc).Err()
// }
} }
/*
type reflectStream struct {
stream server.Stream
}
func (s *reflectStream) Send(rsp *grpcreflect.ServerReflectionResponse) error {
return s.stream.Send(rsp)
}
func (s *reflectStream) Recv() (*grpcreflect.ServerReflectionRequest, error) {
req := &grpcreflect.ServerReflectionRequest{}
err := s.stream.Recv(req)
return req, err
}
func (s *reflectStream) SetHeader(gmetadata.MD) error {
return nil
}
func (s *reflectStream) SendHeader(gmetadata.MD) error {
return nil
}
func (s *reflectStream) SetTrailer(gmetadata.MD) {
}
func (s *reflectStream) Context() context.Context {
return s.stream.Context()
}
func (s *reflectStream) SendMsg(m interface{}) error {
return s.stream.Send(m)
}
func (s *reflectStream) RecvMsg(m interface{}) error {
return s.stream.Recv(m)
}
func (g *ServerReflection) ServerReflectionInfo(ctx context.Context, stream server.Stream) error {
return g.s.ServerReflectionInfo(&reflectStream{stream})
}
*/
func (g *Server) processStream(ctx context.Context, stream grpc.ServerStream, service *service, mtype *methodType, ct string) error { func (g *Server) processStream(ctx context.Context, stream grpc.ServerStream, service *service, mtype *methodType, ct string) error {
opts := g.opts opts := g.opts
@ -496,11 +519,9 @@ func (g *Server) processStream(ctx context.Context, stream grpc.ServerStream, se
return nil return nil
} }
opts.Hooks.EachNext(func(hook options.Hook) { for i := len(opts.HdlrWrappers); i > 0; i-- {
if h, ok := hook.(server.HookHandler); ok { fn = opts.HdlrWrappers[i-1](fn)
fn = h(fn)
} }
})
statusCode := codes.OK statusCode := codes.OK
statusDesc := "" statusDesc := ""
@ -668,7 +689,7 @@ func (g *Server) Register() error {
if !registered { if !registered {
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, fmt.Sprintf("Register [%s] Registering node: %s", config.Register.String(), service.Nodes[0].ID)) config.Logger.Infof(config.Context, "Register [%s] Registering node: %s", config.Register.String(), service.Nodes[0].ID)
} }
} }
@ -704,7 +725,7 @@ func (g *Server) Deregister() error {
} }
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, "Deregistering node: "+service.Nodes[0].ID) config.Logger.Infof(config.Context, "Deregistering node: %s", service.Nodes[0].ID)
} }
if err := server.DefaultDeregisterFunc(service, config); err != nil { if err := server.DefaultDeregisterFunc(service, config); err != nil {
@ -728,11 +749,11 @@ func (g *Server) Deregister() error {
go func(s broker.Subscriber) { go func(s broker.Subscriber) {
defer wg.Done() defer wg.Done()
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, "Unsubscribing from topic: "+s.Topic()) config.Logger.Infof(config.Context, "Unsubscribing from topic: %s", s.Topic())
} }
if err := s.Unsubscribe(g.opts.Context); err != nil { if err := s.Unsubscribe(g.opts.Context); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "Unsubscribing from topic: "+s.Topic(), err) config.Logger.Errorf(config.Context, "Unsubscribing from topic: %s err: %v", s.Topic(), err)
} }
} }
}(sub) }(sub)
@ -779,7 +800,7 @@ func (g *Server) Start() error {
} }
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, "Server [grpc] Listening on "+ts.Addr().String()) config.Logger.Infof(config.Context, "Server [grpc] Listening on %s", ts.Addr().String())
} }
g.Lock() g.Lock()
g.opts.Address = ts.Addr().String() g.opts.Address = ts.Addr().String()
@ -793,13 +814,13 @@ func (g *Server) Start() error {
// connect to the broker // connect to the broker
if err = config.Broker.Connect(config.Context); err != nil { if err = config.Broker.Connect(config.Context); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("broker [%s] connect error", config.Broker.String()), err) config.Logger.Errorf(config.Context, "Broker [%s] connect error: %v", config.Broker.String(), err)
} }
return err return err
} }
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, fmt.Sprintf("broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address())) config.Logger.Infof(config.Context, "Broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address())
} }
} }
@ -807,13 +828,13 @@ func (g *Server) Start() error {
// nolint: nestif // nolint: nestif
if err = g.opts.RegisterCheck(config.Context); err != nil { if err = g.opts.RegisterCheck(config.Context); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("Server %s-%s register check error", config.Name, config.ID), err) config.Logger.Errorf(config.Context, "Server %s-%s register check error: %s", config.Name, config.ID, err)
} }
} else { } else {
// announce self to the world // announce self to the world
if err = g.Register(); err != nil { if err = g.Register(); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "Server register error", err) config.Logger.Errorf(config.Context, "Server register error: %v", err)
} }
} }
} }
@ -826,11 +847,11 @@ func (g *Server) Start() error {
go func() { go func() {
if err = g.srv.Serve(ts); err != nil { if err = g.srv.Serve(ts); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "gRPC Server start error", err) config.Logger.Errorf(config.Context, "gRPC Server start error: %v", err)
} }
if err = g.Stop(); err != nil { if err = g.Stop(); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "gRPC Server stop error", err) config.Logger.Errorf(config.Context, "gRPC Server stop error: %v", err)
} }
} }
} }
@ -860,23 +881,23 @@ func (g *Server) Start() error {
// nolint: nestif // nolint: nestif
if rerr != nil && registered { if rerr != nil && registered {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("Server %s-%s register check error, deregister it", config.Name, config.ID), rerr) config.Logger.Errorf(config.Context, "Server %s-%s register check error: %s, deregister it", config.Name, config.ID, rerr)
} }
// deregister self in case of error // deregister self in case of error
if err = g.Deregister(); err != nil { if err = g.Deregister(); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("Server %s-%s deregister error", config.Name, config.ID), err) config.Logger.Errorf(config.Context, "Server %s-%s deregister error: %s", config.Name, config.ID, err)
} }
} }
} else if rerr != nil && !registered { } else if rerr != nil && !registered {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("Server %s-%s register check error", config.Name, config.ID), rerr) config.Logger.Errorf(config.Context, "Server %s-%s register check error: %s", config.Name, config.ID, rerr)
} }
continue continue
} }
if err = g.Register(); err != nil { if err = g.Register(); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("Server %s-%s register error", config.Name, config.ID), err) config.Logger.Errorf(config.Context, "Server %s-%s register error: %s", config.Name, config.ID, err)
} }
} }
// wait for exit // wait for exit
@ -888,13 +909,13 @@ func (g *Server) Start() error {
// deregister self // deregister self
if err = g.Deregister(); err != nil { if err = g.Deregister(); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, "Server deregister error", err) config.Logger.Errorf(config.Context, "Server deregister error: %v", err)
} }
} }
// wait for waitgroup // wait for waitgroup
if g.opts.Wait != nil { if g.wg != nil {
g.opts.Wait.Wait() g.wg.Wait()
} }
// stop the grpc server // stop the grpc server
@ -907,7 +928,7 @@ func (g *Server) Start() error {
select { select {
case <-exit: case <-exit:
case <-time.After(g.opts.GracefulTimeout): case <-time.After(time.Second):
g.srv.Stop() g.srv.Stop()
} }
@ -915,12 +936,12 @@ func (g *Server) Start() error {
ch <- nil ch <- nil
if config.Logger.V(logger.InfoLevel) { if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, fmt.Sprintf("broker [%s] Disconnected from %s", config.Broker.String(), config.Broker.Address())) config.Logger.Infof(config.Context, "Broker [%s] Disconnected from %s", config.Broker.String(), config.Broker.Address())
} }
// disconnect broker // disconnect broker
if err = config.Broker.Disconnect(config.Context); err != nil { if err = config.Broker.Disconnect(config.Context); err != nil {
if config.Logger.V(logger.ErrorLevel) { if config.Logger.V(logger.ErrorLevel) {
config.Logger.Error(config.Context, fmt.Sprintf("broker [%s] disconnect error", config.Broker.String()), err) config.Logger.Errorf(config.Context, "Broker [%s] disconnect error: %v", config.Broker.String(), err)
} }
} }
}() }()
@ -933,6 +954,37 @@ func (g *Server) Start() error {
return nil return nil
} }
func (g *Server) subscribe() error {
config := g.opts
for sb := range g.subscribers {
handler := g.createSubHandler(sb, config)
var opts []broker.SubscribeOption
if queue := sb.Options().Queue; len(queue) > 0 {
opts = append(opts, broker.SubscribeGroup(queue))
}
subCtx := config.Context
if cx := sb.Options().Context; cx != nil {
subCtx = cx
}
opts = append(opts, broker.SubscribeContext(subCtx))
opts = append(opts, broker.SubscribeAutoAck(sb.Options().AutoAck))
opts = append(opts, broker.SubscribeBodyOnly(sb.Options().BodyOnly))
if config.Logger.V(logger.InfoLevel) {
config.Logger.Infof(config.Context, "Subscribing to topic: %s", sb.Topic())
}
sub, err := config.Broker.Subscribe(subCtx, sb.Topic(), handler, opts...)
if err != nil {
return err
}
g.subscribers[sb] = []broker.Subscriber{sub}
}
return nil
}
func (g *Server) Stop() error { func (g *Server) Stop() error {
g.RLock() g.RLock()
if !g.started { if !g.started {

View File

@ -36,6 +36,7 @@ func Options(opts ...grpc.ServerOption) server.Option {
return server.SetOption(grpcOptions{}, opts) return server.SetOption(grpcOptions{}, opts)
} }
//
// MaxMsgSize set the maximum message in bytes the server can receive and // MaxMsgSize set the maximum message in bytes the server can receive and
// send. Default maximum message size is 4 MB. // send. Default maximum message size is 4 MB.
func MaxMsgSize(s int) server.Option { func MaxMsgSize(s int) server.Option {
@ -43,8 +44,8 @@ func MaxMsgSize(s int) server.Option {
} }
// Reflection enables reflection support in grpc server // Reflection enables reflection support in grpc server
func Reflection(r Reflector) server.Option { func Reflection(b bool) server.Option {
return server.SetOption(reflectionKey{}, r) return server.SetOption(reflectionKey{}, b)
} }
// UnknownServiceHandler enables support for all services // UnknownServiceHandler enables support for all services

View File

@ -1,208 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v4.25.2
// source: test.proto
package testpb
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type CallReq struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Data string `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
}
func (x *CallReq) Reset() {
*x = CallReq{}
if protoimpl.UnsafeEnabled {
mi := &file_test_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CallReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CallReq) ProtoMessage() {}
func (x *CallReq) ProtoReflect() protoreflect.Message {
mi := &file_test_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CallReq.ProtoReflect.Descriptor instead.
func (*CallReq) Descriptor() ([]byte, []int) {
return file_test_proto_rawDescGZIP(), []int{0}
}
func (x *CallReq) GetData() string {
if x != nil {
return x.Data
}
return ""
}
type CallRsp struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Data string `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
}
func (x *CallRsp) Reset() {
*x = CallRsp{}
if protoimpl.UnsafeEnabled {
mi := &file_test_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CallRsp) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CallRsp) ProtoMessage() {}
func (x *CallRsp) ProtoReflect() protoreflect.Message {
mi := &file_test_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CallRsp.ProtoReflect.Descriptor instead.
func (*CallRsp) Descriptor() ([]byte, []int) {
return file_test_proto_rawDescGZIP(), []int{1}
}
func (x *CallRsp) GetData() string {
if x != nil {
return x.Data
}
return ""
}
var File_test_proto protoreflect.FileDescriptor
var file_test_proto_rawDesc = []byte{
0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x74, 0x65,
0x73, 0x74, 0x22, 0x1d, 0x0a, 0x07, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x12, 0x12, 0x0a,
0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x64, 0x61, 0x74,
0x61, 0x22, 0x1d, 0x0a, 0x07, 0x43, 0x61, 0x6c, 0x6c, 0x52, 0x73, 0x70, 0x12, 0x12, 0x0a, 0x04,
0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61,
0x32, 0x33, 0x0a, 0x0b, 0x54, 0x65, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12,
0x24, 0x0a, 0x04, 0x43, 0x61, 0x6c, 0x6c, 0x12, 0x0d, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x43,
0x61, 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x1a, 0x0d, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x43, 0x61,
0x6c, 0x6c, 0x52, 0x73, 0x70, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x3b, 0x74, 0x65, 0x73, 0x74,
0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_test_proto_rawDescOnce sync.Once
file_test_proto_rawDescData = file_test_proto_rawDesc
)
func file_test_proto_rawDescGZIP() []byte {
file_test_proto_rawDescOnce.Do(func() {
file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData)
})
return file_test_proto_rawDescData
}
var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_test_proto_goTypes = []interface{}{
(*CallReq)(nil), // 0: test.CallReq
(*CallRsp)(nil), // 1: test.CallRsp
}
var file_test_proto_depIdxs = []int32{
0, // 0: test.TestService.Call:input_type -> test.CallReq
1, // 1: test.TestService.Call:output_type -> test.CallRsp
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_test_proto_init() }
func file_test_proto_init() {
if File_test_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CallReq); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CallRsp); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_test_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_test_proto_goTypes,
DependencyIndexes: file_test_proto_depIdxs,
MessageInfos: file_test_proto_msgTypes,
}.Build()
File_test_proto = out.File
file_test_proto_rawDesc = nil
file_test_proto_goTypes = nil
file_test_proto_depIdxs = nil
}

View File

@ -1,18 +0,0 @@
syntax = "proto3";
package test;
option go_package = "./;testpb";
service TestService {
rpc Call(CallReq) returns (CallRsp);
}
message CallReq {
string data = 1;
}
message CallRsp {
string data = 1;
}

View File

@ -1,109 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v4.25.2
// source: test.proto
package testpb
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
const (
TestService_Call_FullMethodName = "/test.TestService/Call"
)
// TestServiceClient is the client API for TestService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type TestServiceClient interface {
Call(ctx context.Context, in *CallReq, opts ...grpc.CallOption) (*CallRsp, error)
}
type testServiceClient struct {
cc grpc.ClientConnInterface
}
func NewTestServiceClient(cc grpc.ClientConnInterface) TestServiceClient {
return &testServiceClient{cc}
}
func (c *testServiceClient) Call(ctx context.Context, in *CallReq, opts ...grpc.CallOption) (*CallRsp, error) {
out := new(CallRsp)
err := c.cc.Invoke(ctx, TestService_Call_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// TestServiceServer is the server API for TestService service.
// All implementations must embed UnimplementedTestServiceServer
// for forward compatibility
type TestServiceServer interface {
Call(context.Context, *CallReq) (*CallRsp, error)
mustEmbedUnimplementedTestServiceServer()
}
// UnimplementedTestServiceServer must be embedded to have forward compatible implementations.
type UnimplementedTestServiceServer struct {
}
func (UnimplementedTestServiceServer) Call(context.Context, *CallReq) (*CallRsp, error) {
return nil, status.Errorf(codes.Unimplemented, "method Call not implemented")
}
func (UnimplementedTestServiceServer) mustEmbedUnimplementedTestServiceServer() {}
// UnsafeTestServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to TestServiceServer will
// result in compilation errors.
type UnsafeTestServiceServer interface {
mustEmbedUnimplementedTestServiceServer()
}
func RegisterTestServiceServer(s grpc.ServiceRegistrar, srv TestServiceServer) {
s.RegisterService(&TestService_ServiceDesc, srv)
}
func _TestService_Call_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CallReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(TestServiceServer).Call(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: TestService_Call_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(TestServiceServer).Call(ctx, req.(*CallReq))
}
return interceptor(ctx, in, info, handler)
}
// TestService_ServiceDesc is the grpc.ServiceDesc for TestService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var TestService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "test.TestService",
HandlerType: (*TestServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Call",
Handler: _TestService_Call_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "test.proto",
}

View File

@ -1,21 +1,488 @@
// +build ignore
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package reflection implements server reflection service.
The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To register server reflection on a gRPC server:
import "google.golang.org/grpc/reflection"
s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{})
// Register reflection service on gRPC server.
reflection.Register(s)
s.Serve(lis)
*/
package grpc package grpc
import ( import (
"google.golang.org/grpc/reflection" "bytes"
"google.golang.org/protobuf/reflect/protodesc" "compress/gzip"
"fmt"
"io"
"io/ioutil"
"reflect"
"sort"
"sync"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
) )
type Reflector interface { type serverReflectionServer struct {
protodesc.Resolver rpb.UnimplementedServerReflectionServer
reflection.ServiceInfoProvider s *grpc.Server
reflection.ExtensionResolver
initSymbols sync.Once
serviceNames []string
symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
} }
const ( // Register registers the server reflection service on the given gRPC server.
// ReflectV1ServiceName is the fully-qualified name of the v1 version of the reflection service. func Register(s *grpc.Server) {
ReflectV1ServiceName = "grpc.reflection.v1.ServerReflection" rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
// ReflectServiceURLPathV1 is the full path for reflection service endpoint s: s,
ReflectServiceURLPathV1 = "/" + ReflectV1ServiceName + "/" })
// ReflectMethodName is the reflection service name }
ReflectMethodName = "ServerReflectionInfo"
) // protoMessage is used for type assertion on proto messages.
// Generated proto message implements function Descriptor(), but Descriptor()
// is not part of interface proto.Message. This interface is needed to
// call Descriptor().
type protoMessage interface {
Descriptor() ([]byte, []int)
}
func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
s.initSymbols.Do(func() {
serviceInfo := s.s.GetServiceInfo()
s.symbols = map[string]*dpb.FileDescriptorProto{}
s.serviceNames = make([]string, 0, len(serviceInfo))
processed := map[string]struct{}{}
for svc, info := range serviceInfo {
s.serviceNames = append(s.serviceNames, svc)
fdenc, ok := parseMetadata(info.Metadata)
if !ok {
continue
}
fd, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
s.processFile(fd, processed)
}
sort.Strings(s.serviceNames)
})
return s.serviceNames, s.symbols
}
func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
filename := fd.GetName()
if _, ok := processed[filename]; ok {
return
}
processed[filename] = struct{}{}
prefix := fd.GetPackage()
for _, msg := range fd.MessageType {
s.processMessage(fd, prefix, msg)
}
for _, en := range fd.EnumType {
s.processEnum(fd, prefix, en)
}
for _, ext := range fd.Extension {
s.processField(fd, prefix, ext)
}
for _, svc := range fd.Service {
svcName := fqn(prefix, svc.GetName())
s.symbols[svcName] = fd
for _, meth := range svc.Method {
name := fqn(svcName, meth.GetName())
s.symbols[name] = fd
}
}
for _, dep := range fd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
s.processFile(fdDep, processed)
}
}
func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
msgName := fqn(prefix, msg.GetName())
s.symbols[msgName] = fd
for _, nested := range msg.NestedType {
s.processMessage(fd, msgName, nested)
}
for _, en := range msg.EnumType {
s.processEnum(fd, msgName, en)
}
for _, ext := range msg.Extension {
s.processField(fd, msgName, ext)
}
for _, fld := range msg.Field {
s.processField(fd, msgName, fld)
}
for _, oneof := range msg.OneofDecl {
oneofName := fqn(msgName, oneof.GetName())
s.symbols[oneofName] = fd
}
}
func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
enName := fqn(prefix, en.GetName())
s.symbols[enName] = fd
for _, val := range en.Value {
valName := fqn(enName, val.GetName())
s.symbols[valName] = fd
}
}
func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
fldName := fqn(prefix, fld.GetName())
s.symbols[fldName] = fd
}
func fqn(prefix, name string) string {
if prefix == "" {
return name
}
return prefix + "." + name
}
// fileDescForType gets the file descriptor for the given type.
// The given type should be a proto message.
func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
enc, _ := m.Descriptor()
return decodeFileDesc(enc)
}
// decodeFileDesc does decompression and unmarshalling on the given
// file descriptor byte slice.
func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
raw, err := decompress(enc)
if err != nil {
return nil, fmt.Errorf("failed to decompress enc: %v", err)
}
fd := new(dpb.FileDescriptorProto)
if err := proto.Unmarshal(raw, fd); err != nil {
return nil, fmt.Errorf("bad descriptor: %v", err)
}
return fd, nil
}
// decompress does gzip decompression.
func decompress(b []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
out, err := ioutil.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
return out, nil
}
func typeForName(name string) (reflect.Type, error) {
pt := proto.MessageType(name)
if pt == nil {
return nil, fmt.Errorf("unknown type: %q", name)
}
st := pt.Elem()
return st, nil
}
func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
var extDesc *proto.ExtensionDesc
for id, desc := range proto.RegisteredExtensions(m) {
if id == ext {
extDesc = desc
break
}
}
if extDesc == nil {
return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
}
return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
exts := proto.RegisteredExtensions(m)
out := make([]int32, 0, len(exts))
for id := range exts {
out = append(out, id)
}
return out, nil
}
// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
r := [][]byte{}
queue := []*dpb.FileDescriptorProto{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.GetName()] = true
currentfdEncoded, err := proto.Marshal(currentfd)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for _, dep := range currentfd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
queue = append(queue, fdDep)
}
}
return r, nil
}
// fileDescEncodingByFilename finds the file descriptor for given filename,
// finds all of its previously unsent transitive dependencies, does marshalling
// on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
}
fd, err := decodeFileDesc(enc)
if err != nil {
return nil, err
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// parseMetadata finds the file descriptor bytes specified meta.
// For SupportPackageIsVersion4, m is the name of the proto file, we
// call proto.FileDescriptor to get the byte slice.
// For SupportPackageIsVersion3, m is a byte slice itself.
func parseMetadata(meta interface{}) ([]byte, bool) {
// Check if meta is the file name.
if fileNameForMeta, ok := meta.(string); ok {
return proto.FileDescriptor(fileNameForMeta), true
}
// Check if meta is the byte slice.
if enc, ok := meta.([]byte); ok {
return enc, true
}
return nil, false
}
// fileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
_, symbols := s.getSymbols()
fd := symbols[name]
if fd == nil {
// Check if it's a type name that was not present in the
// transitive dependencies of the registered services.
if st, err := typeForName(name); err == nil {
fd, err = s.fileDescForType(st)
if err != nil {
return nil, err
}
}
}
if fd == nil {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
st, err := typeForName(typeName)
if err != nil {
return nil, err
}
fd, err := fileDescContainingExtension(st, extNum)
if err != nil {
return nil, err
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// allExtensionNumbersForTypeName returns all extension numbers for the given type.
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
st, err := typeForName(name)
if err != nil {
return nil, err
}
extNums, err := s.allExtensionNumbersForType(st)
if err != nil {
return nil, err
}
return extNums, nil
}
// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := &rpb.ServerReflectionResponse{
ValidHost: in.Host,
OriginalRequest: in,
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *rpb.ServerReflectionRequest_ListServices:
svcNames, _ := s.getSymbols()
serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
for i, n := range svcNames {
serviceResponses[i] = &rpb.ServiceResponse{
Name: n,
}
}
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{
Service: serviceResponses,
},
}
default:
return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
}
if err := stream.Send(out); err != nil {
return err
}
}
}

View File

@ -1,489 +0,0 @@
//go:build ignore
// +build ignore
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package reflection implements server reflection service.
The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To register server reflection on a gRPC server:
import "google.golang.org/grpc/reflection"
s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{})
// Register reflection service on gRPC server.
reflection.Register(s)
s.Serve(lis)
*/
package grpc
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"reflect"
"sort"
"sync"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
)
type serverReflectionServer struct {
rpb.UnimplementedServerReflectionServer
s *grpc.Server
initSymbols sync.Once
serviceNames []string
symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
}
// Register registers the server reflection service on the given gRPC server.
func Register(s *grpc.Server) {
rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
s: s,
})
}
// protoMessage is used for type assertion on proto messages.
// Generated proto message implements function Descriptor(), but Descriptor()
// is not part of interface proto.Message. This interface is needed to
// call Descriptor().
type protoMessage interface {
Descriptor() ([]byte, []int)
}
func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
s.initSymbols.Do(func() {
serviceInfo := s.s.GetServiceInfo()
s.symbols = map[string]*dpb.FileDescriptorProto{}
s.serviceNames = make([]string, 0, len(serviceInfo))
processed := map[string]struct{}{}
for svc, info := range serviceInfo {
s.serviceNames = append(s.serviceNames, svc)
fdenc, ok := parseMetadata(info.Metadata)
if !ok {
continue
}
fd, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
s.processFile(fd, processed)
}
sort.Strings(s.serviceNames)
})
return s.serviceNames, s.symbols
}
func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
filename := fd.GetName()
if _, ok := processed[filename]; ok {
return
}
processed[filename] = struct{}{}
prefix := fd.GetPackage()
for _, msg := range fd.MessageType {
s.processMessage(fd, prefix, msg)
}
for _, en := range fd.EnumType {
s.processEnum(fd, prefix, en)
}
for _, ext := range fd.Extension {
s.processField(fd, prefix, ext)
}
for _, svc := range fd.Service {
svcName := fqn(prefix, svc.GetName())
s.symbols[svcName] = fd
for _, meth := range svc.Method {
name := fqn(svcName, meth.GetName())
s.symbols[name] = fd
}
}
for _, dep := range fd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
s.processFile(fdDep, processed)
}
}
func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
msgName := fqn(prefix, msg.GetName())
s.symbols[msgName] = fd
for _, nested := range msg.NestedType {
s.processMessage(fd, msgName, nested)
}
for _, en := range msg.EnumType {
s.processEnum(fd, msgName, en)
}
for _, ext := range msg.Extension {
s.processField(fd, msgName, ext)
}
for _, fld := range msg.Field {
s.processField(fd, msgName, fld)
}
for _, oneof := range msg.OneofDecl {
oneofName := fqn(msgName, oneof.GetName())
s.symbols[oneofName] = fd
}
}
func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
enName := fqn(prefix, en.GetName())
s.symbols[enName] = fd
for _, val := range en.Value {
valName := fqn(enName, val.GetName())
s.symbols[valName] = fd
}
}
func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
fldName := fqn(prefix, fld.GetName())
s.symbols[fldName] = fd
}
func fqn(prefix, name string) string {
if prefix == "" {
return name
}
return prefix + "." + name
}
// fileDescForType gets the file descriptor for the given type.
// The given type should be a proto message.
func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
enc, _ := m.Descriptor()
return decodeFileDesc(enc)
}
// decodeFileDesc does decompression and unmarshalling on the given
// file descriptor byte slice.
func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
raw, err := decompress(enc)
if err != nil {
return nil, fmt.Errorf("failed to decompress enc: %v", err)
}
fd := new(dpb.FileDescriptorProto)
if err := proto.Unmarshal(raw, fd); err != nil {
return nil, fmt.Errorf("bad descriptor: %v", err)
}
return fd, nil
}
// decompress does gzip decompression.
func decompress(b []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
out, err := ioutil.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
return out, nil
}
func typeForName(name string) (reflect.Type, error) {
pt := proto.MessageType(name)
if pt == nil {
return nil, fmt.Errorf("unknown type: %q", name)
}
st := pt.Elem()
return st, nil
}
func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
var extDesc *proto.ExtensionDesc
for id, desc := range proto.RegisteredExtensions(m) {
if id == ext {
extDesc = desc
break
}
}
if extDesc == nil {
return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
}
return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
exts := proto.RegisteredExtensions(m)
out := make([]int32, 0, len(exts))
for id := range exts {
out = append(out, id)
}
return out, nil
}
// fileDescWithDependencies returns a slice of serialized fileDescriptors in
// wire format ([]byte). The fileDescriptors will include fd and all the
// transitive dependencies of fd with names not in sentFileDescriptors.
func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
r := [][]byte{}
queue := []*dpb.FileDescriptorProto{fd}
for len(queue) > 0 {
currentfd := queue[0]
queue = queue[1:]
if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
sentFileDescriptors[currentfd.GetName()] = true
currentfdEncoded, err := proto.Marshal(currentfd)
if err != nil {
return nil, err
}
r = append(r, currentfdEncoded)
}
for _, dep := range currentfd.Dependency {
fdenc := proto.FileDescriptor(dep)
fdDep, err := decodeFileDesc(fdenc)
if err != nil {
continue
}
queue = append(queue, fdDep)
}
}
return r, nil
}
// fileDescEncodingByFilename finds the file descriptor for given filename,
// finds all of its previously unsent transitive dependencies, does marshalling
// on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
}
fd, err := decodeFileDesc(enc)
if err != nil {
return nil, err
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// parseMetadata finds the file descriptor bytes specified meta.
// For SupportPackageIsVersion4, m is the name of the proto file, we
// call proto.FileDescriptor to get the byte slice.
// For SupportPackageIsVersion3, m is a byte slice itself.
func parseMetadata(meta interface{}) ([]byte, bool) {
// Check if meta is the file name.
if fileNameForMeta, ok := meta.(string); ok {
return proto.FileDescriptor(fileNameForMeta), true
}
// Check if meta is the byte slice.
if enc, ok := meta.([]byte); ok {
return enc, true
}
return nil, false
}
// fileDescEncodingContainingSymbol finds the file descriptor containing the
// given symbol, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result. The given symbol
// can be a type, a service or a method.
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
_, symbols := s.getSymbols()
fd := symbols[name]
if fd == nil {
// Check if it's a type name that was not present in the
// transitive dependencies of the registered services.
if st, err := typeForName(name); err == nil {
fd, err = s.fileDescForType(st)
if err != nil {
return nil, err
}
}
}
if fd == nil {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// fileDescEncodingContainingExtension finds the file descriptor containing
// given extension, finds all of its previously unsent transitive dependencies,
// does marshalling on them, and returns the marshalled result.
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
st, err := typeForName(typeName)
if err != nil {
return nil, err
}
fd, err := fileDescContainingExtension(st, extNum)
if err != nil {
return nil, err
}
return fileDescWithDependencies(fd, sentFileDescriptors)
}
// allExtensionNumbersForTypeName returns all extension numbers for the given type.
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
st, err := typeForName(name)
if err != nil {
return nil, err
}
extNums, err := s.allExtensionNumbersForType(st)
if err != nil {
return nil, err
}
return extNums, nil
}
// ServerReflectionInfo is the reflection service handler.
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
sentFileDescriptors := make(map[string]bool)
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := &rpb.ServerReflectionResponse{
ValidHost: in.Host,
OriginalRequest: in,
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *rpb.ServerReflectionRequest_ListServices:
svcNames, _ := s.getSymbols()
serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
for i, n := range svcNames {
serviceResponses[i] = &rpb.ServiceResponse{
Name: n,
}
}
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{
Service: serviceResponses,
},
}
default:
return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
}
if err := stream.Send(out); err != nil {
return err
}
}
}

View File

@ -1,62 +0,0 @@
package grpc
import (
"fmt"
"testing"
_ "go.unistack.org/micro-server-grpc/v3/proto"
"go.unistack.org/micro/v3/server"
"google.golang.org/grpc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
type reflector struct{}
func (r *reflector) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
fd, err := protoregistry.GlobalFiles.FindFileByPath(path)
if err != nil {
fmt.Printf("err: %v\n", err)
return nil, err
}
return fd, nil
}
func (r *reflector) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
fd, err := protoregistry.GlobalFiles.FindDescriptorByName(name)
if err != nil {
return nil, err
}
return fd, nil
}
func (r *reflector) GetServiceInfo() map[string]grpc.ServiceInfo {
fmt.Printf("GetServiceInfo\n")
return nil
}
func (r *reflector) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
fmt.Printf("FindExtensionByName field %#+v\n", field)
return nil, nil
}
func (r *reflector) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
fmt.Printf("FindExtensionByNumber message %#+v field %#+v\n", message, field)
return nil, nil
}
func (r *reflector) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
fmt.Printf("RangeExtensionsByMessage\n")
}
func TestReflector(t *testing.T) {
srv := NewServer(Reflection(&reflector{}), server.Address(":12345"))
if err := srv.Init(); err != nil {
t.Fatal(err)
}
if err := srv.Start(); err != nil {
t.Fatal(err)
}
t.Logf("addr %s", srv.Options().Address)
select {}
}

View File

@ -8,7 +8,10 @@ import (
"go.unistack.org/micro/v3/server" "go.unistack.org/micro/v3/server"
) )
var _ server.Request = &rpcRequest{} var (
_ server.Request = &rpcRequest{}
_ server.Message = &rpcMessage{}
)
type rpcRequest struct { type rpcRequest struct {
rw io.ReadWriter rw io.ReadWriter
@ -22,6 +25,14 @@ type rpcRequest struct {
stream bool stream bool
} }
type rpcMessage struct {
payload interface{}
codec codec.Codec
header metadata.Metadata
topic string
contentType string
}
func (r *rpcRequest) ContentType() string { func (r *rpcRequest) ContentType() string {
return r.contentType return r.contentType
} }
@ -47,7 +58,11 @@ func (r *rpcRequest) Header() metadata.Metadata {
} }
func (r *rpcRequest) Read() ([]byte, error) { func (r *rpcRequest) Read() ([]byte, error) {
return nil, nil f := &codec.Frame{}
if err := r.codec.ReadBody(r.rw, f); err != nil {
return nil, err
}
return f.Data, nil
} }
func (r *rpcRequest) Stream() bool { func (r *rpcRequest) Stream() bool {
@ -57,3 +72,23 @@ func (r *rpcRequest) Stream() bool {
func (r *rpcRequest) Body() interface{} { func (r *rpcRequest) Body() interface{} {
return r.payload return r.payload
} }
func (r *rpcMessage) ContentType() string {
return r.contentType
}
func (r *rpcMessage) Topic() string {
return r.topic
}
func (r *rpcMessage) Body() interface{} {
return r.payload
}
func (r *rpcMessage) Header() metadata.Metadata {
return r.header
}
func (r *rpcMessage) Codec() codec.Codec {
return r.codec
}

View File

@ -27,5 +27,8 @@ func (r *rpcResponse) WriteHeader(hdr metadata.Metadata) {
} }
func (r *rpcResponse) Write(b []byte) error { func (r *rpcResponse) Write(b []byte) error {
return nil return r.codec.Write(r.rw, &codec.Message{
Header: r.header,
Body: b,
}, nil)
} }

View File

@ -4,47 +4,17 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"runtime/debug"
"strings" "strings"
"go.unistack.org/micro/v3/broker" "go.unistack.org/micro/v3/broker"
"go.unistack.org/micro/v3/codec" "go.unistack.org/micro/v3/errors"
"go.unistack.org/micro/v3/logger" "go.unistack.org/micro/v3/logger"
"go.unistack.org/micro/v3/metadata" "go.unistack.org/micro/v3/metadata"
"go.unistack.org/micro/v3/options"
"go.unistack.org/micro/v3/register" "go.unistack.org/micro/v3/register"
"go.unistack.org/micro/v3/server" "go.unistack.org/micro/v3/server"
) )
var _ server.Message = &rpcMessage{}
type rpcMessage struct {
payload interface{}
codec codec.Codec
header metadata.Metadata
topic string
contentType string
}
func (r *rpcMessage) ContentType() string {
return r.contentType
}
func (r *rpcMessage) Topic() string {
return r.topic
}
func (r *rpcMessage) Body() interface{} {
return r.payload
}
func (r *rpcMessage) Header() metadata.Metadata {
return r.header
}
func (r *rpcMessage) Codec() codec.Codec {
return r.codec
}
type handler struct { type handler struct {
reqType reflect.Type reqType reflect.Type
ctxType reflect.Type ctxType reflect.Type
@ -134,6 +104,16 @@ func newSubscriber(topic string, sub interface{}, opts ...server.SubscriberOptio
func (g *Server) createSubHandler(sb *subscriber, opts server.Options) broker.Handler { func (g *Server) createSubHandler(sb *subscriber, opts server.Options) broker.Handler {
return func(p broker.Event) (err error) { return func(p broker.Event) (err error) {
defer func() {
if r := recover(); r != nil {
if g.opts.Logger.V(logger.ErrorLevel) {
g.opts.Logger.Error(g.opts.Context, "panic recovered: ", r)
g.opts.Logger.Error(g.opts.Context, string(debug.Stack()))
}
err = errors.InternalServerError(g.opts.Name+".subscriber", "panic recovered: %v", r)
}
}()
msg := p.Message() msg := p.Message()
// if we don't have headers, create empty map // if we don't have headers, create empty map
if msg.Header == nil { if msg.Header == nil {
@ -152,6 +132,9 @@ func (g *Server) createSubHandler(sb *subscriber, opts server.Options) broker.Ha
hdr := make(map[string]string, len(msg.Header)) hdr := make(map[string]string, len(msg.Header))
for k, v := range msg.Header { for k, v := range msg.Header {
if k == "Content-Type" {
continue
}
hdr[k] = v hdr[k] = v
} }
@ -197,11 +180,9 @@ func (g *Server) createSubHandler(sb *subscriber, opts server.Options) broker.Ha
return nil return nil
} }
opts.Hooks.EachNext(func(hook options.Hook) { for i := len(opts.SubWrappers); i > 0; i-- {
if h, ok := hook.(server.HookSubHandler); ok { fn = opts.SubWrappers[i-1](fn)
fn = h(fn)
} }
})
if g.wg != nil { if g.wg != nil {
g.wg.Add(1) g.wg.Add(1)
@ -248,38 +229,3 @@ func (s *subscriber) Endpoints() []*register.Endpoint {
func (s *subscriber) Options() server.SubscriberOptions { func (s *subscriber) Options() server.SubscriberOptions {
return s.opts return s.opts
} }
func (g *Server) subscribe() error {
config := g.opts
subCtx := config.Context
for sb := range g.subscribers {
if cx := sb.Options().Context; cx != nil {
subCtx = cx
}
opts := []broker.SubscribeOption{
broker.SubscribeContext(subCtx),
broker.SubscribeAutoAck(sb.Options().AutoAck),
broker.SubscribeBodyOnly(sb.Options().BodyOnly),
}
if queue := sb.Options().Queue; len(queue) > 0 {
opts = append(opts, broker.SubscribeGroup(queue))
}
if config.Logger.V(logger.InfoLevel) {
config.Logger.Info(config.Context, "subscribing to topic: "+sb.Topic())
}
sub, err := config.Broker.Subscribe(subCtx, sb.Topic(), g.createSubHandler(sb, config), opts...)
if err != nil {
return err
}
g.subscribers[sb] = []broker.Subscriber{sub}
}
return nil
}

View File

@ -26,7 +26,6 @@ func TestServiceMethod(t *testing.T) {
if err != nil && test.err == true { if err != nil && test.err == true {
continue continue
} }
t.Logf("input %s service %s method %s", test.input, service, method)
// unexpected error // unexpected error
if err != nil && test.err == false { if err != nil && test.err == false {
t.Fatalf("unexpected err %v for %+v", err, test) t.Fatalf("unexpected err %v for %+v", err, test)