Merge branch 'master' into namespace

This commit is contained in:
ben-toogood 2020-04-08 13:44:46 +01:00 committed by GitHub
commit 9f4286fc4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 319 additions and 127 deletions

View File

@ -55,7 +55,8 @@ func (m *memoryBroker) Connect() error {
return nil return nil
} }
addr, err := maddr.Extract("::") // use 127.0.0.1 to avoid scan of all network interfaces
addr, err := maddr.Extract("127.0.0.1")
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/micro/go-micro/v2/codec" "github.com/micro/go-micro/v2/codec"
"github.com/micro/go-micro/v2/codec/bytes" "github.com/micro/go-micro/v2/codec/bytes"
"github.com/oxtoacart/bpool"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
) )
@ -23,6 +24,9 @@ type wrapCodec struct{ encoding.Codec }
var jsonpbMarshaler = &jsonpb.Marshaler{} var jsonpbMarshaler = &jsonpb.Marshaler{}
var useNumber bool var useNumber bool
// create buffer pool with 16 instances each preallocated with 256 bytes
var bufferPool = bpool.NewSizedBufferPool(16, 256)
var ( var (
defaultGRPCCodecs = map[string]encoding.Codec{ defaultGRPCCodecs = map[string]encoding.Codec{
"application/json": jsonCodec{}, "application/json": jsonCodec{},
@ -106,14 +110,19 @@ func (bytesCodec) Name() string {
} }
func (jsonCodec) Marshal(v interface{}) ([]byte, error) { func (jsonCodec) Marshal(v interface{}) ([]byte, error) {
if pb, ok := v.(proto.Message); ok {
s, err := jsonpbMarshaler.MarshalToString(pb)
return []byte(s), err
}
if b, ok := v.(*bytes.Frame); ok { if b, ok := v.(*bytes.Frame); ok {
return b.Data, nil return b.Data, nil
} }
if pb, ok := v.(proto.Message); ok {
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if err := jsonpbMarshaler.Marshal(buf, pb); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
return json.Marshal(v) return json.Marshal(v)
} }

View File

@ -1,7 +1,7 @@
package bytes package bytes
import ( import (
"errors" "github.com/micro/go-micro/v2/codec"
) )
type Marshaler struct{} type Marshaler struct{}
@ -20,7 +20,7 @@ func (n Marshaler) Marshal(v interface{}) ([]byte, error) {
case *Message: case *Message:
return ve.Body, nil return ve.Body, nil
} }
return nil, errors.New("invalid message") return nil, codec.ErrInvalidMessage
} }
func (n Marshaler) Unmarshal(d []byte, v interface{}) error { func (n Marshaler) Unmarshal(d []byte, v interface{}) error {
@ -30,7 +30,7 @@ func (n Marshaler) Unmarshal(d []byte, v interface{}) error {
case *Message: case *Message:
ve.Body = d ve.Body = d
} }
return errors.New("invalid message") return codec.ErrInvalidMessage
} }
func (n Marshaler) String() string { func (n Marshaler) String() string {

View File

@ -7,7 +7,7 @@ import (
) )
var ( var (
maxMessageSize = 1024 * 1024 * 4 MaxMessageSize = 1024 * 1024 * 4 // 4Mb
maxInt = int(^uint(0) >> 1) maxInt = int(^uint(0) >> 1)
) )
@ -34,8 +34,8 @@ func decode(r io.Reader) (uint8, []byte, error) {
if int64(length) > int64(maxInt) { if int64(length) > int64(maxInt) {
return cf, nil, fmt.Errorf("grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt) return cf, nil, fmt.Errorf("grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
} }
if int(length) > maxMessageSize { if int(length) > MaxMessageSize {
return cf, nil, fmt.Errorf("grpc: received message larger than max (%d vs. %d)", length, maxMessageSize) return cf, nil, fmt.Errorf("grpc: received message larger than max (%d vs. %d)", length, MaxMessageSize)
} }
msg := make([]byte, int(length)) msg := make([]byte, int(length))

View File

@ -1,21 +1,36 @@
package json package json
import ( import (
"bytes"
"encoding/json" "encoding/json"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/oxtoacart/bpool"
) )
var jsonpbMarshaler = &jsonpb.Marshaler{}
// create buffer pool with 16 instances each preallocated with 256 bytes
var bufferPool = bpool.NewSizedBufferPool(16, 256)
type Marshaler struct{} type Marshaler struct{}
func (j Marshaler) Marshal(v interface{}) ([]byte, error) { func (j Marshaler) Marshal(v interface{}) ([]byte, error) {
if pb, ok := v.(proto.Message); ok {
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if err := jsonpbMarshaler.Marshal(buf, pb); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
return json.Marshal(v) return json.Marshal(v)
} }
func (j Marshaler) Unmarshal(d []byte, v interface{}) error { func (j Marshaler) Unmarshal(d []byte, v interface{}) error {
if pb, ok := v.(proto.Message); ok { if pb, ok := v.(proto.Message); ok {
return jsonpb.UnmarshalString(string(d), pb) return jsonpb.Unmarshal(bytes.NewReader(d), pb)
} }
return json.Unmarshal(d, v) return json.Unmarshal(d, v)
} }

View File

@ -1,17 +1,45 @@
package proto package proto
import ( import (
"bytes"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/micro/go-micro/v2/codec"
"github.com/oxtoacart/bpool"
) )
// create buffer pool with 16 instances each preallocated with 256 bytes
var bufferPool = bpool.NewSizedBufferPool(16, 256)
type Marshaler struct{} type Marshaler struct{}
func (Marshaler) Marshal(v interface{}) ([]byte, error) { func (Marshaler) Marshal(v interface{}) ([]byte, error) {
return proto.Marshal(v.(proto.Message)) pb, ok := v.(proto.Message)
if !ok {
return nil, codec.ErrInvalidMessage
}
// looks not good, but allows to reuse underlining bytes
buf := bufferPool.Get()
pbuf := proto.NewBuffer(buf.Bytes())
defer func() {
bufferPool.Put(bytes.NewBuffer(pbuf.Bytes()))
}()
if err := pbuf.Marshal(pb); err != nil {
return nil, err
}
return pbuf.Bytes(), nil
} }
func (Marshaler) Unmarshal(data []byte, v interface{}) error { func (Marshaler) Unmarshal(data []byte, v interface{}) error {
return proto.Unmarshal(data, v.(proto.Message)) pb, ok := v.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
return proto.Unmarshal(data, pb)
} }
func (Marshaler) String() string { func (Marshaler) String() string {

View File

@ -34,7 +34,7 @@ func (md Metadata) Delete(key string) {
// Copy makes a copy of the metadata // Copy makes a copy of the metadata
func Copy(md Metadata) Metadata { func Copy(md Metadata) Metadata {
cmd := make(Metadata) cmd := make(Metadata, len(md))
for k, v := range md { for k, v := range md {
cmd[k] = v cmd[k] = v
} }
@ -86,7 +86,7 @@ func FromContext(ctx context.Context) (Metadata, bool) {
} }
// capitalise all values // capitalise all values
newMD := make(map[string]string, len(md)) newMD := make(Metadata, len(md))
for k, v := range md { for k, v := range md {
newMD[strings.Title(k)] = v newMD[strings.Title(k)] = v
} }
@ -105,7 +105,7 @@ func MergeContext(ctx context.Context, patchMd Metadata, overwrite bool) context
ctx = context.Background() ctx = context.Background()
} }
md, _ := ctx.Value(MetadataKey{}).(Metadata) md, _ := ctx.Value(MetadataKey{}).(Metadata)
cmd := make(Metadata) cmd := make(Metadata, len(md))
for k, v := range md { for k, v := range md {
cmd[k] = v cmd[k] = v
} }

View File

@ -339,8 +339,8 @@ func (c *cache) run() {
c.setStatus(err) c.setStatus(err)
if a > 3 { if a > 3 {
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Info("rcache: ", err, " backing off ", d) logger.Debug("rcache: ", err, " backing off ", d)
} }
a = 0 a = 0
} }
@ -364,8 +364,8 @@ func (c *cache) run() {
c.setStatus(err) c.setStatus(err)
if b > 3 { if b > 3 {
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Info("rcache: ", err, " backing off ", d) logger.Debug("rcache: ", err, " backing off ", d)
} }
b = 0 b = 0
} }

View File

@ -59,6 +59,9 @@ type grpcServer struct {
started bool started bool
// used for first registration // used for first registration
registered bool registered bool
// registry service instance
rsvc *registry.Service
} }
func init() { func init() {
@ -102,6 +105,9 @@ func (r grpcRouter) ServeRequest(ctx context.Context, req server.Request, rsp se
} }
func (g *grpcServer) configure(opts ...server.Option) { func (g *grpcServer) configure(opts ...server.Option) {
g.Lock()
defer g.Unlock()
// Don't reprocess where there's no config // Don't reprocess where there's no config
if len(opts) == 0 && g.srv != nil { if len(opts) == 0 && g.srv != nil {
return return
@ -127,6 +133,7 @@ func (g *grpcServer) configure(opts ...server.Option) {
gopts = append(gopts, opts...) gopts = append(gopts, opts...)
} }
g.rsvc = nil
g.srv = grpc.NewServer(gopts...) g.srv = grpc.NewServer(gopts...)
} }
@ -559,11 +566,24 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error {
} }
func (g *grpcServer) Register() error { func (g *grpcServer) Register() error {
g.RLock()
rsvc := g.rsvc
config := g.opts
g.RUnlock()
// if service already filled, reuse it and return early
if rsvc != nil {
rOpts := []registry.RegisterOption{registry.RegisterTTL(config.RegisterTTL)}
if err := config.Registry.Register(rsvc, rOpts...); err != nil {
return err
}
return nil
}
var err error var err error
var advt, host, port string var advt, host, port string
var cacheService bool
// parse address for host, port
config := g.opts
// check the advertise address first // check the advertise address first
// if it exists then use it, otherwise // if it exists then use it, otherwise
@ -584,16 +604,17 @@ func (g *grpcServer) Register() error {
host = advt host = advt
} }
if ip := net.ParseIP(host); ip != nil {
cacheService = true
}
addr, err := addr.Extract(host) addr, err := addr.Extract(host)
if err != nil { if err != nil {
return err return err
} }
// make copy of metadata // make copy of metadata
md := make(meta.Metadata) md := meta.Copy(config.Metadata)
for k, v := range config.Metadata {
md[k] = v
}
// register service // register service
node := &registry.Node{ node := &registry.Node{
@ -646,13 +667,13 @@ func (g *grpcServer) Register() error {
Endpoints: endpoints, Endpoints: endpoints,
} }
g.Lock() g.RLock()
registered := g.registered registered := g.registered
g.Unlock() g.RUnlock()
if !registered { if !registered {
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) logger.Debugf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
} }
} }
@ -671,6 +692,9 @@ func (g *grpcServer) Register() error {
g.Lock() g.Lock()
defer g.Unlock() defer g.Unlock()
if cacheService {
g.rsvc = service
}
g.registered = true g.registered = true
for sb := range g.subscribers { for sb := range g.subscribers {
@ -688,8 +712,8 @@ func (g *grpcServer) Register() error {
opts = append(opts, broker.DisableAutoAck()) opts = append(opts, broker.DisableAutoAck())
} }
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Subscribing to topic: %s", sb.Topic()) logger.Debug("Subscribing to topic: %s", sb.Topic())
} }
sub, err := config.Broker.Subscribe(sb.Topic(), handler, opts...) sub, err := config.Broker.Subscribe(sb.Topic(), handler, opts...)
if err != nil { if err != nil {
@ -705,7 +729,9 @@ func (g *grpcServer) Deregister() error {
var err error var err error
var advt, host, port string var advt, host, port string
g.RLock()
config := g.opts config := g.opts
g.RUnlock()
// check the advertise address first // check the advertise address first
// if it exists then use it, otherwise // if it exists then use it, otherwise
@ -742,14 +768,15 @@ func (g *grpcServer) Deregister() error {
Nodes: []*registry.Node{node}, Nodes: []*registry.Node{node},
} }
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Deregistering node: %s", node.Id) logger.Debugf("Deregistering node: %s", node.Id)
} }
if err := config.Registry.Deregister(service); err != nil { if err := config.Registry.Deregister(service); err != nil {
return err return err
} }
g.Lock() g.Lock()
g.rsvc = nil
if !g.registered { if !g.registered {
g.Unlock() g.Unlock()
@ -760,8 +787,8 @@ func (g *grpcServer) Deregister() error {
for sb, subs := range g.subscribers { for sb, subs := range g.subscribers {
for _, sub := range subs { for _, sub := range subs {
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Unsubscribing from topic: %s", sub.Topic()) logger.Debugf("Unsubscribing from topic: %s", sub.Topic())
} }
sub.Unsubscribe() sub.Unsubscribe()
} }
@ -819,11 +846,14 @@ func (g *grpcServer) Start() error {
if len(g.subscribers) > 0 { if len(g.subscribers) > 0 {
// connect to the broker // connect to the broker
if err := config.Broker.Connect(); err != nil { if err := config.Broker.Connect(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
logger.Errorf("Broker [%s] connect error: %v", config.Broker.String(), err)
}
return err return err
} }
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address()) logger.Debugf("Broker [%s] Connected to %s", config.Broker.String(), config.Broker.Address())
} }
} }
@ -900,11 +930,15 @@ func (g *grpcServer) Start() error {
// close transport // close transport
ch <- nil ch <- nil
if logger.V(logger.InfoLevel, logger.DefaultLogger) { if logger.V(logger.DebugLevel, logger.DefaultLogger) {
logger.Infof("Broker [%s] Disconnected from %s", config.Broker.String(), config.Broker.Address()) logger.Debugf("Broker [%s] Disconnected from %s", config.Broker.String(), config.Broker.Address())
} }
// disconnect broker // disconnect broker
config.Broker.Disconnect() if err := config.Broker.Disconnect(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
logger.Errorf("Broker [%s] disconnect error: %v", config.Broker.String(), err)
}
}
}() }()
// mark the server as started // mark the server as started
@ -930,6 +964,7 @@ func (g *grpcServer) Stop() error {
select { select {
case err = <-ch: case err = <-ch:
g.Lock() g.Lock()
g.rsvc = nil
g.started = false g.started = false
g.Unlock() g.Unlock()
} }

View File

@ -40,6 +40,8 @@ type rpcServer struct {
subscriber broker.Subscriber subscriber broker.Subscriber
// graceful exit // graceful exit
wg *sync.WaitGroup wg *sync.WaitGroup
rsvc *registry.Service
} }
func newRpcServer(opts ...Option) Server { func newRpcServer(opts ...Option) Server {
@ -459,10 +461,11 @@ func (s *rpcServer) Options() Options {
func (s *rpcServer) Init(opts ...Option) error { func (s *rpcServer) Init(opts ...Option) error {
s.Lock() s.Lock()
defer s.Unlock()
for _, opt := range opts { for _, opt := range opts {
opt(&s.opts) opt(&s.opts)
} }
// update router if its the default // update router if its the default
if s.opts.Router == nil { if s.opts.Router == nil {
r := newRpcRouter() r := newRpcRouter()
@ -472,7 +475,8 @@ func (s *rpcServer) Init(opts ...Option) error {
s.router = r s.router = r
} }
s.Unlock() s.rsvc = nil
return nil return nil
} }
@ -510,11 +514,24 @@ func (s *rpcServer) Subscribe(sb Subscriber) error {
} }
func (s *rpcServer) Register() error { func (s *rpcServer) Register() error {
s.RLock()
rsvc := s.rsvc
config := s.Options()
s.RUnlock()
if rsvc != nil {
rOpts := []registry.RegisterOption{registry.RegisterTTL(config.RegisterTTL)}
if err := config.Registry.Register(rsvc, rOpts...); err != nil {
return err
}
return nil
}
var err error var err error
var advt, host, port string var advt, host, port string
var cacheService bool
// parse address for host, port
config := s.Options()
// check the advertise address first // check the advertise address first
// if it exists then use it, otherwise // if it exists then use it, otherwise
@ -535,16 +552,17 @@ func (s *rpcServer) Register() error {
host = advt host = advt
} }
if ip := net.ParseIP(host); ip != nil {
cacheService = true
}
addr, err := addr.Extract(host) addr, err := addr.Extract(host)
if err != nil { if err != nil {
return err return err
} }
// make copy of metadata // make copy of metadata
md := make(metadata.Metadata) md := metadata.Copy(config.Metadata)
for k, v := range config.Metadata {
md[k] = v
}
// mq-rpc(eg. nats) doesn't need the port. its addr is queue name. // mq-rpc(eg. nats) doesn't need the port. its addr is queue name.
if port != "" { if port != "" {
@ -612,7 +630,9 @@ func (s *rpcServer) Register() error {
s.RUnlock() s.RUnlock()
if !registered { if !registered {
log.Infof("Registry [%s] Registering node: %s", config.Registry.String(), node.Id) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Registry [%s] Registering node: %s", config.Registry.String(), node.Id)
}
} }
// create registry options // create registry options
@ -630,6 +650,9 @@ func (s *rpcServer) Register() error {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
if cacheService {
s.rsvc = service
}
s.registered = true s.registered = true
// set what we're advertising // set what we're advertising
s.opts.Advertise = addr s.opts.Advertise = addr
@ -665,8 +688,9 @@ func (s *rpcServer) Register() error {
if err != nil { if err != nil {
return err return err
} }
log.Infof("Subscribing to topic: %s", sub.Topic()) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Subscribing to topic: %s", sub.Topic())
}
s.subscribers[sb] = []broker.Subscriber{sub} s.subscribers[sb] = []broker.Subscriber{sub}
} }
@ -677,7 +701,9 @@ func (s *rpcServer) Deregister() error {
var err error var err error
var advt, host, port string var advt, host, port string
s.RLock()
config := s.Options() config := s.Options()
s.RUnlock()
// check the advertise address first // check the advertise address first
// if it exists then use it, otherwise // if it exists then use it, otherwise
@ -719,12 +745,15 @@ func (s *rpcServer) Deregister() error {
Nodes: []*registry.Node{node}, Nodes: []*registry.Node{node},
} }
log.Infof("Registry [%s] Deregistering node: %s", config.Registry.String(), node.Id) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Registry [%s] Deregistering node: %s", config.Registry.String(), node.Id)
}
if err := config.Registry.Deregister(service); err != nil { if err := config.Registry.Deregister(service); err != nil {
return err return err
} }
s.Lock() s.Lock()
s.rsvc = nil
if !s.registered { if !s.registered {
s.Unlock() s.Unlock()
@ -741,7 +770,9 @@ func (s *rpcServer) Deregister() error {
for sb, subs := range s.subscribers { for sb, subs := range s.subscribers {
for _, sub := range subs { for _, sub := range subs {
log.Infof("Unsubscribing %s from topic: %s", node.Id, sub.Topic()) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Unsubscribing %s from topic: %s", node.Id, sub.Topic())
}
sub.Unsubscribe() sub.Unsubscribe()
} }
s.subscribers[sb] = nil s.subscribers[sb] = nil
@ -767,7 +798,9 @@ func (s *rpcServer) Start() error {
return err return err
} }
log.Infof("Transport [%s] Listening on %s", config.Transport.String(), ts.Addr()) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Transport [%s] Listening on %s", config.Transport.String(), ts.Addr())
}
// swap address // swap address
s.Lock() s.Lock()
@ -775,24 +808,33 @@ func (s *rpcServer) Start() error {
s.opts.Address = ts.Addr() s.opts.Address = ts.Addr()
s.Unlock() s.Unlock()
bname := config.Broker.String()
// connect to the broker // connect to the broker
if err := config.Broker.Connect(); err != nil { if err := config.Broker.Connect(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Broker [%s] connect error: %v", bname, err)
}
return err return err
} }
bname := config.Broker.String() if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Broker [%s] Connected to %s", bname, config.Broker.Address())
log.Infof("Broker [%s] Connected to %s", bname, config.Broker.Address()) }
// use RegisterCheck func before register // use RegisterCheck func before register
if err = s.opts.RegisterCheck(s.opts.Context); err != nil { if err = s.opts.RegisterCheck(s.opts.Context); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s register check error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s register check error: %s", config.Name, config.Id, err)
}
} else { } else {
// announce self to the world // announce self to the world
if err = s.Register(); err != nil { if err = s.Register(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s register error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s register error: %s", config.Name, config.Id, err)
} }
} }
}
exit := make(chan bool) exit := make(chan bool)
@ -811,7 +853,9 @@ func (s *rpcServer) Start() error {
// check the error and backoff // check the error and backoff
default: default:
if err != nil { if err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Accept error: %v", err) log.Errorf("Accept error: %v", err)
}
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
} }
@ -844,18 +888,26 @@ func (s *rpcServer) Start() error {
s.RUnlock() s.RUnlock()
rerr := s.opts.RegisterCheck(s.opts.Context) rerr := s.opts.RegisterCheck(s.opts.Context)
if rerr != nil && registered { if rerr != nil && registered {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s register check error: %s, deregister it", config.Name, config.Id, err) log.Errorf("Server %s-%s register check error: %s, deregister it", config.Name, config.Id, err)
}
// deregister self in case of error // deregister self in case of error
if err := s.Deregister(); err != nil { if err := s.Deregister(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s deregister error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s deregister error: %s", config.Name, config.Id, err)
} }
}
} else if rerr != nil && !registered { } else if rerr != nil && !registered {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s register check error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s register check error: %s", config.Name, config.Id, err)
}
continue continue
} }
if err := s.Register(); err != nil { if err := s.Register(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s register error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s register error: %s", config.Name, config.Id, err)
} }
}
// wait for exit // wait for exit
case ch = <-s.exit: case ch = <-s.exit:
t.Stop() t.Stop()
@ -870,9 +922,11 @@ func (s *rpcServer) Start() error {
if registered { if registered {
// deregister self // deregister self
if err := s.Deregister(); err != nil { if err := s.Deregister(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Server %s-%s deregister error: %s", config.Name, config.Id, err) log.Errorf("Server %s-%s deregister error: %s", config.Name, config.Id, err)
} }
} }
}
s.Lock() s.Lock()
swg := s.wg swg := s.wg
@ -886,9 +940,15 @@ func (s *rpcServer) Start() error {
// close transport listener // close transport listener
ch <- ts.Close() ch <- ts.Close()
log.Infof("Broker [%s] Disconnected from %s", bname, config.Broker.Address()) if logger.V(logger.DebugLevel, logger.DefaultLogger) {
log.Debugf("Broker [%s] Disconnected from %s", bname, config.Broker.Address())
}
// disconnect the broker // disconnect the broker
config.Broker.Disconnect() if err := config.Broker.Disconnect(); err != nil {
if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
log.Errorf("Broker [%s] Disconnect error: %v", bname, err)
}
}
// swap back address // swap back address
s.Lock() s.Lock()

View File

@ -23,6 +23,10 @@ func NewCache(stores ...store.Store) store.Store {
return c return c
} }
func (c *cache) Close() error {
return nil
}
func (c *cache) Init(...store.Option) error { func (c *cache) Init(...store.Option) error {
if len(c.stores) < 2 { if len(c.stores) < 2 {
return errors.New("cache requires at least 2 stores") return errors.New("cache requires at least 2 stores")

View File

@ -101,6 +101,10 @@ func validateOptions(account, token, namespace string) {
} }
} }
func (w *workersKV) Close() error {
return nil
}
func (w *workersKV) Init(opts ...store.Option) error { func (w *workersKV) Init(opts ...store.Option) error {
for _, o := range opts { for _, o := range opts {
o(&w.options) o(&w.options)
@ -108,6 +112,9 @@ func (w *workersKV) Init(opts ...store.Option) error {
if len(w.options.Database) > 0 { if len(w.options.Database) > 0 {
w.namespace = w.options.Database w.namespace = w.options.Database
} }
if w.options.Context == nil {
w.options.Context = context.TODO()
}
ttl := w.options.Context.Value("STORE_CACHE_TTL") ttl := w.options.Context.Value("STORE_CACHE_TTL")
if ttl != nil { if ttl != nil {
ttlduration, ok := ttl.(time.Duration) ttlduration, ok := ttl.(time.Duration)

View File

@ -18,6 +18,7 @@ import (
// will use if no namespace is provided. // will use if no namespace is provided.
var ( var (
DefaultDatabase = "micro" DefaultDatabase = "micro"
DefaultTable = "micro"
) )
type sqlStore struct { type sqlStore struct {
@ -36,6 +37,19 @@ type sqlStore struct {
options store.Options options store.Options
} }
func (s *sqlStore) Close() error {
closeStmt(s.delete)
closeStmt(s.list)
closeStmt(s.readMany)
closeStmt(s.readOffset)
closeStmt(s.readOne)
closeStmt(s.write)
if s.db != nil {
return s.db.Close()
}
return nil
}
func (s *sqlStore) Init(opts ...store.Option) error { func (s *sqlStore) Init(opts ...store.Option) error {
for _, o := range opts { for _, o := range opts {
o(&s.options) o(&s.options)
@ -241,33 +255,25 @@ func (s *sqlStore) initDB() error {
if err != nil { if err != nil {
return errors.Wrap(err, "List statement couldn't be prepared") return errors.Wrap(err, "List statement couldn't be prepared")
} }
if s.list != nil { closeStmt(s.list)
s.list.Close()
}
s.list = list s.list = list
readOne, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key = $1;", s.database, s.table)) readOne, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key = $1;", s.database, s.table))
if err != nil { if err != nil {
return errors.Wrap(err, "ReadOne statement couldn't be prepared") return errors.Wrap(err, "ReadOne statement couldn't be prepared")
} }
if s.readOne != nil { closeStmt(s.readOne)
s.readOne.Close()
}
s.readOne = readOne s.readOne = readOne
readMany, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key LIKE $1;", s.database, s.table)) readMany, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key LIKE $1;", s.database, s.table))
if err != nil { if err != nil {
return errors.Wrap(err, "ReadMany statement couldn't be prepared") return errors.Wrap(err, "ReadMany statement couldn't be prepared")
} }
if s.readMany != nil { closeStmt(s.readMany)
s.readMany.Close()
}
s.readMany = readMany s.readMany = readMany
readOffset, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key LIKE $1 ORDER BY key DESC LIMIT $2 OFFSET $3;", s.database, s.table)) readOffset, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key LIKE $1 ORDER BY key DESC LIMIT $2 OFFSET $3;", s.database, s.table))
if err != nil { if err != nil {
return errors.Wrap(err, "ReadOffset statement couldn't be prepared") return errors.Wrap(err, "ReadOffset statement couldn't be prepared")
} }
if s.readOffset != nil { closeStmt(s.readOffset)
s.readOffset.Close()
}
s.readOffset = readOffset s.readOffset = readOffset
write, err := s.db.Prepare(fmt.Sprintf(`INSERT INTO %s.%s(key, value, expiry) write, err := s.db.Prepare(fmt.Sprintf(`INSERT INTO %s.%s(key, value, expiry)
VALUES ($1, $2::bytea, $3) VALUES ($1, $2::bytea, $3)
@ -277,17 +283,13 @@ func (s *sqlStore) initDB() error {
if err != nil { if err != nil {
return errors.Wrap(err, "Write statement couldn't be prepared") return errors.Wrap(err, "Write statement couldn't be prepared")
} }
if s.write != nil { closeStmt(s.write)
s.write.Close()
}
s.write = write s.write = write
delete, err := s.db.Prepare(fmt.Sprintf("DELETE FROM %s.%s WHERE key = $1;", s.database, s.table)) delete, err := s.db.Prepare(fmt.Sprintf("DELETE FROM %s.%s WHERE key = $1;", s.database, s.table))
if err != nil { if err != nil {
return errors.Wrap(err, "Delete statement couldn't be prepared") return errors.Wrap(err, "Delete statement couldn't be prepared")
} }
if s.delete != nil { closeStmt(s.delete)
s.delete.Close()
}
s.delete = delete s.delete = delete
return nil return nil
@ -295,7 +297,7 @@ func (s *sqlStore) initDB() error {
func (s *sqlStore) configure() error { func (s *sqlStore) configure() error {
if len(s.options.Nodes) == 0 { if len(s.options.Nodes) == 0 {
s.options.Nodes = []string{"postgresql://root@localhost:26257"} s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
} }
database := s.options.Database database := s.options.Database
@ -303,10 +305,10 @@ func (s *sqlStore) configure() error {
database = DefaultDatabase database = DefaultDatabase
} }
if len(s.options.Table) == 0 {
return errors.New("no table set")
}
table := s.options.Table table := s.options.Table
if len(table) == 0 {
table = DefaultTable
}
// store.namespace must only contain letters, numbers and underscores // store.namespace must only contain letters, numbers and underscores
reg, err := regexp.Compile("[^a-zA-Z0-9]+") reg, err := regexp.Compile("[^a-zA-Z0-9]+")
@ -375,3 +377,9 @@ func NewStore(opts ...store.Option) store.Store {
// return store // return store
return s return s
} }
func closeStmt(s *sql.Stmt) {
if s != nil {
s.Close()
}
}

View File

@ -32,6 +32,10 @@ func NewStore(opts ...store.Option) store.Store {
return e return e
} }
func (e *etcdStore) Close() error {
return e.client.Close()
}
func (e *etcdStore) Init(opts ...store.Option) error { func (e *etcdStore) Init(opts ...store.Option) error {
for _, o := range opts { for _, o := range opts {
o(&e.options) o(&e.options)

View File

@ -22,6 +22,9 @@ var (
DefaultTable = "micro" DefaultTable = "micro"
// DefaultDir is the default directory for bbolt files // DefaultDir is the default directory for bbolt files
DefaultDir = os.TempDir() DefaultDir = os.TempDir()
// bucket used for data storage
dataBucket = "data"
) )
// NewStore returns a memory store // NewStore returns a memory store
@ -49,7 +52,7 @@ type record struct {
func (m *fileStore) delete(key string) error { func (m *fileStore) delete(key string) error {
return m.db.Update(func(tx *bolt.Tx) error { return m.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(m.options.Table)) b := tx.Bucket([]byte(dataBucket))
if b == nil { if b == nil {
return nil return nil
} }
@ -72,13 +75,13 @@ func (m *fileStore) init(opts ...store.Option) error {
} }
// create a directory /tmp/micro // create a directory /tmp/micro
dir := filepath.Join(DefaultDir, "micro") dir := filepath.Join(DefaultDir, m.options.Database)
// create the database handle // create the database handle
fname := m.options.Database + ".db" fname := m.options.Table + ".db"
// Ignoring this as the folder might exist. // Ignoring this as the folder might exist.
// Reads/Writes updates will return with sensible error messages // Reads/Writes updates will return with sensible error messages
// about the dir not existing in case this cannot create the path anyway // about the dir not existing in case this cannot create the path anyway
_ = os.Mkdir(dir, 0700) os.MkdirAll(dir, 0700)
m.dir = dir m.dir = dir
m.fileName = fname m.fileName = fname
@ -100,7 +103,7 @@ func (m *fileStore) init(opts ...store.Option) error {
// create the table // create the table
return db.Update(func(tx *bolt.Tx) error { return db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(m.options.Table)) _, err := tx.CreateBucketIfNotExists([]byte(dataBucket))
return err return err
}) })
} }
@ -109,7 +112,7 @@ func (m *fileStore) list(limit, offset uint) []string {
var allItems []string var allItems []string
m.db.View(func(tx *bolt.Tx) error { m.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(m.options.Table)) b := tx.Bucket([]byte(dataBucket))
// nothing to read // nothing to read
if b == nil { if b == nil {
return nil return nil
@ -164,7 +167,7 @@ func (m *fileStore) get(k string) (*store.Record, error) {
m.db.View(func(tx *bolt.Tx) error { m.db.View(func(tx *bolt.Tx) error {
// @todo this is still very experimental... // @todo this is still very experimental...
b := tx.Bucket([]byte(m.options.Table)) b := tx.Bucket([]byte(dataBucket))
if b == nil { if b == nil {
return nil return nil
} }
@ -211,10 +214,10 @@ func (m *fileStore) set(r *store.Record) error {
data, _ := json.Marshal(item) data, _ := json.Marshal(item)
return m.db.Update(func(tx *bolt.Tx) error { return m.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte(m.options.Table)) b := tx.Bucket([]byte(dataBucket))
if b == nil { if b == nil {
var err error var err error
b, err = tx.CreateBucketIfNotExists([]byte(m.options.Table)) b, err = tx.CreateBucketIfNotExists([]byte(dataBucket))
if err != nil { if err != nil {
return err return err
} }
@ -223,6 +226,13 @@ func (m *fileStore) set(r *store.Record) error {
}) })
} }
func (m *fileStore) Close() error {
if m.db != nil {
return m.db.Close()
}
return nil
}
func (m *fileStore) Init(opts ...store.Option) error { func (m *fileStore) Init(opts ...store.Option) error {
return m.init(opts...) return m.init(opts...)
} }

View File

@ -13,14 +13,15 @@ import (
"github.com/micro/go-micro/v2/store" "github.com/micro/go-micro/v2/store"
) )
func cleanup() { func cleanup(db string, s store.Store) {
dir := filepath.Join(DefaultDir, "micro/") s.Close()
dir := filepath.Join(DefaultDir, db+"/")
os.RemoveAll(dir) os.RemoveAll(dir)
} }
func TestFileStoreReInit(t *testing.T) { func TestFileStoreReInit(t *testing.T) {
defer cleanup()
s := NewStore(store.Table("aaa")) s := NewStore(store.Table("aaa"))
defer cleanup(DefaultDatabase, s)
s.Init(store.Table("bbb")) s.Init(store.Table("bbb"))
if s.Options().Table != "bbb" { if s.Options().Table != "bbb" {
t.Error("Init didn't reinitialise the store") t.Error("Init didn't reinitialise the store")
@ -28,26 +29,26 @@ func TestFileStoreReInit(t *testing.T) {
} }
func TestFileStoreBasic(t *testing.T) { func TestFileStoreBasic(t *testing.T) {
defer cleanup()
s := NewStore() s := NewStore()
defer cleanup(DefaultDatabase, s)
fileTest(s, t) fileTest(s, t)
} }
func TestFileStoreTable(t *testing.T) { func TestFileStoreTable(t *testing.T) {
defer cleanup()
s := NewStore(store.Table("testTable")) s := NewStore(store.Table("testTable"))
defer cleanup(DefaultDatabase, s)
fileTest(s, t) fileTest(s, t)
} }
func TestFileStoreDatabase(t *testing.T) { func TestFileStoreDatabase(t *testing.T) {
defer cleanup()
s := NewStore(store.Database("testdb")) s := NewStore(store.Database("testdb"))
defer cleanup("testdb", s)
fileTest(s, t) fileTest(s, t)
} }
func TestFileStoreDatabaseTable(t *testing.T) { func TestFileStoreDatabaseTable(t *testing.T) {
defer cleanup()
s := NewStore(store.Table("testTable"), store.Database("testdb")) s := NewStore(store.Table("testTable"), store.Database("testdb"))
defer cleanup("testdb", s)
fileTest(s, t) fileTest(s, t)
} }
@ -248,19 +249,19 @@ func fileTest(s store.Store, t *testing.T) {
t.Error(err) t.Error(err)
} else { } else {
if len(results) != 5 { if len(results) != 5 {
t.Error("Expected 5 results, got ", len(results)) t.Fatal("Expected 5 results, got ", len(results))
} }
if !strings.HasPrefix(results[0].Key, "a") { if !strings.HasPrefix(results[0].Key, "a") {
t.Errorf("Expected a prefix, got %s", results[0].Key) t.Fatalf("Expected a prefix, got %s", results[0].Key)
} }
} }
// read the rest back // read the rest back
if results, err := s.Read("a", store.ReadLimit(30), store.ReadOffset(5), store.ReadPrefix()); err != nil { if results, err := s.Read("a", store.ReadLimit(30), store.ReadOffset(5), store.ReadPrefix()); err != nil {
t.Error(err) t.Fatal(err)
} else { } else {
if len(results) != 5 { if len(results) != 5 {
t.Error("Expected 5 results, got ", len(results)) t.Fatal("Expected 5 results, got ", len(results))
} }
} }
} }

View File

@ -30,6 +30,11 @@ type memoryStore struct {
store *cache.Cache store *cache.Cache
} }
func (m *memoryStore) Close() error {
m.store.Flush()
return nil
}
func (m *memoryStore) Init(opts ...store.Option) error { func (m *memoryStore) Init(opts ...store.Option) error {
m.store.Flush() m.store.Flush()
for _, o := range opts { for _, o := range opts {

View File

@ -29,3 +29,7 @@ func (n *noopStore) Delete(key string, opts ...DeleteOption) error {
func (n *noopStore) List(opts ...ListOption) ([]string, error) { func (n *noopStore) List(opts ...ListOption) ([]string, error) {
return []string{}, nil return []string{}, nil
} }
func (n *noopStore) Close() error {
return nil
}

View File

@ -32,6 +32,10 @@ type serviceStore struct {
Client pb.StoreService Client pb.StoreService
} }
func (s *serviceStore) Close() error {
return nil
}
func (s *serviceStore) Init(opts ...store.Option) error { func (s *serviceStore) Init(opts ...store.Option) error {
for _, o := range opts { for _, o := range opts {
o(&s.options) o(&s.options)

View File

@ -28,6 +28,8 @@ type Store interface {
Delete(key string, opts ...DeleteOption) error Delete(key string, opts ...DeleteOption) error
// List returns any keys that match, or an empty list with no error if none matched. // List returns any keys that match, or an empty list with no error if none matched.
List(opts ...ListOption) ([]string, error) List(opts ...ListOption) ([]string, error)
// Close the store
Close() error
// String returns the name of the implementation. // String returns the name of the implementation.
String() string String() string
} }

View File

@ -41,6 +41,10 @@ func NewCache(opts ...Option) Cache {
return c return c
} }
func (c *cache) Close() error {
return nil
}
// Init initialises the storeOptions // Init initialises the storeOptions
func (c *cache) Init(opts ...store.Option) error { func (c *cache) Init(opts ...store.Option) error {
for _, o := range opts { for _, o := range opts {

View File

@ -34,18 +34,8 @@ var (
) )
func (c *clientWrapper) setHeaders(ctx context.Context) context.Context { func (c *clientWrapper) setHeaders(ctx context.Context) context.Context {
// copy metadata // don't overwrite keys
mda, _ := metadata.FromContext(ctx) return metadata.MergeContext(ctx, c.headers, false)
md := metadata.Copy(mda)
// set headers
for k, v := range c.headers {
if _, ok := md[k]; !ok {
md[k] = v
}
}
return metadata.NewContext(ctx, md)
} }
func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {

View File

@ -115,8 +115,9 @@ func TestService(t *testing.T) {
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGTERM) signal.Notify(ch, syscall.SIGTERM)
p, _ := os.FindProcess(os.Getpid())
p.Signal(syscall.SIGTERM)
syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
<-ch <-ch
select { select {