diff --git a/.golangci.yml b/.golangci.yml index 524fc7f8..6c81c4a9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,5 @@ run: concurrency: 4 - deadline: 5m issues-exit-code: 1 tests: true @@ -13,15 +12,13 @@ linters-settings: linters: enable: - govet - - deadcode - errcheck - govet - ineffassign - staticcheck - - structcheck - typecheck - unused - - varcheck + - spancheck - bodyclose - gci - goconst @@ -41,4 +38,5 @@ linters: - prealloc - unconvert - unparam + - unused disable-all: false diff --git a/broker/options.go b/broker/options.go index baa72ee9..ddc12e32 100644 --- a/broker/options.go +++ b/broker/options.go @@ -58,14 +58,14 @@ func NewOptions(opts ...options.Option) Options { type PublishOptions struct { // Context holds external options Context context.Context - // BodyOnly flag says the message contains raw body bytes - BodyOnly bool // Message metadata usually passed as message headers Metadata metadata.Metadata // Content-Type of message for marshal ContentType string // Topic destination Topic string + // BodyOnly flag says the message contains raw body bytes + BodyOnly bool } // NewPublishOptions creates PublishOptions struct diff --git a/broker/subscriber.go b/broker/subscriber.go index c330c022..e30c6cce 100644 --- a/broker/subscriber.go +++ b/broker/subscriber.go @@ -19,8 +19,8 @@ var typeOfError = reflect.TypeOf((*error)(nil)).Elem() // Is this an exported - upper case - name? func isExported(name string) bool { - rune, _ := utf8.DecodeRuneInString(name) - return unicode.IsUpper(rune) + r, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(r) } // Is this type exported or a builtin? diff --git a/database/dsn.go b/database/dsn.go index 8e6627a8..d4b158fc 100644 --- a/database/dsn.go +++ b/database/dsn.go @@ -75,78 +75,80 @@ func ParseDSN(dsn string) (*Config, error) { // Find last '/' that goes before dbname foundSlash := false for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - foundSlash = true - var j, k int + if dsn[i] != '/' { + continue + } - // left part is empty if i <= 0 - if i > 0 { - // Find the first ':' in dsn - for j = i; j >= 0; j-- { - if dsn[j] == ':' { - cfg.Scheme = dsn[0:j] - } + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // Find the first ':' in dsn + for j = i; j >= 0; j-- { + if dsn[j] == ':' { + cfg.Scheme = dsn[0:j] } - - // [username[:password]@][host] - // Find the last '@' in dsn[:i] - for j = i; j >= 0; j-- { - if dsn[j] == '@' { - // username[:password] - // Find the second ':' in dsn[:j] - for k = 0; k < j; k++ { - if dsn[k] == ':' { - if cfg.Scheme == dsn[:k] { - continue - } - var err error - cfg.Password, err = url.PathUnescape(dsn[k+1 : j]) - if err != nil { - return nil, err - } - break - } - } - cfg.Username = dsn[len(cfg.Scheme)+3 : k] - break - } - } - - for k = j + 1; k < i; k++ { - if dsn[k] == ':' { - cfg.Host = dsn[j+1 : k] - cfg.Port = dsn[k+1 : i] - break - } - } - } - // dbname[?param1=value1&...¶mN=valueN] - // Find the first '?' in dsn[i+1:] - for j = i + 1; j < len(dsn); j++ { - if dsn[j] == '?' { - parts := strings.Split(dsn[j+1:], "&") - cfg.Params = make([]string, 0, len(parts)*2) - for _, p := range parts { - k, v, found := strings.Cut(p, "=") - if !found { - continue + // [username[:password]@][host] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the second ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + if cfg.Scheme == dsn[:k] { + continue + } + var err error + cfg.Password, err = url.PathUnescape(dsn[k+1 : j]) + if err != nil { + return nil, err + } + break } - cfg.Params = append(cfg.Params, k, v) } - + cfg.Username = dsn[len(cfg.Scheme)+3 : k] break } } - var err error - dbname := dsn[i+1 : j] - if cfg.Database, err = url.PathUnescape(dbname); err != nil { - return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err) + + for k = j + 1; k < i; k++ { + if dsn[k] == ':' { + cfg.Host = dsn[j+1 : k] + cfg.Port = dsn[k+1 : i] + break + } } - break } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + parts := strings.Split(dsn[j+1:], "&") + cfg.Params = make([]string, 0, len(parts)*2) + for _, p := range parts { + k, v, found := strings.Cut(p, "=") + if !found { + continue + } + cfg.Params = append(cfg.Params, k, v) + } + + break + } + } + var err error + dbname := dsn[i+1 : j] + if cfg.Database, err = url.PathUnescape(dbname); err != nil { + return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err) + } + + break } if !foundSlash && len(dsn) > 0 { diff --git a/go.mod b/go.mod index 80403212..5a3afc27 100644 --- a/go.mod +++ b/go.mod @@ -5,17 +5,16 @@ go 1.20 require ( dario.cat/mergo v1.0.0 github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.6.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5 - golang.org/x/sync v0.3.0 - golang.org/x/sys v0.12.0 - google.golang.org/grpc v1.58.2 - google.golang.org/protobuf v1.31.0 + golang.org/x/sync v0.6.0 + golang.org/x/sys v0.16.0 + google.golang.org/grpc v1.62.1 + google.golang.org/protobuf v1.32.0 ) require ( github.com/golang/protobuf v1.5.3 // indirect - golang.org/x/net v0.15.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect ) diff --git a/go.sum b/go.sum index 1bc873a5..cef2316a 100644 --- a/go.sum +++ b/go.sum @@ -6,29 +6,28 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5 h1:G/FZtUu7a6NTWl3KUHMV9jkLAh/Rvtf03NWMHaEDl+E= github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5/go.mod h1:7RTUFBdIRC9nZ7/3RyRNH1bdqIShrDejd1YbLwgPS+I= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= -google.golang.org/grpc v1.58.2 h1:SXUpjxeVF3FKrTYQI4f4KvbGD5u2xccdYdurwowix5I= -google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 h1:AjyfHzEPEFp/NpvfN5g+KDla3EMojjhRVZc1i7cj+oM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s= +google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk= +google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logger/options.go b/logger/options.go index 4424c612..8755c9b7 100644 --- a/logger/options.go +++ b/logger/options.go @@ -18,18 +18,12 @@ type Options struct { Out io.Writer // Context holds exernal options Context context.Context - // Attrs holds additional attributes - Attrs []interface{} - // Name holds the logger name - Name string - // The logging level the logger should log - Level Level - // CallerSkipCount number of frmaes to skip - CallerSkipCount int - // ContextAttrFuncs contains funcs that executed before log func on context - ContextAttrFuncs []ContextAttrFunc + // TimeFunc used to obtain current time + TimeFunc func() time.Time // TimeKey is the key used for the time of the log call TimeKey string + // Name holds the logger name + Name string // LevelKey is the key used for the level of the log call LevelKey string // MessageKey is the key used for the message of the log call @@ -40,12 +34,18 @@ type Options struct { SourceKey string // StacktraceKey is the key used for the stacktrace StacktraceKey string + // Attrs holds additional attributes + Attrs []interface{} + // ContextAttrFuncs contains funcs that executed before log func on context + ContextAttrFuncs []ContextAttrFunc + // CallerSkipCount number of frmaes to skip + CallerSkipCount int + // The logging level the logger should log + Level Level // AddStacktrace controls writing of stacktaces on error AddStacktrace bool // AddSource enabled writing source file and position in log AddSource bool - // TimeFunc used to obtain current time - TimeFunc func() time.Time } // NewOptions creates new options struct diff --git a/logger/unwrap/unwrap.go b/logger/unwrap/unwrap.go index e13844fb..3c456f19 100644 --- a/logger/unwrap/unwrap.go +++ b/logger/unwrap/unwrap.go @@ -56,9 +56,9 @@ type Wrapper struct { s fmt.State pointers map[uintptr]int opts *Options + takeMap map[int]bool depth int ignoreNextType bool - takeMap map[int]bool protoWrapperType bool sqlWrapperType bool } diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 6155e535..809f5274 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -54,7 +54,24 @@ func testOutgoingCtx(ctx context.Context) { } } -func TestPassing(t *testing.T) { +func TestIncoming(t *testing.T) { + ctx := context.TODO() + md1 := New(2) + md1.Set("Key1", "Val1") + md1.Set("Key2", "Val2") + + ctx = NewIncomingContext(ctx, md1) + testIncomingCtx(ctx) + md, ok := FromIncomingContext(ctx) + if !ok { + t.Fatalf("missing metadata from incoming context") + } + if v, ok := md.Get("Key1"); !ok || v != "Val1_new" { + t.Fatalf("invalid metadata value %#+v", md) + } +} + +func TestOutgoing(t *testing.T) { ctx := context.TODO() md1 := New(2) md1.Set("Key1", "Val1") diff --git a/micro.go b/micro.go index c86955c8..6163b2e1 100644 --- a/micro.go +++ b/micro.go @@ -65,6 +65,8 @@ func As(b any, target any) bool { break case targetType.Implements(routerType): break + case targetType.Implements(tracerType): + break default: return false } @@ -76,19 +78,21 @@ func As(b any, target any) bool { return false } -var brokerType = reflect.TypeOf((*broker.Broker)(nil)).Elem() -var loggerType = reflect.TypeOf((*logger.Logger)(nil)).Elem() -var clientType = reflect.TypeOf((*client.Client)(nil)).Elem() -var serverType = reflect.TypeOf((*server.Server)(nil)).Elem() -var codecType = reflect.TypeOf((*codec.Codec)(nil)).Elem() -var flowType = reflect.TypeOf((*flow.Flow)(nil)).Elem() -var fsmType = reflect.TypeOf((*fsm.FSM)(nil)).Elem() -var meterType = reflect.TypeOf((*meter.Meter)(nil)).Elem() -var registerType = reflect.TypeOf((*register.Register)(nil)).Elem() -var resolverType = reflect.TypeOf((*resolver.Resolver)(nil)).Elem() -var routerType = reflect.TypeOf((*router.Router)(nil)).Elem() -var selectorType = reflect.TypeOf((*selector.Selector)(nil)).Elem() -var storeType = reflect.TypeOf((*store.Store)(nil)).Elem() -var syncType = reflect.TypeOf((*sync.Sync)(nil)).Elem() -var tracerType = reflect.TypeOf((*tracer.Tracer)(nil)).Elem() -var serviceType = reflect.TypeOf((*Service)(nil)).Elem() +var ( + brokerType = reflect.TypeOf((*broker.Broker)(nil)).Elem() + loggerType = reflect.TypeOf((*logger.Logger)(nil)).Elem() + clientType = reflect.TypeOf((*client.Client)(nil)).Elem() + serverType = reflect.TypeOf((*server.Server)(nil)).Elem() + codecType = reflect.TypeOf((*codec.Codec)(nil)).Elem() + flowType = reflect.TypeOf((*flow.Flow)(nil)).Elem() + fsmType = reflect.TypeOf((*fsm.FSM)(nil)).Elem() + meterType = reflect.TypeOf((*meter.Meter)(nil)).Elem() + registerType = reflect.TypeOf((*register.Register)(nil)).Elem() + resolverType = reflect.TypeOf((*resolver.Resolver)(nil)).Elem() + routerType = reflect.TypeOf((*router.Router)(nil)).Elem() + selectorType = reflect.TypeOf((*selector.Selector)(nil)).Elem() + storeType = reflect.TypeOf((*store.Store)(nil)).Elem() + syncType = reflect.TypeOf((*sync.Sync)(nil)).Elem() + tracerType = reflect.TypeOf((*tracer.Tracer)(nil)).Elem() + serviceType = reflect.TypeOf((*Service)(nil)).Elem() +) diff --git a/options/options_test.go b/options/options_test.go index acb64651..9a8946a2 100644 --- a/options/options_test.go +++ b/options/options_test.go @@ -1,7 +1,6 @@ package options_test import ( - "fmt" "testing" "go.unistack.org/micro/v4/codec" @@ -132,7 +131,6 @@ func TestMetadataAny(t *testing.T) { var opts []options.Option switch valData := tt.Data.(type) { case []any: - fmt.Printf("%s any %#+v\n", tt.Name, valData) opts = append(opts, options.Metadata(valData...)) case map[string]string, map[string][]string, metadata.Metadata: opts = append(opts, options.Metadata(valData)) diff --git a/register/memory/memory.go b/register/memory/memory.go index e4385597..35a12562 100644 --- a/register/memory/memory.go +++ b/register/memory/memory.go @@ -32,10 +32,10 @@ type record struct { } type memory struct { - sync.RWMutex records map[string]services watchers map[string]*watcher opts register.Options + sync.RWMutex } // services is a KV map with service name as the key and a map of records as the value @@ -102,10 +102,20 @@ func (m *memory) sendEvent(r *register.Result) { } func (m *memory) Connect(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } return nil } func (m *memory) Disconnect(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } return nil } @@ -126,6 +136,11 @@ func (m *memory) Options() register.Options { } func (m *memory) Register(ctx context.Context, s *register.Service, opts ...register.RegisterOption) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } m.Lock() defer m.Unlock() @@ -467,9 +482,7 @@ func serviceToRecord(s *register.Service, ttl time.Duration) *record { } endpoints := make([]*register.Endpoint, len(s.Endpoints)) - for i, e := range s.Endpoints { // TODO: vtolstov use copy - endpoints[i] = e - } + copy(endpoints, s.Endpoints) return &record{ Name: s.Name, diff --git a/register/memory/memory_test.go b/register/memory/memory_test.go index c4dd7151..c60b9adf 100644 --- a/register/memory/memory_test.go +++ b/register/memory/memory_test.go @@ -3,12 +3,13 @@ package memory import ( "context" "fmt" - "go.unistack.org/micro/v4" - "go.unistack.org/micro/v4/register" "reflect" "sync" "testing" "time" + + "go.unistack.org/micro/v4" + "go.unistack.org/micro/v4/register" ) var testData = map[string][]*register.Service{ @@ -209,9 +210,9 @@ func TestMemoryRegistryTTLConcurrent(t *testing.T) { } } - //if len(os.Getenv("IN_TRAVIS_CI")) == 0 { + // if len(os.Getenv("IN_TRAVIS_CI")) == 0 { // t.Logf("test will wait %v, then check TTL timeouts", waitTime) - //} + // } errChan := make(chan error, concurrency) syncChan := make(chan struct{}) @@ -252,6 +253,13 @@ func TestMemoryWildcard(t *testing.T) { m := NewRegister() ctx := context.TODO() + if err := m.Init(); err != nil { + t.Fatal(err) + } + + if err := m.Connect(ctx); err != nil { + t.Fatal(err) + } testSrv := ®ister.Service{Name: "foo", Version: "1.0.0"} if err := m.Register(ctx, testSrv, register.RegisterDomain("one")); err != nil { @@ -291,8 +299,12 @@ func TestWatcher(t *testing.T) { ctx := context.TODO() m := NewRegister() - m.Init() - m.Connect(ctx) + if err := m.Init(); err != nil { + t.Fatal(err) + } + if err := m.Connect(ctx); err != nil { + t.Fatal(err) + } wc, err := m.Watch(ctx) if err != nil { t.Fatalf("cant watch: %v", err) diff --git a/sync/waitgroup.go b/sync/waitgroup.go new file mode 100644 index 00000000..3124d948 --- /dev/null +++ b/sync/waitgroup.go @@ -0,0 +1,69 @@ +package sync + +import ( + "context" + "sync" +) + +type WaitGroup struct { + wg *sync.WaitGroup + c int + mu sync.Mutex +} + +func WrapWaitGroup(wg *sync.WaitGroup) *WaitGroup { + g := &WaitGroup{ + wg: wg, + } + return g +} + +func NewWaitGroup() *WaitGroup { + var wg sync.WaitGroup + return WrapWaitGroup(&wg) +} + +func (g *WaitGroup) Add(n int) { + g.mu.Lock() + g.c += n + g.wg.Add(n) + g.mu.Unlock() +} + +func (g *WaitGroup) Done() { + g.mu.Lock() + g.c += -1 + g.wg.Add(-1) + g.mu.Unlock() +} + +func (g *WaitGroup) Wait() { + g.wg.Wait() +} + +func (g *WaitGroup) WaitContext(ctx context.Context) { + done := make(chan struct{}) + go func() { + g.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + g.mu.Lock() + g.wg.Add(-g.c) + <-done + g.wg.Add(g.c) + g.mu.Unlock() + return + case <-done: + return + } +} + +func (g *WaitGroup) Waiters() int { + g.mu.Lock() + c := g.c + g.mu.Unlock() + return c +} diff --git a/sync/waitgroup_test.go b/sync/waitgroup_test.go new file mode 100644 index 00000000..c3f6f1b7 --- /dev/null +++ b/sync/waitgroup_test.go @@ -0,0 +1,37 @@ +package sync + +import ( + "context" + "testing" + "time" +) + +func TestWaitGroupContext(t *testing.T) { + wg := NewWaitGroup() + _ = t + wg.Add(1) + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +} + +func TestWaitGroupReuse(t *testing.T) { + wg := NewWaitGroup() + defer func() { + if wg.Waiters() != 0 { + t.Fatal("lost goroutines") + } + }() + + wg.Add(1) + defer wg.Done() + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) + + wg.Add(1) + defer wg.Done() + ctx, cancel = context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + wg.WaitContext(ctx) +} diff --git a/util/register/util.go b/util/register/util.go index c235920b..a73e7f19 100644 --- a/util/register/util.go +++ b/util/register/util.go @@ -109,12 +109,11 @@ func Merge(olist []*register.Service, nlist []*register.Service) []*register.Ser seen = true srv = append(srv, sp) break - } else { - sp := ®ister.Service{} - // make copy - *sp = *o - srv = append(srv, sp) } + sp := ®ister.Service{} + // make copy + *sp = *o + srv = append(srv, sp) } if !seen { srv = append(srv, Copy([]*register.Service{n})...) diff --git a/util/structfs/metadata_ec2.go b/util/structfs/metadata_ec2.go index 6c0f63aa..07be6e42 100644 --- a/util/structfs/metadata_ec2.go +++ b/util/structfs/metadata_ec2.go @@ -12,7 +12,7 @@ type EC2Metadata struct { InstanceType string `json:"instance-type"` LocalHostname string `json:"local-hostname"` LocalIPv4 string `json:"local-ipv4"` - kernelID int `json:"kernel-id"` + KernelID int `json:"kernel-id"` Placement string `json:"placement"` AvailabilityZone string `json:"availability-zone"` ProductCodes string `json:"product-codes"` diff --git a/util/structfs/structfs.go b/util/structfs/structfs.go index 755c691c..740b3014 100644 --- a/util/structfs/structfs.go +++ b/util/structfs/structfs.go @@ -67,9 +67,9 @@ func (fi *fileInfo) Name() string { func (fi *fileInfo) Mode() os.FileMode { if strings.HasSuffix(fi.name, "/") { - return os.FileMode(0755) | os.ModeDir + return os.FileMode(0o755) | os.ModeDir } - return os.FileMode(0644) + return os.FileMode(0o644) } func (fi *fileInfo) IsDir() bool { @@ -112,15 +112,14 @@ func (f *file) Readdir(count int) ([]os.FileInfo, error) { func (f *file) Seek(offset int64, whence int) (int64, error) { // log.Printf("seek %d %d %s\n", offset, whence, f.name) switch whence { - case os.SEEK_SET: + case io.SeekStart: f.offset = offset - case os.SEEK_CUR: + case io.SeekCurrent: f.offset += offset - case os.SEEK_END: + case io.SeekEnd: f.offset = int64(len(f.data)) + offset } return f.offset, nil - } func (f *file) Stat() (os.FileInfo, error) { diff --git a/util/structfs/structfs_test.go b/util/structfs/structfs_test.go index 7abf8edb..5c198fda 100644 --- a/util/structfs/structfs_test.go +++ b/util/structfs/structfs_test.go @@ -2,7 +2,7 @@ package structfs import ( "encoding/json" - "io/ioutil" + "io" "net/http" "reflect" "testing" @@ -82,7 +82,7 @@ func get(path string) ([]byte, error) { return nil, err } defer res.Body.Close() - return ioutil.ReadAll(res.Body) + return io.ReadAll(res.Body) } func TestAll(t *testing.T) {