Compare commits
	
		
			60 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 44ec3b663b | |||
| 4bb73514e9 | |||
| 3e86864ce7 | |||
|  | a68d3b24b8 | ||
| 9c22ae5384 | |||
|  | 16bad9a0cd | ||
|  | 3c779b248f | ||
| bbc7512054 | |||
|  | dd810e4ae0 | ||
| 236ed47ab1 | |||
|  | 909bcf51a4 | ||
| 9c24001f52 | |||
| 680cd6f708 | |||
|  | 3fcf3bef6d | ||
|  | 0ecd6199d4 | ||
| 14a30fb6a7 | |||
| 7c613072df | |||
|  | c55e212270 | ||
| 100bc006bb | |||
| 5dbfe8a7a6 | |||
| 6be077dbe8 | |||
| b4878211ee | |||
| ec9178c6d4 | |||
| ae63d44866 | |||
| 883e79216a | |||
| fa636ef6a9 | |||
| cdb81a9ba3 | |||
| 413c6cc2f0 | |||
|  | f56bd70136 | ||
| b51b4107a8 | |||
| 2067c9de6b | |||
| 3f82cb3ba4 | |||
|  | 306b7a3962 | ||
| a8eda9d58d | |||
| 7e4477dcb4 | |||
|  | d846044fc6 | ||
| 29d956e74e | |||
| fcc4faff8a | |||
| 5df8f83f45 | |||
|  | 27fa6e9173 | ||
| bd55a35dc3 | |||
| 653bd386cc | |||
|  | 558c6f4d7c | ||
| d7dd6fbeb2 | |||
| a00cf2c8d9 | |||
|  | a3e8ab2492 | ||
| 06da500ef4 | |||
| 277f04ba19 | |||
|  | 470263ff5f | ||
| b8232e02be | |||
|  | f8c5e10c1d | ||
| 397e71f815 | |||
| 74e31d99f6 | |||
| f39de15d93 | |||
| d291102877 | |||
| 37ffbb18d8 | |||
| 9a85dead86 | |||
| a489aab1c3 | |||
| d160664ef1 | |||
| fa868edcaa | 
| @@ -3,14 +3,16 @@ name: coverage | ||||
| on: | ||||
|   push: | ||||
|     branches: [ main, v3, v4 ] | ||||
|     paths-ignore: | ||||
|       - '.github/**' | ||||
|       - '.gitea/**' | ||||
|   pull_request: | ||||
|     branches: [ main, v3, v4 ] | ||||
|   # Allows you to run this workflow manually from the Actions tab | ||||
|   workflow_dispatch: | ||||
| 
 | ||||
| jobs: | ||||
| 
 | ||||
|   build: | ||||
|     if: github.server_url != 'https://github.com' | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|     - name: checkout code | ||||
| @@ -26,24 +28,24 @@ jobs: | ||||
| 
 | ||||
|     - name: test coverage | ||||
|       run: | | ||||
|         go test -v -cover ./... -coverprofile coverage.out -coverpkg ./... | ||||
|         go test -v -cover ./... -covermode=count -coverprofile coverage.out -coverpkg ./... | ||||
|         go tool cover -func coverage.out -o coverage.out | ||||
| 
 | ||||
|     - name: coverage badge | ||||
|       uses: tj-actions/coverage-badge-go@v1 | ||||
|       uses: tj-actions/coverage-badge-go@v2 | ||||
|       with: | ||||
|         green: 80 | ||||
|         filename: coverage.out | ||||
| 
 | ||||
|     - uses: stefanzweifel/git-auto-commit-action@v4 | ||||
|       id: auto-commit-action | ||||
|       name: autocommit | ||||
|       with: | ||||
|         commit_message: Apply Code Coverage Badge | ||||
|         skip_fetch: true | ||||
|         skip_checkout: true | ||||
|         skip_fetch: false | ||||
|         skip_checkout: false | ||||
|         file_pattern: ./README.md | ||||
| 
 | ||||
|     - name: Push Changes | ||||
|     - name: push | ||||
|       if: steps.auto-commit-action.outputs.changes_detected == 'true' | ||||
|       uses: ad-m/github-push-action@master | ||||
|       with: | ||||
| @@ -3,10 +3,10 @@ name: lint | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|     branches: | ||||
|     - master | ||||
|     - v3 | ||||
|     - v4 | ||||
|     branches: [ master, v3, v4 ] | ||||
|     paths-ignore: | ||||
|       - '.github/**' | ||||
|       - '.gitea/**' | ||||
| 
 | ||||
| jobs: | ||||
|   lint: | ||||
| @@ -24,6 +24,6 @@ jobs: | ||||
|     - name: setup deps | ||||
|       run: go get -v ./... | ||||
|     - name: run lint | ||||
|       uses: https://github.com/golangci/golangci-lint-action@v6 | ||||
|       uses: golangci/golangci-lint-action@v6 | ||||
|       with: | ||||
|         version: 'latest' | ||||
| @@ -3,15 +3,12 @@ name: test | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|     branches: | ||||
|     - master | ||||
|     - v3 | ||||
|     - v4 | ||||
|     branches: [ master, v3, v4 ] | ||||
|   push: | ||||
|     branches: | ||||
|     - master | ||||
|     - v3 | ||||
|     - v4 | ||||
|     branches: [ master, v3, v4 ] | ||||
|     paths-ignore: | ||||
|       - '.github/**' | ||||
|       - '.gitea/**' | ||||
| 
 | ||||
| jobs: | ||||
|   test: | ||||
| @@ -3,15 +3,12 @@ name: test | ||||
| on: | ||||
|   pull_request: | ||||
|     types: [opened, reopened, synchronize] | ||||
|     branches: | ||||
|     - master | ||||
|     - v3 | ||||
|     - v4 | ||||
|     branches: [ master, v3, v4 ] | ||||
|   push: | ||||
|     branches: | ||||
|     - master | ||||
|     - v3 | ||||
|     - v4 | ||||
|     branches: [ master, v3, v4 ] | ||||
|     paths-ignore: | ||||
|       - '.github/**' | ||||
|       - '.gitea/**' | ||||
| 
 | ||||
| jobs: | ||||
|   test: | ||||
| @@ -35,19 +32,19 @@ jobs: | ||||
|         go-version: 'stable' | ||||
|     - name: setup go work | ||||
|       env: | ||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work | ||||
|         GOWORK: ${{ github.workspace }}/go.work | ||||
|       run: | | ||||
|         go work init | ||||
|         go work use . | ||||
|         go work use micro-tests | ||||
|     - name: setup deps | ||||
|       env: | ||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work | ||||
|         GOWORK: ${{ github.workspace }}/go.work | ||||
|       run: go get -v ./... | ||||
|     - name: run tests | ||||
|       env: | ||||
|         INTEGRATION_TESTS: yes | ||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work | ||||
|         GOWORK: ${{ github.workspace }}/go.work | ||||
|       run: | | ||||
|         cd micro-tests | ||||
|         go test -mod readonly -v ./... || true | ||||
| @@ -1,5 +1,5 @@ | ||||
| run: | ||||
|   concurrency: 8 | ||||
|   deadline: 5m | ||||
|   timeout: 5m | ||||
|   issues-exit-code: 1 | ||||
|   tests: true | ||||
|   | ||||
							
								
								
									
										32
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,40 +1,34 @@ | ||||
| # Micro | ||||
|  | ||||
|  | ||||
| [](https://opensource.org/licenses/Apache-2.0) | ||||
| [](https://pkg.go.dev/go.unistack.org/micro/v3?tab=overview) | ||||
| [](https://git.unistack.org/unistack-org/micro/actions?query=workflow%3Abuild+branch%3Av3+event%3Apush) | ||||
| [](https://goreportcard.com/report/go.unistack.org/micro/v3) | ||||
|  | ||||
| Micro is a standard library for microservices. | ||||
|  | ||||
| ## Overview | ||||
|  | ||||
| Micro provides the core requirements for distributed systems development including RPC and Event driven communication.  | ||||
| Micro provides the core requirements for distributed systems development including SYNC and ASYNC communication.  | ||||
|  | ||||
| ## Features | ||||
|  | ||||
| Micro abstracts away the details of distributed systems. Here are the main features. | ||||
|  | ||||
| - **Authentication** - Auth is built in as a first class citizen. Authentication and authorization enable secure  | ||||
| zero trust networking by providing every service an identity and certificates. This additionally includes rule  | ||||
| based access control. | ||||
| Micro abstracts away the details of distributed systems. Main features: | ||||
|  | ||||
| - **Dynamic Config** - Load and hot reload dynamic config from anywhere. The config interface provides a way to load application  | ||||
| level config from any source such as env vars, file, etcd. You can merge the sources and even define fallbacks. | ||||
| level config from any source such as env vars, cmdline, file, consul, vault, etc... You can merge the sources and even define fallbacks. | ||||
|  | ||||
| - **Data Storage** - A simple data store interface to read, write and delete records. It includes support for memory, file and  | ||||
| CockroachDB by default. State and persistence becomes a core requirement beyond prototyping and Micro looks to build that into the framework. | ||||
| s3. State and persistence becomes a core requirement beyond prototyping and Micro looks to build that into the framework. | ||||
|  | ||||
| - **Service Discovery** - Automatic service registration and name resolution. Service discovery is at the core of micro service  | ||||
| development. When service A needs to speak to service B it needs the location of that service. | ||||
|  | ||||
| - **Load Balancing** - Client side load balancing built on service discovery. Once we have the addresses of any number of instances  | ||||
| of a service we now need a way to decide which node to route to. We use random hashed load balancing to provide even distribution  | ||||
| across the services and retry a different node if there's a problem.  | ||||
| development. | ||||
|  | ||||
| - **Message Encoding** - Dynamic message encoding based on content-type. The client and server will use codecs along with content-type  | ||||
| to seamlessly encode and decode Go types for you. Any variety of messages could be encoded and sent from different clients. The client  | ||||
| and server handle this by default. | ||||
|  | ||||
| - **Transport** - gRPC or http based request/response with support for bidirectional streaming. We provide an abstraction for synchronous communication. A request made to a service will be automatically resolved, load balanced, dialled and streamed. | ||||
|  | ||||
| - **Async Messaging** - PubSub is built in as a first class citizen for asynchronous communication and event driven architectures. | ||||
| - **Async Messaging** - Pub/Sub is built in as a first class citizen for asynchronous communication and event driven architectures. | ||||
| Event notifications are a core pattern in micro service development. | ||||
|  | ||||
| - **Synchronization** - Distributed systems are often built in an eventually consistent manner. Support for distributed locking and  | ||||
| @@ -43,10 +37,6 @@ leadership are built in as a Sync interface. When using an eventually consistent | ||||
| - **Pluggable Interfaces** - Micro makes use of Go interfaces for each system abstraction. Because of this these interfaces  | ||||
| are pluggable and allows Micro to be runtime agnostic. | ||||
|  | ||||
| ## Getting Started | ||||
|  | ||||
| To be created. | ||||
|  | ||||
| ## License | ||||
|  | ||||
| Micro is Apache 2.0 licensed. | ||||
|   | ||||
							
								
								
									
										15
									
								
								SECURITY.md
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								SECURITY.md
									
									
									
									
									
								
							| @@ -1,15 +0,0 @@ | ||||
| # Security Policy | ||||
|  | ||||
| ## Supported Versions | ||||
|  | ||||
| Use this section to tell people about which versions of your project are | ||||
| currently being supported with security updates. | ||||
|  | ||||
| | Version | Supported          | | ||||
| | ------- | ------------------ | | ||||
| | 3.7.x   | :white_check_mark: | | ||||
| | < 3.7.0 | :x:                | | ||||
|  | ||||
| ## Reporting a Vulnerability | ||||
|  | ||||
| If you find any issue, please create github issue in this repo | ||||
| @@ -17,11 +17,6 @@ import ( | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| ) | ||||
|  | ||||
| // DefaultCodecs will be used to encode/decode data | ||||
| var DefaultCodecs = map[string]codec.Codec{ | ||||
| 	"application/octet-stream": codec.NewCodec(), | ||||
| } | ||||
|  | ||||
| type noopClient struct { | ||||
| 	funcPublish      FuncPublish | ||||
| 	funcBatchPublish FuncBatchPublish | ||||
| @@ -178,9 +173,6 @@ func (n *noopClient) newCodec(contentType string) (codec.Codec, error) { | ||||
| 	if cf, ok := n.opts.Codecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
| 	return nil, codec.ErrUnknownContentType | ||||
| } | ||||
|  | ||||
| @@ -588,7 +580,6 @@ func (n *noopClient) publish(ctx context.Context, ps []Message, opts ...PublishO | ||||
|  | ||||
| 	for _, p := range ps { | ||||
| 		md := metadata.Copy(omd) | ||||
| 		md[metadata.HeaderContentType] = p.ContentType() | ||||
| 		topic := p.Topic() | ||||
| 		if len(exchange) > 0 { | ||||
| 			topic = exchange | ||||
| @@ -600,6 +591,8 @@ func (n *noopClient) publish(ctx context.Context, ps []Message, opts ...PublishO | ||||
| 			md.Set(k, v) | ||||
| 		} | ||||
|  | ||||
| 		md[metadata.HeaderContentType] = p.ContentType() | ||||
|  | ||||
| 		var body []byte | ||||
|  | ||||
| 		// passed in raw data | ||||
|   | ||||
| @@ -3,6 +3,8 @@ package client | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/codec" | ||||
| ) | ||||
|  | ||||
| type testHook struct { | ||||
| @@ -19,7 +21,7 @@ func (t *testHook) Publish(fn FuncPublish) FuncPublish { | ||||
| func TestNoopHook(t *testing.T) { | ||||
| 	h := &testHook{} | ||||
|  | ||||
| 	c := NewClient(Hooks(HookPublish(h.Publish))) | ||||
| 	c := NewClient(Codec("application/octet-stream", codec.NewCodec()), Hooks(HookPublish(h.Publish))) | ||||
|  | ||||
| 	if err := c.Init(); err != nil { | ||||
| 		t.Fatal(err) | ||||
|   | ||||
| @@ -198,7 +198,7 @@ func NewOptions(opts ...Option) Options { | ||||
| 	options := Options{ | ||||
| 		Context:     context.Background(), | ||||
| 		ContentType: DefaultContentType, | ||||
| 		Codecs:      DefaultCodecs, | ||||
| 		Codecs:      make(map[string]codec.Codec), | ||||
| 		CallOptions: CallOptions{ | ||||
| 			Context:        context.Background(), | ||||
| 			Backoff:        DefaultBackoff, | ||||
|   | ||||
							
								
								
									
										235
									
								
								cluster/hasql/cluster.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								cluster/hasql/cluster.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,235 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| func newSQLRowError() *sql.Row { | ||||
| 	row := &sql.Row{} | ||||
| 	t := reflect.TypeOf(row).Elem() | ||||
| 	field, _ := t.FieldByName("err") | ||||
| 	rowPtr := unsafe.Pointer(row) | ||||
| 	errFieldPtr := unsafe.Pointer(uintptr(rowPtr) + field.Offset) | ||||
| 	errPtr := (*error)(errFieldPtr) | ||||
| 	*errPtr = ErrorNoAliveNodes | ||||
| 	return row | ||||
| } | ||||
|  | ||||
| type ClusterQuerier interface { | ||||
| 	Querier | ||||
| 	WaitForNodes(ctx context.Context, criterion ...hasql.NodeStateCriterion) error | ||||
| } | ||||
|  | ||||
| type Cluster struct { | ||||
| 	hasql   *hasql.Cluster[Querier] | ||||
| 	options ClusterOptions | ||||
| } | ||||
|  | ||||
| // NewCluster returns [Querier] that provides cluster of nodes | ||||
| func NewCluster[T Querier](opts ...ClusterOption) (ClusterQuerier, error) { | ||||
| 	options := ClusterOptions{Context: context.Background()} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(&options) | ||||
| 	} | ||||
| 	if options.NodeChecker == nil { | ||||
| 		return nil, ErrClusterChecker | ||||
| 	} | ||||
| 	if options.NodeDiscoverer == nil { | ||||
| 		return nil, ErrClusterDiscoverer | ||||
| 	} | ||||
| 	if options.NodePicker == nil { | ||||
| 		return nil, ErrClusterPicker | ||||
| 	} | ||||
|  | ||||
| 	if options.Retries < 1 { | ||||
| 		options.Retries = 1 | ||||
| 	} | ||||
|  | ||||
| 	if options.NodeStateCriterion == 0 { | ||||
| 		options.NodeStateCriterion = hasql.Primary | ||||
| 	} | ||||
|  | ||||
| 	options.Options = append(options.Options, hasql.WithNodePicker(options.NodePicker)) | ||||
| 	if p, ok := options.NodePicker.(*CustomPicker[Querier]); ok { | ||||
| 		p.opts.Priority = options.NodePriority | ||||
| 	} | ||||
|  | ||||
| 	c, err := hasql.NewCluster( | ||||
| 		options.NodeDiscoverer, | ||||
| 		options.NodeChecker, | ||||
| 		options.Options..., | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &Cluster{hasql: c, options: options}, nil | ||||
| } | ||||
|  | ||||
| func (c *Cluster) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||
| 	var tx *sql.Tx | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if tx, err = n.DB().BeginTx(ctx, opts); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if tx == nil && err == nil { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return tx, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) Close() error { | ||||
| 	return c.hasql.Close() | ||||
| } | ||||
|  | ||||
| func (c *Cluster) Conn(ctx context.Context) (*sql.Conn, error) { | ||||
| 	var conn *sql.Conn | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if conn, err = n.DB().Conn(ctx); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if conn == nil && err == nil { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return conn, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	var res sql.Result | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	var res *sql.Stmt | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().PrepareContext(ctx, query); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	var res *sql.Rows | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	var res *sql.Row | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			res = n.DB().QueryRowContext(ctx, query, args...) | ||||
| 			if res.Err() == nil { | ||||
| 				return false | ||||
| 			} else if res.Err() != nil && retries >= c.options.Retries { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 		return true | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil { | ||||
| 		res = newSQLRowError() | ||||
| 	} | ||||
|  | ||||
| 	return res | ||||
| } | ||||
|  | ||||
| func (c *Cluster) PingContext(ctx context.Context) error { | ||||
| 	var err error | ||||
| 	var ok bool | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		ok = true | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if err = n.DB().PingContext(ctx); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if !ok { | ||||
| 		err = ErrorNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) WaitForNodes(ctx context.Context, criterions ...hasql.NodeStateCriterion) error { | ||||
| 	for _, criterion := range criterions { | ||||
| 		if _, err := c.hasql.WaitForNode(ctx, criterion); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										171
									
								
								cluster/hasql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								cluster/hasql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/DATA-DOG/go-sqlmock" | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| func TestNewCluster(t *testing.T) { | ||||
| 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbMaster.Close() | ||||
| 	dbMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(1, 0)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("master-dc1")) | ||||
|  | ||||
| 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbDRMaster.Close() | ||||
| 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 40)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster1-dc2")) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster")) | ||||
|  | ||||
| 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC1.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC2.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	c, err := NewCluster[Querier]( | ||||
| 		WithClusterContext(tctx), | ||||
| 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||
| 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||
| 			CustomPickerMaxLag(100), | ||||
| 		)), | ||||
| 		WithClusterNodes( | ||||
| 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||
| 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||
| 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||
| 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||
| 		), | ||||
| 		WithClusterOptions( | ||||
| 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||
| 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||
| 		), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer c.Close() | ||||
|  | ||||
| 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	time.Sleep(500 * time.Millisecond) | ||||
|  | ||||
| 	node1Name := "" | ||||
| 	fmt.Printf("check for Standby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.Standby), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node1Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "slave-dc1" != node1Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node1Name) | ||||
| 	} | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	node2Name := "" | ||||
| 	fmt.Printf("check for PreferStandby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node2Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "slave-dc1" != node2Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node2Name) | ||||
| 	} | ||||
|  | ||||
| 	node3Name := "" | ||||
| 	fmt.Printf("check for PreferPrimary\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferPrimary), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node3Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "master-dc1" != node3Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "master-dc1", node3Name) | ||||
| 	} | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`.*`).WillReturnRows(sqlmock.NewRows([]string{"role"}).RowError(1, fmt.Errorf("row error"))) | ||||
|  | ||||
| 	time.Sleep(2 * time.Second) | ||||
|  | ||||
| 	fmt.Printf("check for PreferStandby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() == nil { | ||||
| 		t.Fatal("must return error") | ||||
| 	} | ||||
|  | ||||
| 	if dbMasterErr := dbMasterMock.ExpectationsWereMet(); dbMasterErr != nil { | ||||
| 		t.Error(dbMasterErr) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										25
									
								
								cluster/hasql/db.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								cluster/hasql/db.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| ) | ||||
|  | ||||
| type Querier interface { | ||||
| 	// Basic connection methods | ||||
| 	PingContext(ctx context.Context) error | ||||
| 	Close() error | ||||
|  | ||||
| 	// Query methods with context | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
|  | ||||
| 	// Prepared statements with context | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
|  | ||||
| 	// Transaction management with context | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
|  | ||||
| 	Conn(ctx context.Context) (*sql.Conn, error) | ||||
| } | ||||
							
								
								
									
										295
									
								
								cluster/hasql/driver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										295
									
								
								cluster/hasql/driver.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,295 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"io" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // OpenDBWithCluster creates a [*sql.DB] that uses the [ClusterQuerier] | ||||
| func OpenDBWithCluster(db ClusterQuerier) (*sql.DB, error) { | ||||
| 	driver := NewClusterDriver(db) | ||||
| 	connector, err := driver.OpenConnector("") | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return sql.OpenDB(connector), nil | ||||
| } | ||||
|  | ||||
| // ClusterDriver implements [driver.Driver] and driver.Connector for an existing [Querier] | ||||
| type ClusterDriver struct { | ||||
| 	db ClusterQuerier | ||||
| } | ||||
|  | ||||
| // NewClusterDriver creates a new [driver.Driver] that uses an existing [ClusterQuerier] | ||||
| func NewClusterDriver(db ClusterQuerier) *ClusterDriver { | ||||
| 	return &ClusterDriver{db: db} | ||||
| } | ||||
|  | ||||
| // Open implements [driver.Driver.Open] | ||||
| func (d *ClusterDriver) Open(name string) (driver.Conn, error) { | ||||
| 	return d.Connect(context.Background()) | ||||
| } | ||||
|  | ||||
| // OpenConnector implements [driver.DriverContext.OpenConnector] | ||||
| func (d *ClusterDriver) OpenConnector(name string) (driver.Connector, error) { | ||||
| 	return d, nil | ||||
| } | ||||
|  | ||||
| // Connect implements [driver.Connector.Connect] | ||||
| func (d *ClusterDriver) Connect(ctx context.Context) (driver.Conn, error) { | ||||
| 	conn, err := d.db.Conn(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &dbConn{conn: conn}, nil | ||||
| } | ||||
|  | ||||
| // Driver implements [driver.Connector.Driver] | ||||
| func (d *ClusterDriver) Driver() driver.Driver { | ||||
| 	return d | ||||
| } | ||||
|  | ||||
| // dbConn implements driver.Conn with both context and legacy methods | ||||
| type dbConn struct { | ||||
| 	conn *sql.Conn | ||||
| 	mu   sync.Mutex | ||||
| } | ||||
|  | ||||
| // Prepare implements [driver.Conn.Prepare] (legacy method) | ||||
| func (c *dbConn) Prepare(query string) (driver.Stmt, error) { | ||||
| 	return c.PrepareContext(context.Background(), query) | ||||
| } | ||||
|  | ||||
| // PrepareContext implements [driver.ConnPrepareContext.PrepareContext] | ||||
| func (c *dbConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	stmt, err := c.conn.PrepareContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &dbStmt{stmt: stmt}, nil | ||||
| } | ||||
|  | ||||
| // Exec implements [driver.Execer.Exec] (legacy method) | ||||
| func (c *dbConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||
| 	namedArgs := make([]driver.NamedValue, len(args)) | ||||
| 	for i, value := range args { | ||||
| 		namedArgs[i] = driver.NamedValue{Value: value} | ||||
| 	} | ||||
| 	return c.ExecContext(context.Background(), query, namedArgs) | ||||
| } | ||||
|  | ||||
| // ExecContext implements [driver.ExecerContext.ExecContext] | ||||
| func (c *dbConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	// Convert driver.NamedValue to any | ||||
| 	interfaceArgs := make([]any, len(args)) | ||||
| 	for i, arg := range args { | ||||
| 		interfaceArgs[i] = arg.Value | ||||
| 	} | ||||
|  | ||||
| 	return c.conn.ExecContext(ctx, query, interfaceArgs...) | ||||
| } | ||||
|  | ||||
| // Query implements [driver.Queryer.Query] (legacy method) | ||||
| func (c *dbConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||
| 	namedArgs := make([]driver.NamedValue, len(args)) | ||||
| 	for i, value := range args { | ||||
| 		namedArgs[i] = driver.NamedValue{Value: value} | ||||
| 	} | ||||
| 	return c.QueryContext(context.Background(), query, namedArgs) | ||||
| } | ||||
|  | ||||
| // QueryContext implements [driver.QueryerContext.QueryContext] | ||||
| func (c *dbConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	// Convert driver.NamedValue to any | ||||
| 	interfaceArgs := make([]any, len(args)) | ||||
| 	for i, arg := range args { | ||||
| 		interfaceArgs[i] = arg.Value | ||||
| 	} | ||||
|  | ||||
| 	rows, err := c.conn.QueryContext(ctx, query, interfaceArgs...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &dbRows{rows: rows}, nil | ||||
| } | ||||
|  | ||||
| // Begin implements [driver.Conn.Begin] (legacy method) | ||||
| func (c *dbConn) Begin() (driver.Tx, error) { | ||||
| 	return c.BeginTx(context.Background(), driver.TxOptions{}) | ||||
| } | ||||
|  | ||||
| // BeginTx implements [driver.ConnBeginTx.BeginTx] | ||||
| func (c *dbConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	sqlOpts := &sql.TxOptions{ | ||||
| 		Isolation: sql.IsolationLevel(opts.Isolation), | ||||
| 		ReadOnly:  opts.ReadOnly, | ||||
| 	} | ||||
|  | ||||
| 	tx, err := c.conn.BeginTx(ctx, sqlOpts) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &dbTx{tx: tx}, nil | ||||
| } | ||||
|  | ||||
| // Ping implements [driver.Pinger.Ping] | ||||
| func (c *dbConn) Ping(ctx context.Context) error { | ||||
| 	return c.conn.PingContext(ctx) | ||||
| } | ||||
|  | ||||
| // Close implements [driver.Conn.Close] | ||||
| func (c *dbConn) Close() error { | ||||
| 	return c.conn.Close() | ||||
| } | ||||
|  | ||||
| // IsValid implements [driver.Validator.IsValid] | ||||
| func (c *dbConn) IsValid() bool { | ||||
| 	// Ping with a short timeout to check if the connection is still valid | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	return c.conn.PingContext(ctx) == nil | ||||
| } | ||||
|  | ||||
| // dbStmt implements [driver.Stmt] with both context and legacy methods | ||||
| type dbStmt struct { | ||||
| 	stmt *sql.Stmt | ||||
| 	mu   sync.Mutex | ||||
| } | ||||
|  | ||||
| // Close implements [driver.Stmt.Close] | ||||
| func (s *dbStmt) Close() error { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	return s.stmt.Close() | ||||
| } | ||||
|  | ||||
| // Close implements [driver.Stmt.NumInput] | ||||
| func (s *dbStmt) NumInput() int { | ||||
| 	return -1 // Number of parameters is unknown | ||||
| } | ||||
|  | ||||
| // Exec implements [driver.Stmt.Exec] (legacy method) | ||||
| func (s *dbStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||
| 	namedArgs := make([]driver.NamedValue, len(args)) | ||||
| 	for i, value := range args { | ||||
| 		namedArgs[i] = driver.NamedValue{Value: value} | ||||
| 	} | ||||
| 	return s.ExecContext(context.Background(), namedArgs) | ||||
| } | ||||
|  | ||||
| // ExecContext implements [driver.StmtExecContext.ExecContext] | ||||
| func (s *dbStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
|  | ||||
| 	interfaceArgs := make([]any, len(args)) | ||||
| 	for i, arg := range args { | ||||
| 		interfaceArgs[i] = arg.Value | ||||
| 	} | ||||
| 	return s.stmt.ExecContext(ctx, interfaceArgs...) | ||||
| } | ||||
|  | ||||
| // Query implements [driver.Stmt.Query] (legacy method) | ||||
| func (s *dbStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||
| 	namedArgs := make([]driver.NamedValue, len(args)) | ||||
| 	for i, value := range args { | ||||
| 		namedArgs[i] = driver.NamedValue{Value: value} | ||||
| 	} | ||||
| 	return s.QueryContext(context.Background(), namedArgs) | ||||
| } | ||||
|  | ||||
| // QueryContext implements [driver.StmtQueryContext.QueryContext] | ||||
| func (s *dbStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
|  | ||||
| 	interfaceArgs := make([]any, len(args)) | ||||
| 	for i, arg := range args { | ||||
| 		interfaceArgs[i] = arg.Value | ||||
| 	} | ||||
|  | ||||
| 	rows, err := s.stmt.QueryContext(ctx, interfaceArgs...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &dbRows{rows: rows}, nil | ||||
| } | ||||
|  | ||||
| // dbRows implements [driver.Rows] | ||||
| type dbRows struct { | ||||
| 	rows *sql.Rows | ||||
| } | ||||
|  | ||||
| // Columns implements [driver.Rows.Columns] | ||||
| func (r *dbRows) Columns() []string { | ||||
| 	cols, err := r.rows.Columns() | ||||
| 	if err != nil { | ||||
| 		// This shouldn't happen if the query was successful | ||||
| 		return []string{} | ||||
| 	} | ||||
| 	return cols | ||||
| } | ||||
|  | ||||
| // Close implements [driver.Rows.Close] | ||||
| func (r *dbRows) Close() error { | ||||
| 	return r.rows.Close() | ||||
| } | ||||
|  | ||||
| // Next implements [driver.Rows.Next] | ||||
| func (r *dbRows) Next(dest []driver.Value) error { | ||||
| 	if !r.rows.Next() { | ||||
| 		if err := r.rows.Err(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return io.EOF | ||||
| 	} | ||||
|  | ||||
| 	// Create a slice of interfaces to scan into | ||||
| 	scanArgs := make([]any, len(dest)) | ||||
| 	for i := range scanArgs { | ||||
| 		scanArgs[i] = &dest[i] | ||||
| 	} | ||||
|  | ||||
| 	return r.rows.Scan(scanArgs...) | ||||
| } | ||||
|  | ||||
| // dbTx implements [driver.Tx] | ||||
| type dbTx struct { | ||||
| 	tx *sql.Tx | ||||
| 	mu sync.Mutex | ||||
| } | ||||
|  | ||||
| // Commit implements [driver.Tx.Commit] | ||||
| func (t *dbTx) Commit() error { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return t.tx.Commit() | ||||
| } | ||||
|  | ||||
| // Rollback implements [driver.Tx.Rollback] | ||||
| func (t *dbTx) Rollback() error { | ||||
| 	t.mu.Lock() | ||||
| 	defer t.mu.Unlock() | ||||
| 	return t.tx.Rollback() | ||||
| } | ||||
							
								
								
									
										141
									
								
								cluster/hasql/driver_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								cluster/hasql/driver_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/DATA-DOG/go-sqlmock" | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| func TestDriver(t *testing.T) { | ||||
| 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbMaster.Close() | ||||
| 	dbMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(1, 0)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("master-dc1")) | ||||
|  | ||||
| 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbDRMaster.Close() | ||||
| 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 40)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster1-dc2")) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster")) | ||||
|  | ||||
| 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC1.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC2.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	c, err := NewCluster[Querier]( | ||||
| 		WithClusterContext(tctx), | ||||
| 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||
| 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||
| 			CustomPickerMaxLag(100), | ||||
| 		)), | ||||
| 		WithClusterNodes( | ||||
| 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||
| 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||
| 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||
| 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||
| 		), | ||||
| 		WithClusterOptions( | ||||
| 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||
| 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||
| 		), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer c.Close() | ||||
|  | ||||
| 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	db, err := OpenDBWithCluster(c) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Use context methods | ||||
| 	row := db.QueryRowContext(NodeStateCriterion(t.Context(), hasql.Primary), "SELECT node_name as name") | ||||
| 	if err = row.Err(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	nodeName := "" | ||||
| 	if err = row.Scan(&nodeName); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if nodeName != "master-dc1" { | ||||
| 		t.Fatalf("invalid node_name %s != %s", "master-dc1", nodeName) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										10
									
								
								cluster/hasql/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								cluster/hasql/error.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package sql | ||||
|  | ||||
| import "errors" | ||||
|  | ||||
| var ( | ||||
| 	ErrClusterChecker    = errors.New("cluster node checker required") | ||||
| 	ErrClusterDiscoverer = errors.New("cluster node discoverer required") | ||||
| 	ErrClusterPicker     = errors.New("cluster node picker required") | ||||
| 	ErrorNoAliveNodes    = errors.New("cluster no alive nodes") | ||||
| ) | ||||
							
								
								
									
										110
									
								
								cluster/hasql/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								cluster/hasql/options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,110 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"math" | ||||
|  | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| // ClusterOptions contains cluster specific options | ||||
| type ClusterOptions struct { | ||||
| 	NodeChecker        hasql.NodeChecker | ||||
| 	NodePicker         hasql.NodePicker[Querier] | ||||
| 	NodeDiscoverer     hasql.NodeDiscoverer[Querier] | ||||
| 	Options            []hasql.ClusterOpt[Querier] | ||||
| 	Context            context.Context | ||||
| 	Retries            int | ||||
| 	NodePriority       map[string]int32 | ||||
| 	NodeStateCriterion hasql.NodeStateCriterion | ||||
| } | ||||
|  | ||||
| // ClusterOption apply cluster options to ClusterOptions | ||||
| type ClusterOption func(*ClusterOptions) | ||||
|  | ||||
| // WithClusterNodeChecker pass hasql.NodeChecker to cluster options | ||||
| func WithClusterNodeChecker(c hasql.NodeChecker) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeChecker = c | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodePicker pass hasql.NodePicker to cluster options | ||||
| func WithClusterNodePicker(p hasql.NodePicker[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodePicker = p | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodeDiscoverer pass hasql.NodeDiscoverer to cluster options | ||||
| func WithClusterNodeDiscoverer(d hasql.NodeDiscoverer[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeDiscoverer = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithRetries retry count on other nodes in case of error | ||||
| func WithRetries(n int) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Retries = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterContext pass context.Context to cluster options and used for checks | ||||
| func WithClusterContext(ctx context.Context) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Context = ctx | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterOptions pass hasql.ClusterOpt | ||||
| func WithClusterOptions(opts ...hasql.ClusterOpt[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Options = append(o.Options, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodeStateCriterion pass default hasql.NodeStateCriterion | ||||
| func WithClusterNodeStateCriterion(c hasql.NodeStateCriterion) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeStateCriterion = c | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ClusterNode struct { | ||||
| 	Name     string | ||||
| 	DB       Querier | ||||
| 	Priority int32 | ||||
| } | ||||
|  | ||||
| // WithClusterNodes create cluster with static NodeDiscoverer | ||||
| func WithClusterNodes(cns ...ClusterNode) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		nodes := make([]*hasql.Node[Querier], 0, len(cns)) | ||||
| 		if o.NodePriority == nil { | ||||
| 			o.NodePriority = make(map[string]int32, len(cns)) | ||||
| 		} | ||||
| 		for _, cn := range cns { | ||||
| 			nodes = append(nodes, hasql.NewNode(cn.Name, cn.DB)) | ||||
| 			if cn.Priority == 0 { | ||||
| 				cn.Priority = math.MaxInt32 | ||||
| 			} | ||||
| 			o.NodePriority[cn.Name] = cn.Priority | ||||
| 		} | ||||
| 		o.NodeDiscoverer = hasql.NewStaticNodeDiscoverer(nodes...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type nodeStateCriterionKey struct{} | ||||
|  | ||||
| // NodeStateCriterion inject hasql.NodeStateCriterion to context | ||||
| func NodeStateCriterion(ctx context.Context, c hasql.NodeStateCriterion) context.Context { | ||||
| 	return context.WithValue(ctx, nodeStateCriterionKey{}, c) | ||||
| } | ||||
|  | ||||
| func (c *Cluster) getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { | ||||
| 	if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { | ||||
| 		return v | ||||
| 	} | ||||
| 	return c.options.NodeStateCriterion | ||||
| } | ||||
							
								
								
									
										113
									
								
								cluster/hasql/picker.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								cluster/hasql/picker.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| // compile time guard | ||||
| var _ hasql.NodePicker[Querier] = (*CustomPicker[Querier])(nil) | ||||
|  | ||||
| // CustomPickerOptions holds options to pick nodes | ||||
| type CustomPickerOptions struct { | ||||
| 	MaxLag   int | ||||
| 	Priority map[string]int32 | ||||
| 	Retries  int | ||||
| } | ||||
|  | ||||
| // CustomPickerOption func apply option to CustomPickerOptions | ||||
| type CustomPickerOption func(*CustomPickerOptions) | ||||
|  | ||||
| // CustomPickerMaxLag specifies max lag for which node can be used | ||||
| func CustomPickerMaxLag(n int) CustomPickerOption { | ||||
| 	return func(o *CustomPickerOptions) { | ||||
| 		o.MaxLag = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewCustomPicker creates new node picker | ||||
| func NewCustomPicker[T Querier](opts ...CustomPickerOption) *CustomPicker[Querier] { | ||||
| 	options := CustomPickerOptions{} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	return &CustomPicker[Querier]{opts: options} | ||||
| } | ||||
|  | ||||
| // CustomPicker holds node picker options | ||||
| type CustomPicker[T Querier] struct { | ||||
| 	opts CustomPickerOptions | ||||
| } | ||||
|  | ||||
| // PickNode used to return specific node | ||||
| func (p *CustomPicker[T]) PickNode(cnodes []hasql.CheckedNode[T]) hasql.CheckedNode[T] { | ||||
| 	for _, n := range cnodes { | ||||
| 		fmt.Printf("node %s\n", n.Node.String()) | ||||
| 	} | ||||
| 	return cnodes[0] | ||||
| } | ||||
|  | ||||
| func (p *CustomPicker[T]) getPriority(nodeName string) int32 { | ||||
| 	if prio, ok := p.opts.Priority[nodeName]; ok { | ||||
| 		return prio | ||||
| 	} | ||||
| 	return math.MaxInt32 // Default to lowest priority | ||||
| } | ||||
|  | ||||
| // CompareNodes used to sort nodes | ||||
| func (p *CustomPicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int { | ||||
| 	// Get replication lag values | ||||
| 	aLag := a.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||
| 	bLag := b.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||
|  | ||||
| 	// First check that lag lower then MaxLag | ||||
| 	if aLag > p.opts.MaxLag && bLag > p.opts.MaxLag { | ||||
| 		return 0 // both are equal | ||||
| 	} | ||||
|  | ||||
| 	// If one node exceeds MaxLag and the other doesn't, prefer the one that doesn't | ||||
| 	if aLag > p.opts.MaxLag { | ||||
| 		return 1 // b is better | ||||
| 	} | ||||
| 	if bLag > p.opts.MaxLag { | ||||
| 		return -1 // a is better | ||||
| 	} | ||||
|  | ||||
| 	// Get node priorities | ||||
| 	aPrio := p.getPriority(a.Node.String()) | ||||
| 	bPrio := p.getPriority(b.Node.String()) | ||||
|  | ||||
| 	// if both priority equals | ||||
| 	if aPrio == bPrio { | ||||
| 		// First compare by replication lag | ||||
| 		if aLag < bLag { | ||||
| 			return -1 | ||||
| 		} | ||||
| 		if aLag > bLag { | ||||
| 			return 1 | ||||
| 		} | ||||
| 		// If replication lag is equal, compare by latency | ||||
| 		aLatency := a.Info.(interface{ Latency() time.Duration }).Latency() | ||||
| 		bLatency := b.Info.(interface{ Latency() time.Duration }).Latency() | ||||
|  | ||||
| 		if aLatency < bLatency { | ||||
| 			return -1 | ||||
| 		} | ||||
| 		if aLatency > bLatency { | ||||
| 			return 1 | ||||
| 		} | ||||
|  | ||||
| 		// If lag and latency is equal | ||||
| 		return 0 | ||||
| 	} | ||||
|  | ||||
| 	// If priorities are different, prefer the node with lower priority value | ||||
| 	if aPrio < bPrio { | ||||
| 		return -1 | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
							
								
								
									
										531
									
								
								cluster/sql/cluster.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										531
									
								
								cluster/sql/cluster.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,531 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| var errNoAliveNodes = errors.New("no alive nodes") | ||||
|  | ||||
| func newSQLRowError() *sql.Row { | ||||
| 	row := &sql.Row{} | ||||
| 	t := reflect.TypeOf(row).Elem() | ||||
| 	field, _ := t.FieldByName("err") | ||||
| 	rowPtr := unsafe.Pointer(row) | ||||
| 	errFieldPtr := unsafe.Pointer(uintptr(rowPtr) + field.Offset) | ||||
| 	errPtr := (*error)(errFieldPtr) | ||||
| 	*errPtr = errNoAliveNodes | ||||
| 	return row | ||||
| } | ||||
|  | ||||
| type ClusterQuerier interface { | ||||
| 	Querier | ||||
| 	WaitForNodes(ctx context.Context, criterion ...hasql.NodeStateCriterion) error | ||||
| } | ||||
|  | ||||
| type Querier interface { | ||||
| 	// Basic connection methods | ||||
| 	PingContext(ctx context.Context) error | ||||
| 	Close() error | ||||
|  | ||||
| 	// Query methods with context | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
|  | ||||
| 	// Prepared statements with context | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
|  | ||||
| 	// Transaction management with context | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
|  | ||||
| 	// Connection pool management | ||||
| 	SetConnMaxLifetime(d time.Duration) | ||||
| 	SetConnMaxIdleTime(d time.Duration) | ||||
| 	SetMaxOpenConns(n int) | ||||
| 	SetMaxIdleConns(n int) | ||||
| 	Stats() sql.DBStats | ||||
|  | ||||
| 	Conn(ctx context.Context) (*sql.Conn, error) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	ErrClusterChecker    = errors.New("cluster node checker required") | ||||
| 	ErrClusterDiscoverer = errors.New("cluster node discoverer required") | ||||
| 	ErrClusterPicker     = errors.New("cluster node picker required") | ||||
| ) | ||||
|  | ||||
| type Cluster struct { | ||||
| 	hasql   *hasql.Cluster[Querier] | ||||
| 	options ClusterOptions | ||||
| } | ||||
|  | ||||
| // NewCluster returns Querier that provides cluster of nodes | ||||
| func NewCluster[T Querier](opts ...ClusterOption) (ClusterQuerier, error) { | ||||
| 	options := ClusterOptions{Context: context.Background()} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(&options) | ||||
| 	} | ||||
| 	if options.NodeChecker == nil { | ||||
| 		return nil, ErrClusterChecker | ||||
| 	} | ||||
| 	if options.NodeDiscoverer == nil { | ||||
| 		return nil, ErrClusterDiscoverer | ||||
| 	} | ||||
| 	if options.NodePicker == nil { | ||||
| 		return nil, ErrClusterPicker | ||||
| 	} | ||||
|  | ||||
| 	if options.Retries < 1 { | ||||
| 		options.Retries = 1 | ||||
| 	} | ||||
|  | ||||
| 	if options.NodeStateCriterion == 0 { | ||||
| 		options.NodeStateCriterion = hasql.Primary | ||||
| 	} | ||||
|  | ||||
| 	options.Options = append(options.Options, hasql.WithNodePicker(options.NodePicker)) | ||||
| 	if p, ok := options.NodePicker.(*CustomPicker[Querier]); ok { | ||||
| 		p.opts.Priority = options.NodePriority | ||||
| 	} | ||||
|  | ||||
| 	c, err := hasql.NewCluster( | ||||
| 		options.NodeDiscoverer, | ||||
| 		options.NodeChecker, | ||||
| 		options.Options..., | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &Cluster{hasql: c, options: options}, nil | ||||
| } | ||||
|  | ||||
| // compile time guard | ||||
| var _ hasql.NodePicker[Querier] = (*CustomPicker[Querier])(nil) | ||||
|  | ||||
| type nodeStateCriterionKey struct{} | ||||
|  | ||||
| // NodeStateCriterion inject hasql.NodeStateCriterion to context | ||||
| func NodeStateCriterion(ctx context.Context, c hasql.NodeStateCriterion) context.Context { | ||||
| 	return context.WithValue(ctx, nodeStateCriterionKey{}, c) | ||||
| } | ||||
|  | ||||
| // CustomPickerOptions holds options to pick nodes | ||||
| type CustomPickerOptions struct { | ||||
| 	MaxLag   int | ||||
| 	Priority map[string]int32 | ||||
| 	Retries  int | ||||
| } | ||||
|  | ||||
| // CustomPickerOption func apply option to CustomPickerOptions | ||||
| type CustomPickerOption func(*CustomPickerOptions) | ||||
|  | ||||
| // CustomPickerMaxLag specifies max lag for which node can be used | ||||
| func CustomPickerMaxLag(n int) CustomPickerOption { | ||||
| 	return func(o *CustomPickerOptions) { | ||||
| 		o.MaxLag = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewCustomPicker creates new node picker | ||||
| func NewCustomPicker[T Querier](opts ...CustomPickerOption) *CustomPicker[Querier] { | ||||
| 	options := CustomPickerOptions{} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	return &CustomPicker[Querier]{opts: options} | ||||
| } | ||||
|  | ||||
| // CustomPicker holds node picker options | ||||
| type CustomPicker[T Querier] struct { | ||||
| 	opts CustomPickerOptions | ||||
| } | ||||
|  | ||||
| // PickNode used to return specific node | ||||
| func (p *CustomPicker[T]) PickNode(cnodes []hasql.CheckedNode[T]) hasql.CheckedNode[T] { | ||||
| 	for _, n := range cnodes { | ||||
| 		fmt.Printf("node %s\n", n.Node.String()) | ||||
| 	} | ||||
| 	return cnodes[0] | ||||
| } | ||||
|  | ||||
| func (p *CustomPicker[T]) getPriority(nodeName string) int32 { | ||||
| 	if prio, ok := p.opts.Priority[nodeName]; ok { | ||||
| 		return prio | ||||
| 	} | ||||
| 	return math.MaxInt32 // Default to lowest priority | ||||
| } | ||||
|  | ||||
| // CompareNodes used to sort nodes | ||||
| func (p *CustomPicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int { | ||||
| 	fmt.Printf("CompareNodes %s %s\n", a.Node.String(), b.Node.String()) | ||||
| 	// Get replication lag values | ||||
| 	aLag := a.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||
| 	bLag := b.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||
|  | ||||
| 	// First check that lag lower then MaxLag | ||||
| 	if aLag > p.opts.MaxLag && bLag > p.opts.MaxLag { | ||||
| 		fmt.Printf("CompareNodes aLag > p.opts.MaxLag && bLag > p.opts.MaxLag\n") | ||||
| 		return 0 // both are equal | ||||
| 	} | ||||
|  | ||||
| 	// If one node exceeds MaxLag and the other doesn't, prefer the one that doesn't | ||||
| 	if aLag > p.opts.MaxLag { | ||||
| 		fmt.Printf("CompareNodes aLag > p.opts.MaxLag\n") | ||||
| 		return 1 // b is better | ||||
| 	} | ||||
| 	if bLag > p.opts.MaxLag { | ||||
| 		fmt.Printf("CompareNodes bLag > p.opts.MaxLag\n") | ||||
| 		return -1 // a is better | ||||
| 	} | ||||
|  | ||||
| 	// Get node priorities | ||||
| 	aPrio := p.getPriority(a.Node.String()) | ||||
| 	bPrio := p.getPriority(b.Node.String()) | ||||
|  | ||||
| 	// if both priority equals | ||||
| 	if aPrio == bPrio { | ||||
| 		fmt.Printf("CompareNodes aPrio == bPrio\n") | ||||
| 		// First compare by replication lag | ||||
| 		if aLag < bLag { | ||||
| 			fmt.Printf("CompareNodes aLag < bLag\n") | ||||
| 			return -1 | ||||
| 		} | ||||
| 		if aLag > bLag { | ||||
| 			fmt.Printf("CompareNodes aLag > bLag\n") | ||||
| 			return 1 | ||||
| 		} | ||||
| 		// If replication lag is equal, compare by latency | ||||
| 		aLatency := a.Info.(interface{ Latency() time.Duration }).Latency() | ||||
| 		bLatency := b.Info.(interface{ Latency() time.Duration }).Latency() | ||||
|  | ||||
| 		if aLatency < bLatency { | ||||
| 			return -1 | ||||
| 		} | ||||
| 		if aLatency > bLatency { | ||||
| 			return 1 | ||||
| 		} | ||||
|  | ||||
| 		// If lag and latency is equal | ||||
| 		return 0 | ||||
| 	} | ||||
|  | ||||
| 	// If priorities are different, prefer the node with lower priority value | ||||
| 	if aPrio < bPrio { | ||||
| 		return -1 | ||||
| 	} | ||||
|  | ||||
| 	return 1 | ||||
| } | ||||
|  | ||||
| // ClusterOptions contains cluster specific options | ||||
| type ClusterOptions struct { | ||||
| 	NodeChecker        hasql.NodeChecker | ||||
| 	NodePicker         hasql.NodePicker[Querier] | ||||
| 	NodeDiscoverer     hasql.NodeDiscoverer[Querier] | ||||
| 	Options            []hasql.ClusterOpt[Querier] | ||||
| 	Context            context.Context | ||||
| 	Retries            int | ||||
| 	NodePriority       map[string]int32 | ||||
| 	NodeStateCriterion hasql.NodeStateCriterion | ||||
| } | ||||
|  | ||||
| // ClusterOption apply cluster options to ClusterOptions | ||||
| type ClusterOption func(*ClusterOptions) | ||||
|  | ||||
| // WithClusterNodeChecker pass hasql.NodeChecker to cluster options | ||||
| func WithClusterNodeChecker(c hasql.NodeChecker) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeChecker = c | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodePicker pass hasql.NodePicker to cluster options | ||||
| func WithClusterNodePicker(p hasql.NodePicker[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodePicker = p | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodeDiscoverer pass hasql.NodeDiscoverer to cluster options | ||||
| func WithClusterNodeDiscoverer(d hasql.NodeDiscoverer[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeDiscoverer = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithRetries retry count on other nodes in case of error | ||||
| func WithRetries(n int) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Retries = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterContext pass context.Context to cluster options and used for checks | ||||
| func WithClusterContext(ctx context.Context) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Context = ctx | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterOptions pass hasql.ClusterOpt | ||||
| func WithClusterOptions(opts ...hasql.ClusterOpt[Querier]) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.Options = append(o.Options, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithClusterNodeStateCriterion pass default hasql.NodeStateCriterion | ||||
| func WithClusterNodeStateCriterion(c hasql.NodeStateCriterion) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		o.NodeStateCriterion = c | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ClusterNode struct { | ||||
| 	Name     string | ||||
| 	DB       Querier | ||||
| 	Priority int32 | ||||
| } | ||||
|  | ||||
| // WithClusterNodes create cluster with static NodeDiscoverer | ||||
| func WithClusterNodes(cns ...ClusterNode) ClusterOption { | ||||
| 	return func(o *ClusterOptions) { | ||||
| 		nodes := make([]*hasql.Node[Querier], 0, len(cns)) | ||||
| 		if o.NodePriority == nil { | ||||
| 			o.NodePriority = make(map[string]int32, len(cns)) | ||||
| 		} | ||||
| 		for _, cn := range cns { | ||||
| 			nodes = append(nodes, hasql.NewNode(cn.Name, cn.DB)) | ||||
| 			if cn.Priority == 0 { | ||||
| 				cn.Priority = math.MaxInt32 | ||||
| 			} | ||||
| 			o.NodePriority[cn.Name] = cn.Priority | ||||
| 		} | ||||
| 		o.NodeDiscoverer = hasql.NewStaticNodeDiscoverer(nodes...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *Cluster) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||
| 	var tx *sql.Tx | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if tx, err = n.DB().BeginTx(ctx, opts); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if tx == nil && err == nil { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return tx, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) Close() error { | ||||
| 	return c.hasql.Close() | ||||
| } | ||||
|  | ||||
| func (c *Cluster) Conn(ctx context.Context) (*sql.Conn, error) { | ||||
| 	var conn *sql.Conn | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if conn, err = n.DB().Conn(ctx); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if conn == nil && err == nil { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return conn, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	var res sql.Result | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	var res *sql.Stmt | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().PrepareContext(ctx, query); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	var res *sql.Rows | ||||
| 	var err error | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil && err == nil { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	var res *sql.Row | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			res = n.DB().QueryRowContext(ctx, query, args...) | ||||
| 			if res.Err() == nil { | ||||
| 				return false | ||||
| 			} else if res.Err() != nil && retries >= c.options.Retries { | ||||
| 				return false | ||||
| 			} | ||||
| 		} | ||||
| 		return true | ||||
| 	}) | ||||
|  | ||||
| 	if res == nil { | ||||
| 		res = newSQLRowError() | ||||
| 	} | ||||
|  | ||||
| 	return res | ||||
| } | ||||
|  | ||||
| func (c *Cluster) PingContext(ctx context.Context) error { | ||||
| 	var err error | ||||
| 	var ok bool | ||||
|  | ||||
| 	retries := 0 | ||||
| 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||
| 		ok = true | ||||
| 		for ; retries < c.options.Retries; retries++ { | ||||
| 			if err = n.DB().PingContext(ctx); err != nil && retries >= c.options.Retries { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if !ok { | ||||
| 		err = errNoAliveNodes | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *Cluster) WaitForNodes(ctx context.Context, criterions ...hasql.NodeStateCriterion) error { | ||||
| 	for _, criterion := range criterions { | ||||
| 		if _, err := c.hasql.WaitForNode(ctx, criterion); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Cluster) SetConnMaxLifetime(td time.Duration) { | ||||
| 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||
| 		n.DB().SetConnMaxIdleTime(td) | ||||
| 		return false | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Cluster) SetConnMaxIdleTime(td time.Duration) { | ||||
| 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||
| 		n.DB().SetConnMaxIdleTime(td) | ||||
| 		return false | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Cluster) SetMaxOpenConns(nc int) { | ||||
| 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||
| 		n.DB().SetMaxOpenConns(nc) | ||||
| 		return false | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Cluster) SetMaxIdleConns(nc int) { | ||||
| 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||
| 		n.DB().SetMaxIdleConns(nc) | ||||
| 		return false | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Cluster) Stats() sql.DBStats { | ||||
| 	s := sql.DBStats{} | ||||
| 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||
| 		st := n.DB().Stats() | ||||
| 		s.Idle += st.Idle | ||||
| 		s.InUse += st.InUse | ||||
| 		s.MaxIdleClosed += st.MaxIdleClosed | ||||
| 		s.MaxIdleTimeClosed += st.MaxIdleTimeClosed | ||||
| 		s.MaxOpenConnections += st.MaxOpenConnections | ||||
| 		s.OpenConnections += st.OpenConnections | ||||
| 		s.WaitCount += st.WaitCount | ||||
| 		s.WaitDuration += st.WaitDuration | ||||
| 		return false | ||||
| 	}) | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (c *Cluster) getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { | ||||
| 	if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { | ||||
| 		return v | ||||
| 	} | ||||
| 	return c.options.NodeStateCriterion | ||||
| } | ||||
							
								
								
									
										171
									
								
								cluster/sql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								cluster/sql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/DATA-DOG/go-sqlmock" | ||||
| 	"golang.yandex/hasql/v2" | ||||
| ) | ||||
|  | ||||
| func TestNewCluster(t *testing.T) { | ||||
| 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbMaster.Close() | ||||
| 	dbMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(1, 0)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("master-dc1")) | ||||
|  | ||||
| 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbDRMaster.Close() | ||||
| 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 40)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster1-dc2")) | ||||
|  | ||||
| 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("drmaster")) | ||||
|  | ||||
| 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC1.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer dbSlaveDC2.Close() | ||||
| 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||
| 		sqlmock.NewRowsWithColumnDefinition( | ||||
| 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||
| 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||
| 			AddRow(2, 50)). | ||||
| 		RowsWillBeClosed(). | ||||
| 		WithoutArgs() | ||||
|  | ||||
| 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	c, err := NewCluster[Querier]( | ||||
| 		WithClusterContext(tctx), | ||||
| 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||
| 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||
| 			CustomPickerMaxLag(100), | ||||
| 		)), | ||||
| 		WithClusterNodes( | ||||
| 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||
| 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||
| 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||
| 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||
| 		), | ||||
| 		WithClusterOptions( | ||||
| 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||
| 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||
| 		), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer c.Close() | ||||
|  | ||||
| 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	time.Sleep(500 * time.Millisecond) | ||||
|  | ||||
| 	node1Name := "" | ||||
| 	fmt.Printf("check for Standby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.Standby), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node1Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "slave-dc1" != node1Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node1Name) | ||||
| 	} | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||
| 		sqlmock.NewRows([]string{"name"}). | ||||
| 			AddRow("slave-dc1")) | ||||
|  | ||||
| 	node2Name := "" | ||||
| 	fmt.Printf("check for PreferStandby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node2Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "slave-dc1" != node2Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node2Name) | ||||
| 	} | ||||
|  | ||||
| 	node3Name := "" | ||||
| 	fmt.Printf("check for PreferPrimary\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferPrimary), "SELECT node_name as name"); row.Err() != nil { | ||||
| 		t.Fatal(row.Err()) | ||||
| 	} else if err = row.Scan(&node3Name); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if "master-dc1" != node3Name { | ||||
| 		t.Fatalf("invalid node name %s != %s", "master-dc1", node3Name) | ||||
| 	} | ||||
|  | ||||
| 	dbSlaveDC1Mock.ExpectQuery(`.*`).WillReturnRows(sqlmock.NewRows([]string{"role"}).RowError(1, fmt.Errorf("row error"))) | ||||
|  | ||||
| 	time.Sleep(2 * time.Second) | ||||
|  | ||||
| 	fmt.Printf("check for PreferStandby\n") | ||||
| 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() == nil { | ||||
| 		t.Fatal("must return error") | ||||
| 	} | ||||
|  | ||||
| 	if dbMasterErr := dbMasterMock.ExpectationsWereMet(); dbMasterErr != nil { | ||||
| 		t.Error(dbMasterErr) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										17
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,21 +1,23 @@ | ||||
| module go.unistack.org/micro/v3 | ||||
|  | ||||
| go 1.23.4 | ||||
| go 1.24.0 | ||||
|  | ||||
| require ( | ||||
| 	dario.cat/mergo v1.0.1 | ||||
| 	github.com/DATA-DOG/go-sqlmock v1.5.0 | ||||
| 	github.com/DATA-DOG/go-sqlmock v1.5.2 | ||||
| 	github.com/KimMachineGun/automemlimit v0.6.1 | ||||
| 	github.com/ash3in/uuidv8 v1.0.1 | ||||
| 	github.com/ash3in/uuidv8 v1.2.0 | ||||
| 	github.com/google/uuid v1.6.0 | ||||
| 	github.com/matoous/go-nanoid v1.5.1 | ||||
| 	github.com/patrickmn/go-cache v2.1.0+incompatible | ||||
| 	github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5 | ||||
| 	github.com/stretchr/testify v1.10.0 | ||||
| 	go.uber.org/automaxprocs v1.6.0 | ||||
| 	go.unistack.org/micro-proto/v3 v3.4.1 | ||||
| 	golang.org/x/sync v0.10.0 | ||||
| 	google.golang.org/grpc v1.68.1 | ||||
| 	google.golang.org/protobuf v1.35.2 | ||||
| 	golang.yandex/hasql/v2 v2.1.0 | ||||
| 	google.golang.org/grpc v1.69.2 | ||||
| 	google.golang.org/protobuf v1.36.1 | ||||
| 	gopkg.in/yaml.v3 v3.0.1 | ||||
| ) | ||||
|  | ||||
| @@ -33,11 +35,10 @@ require ( | ||||
| 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect | ||||
| 	github.com/rogpeppe/go-internal v1.13.1 // indirect | ||||
| 	github.com/sirupsen/logrus v1.9.3 // indirect | ||||
| 	github.com/stretchr/testify v1.10.0 // indirect | ||||
| 	go.uber.org/goleak v1.3.0 // indirect | ||||
| 	golang.org/x/exp v0.0.0-20241210194714-1829a127f884 // indirect | ||||
| 	golang.org/x/net v0.32.0 // indirect | ||||
| 	golang.org/x/net v0.33.0 // indirect | ||||
| 	golang.org/x/sys v0.28.0 // indirect | ||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect | ||||
| 	google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484 // indirect | ||||
| 	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										27
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,11 +1,11 @@ | ||||
| dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= | ||||
| dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= | ||||
| github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= | ||||
| github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= | ||||
| github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= | ||||
| github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= | ||||
| github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= | ||||
| github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= | ||||
| github.com/ash3in/uuidv8 v1.0.1 h1:dIq1XRkWT8lGA7N5s7WRTB4V3k49WTBLvILz7aCLp80= | ||||
| github.com/ash3in/uuidv8 v1.0.1/go.mod h1:EoyUgCtxNBnrnpc9efw5rVN1cQ+LFGCoJiFuD6maOMw= | ||||
| github.com/ash3in/uuidv8 v1.2.0 h1:2oogGdtCPwaVtyvPPGin4TfZLtOGE5F+W++E880G6SI= | ||||
| github.com/ash3in/uuidv8 v1.2.0/go.mod h1:BnU0wJBxnzdEKmVg4xckBkD+VZuecTFTUP3M0dWgyY4= | ||||
| github.com/cilium/ebpf v0.16.0 h1:+BiEnHL6Z7lXnlGUsXQPPAE7+kenAd4ES8MQ5min0Ok= | ||||
| github.com/cilium/ebpf v0.16.0/go.mod h1:L7u2Blt2jMM/vLAVgjxluxtBKlz3/GWjB0dMOEngfwE= | ||||
| github.com/containerd/cgroups/v3 v3.0.4 h1:2fs7l3P0Qxb1nKWuJNFiwhp2CqiKzho71DQkDrHJIo4= | ||||
| @@ -35,6 +35,7 @@ github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtL | ||||
| github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= | ||||
| github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM= | ||||
| github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE= | ||||
| github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= | ||||
| github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= | ||||
| github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= | ||||
| github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= | ||||
| @@ -79,8 +80,8 @@ go.unistack.org/micro-proto/v3 v3.4.1 h1:UTjLSRz2YZuaHk9iSlVqqsA50JQNAEK2ZFboGqt | ||||
| go.unistack.org/micro-proto/v3 v3.4.1/go.mod h1:okx/cnOhzuCX0ggl/vToatbCupi0O44diiiLLsZ93Zo= | ||||
| golang.org/x/exp v0.0.0-20241210194714-1829a127f884 h1:Y/Mj/94zIQQGHVSv1tTtQBDaQaJe62U9bkDZKKyhPCU= | ||||
| golang.org/x/exp v0.0.0-20241210194714-1829a127f884/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= | ||||
| golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= | ||||
| golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= | ||||
| golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= | ||||
| golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= | ||||
| golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= | ||||
| golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= | ||||
| golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| @@ -88,12 +89,14 @@ golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | ||||
| golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||
| golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||
| golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= | ||||
| google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= | ||||
| google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= | ||||
| google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= | ||||
| google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= | ||||
| golang.yandex/hasql/v2 v2.1.0 h1:7CaFFWeHoK5TvA+QvZzlKHlIN5sqNpqM8NSrXskZD/k= | ||||
| golang.yandex/hasql/v2 v2.1.0/go.mod h1:3Au1AxuJDCTXmS117BpbI6e+70kGWeyLR1qJAH6HdtA= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484 h1:Z7FRVJPSMaHQxD0uXU8WdgFh8PseLM8Q8NzhnpMrBhQ= | ||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA= | ||||
| google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= | ||||
| google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= | ||||
| google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= | ||||
| google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= | ||||
| gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= | ||||
|   | ||||
							
								
								
									
										76
									
								
								hooks/metadata/metadata.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								hooks/metadata/metadata.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| package metadata | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/client" | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| 	"go.unistack.org/micro/v3/server" | ||||
| ) | ||||
|  | ||||
| var DefaultMetadataKeys = []string{"x-request-id"} | ||||
|  | ||||
| type hook struct { | ||||
| 	keys []string | ||||
| } | ||||
|  | ||||
| func NewHook(keys ...string) *hook { | ||||
| 	return &hook{keys: keys} | ||||
| } | ||||
|  | ||||
| func metadataCopy(ctx context.Context, keys []string) context.Context { | ||||
| 	if keys == nil { | ||||
| 		return ctx | ||||
| 	} | ||||
| 	if imd, iok := metadata.FromIncomingContext(ctx); iok && imd != nil { | ||||
| 		omd, ook := metadata.FromOutgoingContext(ctx) | ||||
| 		if !ook || omd == nil { | ||||
| 			omd = metadata.New(len(keys)) | ||||
| 		} | ||||
| 		for _, k := range keys { | ||||
| 			if v, ok := imd.Get(k); ok && v != "" { | ||||
| 				omd.Set(k, v) | ||||
| 			} | ||||
| 		} | ||||
| 		if !ook { | ||||
| 			ctx = metadata.NewOutgoingContext(ctx, omd) | ||||
| 		} | ||||
| 	} | ||||
| 	return ctx | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||
| 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 		return next(metadataCopy(ctx, w.keys), req, rsp, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||
| 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||
| 		return next(metadataCopy(ctx, w.keys), req, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||
| 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||
| 		return next(metadataCopy(ctx, w.keys), msg, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||
| 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||
| 		return next(metadataCopy(ctx, w.keys), msgs, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||
| 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||
| 		return next(metadataCopy(ctx, w.keys), req, rsp) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||
| 	return func(ctx context.Context, msg server.Message) error { | ||||
| 		return next(metadataCopy(ctx, w.keys), msg) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										94
									
								
								hooks/recovery/recovery.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								hooks/recovery/recovery.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| package recovery | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/errors" | ||||
| 	"go.unistack.org/micro/v3/server" | ||||
| ) | ||||
|  | ||||
| func NewOptions(opts ...Option) Options { | ||||
| 	options := Options{ | ||||
| 		ServerHandlerFn:    DefaultServerHandlerFn, | ||||
| 		ServerSubscriberFn: DefaultServerSubscriberFn, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	return options | ||||
| } | ||||
|  | ||||
| type Options struct { | ||||
| 	ServerHandlerFn    func(context.Context, server.Request, interface{}, error) error | ||||
| 	ServerSubscriberFn func(context.Context, server.Message, error) error | ||||
| } | ||||
|  | ||||
| type Option func(*Options) | ||||
|  | ||||
| func ServerHandlerFunc(fn func(context.Context, server.Request, interface{}, error) error) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ServerHandlerFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ServerSubscriberFunc(fn func(context.Context, server.Message, error) error) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ServerSubscriberFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	DefaultServerHandlerFn = func(ctx context.Context, req server.Request, rsp interface{}, err error) error { | ||||
| 		return errors.BadRequest("", "%v", err) | ||||
| 	} | ||||
| 	DefaultServerSubscriberFn = func(ctx context.Context, req server.Message, err error) error { | ||||
| 		return errors.BadRequest("", "%v", err) | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| var Hook = NewHook() | ||||
|  | ||||
| type hook struct { | ||||
| 	opts Options | ||||
| } | ||||
|  | ||||
| func NewHook(opts ...Option) *hook { | ||||
| 	return &hook{opts: NewOptions(opts...)} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||
| 	return func(ctx context.Context, req server.Request, rsp interface{}) (err error) { | ||||
| 		defer func() { | ||||
| 			r := recover() | ||||
| 			switch verr := r.(type) { | ||||
| 			case nil: | ||||
| 				return | ||||
| 			case error: | ||||
| 				err = w.opts.ServerHandlerFn(ctx, req, rsp, verr) | ||||
| 			default: | ||||
| 				err = w.opts.ServerHandlerFn(ctx, req, rsp, fmt.Errorf("%v", r)) | ||||
| 			} | ||||
| 		}() | ||||
| 		err = next(ctx, req, rsp) | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||
| 	return func(ctx context.Context, msg server.Message) (err error) { | ||||
| 		defer func() { | ||||
| 			r := recover() | ||||
| 			switch verr := r.(type) { | ||||
| 			case nil: | ||||
| 				return | ||||
| 			case error: | ||||
| 				err = w.opts.ServerSubscriberFn(ctx, msg, verr) | ||||
| 			default: | ||||
| 				err = w.opts.ServerSubscriberFn(ctx, msg, fmt.Errorf("%v", r)) | ||||
| 			} | ||||
| 		}() | ||||
| 		err = next(ctx, msg) | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										139
									
								
								hooks/requestid/requestid.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								hooks/requestid/requestid.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | ||||
| package requestid | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/textproto" | ||||
| 	 | ||||
| 	"go.unistack.org/micro/v3/client" | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| 	"go.unistack.org/micro/v3/server" | ||||
| 	"go.unistack.org/micro/v3/util/id" | ||||
| ) | ||||
|  | ||||
| type XRequestIDKey struct{} | ||||
|  | ||||
| // DefaultMetadataKey contains metadata key | ||||
| var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") | ||||
|  | ||||
| // DefaultMetadataFunc wil be used if user not provide own func to fill metadata | ||||
| var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { | ||||
| 	var xid string | ||||
|  | ||||
| 	cid, cok := ctx.Value(XRequestIDKey{}).(string) | ||||
| 	if cok && cid != "" { | ||||
| 		xid = cid | ||||
| 	} | ||||
|  | ||||
| 	imd, iok := metadata.FromIncomingContext(ctx) | ||||
| 	if !iok || imd == nil { | ||||
| 		imd = metadata.New(1) | ||||
| 		ctx = metadata.NewIncomingContext(ctx, imd) | ||||
| 	} | ||||
|  | ||||
| 	omd, ook := metadata.FromOutgoingContext(ctx) | ||||
| 	if !ook || omd == nil { | ||||
| 		omd = metadata.New(1) | ||||
| 		ctx = metadata.NewOutgoingContext(ctx, omd) | ||||
| 	} | ||||
|  | ||||
| 	if xid == "" { | ||||
| 		var id string | ||||
| 		if id, iok = imd.Get(DefaultMetadataKey); iok && id != "" { | ||||
| 			xid = id | ||||
| 		} | ||||
| 		if id, ook = omd.Get(DefaultMetadataKey); ook && id != "" { | ||||
| 			xid = id | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if xid == "" { | ||||
| 		var err error | ||||
| 		xid, err = id.New() | ||||
| 		if err != nil { | ||||
| 			return ctx, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !cok { | ||||
| 		ctx = context.WithValue(ctx, XRequestIDKey{}, xid) | ||||
| 	} | ||||
|  | ||||
| 	if !iok { | ||||
| 		imd.Set(DefaultMetadataKey, xid) | ||||
| 	} | ||||
|  | ||||
| 	if !ook { | ||||
| 		omd.Set(DefaultMetadataKey, xid) | ||||
| 	} | ||||
|  | ||||
| 	return ctx, nil | ||||
| } | ||||
|  | ||||
| type hook struct{} | ||||
|  | ||||
| func NewHook() *hook { | ||||
| 	return &hook{} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||
| 	return func(ctx context.Context, msg server.Message) error { | ||||
| 		var err error | ||||
| 		if xid, ok := msg.Header()[DefaultMetadataKey]; ok { | ||||
| 			ctx = context.WithValue(ctx, XRequestIDKey{}, xid) | ||||
| 		} | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return next(ctx, msg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||
| 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||
| 		var err error | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return next(ctx, req, rsp) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||
| 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||
| 		var err error | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return next(ctx, msgs, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||
| 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||
| 		var err error | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return next(ctx, msg, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||
| 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 		var err error | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return next(ctx, req, rsp, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||
| 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||
| 		var err error | ||||
| 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return next(ctx, req, opts...) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										33
									
								
								hooks/requestid/requestid_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								hooks/requestid/requestid_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| package requestid | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| ) | ||||
|  | ||||
| func TestDefaultMetadataFunc(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
|  | ||||
| 	nctx, err := DefaultMetadataFunc(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("%v", err) | ||||
| 	} | ||||
|  | ||||
| 	imd, ok := metadata.FromIncomingContext(nctx) | ||||
| 	if !ok { | ||||
| 		t.Fatalf("md missing in incoming context") | ||||
| 	} | ||||
| 	omd, ok := metadata.FromOutgoingContext(nctx) | ||||
| 	if !ok { | ||||
| 		t.Fatalf("md missing in outgoing context") | ||||
| 	} | ||||
|  | ||||
| 	_, iok := imd.Get(DefaultMetadataKey) | ||||
| 	_, ook := omd.Get(DefaultMetadataKey) | ||||
|  | ||||
| 	if !iok || !ook { | ||||
| 		t.Fatalf("missing metadata key value") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										51
									
								
								hooks/sql/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								hooks/sql/common.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"runtime" | ||||
| ) | ||||
|  | ||||
| //go:generate sh -c "go run gen.go > wrap_gen.go" | ||||
|  | ||||
| // namedValueToValue converts driver arguments of NamedValue format to Value format. Implemented in the same way as in | ||||
| // database/sql ctxutil.go. | ||||
| func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | ||||
| 	dargs := make([]driver.Value, len(named)) | ||||
| 	for n, param := range named { | ||||
| 		if len(param.Name) > 0 { | ||||
| 			return nil, errors.New("sql: driver does not support the use of Named Parameters") | ||||
| 		} | ||||
| 		dargs[n] = param.Value | ||||
| 	} | ||||
| 	return dargs, nil | ||||
| } | ||||
|  | ||||
| // namedValueToLabels convert driver arguments to interface{} slice | ||||
| func namedValueToLabels(named []driver.NamedValue) []interface{} { | ||||
| 	largs := make([]interface{}, 0, len(named)*2) | ||||
| 	var name string | ||||
| 	for _, param := range named { | ||||
| 		if param.Name != "" { | ||||
| 			name = param.Name | ||||
| 		} else { | ||||
| 			name = fmt.Sprintf("$%d", param.Ordinal) | ||||
| 		} | ||||
| 		largs = append(largs, fmt.Sprintf("%s=%v", name, param.Value)) | ||||
| 	} | ||||
| 	return largs | ||||
| } | ||||
|  | ||||
| // getCallerName get the name of the function A where A() -> B() -> GetFunctionCallerName() | ||||
| func getCallerName() string { | ||||
| 	pc, _, _, ok := runtime.Caller(3) | ||||
| 	details := runtime.FuncForPC(pc) | ||||
| 	var callerName string | ||||
| 	if ok && details != nil { | ||||
| 		callerName = details.Name() | ||||
| 	} else { | ||||
| 		callerName = labelUnknown | ||||
| 	} | ||||
| 	return callerName | ||||
| } | ||||
							
								
								
									
										467
									
								
								hooks/sql/conn.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										467
									
								
								hooks/sql/conn.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,467 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/hooks/requestid" | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	_ driver.Conn               = (*wrapperConn)(nil) | ||||
| 	_ driver.ConnBeginTx        = (*wrapperConn)(nil) | ||||
| 	_ driver.ConnPrepareContext = (*wrapperConn)(nil) | ||||
| 	_ driver.Pinger             = (*wrapperConn)(nil) | ||||
| 	_ driver.Validator          = (*wrapperConn)(nil) | ||||
| 	_ driver.Queryer            = (*wrapperConn)(nil) // nolint:staticcheck | ||||
| 	_ driver.QueryerContext     = (*wrapperConn)(nil) | ||||
| 	_ driver.Execer             = (*wrapperConn)(nil) // nolint:staticcheck | ||||
| 	_ driver.ExecerContext      = (*wrapperConn)(nil) | ||||
| 	//	_ driver.Connector | ||||
| 	//	_ driver.Driver | ||||
| 	//	_ driver.DriverContext | ||||
| ) | ||||
|  | ||||
| // wrapperConn defines a wrapper for driver.Conn | ||||
| type wrapperConn struct { | ||||
| 	d     *wrapperDriver | ||||
| 	dname string | ||||
| 	conn  driver.Conn | ||||
| 	opts  Options | ||||
| 	ctx   context.Context | ||||
| 	//span  tracer.Span | ||||
| } | ||||
|  | ||||
| // Close implements driver.Conn Close | ||||
| func (w *wrapperConn) Close() error { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Close"} | ||||
| 	ts := time.Now() | ||||
| 	err := w.conn.Close() | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Close", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Begin implements driver.Conn Begin | ||||
| func (w *wrapperConn) Begin() (driver.Tx, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
|  | ||||
| 	labels := []string{labelMethod, "Begin"} | ||||
| 	ts := time.Now() | ||||
| 	tx, err := w.conn.Begin() // nolint:staticcheck | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Begin", getCallerName(), td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Begin", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return &wrapperTx{tx: tx, opts: w.opts, ctx: ctx}, nil | ||||
| } | ||||
|  | ||||
| // BeginTx implements driver.ConnBeginTx BeginTx | ||||
| func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||
| 	name := getQueryName(ctx) | ||||
| 	nctx, span := w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	span.AddLabels("db.method", "BeginTx") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "BeginTx", labelQuery, name} | ||||
|  | ||||
| 	connBeginTx, ok := w.conn.(driver.ConnBeginTx) | ||||
| 	if !ok { | ||||
| 		return w.Begin() | ||||
| 	} | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	tx, err := connBeginTx.BeginTx(nctx, opts) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "BeginTx", getCallerName(), td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "BeginTx", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return &wrapperTx{tx: tx, opts: w.opts, ctx: ctx, span: span}, nil | ||||
| } | ||||
|  | ||||
| // Prepare implements driver.Conn Prepare | ||||
| func (w *wrapperConn) Prepare(query string) (driver.Stmt, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Prepare", labelQuery, getCallerName()} | ||||
| 	ts := time.Now() | ||||
| 	stmt, err := w.conn.Prepare(query) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Prepare", getCallerName(), td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Prepare", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return wrapStmt(stmt, query, w.opts), nil | ||||
| } | ||||
|  | ||||
| // PrepareContext implements driver.ConnPrepareContext PrepareContext | ||||
| func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||
| 	var nctx context.Context | ||||
| 	var span tracer.Span | ||||
|  | ||||
| 	name := getQueryName(ctx) | ||||
| 	if w.ctx != nil { | ||||
| 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} else { | ||||
| 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} | ||||
| 	span.AddLabels("db.method", "PrepareContext") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "PrepareContext", labelQuery, name} | ||||
| 	conn, ok := w.conn.(driver.ConnPrepareContext) | ||||
| 	if !ok { | ||||
| 		return w.Prepare(query) | ||||
| 	} | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	stmt, err := conn.PrepareContext(nctx, query) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "PrepareContext", getCallerName(), td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "PrepareContext", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return wrapStmt(stmt, query, w.opts), nil | ||||
| } | ||||
|  | ||||
| // Exec implements driver.Execer Exec | ||||
| func (w *wrapperConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Exec", labelQuery, getCallerName()} | ||||
|  | ||||
| 	// nolint:staticcheck | ||||
| 	conn, ok := w.conn.(driver.Execer) | ||||
| 	if !ok { | ||||
| 		return nil, driver.ErrSkip | ||||
| 	} | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	res, err := conn.Exec(query, args) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Exec", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| // Exec implements driver.StmtExecContext ExecContext | ||||
| func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||
| 	var nctx context.Context | ||||
| 	var span tracer.Span | ||||
|  | ||||
| 	name := getQueryName(ctx) | ||||
| 	if w.ctx != nil { | ||||
| 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} else { | ||||
| 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} | ||||
| 	span.AddLabels("db.method", "ExecContext") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	defer span.Finish() | ||||
| 	if len(args) > 0 { | ||||
| 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "ExecContext", labelQuery, name} | ||||
|  | ||||
| 	conn, ok := w.conn.(driver.ExecerContext) | ||||
| 	if !ok { | ||||
| 		// nolint:staticcheck | ||||
| 		return nil, driver.ErrSkip | ||||
| 	} | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	res, err := conn.ExecContext(nctx, query, args) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| // Ping implements driver.Pinger Ping | ||||
| func (w *wrapperConn) Ping(ctx context.Context) error { | ||||
| 	conn, ok := w.conn.(driver.Pinger) | ||||
|  | ||||
| 	if !ok { | ||||
| 		// fallback path to check db alive | ||||
| 		pc, err := w.d.Open(w.dname) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return pc.Close() | ||||
| 	} | ||||
|  | ||||
| 	var nctx context.Context //nolint: gosimple | ||||
| 	nctx = ctx | ||||
| 	/* | ||||
| 		var span tracer.Span | ||||
| 		if w.ctx != nil { | ||||
| 			nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 		} else { | ||||
| 			nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 		} | ||||
| 		span.AddLabels("db.method", "Ping") | ||||
| 		defer span.Finish() | ||||
| 	*/ | ||||
| 	labels := []string{labelMethod, "Ping"} | ||||
| 	ts := time.Now() | ||||
| 	err := conn.Ping(nctx) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		// span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Ping", getCallerName(), td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return err | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Query implements driver.Queryer Query | ||||
| func (w *wrapperConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	// nolint:staticcheck | ||||
| 	conn, ok := w.conn.(driver.Queryer) | ||||
| 	if !ok { | ||||
| 		return nil, driver.ErrSkip | ||||
| 	} | ||||
|  | ||||
| 	labels := []string{labelMethod, "Query", labelQuery, getCallerName()} | ||||
| 	ts := time.Now() | ||||
| 	rows, err := conn.Query(query, args) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Query", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return rows, err | ||||
| } | ||||
|  | ||||
| // QueryContext implements Driver.QueryerContext QueryContext | ||||
| func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||
| 	var nctx context.Context | ||||
| 	var span tracer.Span | ||||
|  | ||||
| 	name := getQueryName(ctx) | ||||
| 	if w.ctx != nil { | ||||
| 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} else { | ||||
| 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} | ||||
| 	span.AddLabels("db.method", "QueryContext") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	defer span.Finish() | ||||
| 	if len(args) > 0 { | ||||
| 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "QueryContext", labelQuery, name} | ||||
| 	conn, ok := w.conn.(driver.QueryerContext) | ||||
| 	if !ok { | ||||
| 		return nil, driver.ErrSkip | ||||
| 	} | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	rows, err := conn.QueryContext(nctx, query, args) | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return rows, err | ||||
| } | ||||
|  | ||||
| // CheckNamedValue implements driver.NamedValueChecker | ||||
| func (w *wrapperConn) CheckNamedValue(v *driver.NamedValue) error { | ||||
| 	s, ok := w.conn.(driver.NamedValueChecker) | ||||
| 	if !ok { | ||||
| 		return driver.ErrSkip | ||||
| 	} | ||||
| 	return s.CheckNamedValue(v) | ||||
| } | ||||
|  | ||||
| // IsValid implements driver.Validator | ||||
| func (w *wrapperConn) IsValid() bool { | ||||
| 	v, ok := w.conn.(driver.Validator) | ||||
| 	if !ok { | ||||
| 		return w.conn != nil | ||||
| 	} | ||||
| 	return v.IsValid() | ||||
| } | ||||
|  | ||||
| func (w *wrapperConn) ResetSession(ctx context.Context) error { | ||||
| 	s, ok := w.conn.(driver.SessionResetter) | ||||
| 	if !ok { | ||||
| 		return driver.ErrSkip | ||||
| 	} | ||||
| 	return s.ResetSession(ctx) | ||||
| } | ||||
							
								
								
									
										94
									
								
								hooks/sql/driver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								hooks/sql/driver.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql/driver" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| // _ driver.DriverContext = (*wrapperDriver)(nil) | ||||
| // _ driver.Connector     = (*wrapperDriver)(nil) | ||||
| ) | ||||
|  | ||||
| /* | ||||
| type conn interface { | ||||
| 	driver.Pinger | ||||
| 	driver.Execer | ||||
| 	driver.ExecerContext | ||||
| 	driver.Queryer | ||||
| 	driver.QueryerContext | ||||
| 	driver.Conn | ||||
| 	driver.ConnPrepareContext | ||||
| 	driver.ConnBeginTx | ||||
| } | ||||
| */ | ||||
|  | ||||
| // wrapperDriver defines a wrapper for driver.Driver | ||||
| type wrapperDriver struct { | ||||
| 	driver driver.Driver | ||||
| 	opts   Options | ||||
| 	ctx    context.Context | ||||
| } | ||||
|  | ||||
| // NewWrapper creates and returns a new SQL driver with passed capabilities | ||||
| func NewWrapper(d driver.Driver, opts ...Option) driver.Driver { | ||||
| 	return &wrapperDriver{driver: d, opts: NewOptions(opts...), ctx: context.Background()} | ||||
| } | ||||
|  | ||||
| type wrappedConnector struct { | ||||
| 	connector driver.Connector | ||||
| //	name      string | ||||
| 	opts      Options | ||||
| 	ctx       context.Context | ||||
| } | ||||
|  | ||||
| func NewWrapperConnector(c driver.Connector, opts ...Option) driver.Connector { | ||||
| 	return &wrappedConnector{connector: c, opts: NewOptions(opts...), ctx: context.Background()} | ||||
| } | ||||
|  | ||||
| // Connect implements driver.Driver Connect | ||||
| func (w *wrappedConnector) Connect(ctx context.Context) (driver.Conn, error) { | ||||
| 	return w.connector.Connect(ctx) | ||||
| } | ||||
|  | ||||
| // Driver implements driver.Driver Driver | ||||
| func (w *wrappedConnector) Driver() driver.Driver { | ||||
| 	return w.connector.Driver() | ||||
| } | ||||
|  | ||||
| /* | ||||
| // Connect implements driver.Driver OpenConnector | ||||
| func (w *wrapperDriver) OpenConnector(name string) (driver.Conn, error) { | ||||
| 	return &wrapperConnector{driver: w.driver, name: name, opts: w.opts}, nil | ||||
| } | ||||
| */ | ||||
|  | ||||
| // Open implements driver.Driver Open | ||||
| func (w *wrapperDriver) Open(name string) (driver.Conn, error) { | ||||
| 	// ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout | ||||
| 	// defer cancel() | ||||
|  | ||||
| 	/* | ||||
| 		connector, err := w.OpenConnector(name) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return connector.Connect(ctx) | ||||
| 	*/ | ||||
|  | ||||
| 	ts := time.Now() | ||||
| 	c, err := w.driver.Open(name) | ||||
| 	td := time.Since(ts) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled { | ||||
| 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Open", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	_ = td | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return wrapConn(c, w.opts), nil | ||||
| } | ||||
							
								
								
									
										165
									
								
								hooks/sql/gen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								hooks/sql/gen.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,165 @@ | ||||
| //go:build ignore | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/md5" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| var connIfaces = []string{ | ||||
| 	"driver.ConnBeginTx", | ||||
| 	"driver.ConnPrepareContext", | ||||
| 	"driver.Execer", | ||||
| 	"driver.ExecerContext", | ||||
| 	"driver.NamedValueChecker", | ||||
| 	"driver.Pinger", | ||||
| 	"driver.Queryer", | ||||
| 	"driver.QueryerContext", | ||||
| 	"driver.SessionResetter", | ||||
| 	"driver.Validator", | ||||
| } | ||||
|  | ||||
| var stmtIfaces = []string{ | ||||
| 	"driver.StmtExecContext", | ||||
| 	"driver.StmtQueryContext", | ||||
| 	"driver.ColumnConverter", | ||||
| 	"driver.NamedValueChecker", | ||||
| } | ||||
|  | ||||
| func getHash(s []string) string { | ||||
| 	h := md5.New() | ||||
| 	io.WriteString(h, strings.Join(s, "|")) | ||||
| 	return fmt.Sprintf("%x", h.Sum(nil)) | ||||
| } | ||||
|  | ||||
| func main() { | ||||
| 	comboConn := all(connIfaces) | ||||
|  | ||||
| 	sort.Slice(comboConn, func(i, j int) bool { | ||||
| 		return len(comboConn[i]) < len(comboConn[j]) | ||||
| 	}) | ||||
|  | ||||
| 	comboStmt := all(stmtIfaces) | ||||
|  | ||||
| 	sort.Slice(comboStmt, func(i, j int) bool { | ||||
| 		return len(comboStmt[i]) < len(comboStmt[j]) | ||||
| 	}) | ||||
|  | ||||
| 	b := bytes.NewBuffer(nil) | ||||
| 	b.WriteString("// Code generated. DO NOT EDIT.\n\n") | ||||
| 	b.WriteString("package sql\n\n") | ||||
| 	b.WriteString(`import "database/sql/driver"`) | ||||
| 	b.WriteString("\n\n") | ||||
|  | ||||
| 	b.WriteString("func wrapConn(dc driver.Conn, opts Options) driver.Conn {\n") | ||||
| 	b.WriteString("\tc := &wrapperConn{conn: dc, opts: opts}\n") | ||||
|  | ||||
| 	for idx := len(comboConn) - 1; idx >= 0; idx-- { | ||||
| 		ifaces := comboConn[idx] | ||||
| 		n := len(ifaces) | ||||
| 		if n == 0 { | ||||
| 			continue | ||||
| 		} | ||||
| 		h := getHash(ifaces) | ||||
| 		b.WriteString(fmt.Sprintf("\tif _, ok := dc.(wrapConn%04d_%s); ok {\n", n, h)) | ||||
| 		b.WriteString("\t\treturn struct {\n") | ||||
| 		b.WriteString(strings.Join(append([]string{"\t\t\tdriver.Conn"}, ifaces...), "\n\t\t\t")) | ||||
| 		b.WriteString("\n\t\t}{") | ||||
| 		for idx := range ifaces { | ||||
| 			if idx > 0 { | ||||
| 				b.WriteString(", ") | ||||
| 				b.WriteString("c") | ||||
| 			} else if idx == 0 { | ||||
| 				b.WriteString("c") | ||||
| 			} else { | ||||
| 				b.WriteString("c") | ||||
| 			} | ||||
| 		} | ||||
| 		b.WriteString(", c}\n") | ||||
| 		b.WriteString("\t}\n\n") | ||||
| 	} | ||||
| 	b.WriteString("\treturn c\n") | ||||
| 	b.WriteString("}\n\n") | ||||
|  | ||||
| 	for idx := len(comboConn) - 1; idx >= 0; idx-- { | ||||
| 		ifaces := comboConn[idx] | ||||
| 		n := len(ifaces) | ||||
| 		if n == 0 { | ||||
| 			continue | ||||
| 		} | ||||
| 		h := getHash(ifaces) | ||||
| 		b.WriteString(fmt.Sprintf("// %s\n", strings.Join(ifaces, "|"))) | ||||
| 		b.WriteString(fmt.Sprintf("type wrapConn%04d_%s interface {\n", n, h)) | ||||
| 		for _, iface := range ifaces { | ||||
| 			b.WriteString(fmt.Sprintf("\t%s\n", iface)) | ||||
| 		} | ||||
| 		b.WriteString("}\n\n") | ||||
| 	} | ||||
|  | ||||
| 	b.WriteString("func wrapStmt(stmt driver.Stmt, query string, opts Options) driver.Stmt {\n") | ||||
| 	b.WriteString("\tc := &wrapperStmt{stmt: stmt, query: query, opts: opts}\n") | ||||
|  | ||||
| 	for idx := len(comboStmt) - 1; idx >= 0; idx-- { | ||||
| 		ifaces := comboStmt[idx] | ||||
| 		n := len(ifaces) | ||||
| 		if n == 0 { | ||||
| 			continue | ||||
| 		} | ||||
| 		h := getHash(ifaces) | ||||
| 		b.WriteString(fmt.Sprintf("\tif _, ok := stmt.(wrapStmt%04d_%s); ok {\n", n, h)) | ||||
| 		b.WriteString("\t\treturn struct {\n") | ||||
| 		b.WriteString(strings.Join(append([]string{"\t\t\tdriver.Stmt"}, ifaces...), "\n\t\t\t")) | ||||
| 		b.WriteString("\n\t\t}{") | ||||
| 		for idx := range ifaces { | ||||
| 			if idx > 0 { | ||||
| 				b.WriteString(", ") | ||||
| 				b.WriteString("c") | ||||
| 			} else if idx == 0 { | ||||
| 				b.WriteString("c") | ||||
| 			} else { | ||||
| 				b.WriteString("c") | ||||
| 			} | ||||
| 		} | ||||
| 		b.WriteString(", c}\n") | ||||
| 		b.WriteString("\t}\n\n") | ||||
| 	} | ||||
| 	b.WriteString("\treturn c\n") | ||||
| 	b.WriteString("}\n") | ||||
|  | ||||
| 	for idx := len(comboStmt) - 1; idx >= 0; idx-- { | ||||
| 		ifaces := comboStmt[idx] | ||||
| 		n := len(ifaces) | ||||
| 		if n == 0 { | ||||
| 			continue | ||||
| 		} | ||||
| 		h := getHash(ifaces) | ||||
| 		b.WriteString(fmt.Sprintf("\n// %s\n", strings.Join(ifaces, "|"))) | ||||
| 		b.WriteString(fmt.Sprintf("type wrapStmt%04d_%s interface {\n", n, h)) | ||||
| 		for _, iface := range ifaces { | ||||
| 			b.WriteString(fmt.Sprintf("\t%s\n", iface)) | ||||
| 		} | ||||
| 		b.WriteString("}\n") | ||||
| 	} | ||||
|  | ||||
| 	fmt.Printf("%s", b.String()) | ||||
| } | ||||
|  | ||||
| // all returns all combinations for a given string array. | ||||
| func all[T any](set []T) (subsets [][]T) { | ||||
| 	length := uint(len(set)) | ||||
| 	for subsetBits := 1; subsetBits < (1 << length); subsetBits++ { | ||||
| 		var subset []T | ||||
| 		for object := uint(0); object < length; object++ { | ||||
| 			if (subsetBits>>object)&1 == 1 { | ||||
| 				subset = append(subset, set[object]) | ||||
| 			} | ||||
| 		} | ||||
| 		subsets = append(subsets, subset) | ||||
| 	} | ||||
| 	return subsets | ||||
| } | ||||
							
								
								
									
										172
									
								
								hooks/sql/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								hooks/sql/options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/logger" | ||||
| 	"go.unistack.org/micro/v3/meter" | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// DefaultMeterStatsInterval holds default stats interval | ||||
| 	DefaultMeterStatsInterval = 5 * time.Second | ||||
| 	// DefaultLoggerObserver used to prepare labels for logger | ||||
| 	DefaultLoggerObserver = func(ctx context.Context, method string, query string, td time.Duration, err error) []interface{} { | ||||
| 		labels := []interface{}{"db.method", method, "took", fmt.Sprintf("%v", td)} | ||||
| 		if err != nil { | ||||
| 			labels = append(labels, "error", err.Error()) | ||||
| 		} | ||||
| 		if query != labelUnknown { | ||||
| 			labels = append(labels, "query", query) | ||||
| 		} | ||||
| 		return labels | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	MaxOpenConnections = "micro_sql_max_open_conn" | ||||
| 	OpenConnections    = "micro_sql_open_conn" | ||||
| 	InuseConnections   = "micro_sql_inuse_conn" | ||||
| 	IdleConnections    = "micro_sql_idle_conn" | ||||
| 	WaitConnections    = "micro_sql_waited_conn" | ||||
| 	BlockedSeconds     = "micro_sql_blocked_seconds" | ||||
| 	MaxIdleClosed      = "micro_sql_max_idle_closed" | ||||
| 	MaxIdletimeClosed  = "micro_sql_closed_max_idle" | ||||
| 	MaxLifetimeClosed  = "micro_sql_closed_max_lifetime" | ||||
|  | ||||
| 	meterRequestTotal               = "micro_sql_request_total" | ||||
| 	meterRequestLatencyMicroseconds = "micro_sql_latency_microseconds" | ||||
| 	meterRequestDurationSeconds     = "micro_sql_request_duration_seconds" | ||||
|  | ||||
| 	labelUnknown  = "unknown" | ||||
| 	labelQuery    = "db_statement" | ||||
| 	labelMethod   = "db_method" | ||||
| 	labelStatus   = "status" | ||||
| 	labelSuccess  = "success" | ||||
| 	labelFailure  = "failure" | ||||
| 	labelHost     = "db_host" | ||||
| 	labelDatabase = "db_name" | ||||
| ) | ||||
|  | ||||
| // Options struct holds wrapper options | ||||
| type Options struct { | ||||
| 	Logger             logger.Logger | ||||
| 	Meter              meter.Meter | ||||
| 	Tracer             tracer.Tracer | ||||
| 	DatabaseHost       string | ||||
| 	DatabaseName       string | ||||
| 	MeterStatsInterval time.Duration | ||||
| 	LoggerLevel        logger.Level | ||||
| 	LoggerEnabled      bool | ||||
| 	LoggerObserver     func(ctx context.Context, method string, name string, td time.Duration, err error) []interface{} | ||||
| } | ||||
|  | ||||
| // Option func signature | ||||
| type Option func(*Options) | ||||
|  | ||||
| // NewOptions create new Options struct from provided option slice | ||||
| func NewOptions(opts ...Option) Options { | ||||
| 	options := Options{ | ||||
| 		Logger:             logger.DefaultLogger, | ||||
| 		Meter:              meter.DefaultMeter, | ||||
| 		Tracer:             tracer.DefaultTracer, | ||||
| 		MeterStatsInterval: DefaultMeterStatsInterval, | ||||
| 		LoggerLevel:        logger.ErrorLevel, | ||||
| 		LoggerObserver:     DefaultLoggerObserver, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
|  | ||||
| 	options.Meter = options.Meter.Clone( | ||||
| 		meter.Labels( | ||||
| 			labelHost, options.DatabaseHost, | ||||
| 			labelDatabase, options.DatabaseName, | ||||
| 		), | ||||
| 	) | ||||
|  | ||||
| 	options.Logger = options.Logger.Clone(logger.WithAddCallerSkipCount(1)) | ||||
|  | ||||
| 	return options | ||||
| } | ||||
|  | ||||
| // MetricInterval specifies stats interval for *sql.DB | ||||
| func MetricInterval(td time.Duration) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.MeterStatsInterval = td | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func DatabaseHost(host string) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.DatabaseHost = host | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func DatabaseName(name string) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.DatabaseName = name | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Meter passes meter.Meter to wrapper | ||||
| func Meter(m meter.Meter) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Meter = m | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Logger passes logger.Logger to wrapper | ||||
| func Logger(l logger.Logger) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Logger = l | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // LoggerEnabled enable sql logging | ||||
| func LoggerEnabled(b bool) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.LoggerEnabled = b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // LoggerLevel passes logger.Level option | ||||
| func LoggerLevel(lvl logger.Level) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.LoggerLevel = lvl | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // LoggerObserver passes observer to fill logger fields | ||||
| func LoggerObserver(obs func(context.Context, string, string, time.Duration, error) []interface{}) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.LoggerObserver = obs | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Tracer passes tracer.Tracer to wrapper | ||||
| func Tracer(t tracer.Tracer) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Tracer = t | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type queryNameKey struct{} | ||||
|  | ||||
| // QueryName passes query name to wrapper func | ||||
| func QueryName(ctx context.Context, name string) context.Context { | ||||
| 	if ctx == nil { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	return context.WithValue(ctx, queryNameKey{}, name) | ||||
| } | ||||
|  | ||||
| func getQueryName(ctx context.Context) string { | ||||
| 	if v, ok := ctx.Value(queryNameKey{}).(string); ok && v != labelUnknown { | ||||
| 		return v | ||||
| 	} | ||||
| 	return getCallerName() | ||||
| } | ||||
							
								
								
									
										41
									
								
								hooks/sql/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								hooks/sql/stats.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type Statser interface { | ||||
| 	Stats() sql.DBStats | ||||
| } | ||||
|  | ||||
| func NewStatsMeter(ctx context.Context, db Statser, opts ...Option) { | ||||
| 	options := NewOptions(opts...) | ||||
|  | ||||
| 	go func() { | ||||
| 		ticker := time.NewTicker(options.MeterStatsInterval) | ||||
| 		defer ticker.Stop() | ||||
|  | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ctx.Done(): | ||||
| 				return | ||||
| 			case <-ticker.C: | ||||
| 				if db == nil { | ||||
| 					return | ||||
| 				} | ||||
| 				stats := db.Stats() | ||||
| 				options.Meter.Counter(MaxOpenConnections).Set(uint64(stats.MaxOpenConnections)) | ||||
| 				options.Meter.Counter(OpenConnections).Set(uint64(stats.OpenConnections)) | ||||
| 				options.Meter.Counter(InuseConnections).Set(uint64(stats.InUse)) | ||||
| 				options.Meter.Counter(IdleConnections).Set(uint64(stats.Idle)) | ||||
| 				options.Meter.Counter(WaitConnections).Set(uint64(stats.WaitCount)) | ||||
| 				options.Meter.FloatCounter(BlockedSeconds).Set(stats.WaitDuration.Seconds()) | ||||
| 				options.Meter.Counter(MaxIdleClosed).Set(uint64(stats.MaxIdleClosed)) | ||||
| 				options.Meter.Counter(MaxIdletimeClosed).Set(uint64(stats.MaxIdleTimeClosed)) | ||||
| 				options.Meter.Counter(MaxLifetimeClosed).Set(uint64(stats.MaxLifetimeClosed)) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
							
								
								
									
										287
									
								
								hooks/sql/stmt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								hooks/sql/stmt.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,287 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/hooks/requestid" | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	_ driver.Stmt              = (*wrapperStmt)(nil) | ||||
| 	_ driver.StmtQueryContext  = (*wrapperStmt)(nil) | ||||
| 	_ driver.StmtExecContext   = (*wrapperStmt)(nil) | ||||
| 	_ driver.NamedValueChecker = (*wrapperStmt)(nil) | ||||
| ) | ||||
|  | ||||
| // wrapperStmt defines a wrapper for driver.Stmt | ||||
| type wrapperStmt struct { | ||||
| 	stmt  driver.Stmt | ||||
| 	opts  Options | ||||
| 	query string | ||||
| 	ctx   context.Context | ||||
| } | ||||
|  | ||||
| // Close implements driver.Stmt Close | ||||
| func (w *wrapperStmt) Close() error { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Close"} | ||||
| 	ts := time.Now() | ||||
| 	err := w.stmt.Close() | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Close", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // NumInput implements driver.Stmt NumInput | ||||
| func (w *wrapperStmt) NumInput() int { | ||||
| 	return w.stmt.NumInput() | ||||
| } | ||||
|  | ||||
| // CheckNamedValue implements driver.NamedValueChecker | ||||
| func (w *wrapperStmt) CheckNamedValue(v *driver.NamedValue) error { | ||||
| 	s, ok := w.stmt.(driver.NamedValueChecker) | ||||
| 	if !ok { | ||||
| 		return driver.ErrSkip | ||||
| 	} | ||||
| 	return s.CheckNamedValue(v) | ||||
| } | ||||
|  | ||||
| // Exec implements driver.Stmt Exec | ||||
| func (w *wrapperStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Exec"} | ||||
| 	ts := time.Now() | ||||
| 	res, err := w.stmt.Exec(args) // nolint:staticcheck | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Exec", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| // Query implements driver.Stmt Query | ||||
| func (w *wrapperStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||
| 	var ctx context.Context | ||||
| 	if w.ctx != nil { | ||||
| 		ctx = w.ctx | ||||
| 	} else { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
| 	_ = ctx | ||||
| 	labels := []string{labelMethod, "Query"} | ||||
| 	ts := time.Now() | ||||
| 	rows, err := w.stmt.Query(args) // nolint:staticcheck | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Query", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return rows, err | ||||
| } | ||||
|  | ||||
| // ColumnConverter implements driver.ColumnConverter | ||||
| func (w *wrapperStmt) ColumnConverter(idx int) driver.ValueConverter { | ||||
| 	s, ok := w.stmt.(driver.ColumnConverter) // nolint:staticcheck | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return s.ColumnConverter(idx) | ||||
| } | ||||
|  | ||||
| // ExecContext implements driver.StmtExecContext ExecContext | ||||
| func (w *wrapperStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | ||||
| 	var nctx context.Context | ||||
| 	var span tracer.Span | ||||
|  | ||||
| 	name := getQueryName(ctx) | ||||
| 	if w.ctx != nil { | ||||
| 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} else { | ||||
| 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} | ||||
| 	span.AddLabels("db.method", "ExecContext") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	defer span.Finish() | ||||
| 	if len(args) > 0 { | ||||
| 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||
| 	} | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "ExecContext", labelQuery, name} | ||||
|  | ||||
| 	if conn, ok := w.stmt.(driver.StmtExecContext); ok { | ||||
| 		ts := time.Now() | ||||
| 		res, err := conn.ExecContext(nctx, args) | ||||
| 		td := time.Since(ts) | ||||
| 		te := td.Seconds() | ||||
| 		if err != nil { | ||||
| 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 			span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		} else { | ||||
| 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 		} | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return res, err | ||||
| 	} | ||||
|  | ||||
| 	values, err := namedValueToValue(args) | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, 0, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ts := time.Now() | ||||
| 	res, err := w.Exec(values) // nolint:staticcheck | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
|  | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return res, err | ||||
| } | ||||
|  | ||||
| // QueryContext implements driver.StmtQueryContext StmtQueryContext | ||||
| func (w *wrapperStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | ||||
| 	var nctx context.Context | ||||
| 	var span tracer.Span | ||||
|  | ||||
| 	name := getQueryName(ctx) | ||||
| 	if w.ctx != nil { | ||||
| 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} else { | ||||
| 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||
| 	} | ||||
| 	span.AddLabels("db.method", "QueryContext") | ||||
| 	span.AddLabels("db.statement", name) | ||||
| 	defer span.Finish() | ||||
| 	if len(args) > 0 { | ||||
| 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||
| 	} | ||||
| 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||
| 		span.AddLabels("x-request-id", id) | ||||
| 	} | ||||
| 	labels := []string{labelMethod, "QueryContext", labelQuery, name} | ||||
| 	if conn, ok := w.stmt.(driver.StmtQueryContext); ok { | ||||
| 		ts := time.Now() | ||||
| 		rows, err := conn.QueryContext(nctx, args) | ||||
| 		td := time.Since(ts) | ||||
| 		te := td.Seconds() | ||||
| 		if err != nil { | ||||
| 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 			span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		} else { | ||||
| 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 		} | ||||
|  | ||||
| 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return rows, err | ||||
| 	} | ||||
|  | ||||
| 	values, err := namedValueToValue(args) | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
|  | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		/* | ||||
| 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, 0, err)...) | ||||
| 			} | ||||
| 		*/ | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ts := time.Now() | ||||
| 	rows, err := w.Query(values) // nolint:staticcheck | ||||
| 	td := time.Since(ts) | ||||
| 	te := td.Seconds() | ||||
| 	if err != nil { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||
| 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 	} else { | ||||
| 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||
| 	} | ||||
|  | ||||
| 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||
| 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	return rows, err | ||||
| } | ||||
							
								
								
									
										63
									
								
								hooks/sql/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								hooks/sql/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql/driver" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| ) | ||||
|  | ||||
| var _ driver.Tx = (*wrapperTx)(nil) | ||||
|  | ||||
| // wrapperTx defines a wrapper for driver.Tx | ||||
| type wrapperTx struct { | ||||
| 	tx   driver.Tx | ||||
| 	span tracer.Span | ||||
| 	opts Options | ||||
| 	ctx  context.Context | ||||
| } | ||||
|  | ||||
| // Commit implements driver.Tx Commit | ||||
| func (w *wrapperTx) Commit() error { | ||||
| 	ts := time.Now() | ||||
| 	err := w.tx.Commit() | ||||
| 	td := time.Since(ts) | ||||
| 	_ = td | ||||
| 	if w.span != nil { | ||||
| 		if err != nil { | ||||
| 			w.span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		} | ||||
| 		w.span.Finish() | ||||
| 	} | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Commit", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	w.ctx = nil | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Rollback implements driver.Tx Rollback | ||||
| func (w *wrapperTx) Rollback() error { | ||||
| 	ts := time.Now() | ||||
| 	err := w.tx.Rollback() | ||||
| 	td := time.Since(ts) | ||||
| 	_ = td | ||||
| 	if w.span != nil { | ||||
| 		if err != nil { | ||||
| 			w.span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||
| 		} | ||||
| 		w.span.Finish() | ||||
| 	} | ||||
| 	/* | ||||
| 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||
| 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Rollback", getCallerName(), td, err)...) | ||||
| 		} | ||||
| 	*/ | ||||
| 	w.ctx = nil | ||||
|  | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										19
									
								
								hooks/sql/wrap.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								hooks/sql/wrap.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| func wrapDriver(d driver.Driver, opts Options) driver.Driver { | ||||
| 	if _, ok := d.(driver.DriverContext); ok { | ||||
| 		return &wrapperDriver{driver: d, opts: opts} | ||||
| 	} | ||||
| 	return struct{ driver.Driver }{&wrapperDriver{driver: d, opts: opts}} | ||||
| } | ||||
| */ | ||||
|  | ||||
| // WrapConn allows an existing driver.Conn to be wrapped. | ||||
| func WrapConn(c driver.Conn, opts ...Option) driver.Conn { | ||||
| 	return wrapConn(c, NewOptions(opts...)) | ||||
| } | ||||
							
								
								
									
										20699
									
								
								hooks/sql/wrap_gen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20699
									
								
								hooks/sql/wrap_gen.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										194
									
								
								hooks/validator/validator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								hooks/validator/validator.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,194 @@ | ||||
| package validator | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/client" | ||||
| 	"go.unistack.org/micro/v3/errors" | ||||
| 	"go.unistack.org/micro/v3/server" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	DefaultClientErrorFunc = func(req client.Request, rsp interface{}, err error) error { | ||||
| 		if rsp != nil { | ||||
| 			return errors.BadGateway(req.Service(), "%v", err) | ||||
| 		} | ||||
| 		return errors.BadRequest(req.Service(), "%v", err) | ||||
| 	} | ||||
|  | ||||
| 	DefaultServerErrorFunc = func(req server.Request, rsp interface{}, err error) error { | ||||
| 		if rsp != nil { | ||||
| 			return errors.BadGateway(req.Service(), "%v", err) | ||||
| 		} | ||||
| 		return errors.BadRequest(req.Service(), "%v", err) | ||||
| 	} | ||||
|  | ||||
| 	DefaultPublishErrorFunc = func(msg client.Message, err error) error { | ||||
| 		return errors.BadRequest(msg.Topic(), "%v", err) | ||||
| 	} | ||||
|  | ||||
| 	DefaultSubscribeErrorFunc = func(msg server.Message, err error) error { | ||||
| 		return errors.BadRequest(msg.Topic(), "%v", err) | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| type ( | ||||
| 	ClientErrorFunc    func(client.Request, interface{}, error) error | ||||
| 	ServerErrorFunc    func(server.Request, interface{}, error) error | ||||
| 	PublishErrorFunc   func(client.Message, error) error | ||||
| 	SubscribeErrorFunc func(server.Message, error) error | ||||
| ) | ||||
|  | ||||
| // Options struct holds wrapper options | ||||
| type Options struct { | ||||
| 	ClientErrorFn          ClientErrorFunc | ||||
| 	ServerErrorFn          ServerErrorFunc | ||||
| 	PublishErrorFn         PublishErrorFunc | ||||
| 	SubscribeErrorFn       SubscribeErrorFunc | ||||
| 	ClientValidateResponse bool | ||||
| 	ServerValidateResponse bool | ||||
| } | ||||
|  | ||||
| // Option func signature | ||||
| type Option func(*Options) | ||||
|  | ||||
| func ClientValidateResponse(b bool) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ClientValidateResponse = b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ServerValidateResponse(b bool) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ClientValidateResponse = b | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ClientReqErrorFn(fn ClientErrorFunc) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ClientErrorFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ServerErrorFn(fn ServerErrorFunc) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.ServerErrorFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func PublishErrorFn(fn PublishErrorFunc) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.PublishErrorFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func SubscribeErrorFn(fn SubscribeErrorFunc) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.SubscribeErrorFn = fn | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func NewOptions(opts ...Option) Options { | ||||
| 	options := Options{ | ||||
| 		ClientErrorFn:    DefaultClientErrorFunc, | ||||
| 		ServerErrorFn:    DefaultServerErrorFunc, | ||||
| 		PublishErrorFn:   DefaultPublishErrorFunc, | ||||
| 		SubscribeErrorFn: DefaultSubscribeErrorFunc, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| 	} | ||||
| 	return options | ||||
| } | ||||
|  | ||||
| func NewHook(opts ...Option) *hook { | ||||
| 	return &hook{opts: NewOptions(opts...)} | ||||
| } | ||||
|  | ||||
| type validator interface { | ||||
| 	Validate() error | ||||
| } | ||||
|  | ||||
| type hook struct { | ||||
| 	opts Options | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||
| 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||
| 		if v, ok := req.Body().(validator); ok { | ||||
| 			if err := v.Validate(); err != nil { | ||||
| 				return w.opts.ClientErrorFn(req, nil, err) | ||||
| 			} | ||||
| 		} | ||||
| 		err := next(ctx, req, rsp, opts...) | ||||
| 		if v, ok := rsp.(validator); ok && w.opts.ClientValidateResponse { | ||||
| 			if verr := v.Validate(); verr != nil { | ||||
| 				return w.opts.ClientErrorFn(req, rsp, verr) | ||||
| 			} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||
| 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||
| 		if v, ok := req.Body().(validator); ok { | ||||
| 			if err := v.Validate(); err != nil { | ||||
| 				return nil, w.opts.ClientErrorFn(req, nil, err) | ||||
| 			} | ||||
| 		} | ||||
| 		return next(ctx, req, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||
| 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||
| 		if v, ok := msg.Payload().(validator); ok { | ||||
| 			if err := v.Validate(); err != nil { | ||||
| 				return w.opts.PublishErrorFn(msg, err) | ||||
| 			} | ||||
| 		} | ||||
| 		return next(ctx, msg, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||
| 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||
| 		for _, msg := range msgs { | ||||
| 			if v, ok := msg.Payload().(validator); ok { | ||||
| 				if err := v.Validate(); err != nil { | ||||
| 					return w.opts.PublishErrorFn(msg, err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return next(ctx, msgs, opts...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||
| 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||
| 		if v, ok := req.Body().(validator); ok { | ||||
| 			if err := v.Validate(); err != nil { | ||||
| 				return w.opts.ServerErrorFn(req, nil, err) | ||||
| 			} | ||||
| 		} | ||||
| 		err := next(ctx, req, rsp) | ||||
| 		if v, ok := rsp.(validator); ok && w.opts.ServerValidateResponse { | ||||
| 			if verr := v.Validate(); verr != nil { | ||||
| 				return w.opts.ServerErrorFn(req, rsp, verr) | ||||
| 			} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||
| 	return func(ctx context.Context, msg server.Message) error { | ||||
| 		if v, ok := msg.Body().(validator); ok { | ||||
| 			if err := v.Validate(); err != nil { | ||||
| 				return w.opts.SubscribeErrorFn(msg, err) | ||||
| 			} | ||||
| 		} | ||||
| 		return next(ctx, msg) | ||||
| 	} | ||||
| } | ||||
| @@ -4,18 +4,20 @@ package logger | ||||
| type Level int8 | ||||
|  | ||||
| const ( | ||||
| 	// TraceLevel level usually used to find bugs, very verbose | ||||
| 	// TraceLevel usually used to find bugs, very verbose | ||||
| 	TraceLevel Level = iota - 2 | ||||
| 	// DebugLevel level used only when enabled debugging | ||||
| 	// DebugLevel used only when enabled debugging | ||||
| 	DebugLevel | ||||
| 	// InfoLevel level used for general info about what's going on inside the application | ||||
| 	// InfoLevel used for general info about what's going on inside the application | ||||
| 	InfoLevel | ||||
| 	// WarnLevel level used for non-critical entries | ||||
| 	// WarnLevel used for non-critical entries | ||||
| 	WarnLevel | ||||
| 	// ErrorLevel level used for errors that should definitely be noted | ||||
| 	// ErrorLevel used for errors that should definitely be noted | ||||
| 	ErrorLevel | ||||
| 	// FatalLevel level used for critical errors and then calls `os.Exit(1)` | ||||
| 	// FatalLevel used for critical errors and then calls `os.Exit(1)` | ||||
| 	FatalLevel | ||||
| 	// NoneLevel used to disable logging | ||||
| 	NoneLevel | ||||
| ) | ||||
|  | ||||
| // String returns logger level string representation | ||||
| @@ -33,6 +35,8 @@ func (l Level) String() string { | ||||
| 		return "error" | ||||
| 	case FatalLevel: | ||||
| 		return "fatal" | ||||
| 	case NoneLevel: | ||||
| 		return "none" | ||||
| 	} | ||||
| 	return "info" | ||||
| } | ||||
| @@ -58,6 +62,8 @@ func ParseLevel(lvl string) Level { | ||||
| 		return ErrorLevel | ||||
| 	case FatalLevel.String(): | ||||
| 		return FatalLevel | ||||
| 	case NoneLevel.String(): | ||||
| 		return NoneLevel | ||||
| 	} | ||||
| 	return InfoLevel | ||||
| } | ||||
|   | ||||
| @@ -52,6 +52,12 @@ type Options struct { | ||||
| 	AddStacktrace bool | ||||
| 	// DedupKeys deduplicate keys in log output | ||||
| 	DedupKeys bool | ||||
| 	// FatalFinalizers runs in order in [logger.Fatal] method | ||||
| 	FatalFinalizers []func(context.Context) | ||||
| } | ||||
|  | ||||
| var DefaultFatalFinalizer = func(ctx context.Context) { | ||||
| 	os.Exit(1) | ||||
| } | ||||
|  | ||||
| // NewOptions creates new options struct | ||||
| @@ -65,6 +71,7 @@ func NewOptions(opts ...Option) Options { | ||||
| 		AddSource:        true, | ||||
| 		TimeFunc:         time.Now, | ||||
| 		Meter:            meter.DefaultMeter, | ||||
| 		FatalFinalizers:  []func(context.Context){DefaultFatalFinalizer}, | ||||
| 	} | ||||
|  | ||||
| 	WithMicroKeys()(&options) | ||||
| @@ -76,6 +83,13 @@ func NewOptions(opts ...Option) Options { | ||||
| 	return options | ||||
| } | ||||
|  | ||||
| // WithFatalFinalizers set logger.Fatal finalizers | ||||
| func WithFatalFinalizers(fncs ...func(context.Context)) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.FatalFinalizers = fncs | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithContextAttrFuncs appends default funcs for the context attrs filler | ||||
| func WithContextAttrFuncs(fncs ...ContextAttrFunc) Option { | ||||
| 	return func(o *Options) { | ||||
| @@ -99,6 +113,7 @@ func WithAddFields(fields ...interface{}) Option { | ||||
| 					iv, iok := o.Fields[i].(string) | ||||
| 					jv, jok := fields[j].(string) | ||||
| 					if iok && jok && iv == jv { | ||||
| 						o.Fields[i+1] = fields[j+1] | ||||
| 						fields = slices.Delete(fields, j, j+2) | ||||
| 					} | ||||
| 				} | ||||
|   | ||||
| @@ -4,14 +4,12 @@ import ( | ||||
| 	"context" | ||||
| 	"io" | ||||
| 	"log/slog" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/logger" | ||||
| 	"go.unistack.org/micro/v3/semconv" | ||||
| @@ -22,6 +20,7 @@ const ( | ||||
| 	badKey = "!BADKEY" | ||||
| 	// defaultCallerSkipCount used by logger | ||||
| 	defaultCallerSkipCount = 3 | ||||
| 	timeFormat             = "2006-01-02T15:04:05.000000000Z07:00" | ||||
| ) | ||||
|  | ||||
| var reTrace = regexp.MustCompile(`.*/slog/logger\.go.*\n`) | ||||
| @@ -33,6 +32,7 @@ var ( | ||||
| 	warnValue  = slog.StringValue("warn") | ||||
| 	errorValue = slog.StringValue("error") | ||||
| 	fatalValue = slog.StringValue("fatal") | ||||
| 	noneValue  = slog.StringValue("none") | ||||
| ) | ||||
|  | ||||
| type wrapper struct { | ||||
| @@ -64,6 +64,7 @@ func (s *slogLogger) renameAttr(_ []string, a slog.Attr) slog.Attr { | ||||
| 		a.Key = s.opts.SourceKey | ||||
| 	case slog.TimeKey: | ||||
| 		a.Key = s.opts.TimeKey | ||||
| 		a.Value = slog.StringValue(a.Value.Time().Format(timeFormat)) | ||||
| 	case slog.MessageKey: | ||||
| 		a.Key = s.opts.MessageKey | ||||
| 	case slog.LevelKey: | ||||
| @@ -83,6 +84,8 @@ func (s *slogLogger) renameAttr(_ []string, a slog.Attr) slog.Attr { | ||||
| 			a.Value = errorValue | ||||
| 		case lvl >= logger.FatalLevel: | ||||
| 			a.Value = fatalValue | ||||
| 		case lvl >= logger.NoneLevel: | ||||
| 			a.Value = noneValue | ||||
| 		default: | ||||
| 			a.Value = infoValue | ||||
| 		} | ||||
| @@ -226,11 +229,12 @@ func (s *slogLogger) Error(ctx context.Context, msg string, attrs ...interface{} | ||||
|  | ||||
| func (s *slogLogger) Fatal(ctx context.Context, msg string, attrs ...interface{}) { | ||||
| 	s.printLog(ctx, logger.FatalLevel, msg, attrs...) | ||||
| 	for _, fn := range s.opts.FatalFinalizers { | ||||
| 		fn(ctx) | ||||
| 	} | ||||
| 	if closer, ok := s.opts.Out.(io.Closer); ok { | ||||
| 		closer.Close() | ||||
| 	} | ||||
| 	time.Sleep(1 * time.Second) | ||||
| 	os.Exit(1) | ||||
| } | ||||
|  | ||||
| func (s *slogLogger) Warn(ctx context.Context, msg string, attrs ...interface{}) { | ||||
| @@ -276,7 +280,7 @@ func (s *slogLogger) printLog(ctx context.Context, lvl logger.Level, msg string, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if (s.opts.AddStacktrace || lvl == logger.FatalLevel) || (s.opts.AddStacktrace && lvl == logger.ErrorLevel) { | ||||
| 	if s.opts.AddStacktrace && (lvl == logger.FatalLevel || lvl == logger.ErrorLevel) { | ||||
| 		stackInfo := make([]byte, 1024*1024) | ||||
| 		if stackSize := runtime.Stack(stackInfo, false); stackSize > 0 { | ||||
| 			traceLines := reTrace.Split(string(stackInfo[:stackSize]), -1) | ||||
| @@ -314,6 +318,8 @@ func loggerToSlogLevel(level logger.Level) slog.Level { | ||||
| 		return slog.LevelDebug - 1 | ||||
| 	case logger.FatalLevel: | ||||
| 		return slog.LevelError + 1 | ||||
| 	case logger.NoneLevel: | ||||
| 		return slog.LevelError + 2 | ||||
| 	default: | ||||
| 		return slog.LevelInfo | ||||
| 	} | ||||
| @@ -331,6 +337,8 @@ func slogToLoggerLevel(level slog.Level) logger.Level { | ||||
| 		return logger.TraceLevel | ||||
| 	case slog.LevelError + 1: | ||||
| 		return logger.FatalLevel | ||||
| 	case slog.LevelError + 2: | ||||
| 		return logger.NoneLevel | ||||
| 	default: | ||||
| 		return logger.InfoLevel | ||||
| 	} | ||||
|   | ||||
| @@ -9,16 +9,19 @@ import ( | ||||
| 	"log/slog" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| 	"go.unistack.org/micro/v3/logger" | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| 	"go.unistack.org/micro/v3/util/buffer" | ||||
| ) | ||||
|  | ||||
| // always first to have proper check | ||||
| func TestStacktrace(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	l := NewLogger(logger.WithLevel(logger.ErrorLevel), logger.WithOutput(buf), | ||||
| 	l := NewLogger(logger.WithLevel(logger.DebugLevel), logger.WithOutput(buf), | ||||
| 		WithHandlerFunc(slog.NewTextHandler), | ||||
| 		logger.WithAddStacktrace(true), | ||||
| 	) | ||||
| @@ -28,7 +31,65 @@ func TestStacktrace(t *testing.T) { | ||||
|  | ||||
| 	l.Error(ctx, "msg1", errors.New("err")) | ||||
|  | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`slog_test.go:29`)) { | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`slog_test.go:32`)) { | ||||
| 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNoneLevel(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	l := NewLogger(logger.WithLevel(logger.NoneLevel), logger.WithOutput(buf), | ||||
| 		WithHandlerFunc(slog.NewTextHandler), | ||||
| 		logger.WithAddStacktrace(true), | ||||
| 	) | ||||
| 	if err := l.Init(logger.WithFields("key1", "val1")); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	l.Error(ctx, "msg1", errors.New("err")) | ||||
|  | ||||
| 	if buf.Len() != 0 { | ||||
| 		t.Fatalf("logger none level not works, buf contains: %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDelayedBuffer(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	dbuf := buffer.NewDelayedBuffer(100, 100*time.Millisecond, buf) | ||||
| 	l := NewLogger(logger.WithLevel(logger.ErrorLevel), logger.WithOutput(dbuf), | ||||
| 		WithHandlerFunc(slog.NewTextHandler), | ||||
| 		logger.WithAddStacktrace(true), | ||||
| 	) | ||||
| 	if err := l.Init(logger.WithFields("key1", "val1")); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	l.Error(ctx, "msg1", errors.New("err")) | ||||
| 	time.Sleep(120 * time.Millisecond) | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`key1=val1`)) { | ||||
| 		t.Fatalf("logger delayed buffer not works, buf contains: %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestTime(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	l := NewLogger(logger.WithLevel(logger.ErrorLevel), logger.WithOutput(buf), | ||||
| 		WithHandlerFunc(slog.NewTextHandler), | ||||
| 		logger.WithAddStacktrace(true), | ||||
| 		logger.WithTimeFunc(func() time.Time { | ||||
| 			return time.Unix(0, 0).UTC() | ||||
| 		}), | ||||
| 	) | ||||
| 	if err := l.Init(logger.WithFields("key1", "val1")); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	l.Error(ctx, "msg1", errors.New("err")) | ||||
|  | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T00:00:00.000000000Z`)) { | ||||
| 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
| @@ -80,7 +141,7 @@ func TestWithDedupKeysWithAddFields(t *testing.T) { | ||||
|  | ||||
| 	l.Info(ctx, "msg3") | ||||
|  | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`msg=msg3 key1=val1 key2=val2`)) { | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`msg=msg3 key1=val4 key2=val3`)) { | ||||
| 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
| @@ -362,15 +423,16 @@ func TestLogger(t *testing.T) { | ||||
| func Test_WithContextAttrFunc(t *testing.T) { | ||||
| 	loggerContextAttrFuncs := []logger.ContextAttrFunc{ | ||||
| 		func(ctx context.Context) []interface{} { | ||||
| 			md, ok := metadata.FromIncomingContext(ctx) | ||||
| 			md, ok := metadata.FromOutgoingContext(ctx) | ||||
| 			if !ok { | ||||
| 				return nil | ||||
| 			} | ||||
| 			attrs := make([]interface{}, 0, 10) | ||||
| 			for k, v := range md { | ||||
| 				switch k { | ||||
| 				case "X-Request-Id", "Phone", "External-Id", "Source-Service", "X-App-Install-Id", "Client-Id", "Client-Ip": | ||||
| 					attrs = append(attrs, strings.ToLower(k), v) | ||||
| 				key := strings.ToLower(k) | ||||
| 				switch key { | ||||
| 				case "x-request-id", "phone", "external-Id", "source-service", "x-app-install-id", "client-id", "client-ip": | ||||
| 					attrs = append(attrs, key, v) | ||||
| 				} | ||||
| 			} | ||||
| 			return attrs | ||||
| @@ -380,7 +442,7 @@ func Test_WithContextAttrFunc(t *testing.T) { | ||||
| 	logger.DefaultContextAttrFuncs = append(logger.DefaultContextAttrFuncs, loggerContextAttrFuncs...) | ||||
|  | ||||
| 	ctx := context.TODO() | ||||
| 	ctx = metadata.AppendIncomingContext(ctx, "X-Request-Id", uuid.New().String(), | ||||
| 	ctx = metadata.AppendOutgoingContext(ctx, "X-Request-Id", uuid.New().String(), | ||||
| 		"Source-Service", "Test-System") | ||||
|  | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| @@ -393,17 +455,39 @@ func Test_WithContextAttrFunc(t *testing.T) { | ||||
| 	if !(bytes.Contains(buf.Bytes(), []byte(`"level":"info"`)) && bytes.Contains(buf.Bytes(), []byte(`"msg":"test message"`))) { | ||||
| 		t.Fatalf("logger info, buf %s", buf.Bytes()) | ||||
| 	} | ||||
| 	if !(bytes.Contains(buf.Bytes(), []byte(`"x-request-id":"`))) { | ||||
| 	if !(bytes.Contains(buf.Bytes(), []byte(`"x-request-id":`))) { | ||||
| 		t.Fatalf("logger info, buf %s", buf.Bytes()) | ||||
| 	} | ||||
| 	if !(bytes.Contains(buf.Bytes(), []byte(`"source-service":"Test-System"`))) { | ||||
| 		t.Fatalf("logger info, buf %s", buf.Bytes()) | ||||
| 	} | ||||
| 	buf.Reset() | ||||
| 	imd, _ := metadata.FromIncomingContext(ctx) | ||||
| 	omd, _ := metadata.FromOutgoingContext(ctx) | ||||
| 	l.Info(ctx, "test message1") | ||||
| 	imd.Set("Source-Service", "Test-System2") | ||||
| 	omd.Set("Source-Service", "Test-System2") | ||||
| 	l.Info(ctx, "test message2") | ||||
|  | ||||
| 	// t.Logf("xxx %s", buf.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestFatalFinalizers(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	l := NewLogger( | ||||
| 		logger.WithLevel(logger.TraceLevel), | ||||
| 		logger.WithOutput(buf), | ||||
| 	) | ||||
| 	if err := l.Init( | ||||
| 		logger.WithFatalFinalizers(func(ctx context.Context) { | ||||
| 			l.Info(ctx, "fatal finalizer") | ||||
| 		})); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	l.Fatal(ctx, "info_msg1") | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte("fatal finalizer")) { | ||||
| 		t.Fatalf("logger dont have fatal message, buf %s", buf.Bytes()) | ||||
| 	} | ||||
| 	if !bytes.Contains(buf.Bytes(), []byte("info_msg1")) { | ||||
| 		t.Fatalf("logger dont have info_msg1 message, buf %s", buf.Bytes()) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -4,8 +4,8 @@ package meter | ||||
| import ( | ||||
| 	"io" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| @@ -49,9 +49,11 @@ type Meter interface { | ||||
| 	Set(opts ...Option) Meter | ||||
| 	// Histogram get or create histogram | ||||
| 	Histogram(name string, labels ...string) Histogram | ||||
| 	// HistogramExt get or create histogram with specified quantiles | ||||
| 	HistogramExt(name string, quantiles []float64, labels ...string) Histogram | ||||
| 	// Summary get or create summary | ||||
| 	Summary(name string, labels ...string) Summary | ||||
| 	// SummaryExt get or create summary with spcified quantiles and window time | ||||
| 	// SummaryExt get or create summary with specified quantiles and window time | ||||
| 	SummaryExt(name string, window time.Duration, quantiles []float64, labels ...string) Summary | ||||
| 	// Write writes metrics to io.Writer | ||||
| 	Write(w io.Writer, opts ...Option) error | ||||
| @@ -59,6 +61,8 @@ type Meter interface { | ||||
| 	Options() Options | ||||
| 	// String return meter type | ||||
| 	String() string | ||||
| 	// Unregister metric name and drop all data | ||||
| 	Unregister(name string, labels ...string) bool | ||||
| } | ||||
|  | ||||
| // Counter is a counter | ||||
| @@ -80,7 +84,11 @@ type FloatCounter interface { | ||||
|  | ||||
| // Gauge is a float64 gauge | ||||
| type Gauge interface { | ||||
| 	Add(float64) | ||||
| 	Get() float64 | ||||
| 	Set(float64) | ||||
| 	Dec() | ||||
| 	Inc() | ||||
| } | ||||
|  | ||||
| // Histogram is a histogram for non-negative values with automatically created buckets | ||||
| @@ -117,6 +125,39 @@ func BuildLabels(labels ...string) []string { | ||||
| 	return labels | ||||
| } | ||||
|  | ||||
| var spool = newStringsPool(500) | ||||
|  | ||||
| type stringsPool struct { | ||||
| 	p *sync.Pool | ||||
| 	c int | ||||
| } | ||||
|  | ||||
| func newStringsPool(size int) *stringsPool { | ||||
| 	p := &stringsPool{c: size} | ||||
| 	p.p = &sync.Pool{ | ||||
| 		New: func() interface{} { | ||||
| 			return &strings.Builder{} | ||||
| 		}, | ||||
| 	} | ||||
| 	return p | ||||
| } | ||||
|  | ||||
| func (p *stringsPool) Cap() int { | ||||
| 	return p.c | ||||
| } | ||||
|  | ||||
| func (p *stringsPool) Get() *strings.Builder { | ||||
| 	return p.p.Get().(*strings.Builder) | ||||
| } | ||||
|  | ||||
| func (p *stringsPool) Put(b *strings.Builder) { | ||||
| 	if b.Cap() > p.c { | ||||
| 		return | ||||
| 	} | ||||
| 	b.Reset() | ||||
| 	p.p.Put(b) | ||||
| } | ||||
|  | ||||
| // BuildName used to combine metric with labels. | ||||
| // If labels count is odd, drop last element | ||||
| func BuildName(name string, labels ...string) string { | ||||
| @@ -125,8 +166,6 @@ func BuildName(name string, labels ...string) string { | ||||
| 	} | ||||
|  | ||||
| 	if len(labels) > 2 { | ||||
| 		sort.Sort(byKey(labels)) | ||||
|  | ||||
| 		idx := 0 | ||||
| 		for { | ||||
| 			if labels[idx] == labels[idx+2] { | ||||
| @@ -141,7 +180,9 @@ func BuildName(name string, labels ...string) string { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var b strings.Builder | ||||
| 	b := spool.Get() | ||||
| 	defer spool.Put(b) | ||||
|  | ||||
| 	_, _ = b.WriteString(name) | ||||
| 	_, _ = b.WriteRune('{') | ||||
| 	for idx := 0; idx < len(labels); idx += 2 { | ||||
| @@ -149,8 +190,9 @@ func BuildName(name string, labels ...string) string { | ||||
| 			_, _ = b.WriteRune(',') | ||||
| 		} | ||||
| 		_, _ = b.WriteString(labels[idx]) | ||||
| 		_, _ = b.WriteString(`=`) | ||||
| 		_, _ = b.WriteString(strconv.Quote(labels[idx+1])) | ||||
| 		_, _ = b.WriteString(`="`) | ||||
| 		_, _ = b.WriteString(labels[idx+1]) | ||||
| 		_, _ = b.WriteRune('"') | ||||
| 	} | ||||
| 	_, _ = b.WriteRune('}') | ||||
|  | ||||
|   | ||||
| @@ -50,11 +50,12 @@ func TestBuildName(t *testing.T) { | ||||
| 	data := map[string][]string{ | ||||
| 		`my_metric{firstlabel="value2",zerolabel="value3"}`: { | ||||
| 			"my_metric", | ||||
| 			"zerolabel", "value3", "firstlabel", "value2", | ||||
| 			"firstlabel", "value2", | ||||
| 			"zerolabel", "value3", | ||||
| 		}, | ||||
| 		`my_metric{broker="broker2",register="mdns",server="tcp"}`: { | ||||
| 			"my_metric", | ||||
| 			"broker", "broker1", "broker", "broker2", "server", "http", "server", "tcp", "register", "mdns", | ||||
| 			"broker", "broker1", "broker", "broker2", "register", "mdns", "server", "http", "server", "tcp", | ||||
| 		}, | ||||
| 		`my_metric{aaa="aaa"}`: { | ||||
| 			"my_metric", | ||||
|   | ||||
| @@ -28,6 +28,10 @@ func (r *noopMeter) Name() string { | ||||
| 	return r.opts.Name | ||||
| } | ||||
|  | ||||
| func (r *noopMeter) Unregister(name string, labels ...string) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // Init initialize options | ||||
| func (r *noopMeter) Init(opts ...Option) error { | ||||
| 	for _, o := range opts { | ||||
| @@ -66,6 +70,11 @@ func (r *noopMeter) Histogram(_ string, labels ...string) Histogram { | ||||
| 	return &noopHistogram{labels: labels} | ||||
| } | ||||
|  | ||||
| // HistogramExt implements the Meter interface | ||||
| func (r *noopMeter) HistogramExt(_ string, quantiles []float64, labels ...string) Histogram { | ||||
| 	return &noopHistogram{labels: labels} | ||||
| } | ||||
|  | ||||
| // Set implements the Meter interface | ||||
| func (r *noopMeter) Set(opts ...Option) Meter { | ||||
| 	m := &noopMeter{opts: r.opts} | ||||
| @@ -132,6 +141,18 @@ type noopGauge struct { | ||||
| 	labels []string | ||||
| } | ||||
|  | ||||
| func (r *noopGauge) Add(float64) { | ||||
| } | ||||
|  | ||||
| func (r *noopGauge) Set(float64) { | ||||
| } | ||||
|  | ||||
| func (r *noopGauge) Inc() { | ||||
| } | ||||
|  | ||||
| func (r *noopGauge) Dec() { | ||||
| } | ||||
|  | ||||
| func (r *noopGauge) Get() float64 { | ||||
| 	return 0 | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,8 @@ import ( | ||||
| 	"context" | ||||
| ) | ||||
|  | ||||
| var DefaultQuantiles = []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10} | ||||
|  | ||||
| // Option powers the configuration for metrics implementations: | ||||
| type Option func(*Options) | ||||
|  | ||||
| @@ -23,6 +25,8 @@ type Options struct { | ||||
| 	WriteProcessMetrics bool | ||||
| 	// WriteFDMetrics flag to write fd metrics | ||||
| 	WriteFDMetrics bool | ||||
| 	// Quantiles specifies buckets for histogram | ||||
| 	Quantiles []float64 | ||||
| } | ||||
|  | ||||
| // NewOptions prepares a set of options: | ||||
| @@ -61,14 +65,12 @@ func Address(value string) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| /* | ||||
| // TimingObjectives defines the desired spread of statistics for histogram / timing metrics: | ||||
| func TimingObjectives(value map[float64]float64) Option { | ||||
| // Quantiles defines the desired spread of statistics for histogram metrics: | ||||
| func Quantiles(quantiles []float64) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.TimingObjectives = value | ||||
| 		o.Quantiles = quantiles | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|  | ||||
| // Labels add the meter labels | ||||
| func Labels(ls ...string) Option { | ||||
|   | ||||
| @@ -1,12 +1,9 @@ | ||||
| package register | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"unicode" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| ) | ||||
|  | ||||
| // ExtractValue from reflect.Type from specified depth | ||||
| @@ -38,53 +35,6 @@ func ExtractValue(v reflect.Type, d int) string { | ||||
| 	return v.Name() | ||||
| } | ||||
|  | ||||
| // ExtractEndpoint extract *Endpoint from reflect.Method | ||||
| func ExtractEndpoint(method reflect.Method) *Endpoint { | ||||
| 	if method.PkgPath != "" { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	var rspType, reqType reflect.Type | ||||
| 	var stream bool | ||||
| 	mt := method.Type | ||||
|  | ||||
| 	switch mt.NumIn() { | ||||
| 	case 3: | ||||
| 		reqType = mt.In(1) | ||||
| 		rspType = mt.In(2) | ||||
| 	case 4: | ||||
| 		reqType = mt.In(2) | ||||
| 		rspType = mt.In(3) | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// are we dealing with a stream? | ||||
| 	switch rspType.Kind() { | ||||
| 	case reflect.Func, reflect.Interface: | ||||
| 		stream = true | ||||
| 	} | ||||
|  | ||||
| 	request := ExtractValue(reqType, 0) | ||||
| 	response := ExtractValue(rspType, 0) | ||||
| 	if request == "" || response == "" { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ep := &Endpoint{ | ||||
| 		Name:     method.Name, | ||||
| 		Request:  request, | ||||
| 		Response: response, | ||||
| 		Metadata: metadata.New(0), | ||||
| 	} | ||||
|  | ||||
| 	if stream { | ||||
| 		ep.Metadata.Set("stream", fmt.Sprintf("%v", stream)) | ||||
| 	} | ||||
|  | ||||
| 	return ep | ||||
| } | ||||
|  | ||||
| // ExtractSubValue exctact *Value from reflect.Type | ||||
| func ExtractSubValue(typ reflect.Type) string { | ||||
| 	var reqType reflect.Type | ||||
|   | ||||
| @@ -2,8 +2,6 @@ package register | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type TestHandler struct{} | ||||
| @@ -15,40 +13,3 @@ type TestResponse struct{} | ||||
| func (t *TestHandler) Test(ctx context.Context, req *TestRequest, rsp *TestResponse) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TestExtractEndpoint(t *testing.T) { | ||||
| 	handler := &TestHandler{} | ||||
| 	typ := reflect.TypeOf(handler) | ||||
|  | ||||
| 	var endpoints []*Endpoint | ||||
|  | ||||
| 	for m := 0; m < typ.NumMethod(); m++ { | ||||
| 		if e := ExtractEndpoint(typ.Method(m)); e != nil { | ||||
| 			endpoints = append(endpoints, e) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if i := len(endpoints); i != 1 { | ||||
| 		t.Fatalf("Expected 1 endpoint, have %d", i) | ||||
| 	} | ||||
|  | ||||
| 	if endpoints[0].Name != "Test" { | ||||
| 		t.Fatalf("Expected handler Test, got %s", endpoints[0].Name) | ||||
| 	} | ||||
|  | ||||
| 	if endpoints[0].Request == "" { | ||||
| 		t.Fatal("Expected non nil Request") | ||||
| 	} | ||||
|  | ||||
| 	if endpoints[0].Response == "" { | ||||
| 		t.Fatal("Expected non nil Request") | ||||
| 	} | ||||
|  | ||||
| 	if endpoints[0].Request != "TestRequest" { | ||||
| 		t.Fatalf("Expected TestRequest got %s", endpoints[0].Request) | ||||
| 	} | ||||
|  | ||||
| 	if endpoints[0].Response != "TestResponse" { | ||||
| 		t.Fatalf("Expected TestResponse got %s", endpoints[0].Response) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -27,7 +27,6 @@ type record struct { | ||||
| 	Version  string | ||||
| 	Metadata map[string]string | ||||
| 	Nodes    map[string]*node | ||||
| 	Endpoints []*register.Endpoint | ||||
| } | ||||
|  | ||||
| type memory struct { | ||||
| @@ -59,7 +58,7 @@ func (m *memory) ttlPrune() { | ||||
|  | ||||
| 	for range prune.C { | ||||
| 		m.Lock() | ||||
| 		for domain, services := range m.records { | ||||
| 		for namespace, services := range m.records { | ||||
| 			for service, versions := range services { | ||||
| 				for version, record := range versions { | ||||
| 					for id, n := range record.Nodes { | ||||
| @@ -67,7 +66,7 @@ func (m *memory) ttlPrune() { | ||||
| 							if m.opts.Logger.V(logger.DebugLevel) { | ||||
| 								m.opts.Logger.Debug(m.opts.Context, fmt.Sprintf("Register TTL expired for node %s of service %s", n.ID, service)) | ||||
| 							} | ||||
| 							delete(m.records[domain][service][version].Nodes, id) | ||||
| 							delete(m.records[namespace][service][version].Nodes, id) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| @@ -131,17 +130,12 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist | ||||
| 	options := register.NewRegisterOptions(opts...) | ||||
|  | ||||
| 	// get the services for this domain from the register | ||||
| 	srvs, ok := m.records[options.Domain] | ||||
| 	srvs, ok := m.records[options.Namespace] | ||||
| 	if !ok { | ||||
| 		srvs = make(services) | ||||
| 	} | ||||
|  | ||||
| 	// domain is set in metadata so it can be passed to watchers | ||||
| 	if s.Metadata == nil { | ||||
| 		s.Metadata = map[string]string{"domain": options.Domain} | ||||
| 	} else { | ||||
| 		s.Metadata["domain"] = options.Domain | ||||
| 	} | ||||
| 	s.Namespace = options.Namespace | ||||
|  | ||||
| 	// ensure the service name exists | ||||
| 	r := serviceToRecord(s, options.TTL) | ||||
| @@ -154,8 +148,8 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist | ||||
| 		if m.opts.Logger.V(logger.DebugLevel) { | ||||
| 			m.opts.Logger.Debug(m.opts.Context, fmt.Sprintf("Register added new service: %s, version: %s", s.Name, s.Version)) | ||||
| 		} | ||||
| 		m.records[options.Domain] = srvs | ||||
| 		go m.sendEvent(®ister.Result{Action: "create", Service: s}) | ||||
| 		m.records[options.Namespace] = srvs | ||||
| 		go m.sendEvent(®ister.Result{Action: register.EventCreate, Service: s}) | ||||
| 	} | ||||
|  | ||||
| 	var addedNodes bool | ||||
| @@ -173,9 +167,6 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist | ||||
| 			metadata[k] = v | ||||
| 		} | ||||
|  | ||||
| 		// set the domain | ||||
| 		metadata["domain"] = options.Domain | ||||
|  | ||||
| 		// add the node | ||||
| 		srvs[s.Name][s.Version].Nodes[n.ID] = &node{ | ||||
| 			Node: ®ister.Node{ | ||||
| @@ -194,7 +185,7 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist | ||||
| 		if m.opts.Logger.V(logger.DebugLevel) { | ||||
| 			m.opts.Logger.Debug(m.opts.Context, fmt.Sprintf("Register added new node to service: %s, version: %s", s.Name, s.Version)) | ||||
| 		} | ||||
| 		go m.sendEvent(®ister.Result{Action: "update", Service: s}) | ||||
| 		go m.sendEvent(®ister.Result{Action: register.EventUpdate, Service: s}) | ||||
| 	} else { | ||||
| 		// refresh TTL and timestamp | ||||
| 		for _, n := range s.Nodes { | ||||
| @@ -206,7 +197,7 @@ func (m *memory) Register(_ context.Context, s *register.Service, opts ...regist | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	m.records[options.Domain] = srvs | ||||
| 	m.records[options.Namespace] = srvs | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -216,15 +207,8 @@ func (m *memory) Deregister(ctx context.Context, s *register.Service, opts ...re | ||||
|  | ||||
| 	options := register.NewDeregisterOptions(opts...) | ||||
|  | ||||
| 	// domain is set in metadata so it can be passed to watchers | ||||
| 	if s.Metadata == nil { | ||||
| 		s.Metadata = map[string]string{"domain": options.Domain} | ||||
| 	} else { | ||||
| 		s.Metadata["domain"] = options.Domain | ||||
| 	} | ||||
|  | ||||
| 	// if the domain doesn't exist, there is nothing to deregister | ||||
| 	services, ok := m.records[options.Domain] | ||||
| 	services, ok := m.records[options.Namespace] | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -253,16 +237,16 @@ func (m *memory) Deregister(ctx context.Context, s *register.Service, opts ...re | ||||
| 	// if the nodes not empty, we replace the version in the store and exist, the rest of the logic | ||||
| 	// is cleanup | ||||
| 	if len(version.Nodes) > 0 { | ||||
| 		m.records[options.Domain][s.Name][s.Version] = version | ||||
| 		go m.sendEvent(®ister.Result{Action: "update", Service: s}) | ||||
| 		m.records[options.Namespace][s.Name][s.Version] = version | ||||
| 		go m.sendEvent(®ister.Result{Action: register.EventUpdate, Service: s}) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// if this version was the only version of the service, we can remove the whole service from the | ||||
| 	// register and exit | ||||
| 	if len(versions) == 1 { | ||||
| 		delete(m.records[options.Domain], s.Name) | ||||
| 		go m.sendEvent(®ister.Result{Action: "delete", Service: s}) | ||||
| 		delete(m.records[options.Namespace], s.Name) | ||||
| 		go m.sendEvent(®ister.Result{Action: register.EventDelete, Service: s}) | ||||
|  | ||||
| 		if m.opts.Logger.V(logger.DebugLevel) { | ||||
| 			m.opts.Logger.Debug(m.opts.Context, fmt.Sprintf("Register removed service: %s", s.Name)) | ||||
| @@ -271,8 +255,8 @@ func (m *memory) Deregister(ctx context.Context, s *register.Service, opts ...re | ||||
| 	} | ||||
|  | ||||
| 	// there are other versions of the service running, so only remove this version of it | ||||
| 	delete(m.records[options.Domain][s.Name], s.Version) | ||||
| 	go m.sendEvent(®ister.Result{Action: "delete", Service: s}) | ||||
| 	delete(m.records[options.Namespace][s.Name], s.Version) | ||||
| 	go m.sendEvent(®ister.Result{Action: register.EventDelete, Service: s}) | ||||
| 	if m.opts.Logger.V(logger.DebugLevel) { | ||||
| 		m.opts.Logger.Debug(m.opts.Context, fmt.Sprintf("Register removed service: %s, version: %s", s.Name, s.Version)) | ||||
| 	} | ||||
| @@ -284,15 +268,15 @@ func (m *memory) LookupService(ctx context.Context, name string, opts ...registe | ||||
| 	options := register.NewLookupOptions(opts...) | ||||
|  | ||||
| 	// if it's a wildcard domain, return from all domains | ||||
| 	if options.Domain == register.WildcardDomain { | ||||
| 	if options.Namespace == register.WildcardNamespace { | ||||
| 		m.RLock() | ||||
| 		recs := m.records | ||||
| 		m.RUnlock() | ||||
|  | ||||
| 		var services []*register.Service | ||||
|  | ||||
| 		for domain := range recs { | ||||
| 			srvs, err := m.LookupService(ctx, name, append(opts, register.LookupDomain(domain))...) | ||||
| 		for namespace := range recs { | ||||
| 			srvs, err := m.LookupService(ctx, name, append(opts, register.LookupNamespace(namespace))...) | ||||
| 			if err == register.ErrNotFound { | ||||
| 				continue | ||||
| 			} else if err != nil { | ||||
| @@ -311,7 +295,7 @@ func (m *memory) LookupService(ctx context.Context, name string, opts ...registe | ||||
| 	defer m.RUnlock() | ||||
|  | ||||
| 	// check the domain exists | ||||
| 	services, ok := m.records[options.Domain] | ||||
| 	services, ok := m.records[options.Namespace] | ||||
| 	if !ok { | ||||
| 		return nil, register.ErrNotFound | ||||
| 	} | ||||
| @@ -328,7 +312,7 @@ func (m *memory) LookupService(ctx context.Context, name string, opts ...registe | ||||
| 	var i int | ||||
|  | ||||
| 	for _, r := range versions { | ||||
| 		result[i] = recordToService(r, options.Domain) | ||||
| 		result[i] = recordToService(r, options.Namespace) | ||||
| 		i++ | ||||
| 	} | ||||
|  | ||||
| @@ -339,15 +323,15 @@ func (m *memory) ListServices(ctx context.Context, opts ...register.ListOption) | ||||
| 	options := register.NewListOptions(opts...) | ||||
|  | ||||
| 	// if it's a wildcard domain, list from all domains | ||||
| 	if options.Domain == register.WildcardDomain { | ||||
| 	if options.Namespace == register.WildcardNamespace { | ||||
| 		m.RLock() | ||||
| 		recs := m.records | ||||
| 		m.RUnlock() | ||||
|  | ||||
| 		var services []*register.Service | ||||
|  | ||||
| 		for domain := range recs { | ||||
| 			srvs, err := m.ListServices(ctx, append(opts, register.ListDomain(domain))...) | ||||
| 		for namespace := range recs { | ||||
| 			srvs, err := m.ListServices(ctx, append(opts, register.ListNamespace(namespace))...) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| @@ -361,7 +345,7 @@ func (m *memory) ListServices(ctx context.Context, opts ...register.ListOption) | ||||
| 	defer m.RUnlock() | ||||
|  | ||||
| 	// ensure the domain exists | ||||
| 	services, ok := m.records[options.Domain] | ||||
| 	services, ok := m.records[options.Namespace] | ||||
| 	if !ok { | ||||
| 		return make([]*register.Service, 0), nil | ||||
| 	} | ||||
| @@ -371,7 +355,7 @@ func (m *memory) ListServices(ctx context.Context, opts ...register.ListOption) | ||||
|  | ||||
| 	for _, service := range services { | ||||
| 		for _, version := range service { | ||||
| 			result = append(result, recordToService(version, options.Domain)) | ||||
| 			result = append(result, recordToService(version, options.Namespace)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -426,16 +410,13 @@ func (m *watcher) Next() (*register.Result, error) { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// extract domain from service metadata | ||||
| 			var domain string | ||||
| 			if r.Service.Metadata != nil && len(r.Service.Metadata["domain"]) > 0 { | ||||
| 				domain = r.Service.Metadata["domain"] | ||||
| 			} else { | ||||
| 				domain = register.DefaultDomain | ||||
| 			namespace := register.DefaultNamespace | ||||
| 			if r.Service.Namespace != "" { | ||||
| 				namespace = r.Service.Namespace | ||||
| 			} | ||||
|  | ||||
| 			// only send the event if watching the wildcard or this specific domain | ||||
| 			if m.wo.Domain == register.WildcardDomain || m.wo.Domain == domain { | ||||
| 			if m.wo.Namespace == register.WildcardNamespace || m.wo.Namespace == namespace { | ||||
| 				return r, nil | ||||
| 			} | ||||
| 		case <-m.exit: | ||||
| @@ -454,11 +435,6 @@ func (m *watcher) Stop() { | ||||
| } | ||||
|  | ||||
| func serviceToRecord(s *register.Service, ttl time.Duration) *record { | ||||
| 	metadata := make(map[string]string, len(s.Metadata)) | ||||
| 	for k, v := range s.Metadata { | ||||
| 		metadata[k] = v | ||||
| 	} | ||||
|  | ||||
| 	nodes := make(map[string]*node, len(s.Nodes)) | ||||
| 	for _, n := range s.Nodes { | ||||
| 		nodes[n.ID] = &node{ | ||||
| @@ -468,42 +444,19 @@ func serviceToRecord(s *register.Service, ttl time.Duration) *record { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	endpoints := make([]*register.Endpoint, len(s.Endpoints)) | ||||
| 	copy(endpoints, s.Endpoints) | ||||
|  | ||||
| 	return &record{ | ||||
| 		Name:    s.Name, | ||||
| 		Version: s.Version, | ||||
| 		Metadata:  metadata, | ||||
| 		Nodes:   nodes, | ||||
| 		Endpoints: endpoints, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func recordToService(r *record, domain string) *register.Service { | ||||
| func recordToService(r *record, namespace string) *register.Service { | ||||
| 	metadata := make(map[string]string, len(r.Metadata)) | ||||
| 	for k, v := range r.Metadata { | ||||
| 		metadata[k] = v | ||||
| 	} | ||||
|  | ||||
| 	// set the domain in metadata so it can be determined when a wildcard query is performed | ||||
| 	metadata["domain"] = domain | ||||
|  | ||||
| 	endpoints := make([]*register.Endpoint, len(r.Endpoints)) | ||||
| 	for i, e := range r.Endpoints { | ||||
| 		md := make(map[string]string, len(e.Metadata)) | ||||
| 		for k, v := range e.Metadata { | ||||
| 			md[k] = v | ||||
| 		} | ||||
|  | ||||
| 		endpoints[i] = ®ister.Endpoint{ | ||||
| 			Name:     e.Name, | ||||
| 			Request:  e.Request, | ||||
| 			Response: e.Response, | ||||
| 			Metadata: md, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	nodes := make([]*register.Node, len(r.Nodes)) | ||||
| 	i := 0 | ||||
| 	for _, n := range r.Nodes { | ||||
| @@ -523,8 +476,7 @@ func recordToService(r *record, domain string) *register.Service { | ||||
| 	return ®ister.Service{ | ||||
| 		Name:      r.Name, | ||||
| 		Version:   r.Version, | ||||
| 		Metadata:  metadata, | ||||
| 		Endpoints: endpoints, | ||||
| 		Nodes:     nodes, | ||||
| 		Namespace: namespace, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -253,32 +253,32 @@ func TestMemoryWildcard(t *testing.T) { | ||||
|  | ||||
| 	testSrv := ®ister.Service{Name: "foo", Version: "1.0.0"} | ||||
|  | ||||
| 	if err := m.Register(ctx, testSrv, register.RegisterDomain("one")); err != nil { | ||||
| 	if err := m.Register(ctx, testSrv, register.RegisterNamespace("one")); err != nil { | ||||
| 		t.Fatalf("Register err: %v", err) | ||||
| 	} | ||||
| 	if err := m.Register(ctx, testSrv, register.RegisterDomain("two")); err != nil { | ||||
| 	if err := m.Register(ctx, testSrv, register.RegisterNamespace("two")); err != nil { | ||||
| 		t.Fatalf("Register err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if recs, err := m.ListServices(ctx, register.ListDomain("one")); err != nil { | ||||
| 	if recs, err := m.ListServices(ctx, register.ListNamespace("one")); err != nil { | ||||
| 		t.Errorf("List err: %v", err) | ||||
| 	} else if len(recs) != 1 { | ||||
| 		t.Errorf("Expected 1 record, got %v", len(recs)) | ||||
| 	} | ||||
|  | ||||
| 	if recs, err := m.ListServices(ctx, register.ListDomain("*")); err != nil { | ||||
| 	if recs, err := m.ListServices(ctx, register.ListNamespace("*")); err != nil { | ||||
| 		t.Errorf("List err: %v", err) | ||||
| 	} else if len(recs) != 2 { | ||||
| 		t.Errorf("Expected 2 records, got %v", len(recs)) | ||||
| 	} | ||||
|  | ||||
| 	if recs, err := m.LookupService(ctx, testSrv.Name, register.LookupDomain("one")); err != nil { | ||||
| 	if recs, err := m.LookupService(ctx, testSrv.Name, register.LookupNamespace("one")); err != nil { | ||||
| 		t.Errorf("Lookup err: %v", err) | ||||
| 	} else if len(recs) != 1 { | ||||
| 		t.Errorf("Expected 1 record, got %v", len(recs)) | ||||
| 	} | ||||
|  | ||||
| 	if recs, err := m.LookupService(ctx, testSrv.Name, register.LookupDomain("*")); err != nil { | ||||
| 	if recs, err := m.LookupService(ctx, testSrv.Name, register.LookupNamespace("*")); err != nil { | ||||
| 		t.Errorf("Lookup err: %v", err) | ||||
| 	} else if len(recs) != 2 { | ||||
| 		t.Errorf("Expected 2 records, got %v", len(recs)) | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"crypto/tls" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/codec" | ||||
| 	"go.unistack.org/micro/v3/logger" | ||||
| 	"go.unistack.org/micro/v3/meter" | ||||
| 	"go.unistack.org/micro/v3/tracer" | ||||
| @@ -26,6 +27,8 @@ type Options struct { | ||||
| 	Name string | ||||
| 	// Addrs specifies register addrs | ||||
| 	Addrs []string | ||||
| 	// Codec used to marshal/unmarshal data in register | ||||
| 	Codec codec.Codec | ||||
| 	// Timeout specifies timeout | ||||
| 	Timeout time.Duration | ||||
| } | ||||
| @@ -37,6 +40,7 @@ func NewOptions(opts ...Option) Options { | ||||
| 		Meter:   meter.DefaultMeter, | ||||
| 		Tracer:  tracer.DefaultTracer, | ||||
| 		Context: context.Background(), | ||||
| 		Codec:   codec.NewCodec(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -47,7 +51,7 @@ func NewOptions(opts ...Option) Options { | ||||
| // RegisterOptions holds options for register method | ||||
| type RegisterOptions struct { // nolint: golint,revive | ||||
| 	Context   context.Context | ||||
| 	Domain   string | ||||
| 	Namespace string | ||||
| 	TTL       time.Duration | ||||
| 	Attempts  int | ||||
| } | ||||
| @@ -55,7 +59,7 @@ type RegisterOptions struct { // nolint: golint,revive | ||||
| // NewRegisterOptions returns register options struct filled by opts | ||||
| func NewRegisterOptions(opts ...RegisterOption) RegisterOptions { | ||||
| 	options := RegisterOptions{ | ||||
| 		Domain:  DefaultDomain, | ||||
| 		Namespace: DefaultNamespace, | ||||
| 		Context:   context.Background(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| @@ -72,14 +76,14 @@ type WatchOptions struct { | ||||
| 	// Other options for implementations of the interface | ||||
| 	// can be stored in a context | ||||
| 	Context context.Context | ||||
| 	// Domain to watch | ||||
| 	Domain string | ||||
| 	// Namespace to watch | ||||
| 	Namespace string | ||||
| } | ||||
|  | ||||
| // NewWatchOptions returns watch options filled by opts | ||||
| func NewWatchOptions(opts ...WatchOption) WatchOptions { | ||||
| 	options := WatchOptions{ | ||||
| 		Domain:  DefaultDomain, | ||||
| 		Namespace: DefaultNamespace, | ||||
| 		Context:   context.Background(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| @@ -91,8 +95,8 @@ func NewWatchOptions(opts ...WatchOption) WatchOptions { | ||||
| // DeregisterOptions holds options for deregister method | ||||
| type DeregisterOptions struct { | ||||
| 	Context context.Context | ||||
| 	// Domain the service was registered in | ||||
| 	Domain string | ||||
| 	// Namespace the service was registered in | ||||
| 	Namespace string | ||||
| 	// Atempts specify max attempts for deregister | ||||
| 	Attempts int | ||||
| } | ||||
| @@ -100,7 +104,7 @@ type DeregisterOptions struct { | ||||
| // NewDeregisterOptions returns options for deregister filled by opts | ||||
| func NewDeregisterOptions(opts ...DeregisterOption) DeregisterOptions { | ||||
| 	options := DeregisterOptions{ | ||||
| 		Domain:  DefaultDomain, | ||||
| 		Namespace: DefaultNamespace, | ||||
| 		Context:   context.Background(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| @@ -112,14 +116,14 @@ func NewDeregisterOptions(opts ...DeregisterOption) DeregisterOptions { | ||||
| // LookupOptions holds lookup options | ||||
| type LookupOptions struct { | ||||
| 	Context context.Context | ||||
| 	// Domain to scope the request to | ||||
| 	Domain string | ||||
| 	// Namespace to scope the request to | ||||
| 	Namespace string | ||||
| } | ||||
|  | ||||
| // NewLookupOptions returns lookup options filled by opts | ||||
| func NewLookupOptions(opts ...LookupOption) LookupOptions { | ||||
| 	options := LookupOptions{ | ||||
| 		Domain:  DefaultDomain, | ||||
| 		Namespace: DefaultNamespace, | ||||
| 		Context:   context.Background(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| @@ -130,15 +134,16 @@ func NewLookupOptions(opts ...LookupOption) LookupOptions { | ||||
|  | ||||
| // ListOptions holds the list options for list method | ||||
| type ListOptions struct { | ||||
| 	// Context used to store additional options | ||||
| 	Context context.Context | ||||
| 	// Domain to scope the request to | ||||
| 	Domain string | ||||
| 	// Namespace to scope the request to | ||||
| 	Namespace string | ||||
| } | ||||
|  | ||||
| // NewListOptions returns list options filled by opts | ||||
| func NewListOptions(opts ...ListOption) ListOptions { | ||||
| 	options := ListOptions{ | ||||
| 		Domain:  DefaultDomain, | ||||
| 		Namespace: DefaultNamespace, | ||||
| 		Context:   context.Background(), | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| @@ -217,10 +222,10 @@ func RegisterContext(ctx context.Context) RegisterOption { // nolint: golint,rev | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // RegisterDomain secifies register domain | ||||
| func RegisterDomain(d string) RegisterOption { // nolint: golint,revive | ||||
| // RegisterNamespace secifies register Namespace | ||||
| func RegisterNamespace(d string) RegisterOption { // nolint: golint,revive | ||||
| 	return func(o *RegisterOptions) { | ||||
| 		o.Domain = d | ||||
| 		o.Namespace = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -238,10 +243,10 @@ func WatchContext(ctx context.Context) WatchOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WatchDomain sets the domain for watch | ||||
| func WatchDomain(d string) WatchOption { | ||||
| // WatchNamespace sets the Namespace for watch | ||||
| func WatchNamespace(d string) WatchOption { | ||||
| 	return func(o *WatchOptions) { | ||||
| 		o.Domain = d | ||||
| 		o.Namespace = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -259,10 +264,10 @@ func DeregisterContext(ctx context.Context) DeregisterOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // DeregisterDomain specifies deregister domain | ||||
| func DeregisterDomain(d string) DeregisterOption { | ||||
| // DeregisterNamespace specifies deregister Namespace | ||||
| func DeregisterNamespace(d string) DeregisterOption { | ||||
| 	return func(o *DeregisterOptions) { | ||||
| 		o.Domain = d | ||||
| 		o.Namespace = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -273,10 +278,10 @@ func LookupContext(ctx context.Context) LookupOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // LookupDomain sets the domain for lookup | ||||
| func LookupDomain(d string) LookupOption { | ||||
| // LookupNamespace sets the Namespace for lookup | ||||
| func LookupNamespace(d string) LookupOption { | ||||
| 	return func(o *LookupOptions) { | ||||
| 		o.Domain = d | ||||
| 		o.Namespace = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -287,10 +292,10 @@ func ListContext(ctx context.Context) ListOption { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ListDomain sets the domain for list method | ||||
| func ListDomain(d string) ListOption { | ||||
| // ListNamespace sets the Namespace for list method | ||||
| func ListNamespace(d string) ListOption { | ||||
| 	return func(o *ListOptions) { | ||||
| 		o.Domain = d | ||||
| 		o.Namespace = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -300,3 +305,9 @@ func Name(n string) Option { | ||||
| 		o.Name = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Codec(c codec.Codec) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Codec = c | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -9,12 +9,12 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	// WildcardDomain indicates any domain | ||||
| 	WildcardDomain = "*" | ||||
| 	// WildcardNamespace indicates any Namespace | ||||
| 	WildcardNamespace = "*" | ||||
| ) | ||||
|  | ||||
| // DefaultDomain to use if none was provided in options | ||||
| var DefaultDomain = "micro" | ||||
| // DefaultNamespace to use if none was provided in options | ||||
| var DefaultNamespace = "micro" | ||||
|  | ||||
| var ( | ||||
| 	// DefaultRegister is the global default register | ||||
| @@ -59,26 +59,17 @@ type Register interface { | ||||
|  | ||||
| // Service holds service register info | ||||
| type Service struct { | ||||
| 	Name      string            `json:"name"` | ||||
| 	Version   string            `json:"version"` | ||||
| 	Metadata  metadata.Metadata `json:"metadata"` | ||||
| 	Endpoints []*Endpoint       `json:"endpoints"` | ||||
| 	Nodes     []*Node           `json:"nodes"` | ||||
| 	Name      string  `json:"name,omitempty"` | ||||
| 	Version   string  `json:"version,omitempty"` | ||||
| 	Nodes     []*Node `json:"nodes,omitempty"` | ||||
| 	Namespace string  `json:"namespace,omitempty"` | ||||
| } | ||||
|  | ||||
| // Node holds node register info | ||||
| type Node struct { | ||||
| 	Metadata metadata.Metadata `json:"metadata"` | ||||
| 	ID       string            `json:"id"` | ||||
| 	Address  string            `json:"address"` | ||||
| } | ||||
|  | ||||
| // Endpoint holds endpoint register info | ||||
| type Endpoint struct { | ||||
| 	Request  string            `json:"request"` | ||||
| 	Response string            `json:"response"` | ||||
| 	Metadata metadata.Metadata `json:"metadata"` | ||||
| 	Name     string            `json:"name"` | ||||
| 	Metadata metadata.Metadata `json:"metadata,omitempty"` | ||||
| 	ID       string            `json:"id,omitempty"` | ||||
| 	Address  string            `json:"address,omitempty"` | ||||
| } | ||||
|  | ||||
| // Option func signature | ||||
|   | ||||
| @@ -15,31 +15,31 @@ type Watcher interface { | ||||
| // the watcher. Actions can be create, update, delete | ||||
| type Result struct { | ||||
| 	// Service holds register service | ||||
| 	Service *Service | ||||
| 	Service *Service `json:"service,omitempty"` | ||||
| 	// Action holds the action | ||||
| 	Action string | ||||
| 	Action EventType `json:"action,omitempty"` | ||||
| } | ||||
|  | ||||
| // EventType defines register event type | ||||
| type EventType int | ||||
|  | ||||
| const ( | ||||
| 	// Create is emitted when a new service is registered | ||||
| 	Create EventType = iota | ||||
| 	// Delete is emitted when an existing service is deregistered | ||||
| 	Delete | ||||
| 	// Update is emitted when an existing service is updated | ||||
| 	Update | ||||
| 	// EventCreate is emitted when a new service is registered | ||||
| 	EventCreate EventType = iota | ||||
| 	// EventDelete is emitted when an existing service is deregistered | ||||
| 	EventDelete | ||||
| 	// EventUpdate is emitted when an existing service is updated | ||||
| 	EventUpdate | ||||
| ) | ||||
|  | ||||
| // String returns human readable event type | ||||
| func (t EventType) String() string { | ||||
| 	switch t { | ||||
| 	case Create: | ||||
| 	case EventCreate: | ||||
| 		return "create" | ||||
| 	case Delete: | ||||
| 	case EventDelete: | ||||
| 		return "delete" | ||||
| 	case Update: | ||||
| 	case EventUpdate: | ||||
| 		return "update" | ||||
| 	default: | ||||
| 		return "unknown" | ||||
| @@ -49,11 +49,11 @@ func (t EventType) String() string { | ||||
| // Event is register event | ||||
| type Event struct { | ||||
| 	// Timestamp is event timestamp | ||||
| 	Timestamp time.Time | ||||
| 	Timestamp time.Time `json:"timestamp,omitempty"` | ||||
| 	// Service is register service | ||||
| 	Service *Service | ||||
| 	Service *Service `json:"service,omitempty"` | ||||
| 	// ID is register id | ||||
| 	ID string | ||||
| 	ID string `json:"id,omitempty"` | ||||
| 	// Type defines type of event | ||||
| 	Type EventType | ||||
| 	Type EventType `json:"type,omitempty"` | ||||
| } | ||||
|   | ||||
| @@ -5,7 +5,6 @@ import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"runtime/debug" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @@ -22,11 +21,6 @@ import ( | ||||
| 	"go.unistack.org/micro/v3/util/rand" | ||||
| ) | ||||
|  | ||||
| // DefaultCodecs will be used to encode/decode | ||||
| var DefaultCodecs = map[string]codec.Codec{ | ||||
| 	"application/octet-stream": codec.NewCodec(), | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	defaultContentType = "application/json" | ||||
| ) | ||||
| @@ -35,34 +29,17 @@ type rpcHandler struct { | ||||
| 	opts    HandlerOptions | ||||
| 	handler interface{} | ||||
| 	name    string | ||||
| 	endpoints []*register.Endpoint | ||||
| } | ||||
|  | ||||
| func newRPCHandler(handler interface{}, opts ...HandlerOption) Handler { | ||||
| 	options := NewHandlerOptions(opts...) | ||||
|  | ||||
| 	typ := reflect.TypeOf(handler) | ||||
| 	hdlr := reflect.ValueOf(handler) | ||||
| 	name := reflect.Indirect(hdlr).Type().Name() | ||||
|  | ||||
| 	var endpoints []*register.Endpoint | ||||
|  | ||||
| 	for m := 0; m < typ.NumMethod(); m++ { | ||||
| 		if e := register.ExtractEndpoint(typ.Method(m)); e != nil { | ||||
| 			e.Name = name + "." + e.Name | ||||
|  | ||||
| 			for k, v := range options.Metadata[e.Name] { | ||||
| 				e.Metadata[k] = v | ||||
| 			} | ||||
|  | ||||
| 			endpoints = append(endpoints, e) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &rpcHandler{ | ||||
| 		name:    name, | ||||
| 		handler: handler, | ||||
| 		endpoints: endpoints, | ||||
| 		opts:    options, | ||||
| 	} | ||||
| } | ||||
| @@ -75,10 +52,6 @@ func (r *rpcHandler) Handler() interface{} { | ||||
| 	return r.handler | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Endpoints() []*register.Endpoint { | ||||
| 	return r.endpoints | ||||
| } | ||||
|  | ||||
| func (r *rpcHandler) Options() HandlerOptions { | ||||
| 	return r.opts | ||||
| } | ||||
| @@ -115,9 +88,6 @@ func (n *noopServer) newCodec(contentType string) (codec.Codec, error) { | ||||
| 	if cf, ok := n.opts.Codecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { | ||||
| 		return cf, nil | ||||
| 	} | ||||
| 	return nil, codec.ErrUnknownContentType | ||||
| } | ||||
|  | ||||
| @@ -249,35 +219,6 @@ func (n *noopServer) Register() error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	n.RLock() | ||||
| 	handlerList := make([]string, 0, len(n.handlers)) | ||||
| 	for n := range n.handlers { | ||||
| 		handlerList = append(handlerList, n) | ||||
| 	} | ||||
|  | ||||
| 	sort.Strings(handlerList) | ||||
|  | ||||
| 	subscriberList := make([]*subscriber, 0, len(n.subscribers)) | ||||
| 	for e := range n.subscribers { | ||||
| 		subscriberList = append(subscriberList, e) | ||||
| 	} | ||||
| 	sort.Slice(subscriberList, func(i, j int) bool { | ||||
| 		return subscriberList[i].topic > subscriberList[j].topic | ||||
| 	}) | ||||
|  | ||||
| 	endpoints := make([]*register.Endpoint, 0, len(handlerList)+len(subscriberList)) | ||||
| 	for _, h := range handlerList { | ||||
| 		endpoints = append(endpoints, n.handlers[h].Endpoints()...) | ||||
| 	} | ||||
| 	for _, e := range subscriberList { | ||||
| 		endpoints = append(endpoints, e.Endpoints()...) | ||||
| 	} | ||||
| 	n.RUnlock() | ||||
|  | ||||
| 	service.Nodes[0].Metadata["protocol"] = "noop" | ||||
| 	service.Nodes[0].Metadata["transport"] = service.Nodes[0].Metadata["protocol"] | ||||
| 	service.Endpoints = endpoints | ||||
|  | ||||
| 	n.RLock() | ||||
| 	registered := n.registered | ||||
| 	n.RUnlock() | ||||
| @@ -576,7 +517,6 @@ func (n *noopServer) Stop() error { | ||||
| } | ||||
|  | ||||
| func newSubscriber(topic string, sub interface{}, opts ...SubscriberOption) Subscriber { | ||||
| 	var endpoints []*register.Endpoint | ||||
| 	var handlers []*handler | ||||
|  | ||||
| 	options := NewSubscriberOptions(opts...) | ||||
| @@ -595,18 +535,7 @@ func newSubscriber(topic string, sub interface{}, opts ...SubscriberOption) Subs | ||||
| 		} | ||||
|  | ||||
| 		handlers = append(handlers, h) | ||||
| 		ep := ®ister.Endpoint{ | ||||
| 			Name:     "Func", | ||||
| 			Request:  register.ExtractSubValue(typ), | ||||
| 			Metadata: metadata.New(2), | ||||
| 		} | ||||
| 		ep.Metadata.Set("topic", topic) | ||||
| 		ep.Metadata.Set("subscriber", "true") | ||||
| 		endpoints = append(endpoints, ep) | ||||
| 	} else { | ||||
| 		hdlr := reflect.ValueOf(sub) | ||||
| 		name := reflect.Indirect(hdlr).Type().Name() | ||||
|  | ||||
| 		for m := 0; m < typ.NumMethod(); m++ { | ||||
| 			method := typ.Method(m) | ||||
| 			h := &handler{ | ||||
| @@ -622,14 +551,6 @@ func newSubscriber(topic string, sub interface{}, opts ...SubscriberOption) Subs | ||||
| 			} | ||||
|  | ||||
| 			handlers = append(handlers, h) | ||||
| 			ep := ®ister.Endpoint{ | ||||
| 				Name:     name + "." + method.Name, | ||||
| 				Request:  register.ExtractSubValue(method.Type), | ||||
| 				Metadata: metadata.New(2), | ||||
| 			} | ||||
| 			ep.Metadata.Set("topic", topic) | ||||
| 			ep.Metadata.Set("subscriber", "true") | ||||
| 			endpoints = append(endpoints, ep) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -639,7 +560,6 @@ func newSubscriber(topic string, sub interface{}, opts ...SubscriberOption) Subs | ||||
| 		topic:      topic, | ||||
| 		subscriber: sub, | ||||
| 		handlers:   handlers, | ||||
| 		endpoints:  endpoints, | ||||
| 		opts:       options, | ||||
| 	} | ||||
| } | ||||
| @@ -766,10 +686,6 @@ func (s *subscriber) Subscriber() interface{} { | ||||
| 	return s.subscriber | ||||
| } | ||||
|  | ||||
| func (s *subscriber) Endpoints() []*register.Endpoint { | ||||
| 	return s.endpoints | ||||
| } | ||||
|  | ||||
| func (s *subscriber) Options() SubscriberOptions { | ||||
| 	return s.opts | ||||
| } | ||||
| @@ -780,7 +696,6 @@ type subscriber struct { | ||||
| 	typ        reflect.Type | ||||
| 	subscriber interface{} | ||||
|  | ||||
| 	endpoints []*register.Endpoint | ||||
| 	handlers []*handler | ||||
|  | ||||
| 	rcvr reflect.Value | ||||
|   | ||||
| @@ -17,7 +17,7 @@ var ( | ||||
|  | ||||
| 		opts := []register.RegisterOption{ | ||||
| 			register.RegisterTTL(config.RegisterTTL), | ||||
| 			register.RegisterDomain(config.Namespace), | ||||
| 			register.RegisterNamespace(config.Namespace), | ||||
| 		} | ||||
|  | ||||
| 		for i := 0; i <= config.RegisterAttempts; i++ { | ||||
| @@ -36,7 +36,7 @@ var ( | ||||
| 		var err error | ||||
|  | ||||
| 		opts := []register.DeregisterOption{ | ||||
| 			register.DeregisterDomain(config.Namespace), | ||||
| 			register.DeregisterNamespace(config.Namespace), | ||||
| 		} | ||||
|  | ||||
| 		for i := 0; i <= config.DeregisterAttempts; i++ { | ||||
| @@ -85,6 +85,5 @@ func NewRegisterService(s Server) (*register.Service, error) { | ||||
| 		Name:    opts.Name, | ||||
| 		Version: opts.Version, | ||||
| 		Nodes:   []*register.Node{node}, | ||||
| 		Metadata: metadata.New(0), | ||||
| 	}, nil | ||||
| } | ||||
|   | ||||
| @@ -7,7 +7,6 @@ import ( | ||||
|  | ||||
| 	"go.unistack.org/micro/v3/codec" | ||||
| 	"go.unistack.org/micro/v3/metadata" | ||||
| 	"go.unistack.org/micro/v3/register" | ||||
| ) | ||||
|  | ||||
| // DefaultServer default server | ||||
| @@ -170,7 +169,6 @@ type Stream interface { | ||||
| type Handler interface { | ||||
| 	Name() string | ||||
| 	Handler() interface{} | ||||
| 	Endpoints() []*register.Endpoint | ||||
| 	Options() HandlerOptions | ||||
| } | ||||
|  | ||||
| @@ -180,6 +178,5 @@ type Handler interface { | ||||
| type Subscriber interface { | ||||
| 	Topic() string | ||||
| 	Subscriber() interface{} | ||||
| 	Endpoints() []*register.Endpoint | ||||
| 	Options() SubscriberOptions | ||||
| } | ||||
|   | ||||
							
								
								
									
										20
									
								
								service.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								service.go
									
									
									
									
									
								
							| @@ -104,6 +104,7 @@ type service struct { | ||||
| 	done chan struct{} | ||||
| 	opts Options | ||||
| 	sync.RWMutex | ||||
| 	stopped bool | ||||
| } | ||||
|  | ||||
| // NewService creates and returns a new Service based on the packages within. | ||||
| @@ -429,7 +430,7 @@ func (s *service) Stop() error { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	close(s.done) | ||||
| 	s.notifyShutdown() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -453,10 +454,23 @@ func (s *service) Run() error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// wait on context cancel | ||||
| 	<-s.done | ||||
|  | ||||
| 	return s.Stop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // notifyShutdown marks the service as stopped and closes the done channel. | ||||
| // It ensures the channel is closed only once, preventing multiple closures. | ||||
| func (s *service) notifyShutdown() { | ||||
| 	s.Lock() | ||||
| 	if s.stopped { | ||||
| 		s.Unlock() | ||||
| 		return | ||||
| 	} | ||||
| 	s.stopped = true | ||||
| 	s.Unlock() | ||||
|  | ||||
| 	close(s.done) | ||||
| } | ||||
|  | ||||
| type Namer interface { | ||||
|   | ||||
| @@ -4,7 +4,9 @@ import ( | ||||
| 	"context" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"go.unistack.org/micro/v3/broker" | ||||
| 	"go.unistack.org/micro/v3/client" | ||||
| 	"go.unistack.org/micro/v3/config" | ||||
| @@ -773,3 +775,41 @@ func Test_getNameIndex(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|  | ||||
| func TestServiceShutdown(t *testing.T) { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			t.Fatalf("service shutdown failed: %v", r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	s, ok := NewService().(*service) | ||||
| 	require.NotNil(t, s) | ||||
| 	require.True(t, ok) | ||||
|  | ||||
| 	require.NoError(t, s.Start()) | ||||
| 	require.False(t, s.stopped) | ||||
|  | ||||
| 	require.NoError(t, s.Stop()) | ||||
| 	require.True(t, s.stopped) | ||||
| } | ||||
|  | ||||
| func TestServiceMultipleShutdowns(t *testing.T) { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			t.Fatalf("service shutdown failed: %v", r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	s := NewService() | ||||
|  | ||||
| 	go func() { | ||||
| 		time.Sleep(10 * time.Millisecond) | ||||
| 		// first call | ||||
| 		require.NoError(t, s.Stop()) | ||||
| 		// duplicate call | ||||
| 		require.NoError(t, s.Stop()) | ||||
| 	}() | ||||
|  | ||||
| 	require.NoError(t, s.Run()) | ||||
| } | ||||
|   | ||||
| @@ -46,6 +46,10 @@ func (s memoryStringer) String() string { | ||||
| 	return s.s | ||||
| } | ||||
|  | ||||
| func (t *Tracer) Enabled() bool { | ||||
| 	return t.opts.Enabled | ||||
| } | ||||
|  | ||||
| func (t *Tracer) Flush(_ context.Context) error { | ||||
| 	return nil | ||||
| } | ||||
| @@ -89,6 +93,10 @@ func (s *Span) Tracer() tracer.Tracer { | ||||
| 	return s.tracer | ||||
| } | ||||
|  | ||||
| func (s *Span) IsRecording() bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| type Event struct { | ||||
| 	name   string | ||||
| 	labels []interface{} | ||||
|   | ||||
| @@ -18,6 +18,10 @@ func (t *noopTracer) Spans() []Span { | ||||
| 	return t.spans | ||||
| } | ||||
|  | ||||
| func (t *noopTracer) Enabled() bool { | ||||
| 	return t.opts.Enabled | ||||
| } | ||||
|  | ||||
| func (t *noopTracer) Start(ctx context.Context, name string, opts ...SpanOption) (context.Context, Span) { | ||||
| 	options := NewSpanOptions(opts...) | ||||
| 	span := &noopSpan{ | ||||
| @@ -120,6 +124,10 @@ func (s *noopSpan) SetStatus(st SpanStatus, msg string) { | ||||
| 	s.statusMsg = msg | ||||
| } | ||||
|  | ||||
| func (s *noopSpan) IsRecording() bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // NewTracer returns new memory tracer | ||||
| func NewTracer(opts ...Option) Tracer { | ||||
| 	return &noopTracer{ | ||||
|   | ||||
| @@ -142,6 +142,8 @@ type Options struct { | ||||
| 	Name string | ||||
| 	// ContextAttrFuncs contains funcs that provides tracing | ||||
| 	ContextAttrFuncs []ContextAttrFunc | ||||
| 	// Enabled specify trace status | ||||
| 	Enabled bool | ||||
| } | ||||
|  | ||||
| // Option func signature | ||||
| @@ -181,6 +183,7 @@ func NewOptions(opts ...Option) Options { | ||||
| 		Logger:           logger.DefaultLogger, | ||||
| 		Context:          context.Background(), | ||||
| 		ContextAttrFuncs: DefaultContextAttrFuncs, | ||||
| 		Enabled:          true, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -194,3 +197,10 @@ func Name(n string) Option { | ||||
| 		o.Name = n | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Disabled disable tracer | ||||
| func Disabled(b bool) Option { | ||||
| 	return func(o *Options) { | ||||
| 		o.Enabled = !b | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -51,6 +51,8 @@ type Tracer interface { | ||||
| 	// Extract(ctx context.Context) | ||||
| 	// Flush flushes spans | ||||
| 	Flush(ctx context.Context) error | ||||
| 	// Enabled returns tracer status | ||||
| 	Enabled() bool | ||||
| } | ||||
|  | ||||
| type Span interface { | ||||
| @@ -78,4 +80,6 @@ type Span interface { | ||||
| 	TraceID() string | ||||
| 	// SpanID returns span id | ||||
| 	SpanID() string | ||||
| 	// IsRecording returns the recording state of the Span. | ||||
| 	IsRecording() bool | ||||
| } | ||||
|   | ||||
| @@ -1,27 +0,0 @@ | ||||
| package buf | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| ) | ||||
|  | ||||
| var _ io.Closer = &Buffer{} | ||||
|  | ||||
| // Buffer bytes.Buffer wrapper to satisfie io.Closer interface | ||||
| type Buffer struct { | ||||
| 	*bytes.Buffer | ||||
| } | ||||
|  | ||||
| // Close reset buffer contents | ||||
| func (b *Buffer) Close() error { | ||||
| 	b.Buffer.Reset() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // New creates new buffer that satisfies Closer interface | ||||
| func New(b *bytes.Buffer) *Buffer { | ||||
| 	if b == nil { | ||||
| 		b = bytes.NewBuffer(nil) | ||||
| 	} | ||||
| 	return &Buffer{b} | ||||
| } | ||||
							
								
								
									
										85
									
								
								util/buffer/delayed_buffer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								util/buffer/delayed_buffer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,85 @@ | ||||
| package buffer | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var _ io.WriteCloser = (*DelayedBuffer)(nil) | ||||
|  | ||||
| // DelayedBuffer is the buffer that holds items until either the buffer filled or a specified time limit is reached | ||||
| type DelayedBuffer struct { | ||||
| 	mu        sync.Mutex | ||||
| 	maxWait   time.Duration | ||||
| 	flushTime time.Time | ||||
| 	buffer    chan []byte | ||||
| 	ticker    *time.Ticker | ||||
| 	w         io.Writer | ||||
| 	err       error | ||||
| } | ||||
|  | ||||
| func NewDelayedBuffer(size int, maxWait time.Duration, w io.Writer) *DelayedBuffer { | ||||
| 	b := &DelayedBuffer{ | ||||
| 		buffer:    make(chan []byte, size), | ||||
| 		ticker:    time.NewTicker(maxWait), | ||||
| 		w:         w, | ||||
| 		flushTime: time.Now(), | ||||
| 		maxWait:   maxWait, | ||||
| 	} | ||||
| 	b.loop() | ||||
| 	return b | ||||
| } | ||||
|  | ||||
| func (b *DelayedBuffer) loop() { | ||||
| 	go func() { | ||||
| 		for range b.ticker.C { | ||||
| 			b.mu.Lock() | ||||
| 			if time.Since(b.flushTime) > b.maxWait { | ||||
| 				b.flush() | ||||
| 			} | ||||
| 			b.mu.Unlock() | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (b *DelayedBuffer) flush() { | ||||
| 	bufLen := len(b.buffer) | ||||
| 	if bufLen > 0 { | ||||
| 		tmp := make([][]byte, bufLen) | ||||
| 		for i := 0; i < bufLen; i++ { | ||||
| 			tmp[i] = <-b.buffer | ||||
| 		} | ||||
| 		for _, t := range tmp { | ||||
| 			_, b.err = b.w.Write(t) | ||||
| 		} | ||||
| 		b.flushTime = time.Now() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (b *DelayedBuffer) Put(items ...[]byte) { | ||||
| 	b.mu.Lock() | ||||
| 	for _, item := range items { | ||||
| 		select { | ||||
| 		case b.buffer <- item: | ||||
| 		default: | ||||
| 			b.flush() | ||||
| 			b.buffer <- item | ||||
| 		} | ||||
| 	} | ||||
| 	b.mu.Unlock() | ||||
| } | ||||
|  | ||||
| func (b *DelayedBuffer) Close() error { | ||||
| 	b.mu.Lock() | ||||
| 	b.flush() | ||||
| 	close(b.buffer) | ||||
| 	b.ticker.Stop() | ||||
| 	b.mu.Unlock() | ||||
| 	return b.err | ||||
| } | ||||
|  | ||||
| func (b *DelayedBuffer) Write(data []byte) (int, error) { | ||||
| 	b.Put(data) | ||||
| 	return len(data), b.err | ||||
| } | ||||
							
								
								
									
										22
									
								
								util/buffer/delayed_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								util/buffer/delayed_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| package buffer | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestTimedBuffer(t *testing.T) { | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	b := NewDelayedBuffer(100, 300*time.Millisecond, buf) | ||||
| 	for i := 0; i < 100; i++ { | ||||
| 		_, _ = b.Write([]byte(`test`)) | ||||
| 	} | ||||
| 	if buf.Len() != 0 { | ||||
| 		t.Fatal("delayed write not worked") | ||||
| 	} | ||||
| 	time.Sleep(400 * time.Millisecond) | ||||
| 	if buf.Len() == 0 { | ||||
| 		t.Fatal("delayed write not worked") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										84
									
								
								util/buffer/seeker_buffer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								util/buffer/seeker_buffer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | ||||
| package buffer | ||||
|  | ||||
| import "io" | ||||
|  | ||||
| var _ interface { | ||||
| 	io.ReadCloser | ||||
| 	io.ReadSeeker | ||||
| } = (*SeekerBuffer)(nil) | ||||
|  | ||||
| // Buffer is a ReadWriteCloser that supports seeking. It's intended to | ||||
| // replicate the functionality of bytes.Buffer that I use in my projects. | ||||
| // | ||||
| // Note that the seeking is limited to the read marker; all writes are | ||||
| // append-only. | ||||
| type SeekerBuffer struct { | ||||
| 	data []byte | ||||
| 	pos  int64 | ||||
| } | ||||
|  | ||||
| func NewSeekerBuffer(data []byte) *SeekerBuffer { | ||||
| 	return &SeekerBuffer{ | ||||
| 		data: data, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (b *SeekerBuffer) Read(p []byte) (int, error) { | ||||
| 	if b.pos >= int64(len(b.data)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	n := copy(p, b.data[b.pos:]) | ||||
| 	b.pos += int64(n) | ||||
| 	return n, nil | ||||
| } | ||||
|  | ||||
| func (b *SeekerBuffer) Write(p []byte) (int, error) { | ||||
| 	b.data = append(b.data, p...) | ||||
| 	return len(p), nil | ||||
| } | ||||
|  | ||||
| // Seek sets the read pointer to pos. | ||||
| func (b *SeekerBuffer) Seek(offset int64, whence int) (int64, error) { | ||||
| 	switch whence { | ||||
| 	case io.SeekStart: | ||||
| 		b.pos = offset | ||||
| 	case io.SeekEnd: | ||||
| 		b.pos = int64(len(b.data)) + offset | ||||
| 	case io.SeekCurrent: | ||||
| 		b.pos += offset | ||||
| 	} | ||||
|  | ||||
| 	return b.pos, nil | ||||
| } | ||||
|  | ||||
| // Rewind resets the read pointer to 0. | ||||
| func (b *SeekerBuffer) Rewind() error { | ||||
| 	if _, err := b.Seek(0, io.SeekStart); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Close clears all the data out of the buffer and sets the read position to 0. | ||||
| func (b *SeekerBuffer) Close() error { | ||||
| 	b.data = nil | ||||
| 	b.pos = 0 | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Reset clears all the data out of the buffer and sets the read position to 0. | ||||
| func (b *SeekerBuffer) Reset() { | ||||
| 	b.data = nil | ||||
| 	b.pos = 0 | ||||
| } | ||||
|  | ||||
| // Len returns the length of data remaining to be read. | ||||
| func (b *SeekerBuffer) Len() int { | ||||
| 	return len(b.data[b.pos:]) | ||||
| } | ||||
|  | ||||
| // Bytes returns the underlying bytes from the current position. | ||||
| func (b *SeekerBuffer) Bytes() []byte { | ||||
| 	return b.data[b.pos:] | ||||
| } | ||||
							
								
								
									
										55
									
								
								util/buffer/seeker_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								util/buffer/seeker_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package buffer | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func noErrorT(t *testing.T, err error) { | ||||
| 	if nil != err { | ||||
| 		t.Fatalf("%s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func boolT(t *testing.T, cond bool, s ...string) { | ||||
| 	if !cond { | ||||
| 		what := strings.Join(s, ", ") | ||||
| 		if len(what) > 0 { | ||||
| 			what = ": " + what | ||||
| 		} | ||||
| 		t.Fatalf("assert.Bool failed%s", what) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSeeking(t *testing.T) { | ||||
| 	partA := []byte("hello, ") | ||||
| 	partB := []byte("world!") | ||||
|  | ||||
| 	buf := NewSeekerBuffer(partA) | ||||
|  | ||||
| 	boolT(t, buf.Len() == len(partA), fmt.Sprintf("on init: have length %d, want length %d", buf.Len(), len(partA))) | ||||
|  | ||||
| 	b := make([]byte, 32) | ||||
|  | ||||
| 	n, err := buf.Read(b) | ||||
| 	noErrorT(t, err) | ||||
| 	boolT(t, buf.Len() == 0, fmt.Sprintf("after reading 1: have length %d, want length 0", buf.Len())) | ||||
| 	boolT(t, n == len(partA), fmt.Sprintf("after reading 2: have length %d, want length %d", n, len(partA))) | ||||
|  | ||||
| 	n, err = buf.Write(partB) | ||||
| 	noErrorT(t, err) | ||||
| 	boolT(t, n == len(partB), fmt.Sprintf("after writing: have length %d, want length %d", n, len(partB))) | ||||
|  | ||||
| 	n, err = buf.Read(b) | ||||
| 	noErrorT(t, err) | ||||
| 	boolT(t, buf.Len() == 0, fmt.Sprintf("after rereading 1: have length %d, want length 0", buf.Len())) | ||||
| 	boolT(t, n == len(partB), fmt.Sprintf("after rereading 2: have length %d, want length %d", n, len(partB))) | ||||
|  | ||||
| 	partsLen := len(partA) + len(partB) | ||||
| 	_ = buf.Rewind() | ||||
| 	boolT(t, buf.Len() == partsLen, fmt.Sprintf("after rewinding: have length %d, want length %d", buf.Len(), partsLen)) | ||||
|  | ||||
| 	buf.Close() | ||||
| 	boolT(t, buf.Len() == 0, fmt.Sprintf("after closing, have length %d, want length 0", buf.Len())) | ||||
| } | ||||
| @@ -489,35 +489,74 @@ func URLMap(query string) (map[string]interface{}, error) { | ||||
| 	return mp.(map[string]interface{}), nil | ||||
| } | ||||
|  | ||||
| // FlattenMap expand key.subkey to nested map | ||||
| func FlattenMap(a map[string]interface{}) map[string]interface{} { | ||||
| 	// preprocess map | ||||
| 	nb := make(map[string]interface{}, len(a)) | ||||
| 	for k, v := range a { | ||||
| 		ps := strings.Split(k, ".") | ||||
| 		if len(ps) == 1 { | ||||
| 			nb[k] = v | ||||
| // FlattenMap flattens a nested map into a single-level map using dot notation for nested keys. | ||||
| // In case of key conflicts, all nested levels will be discarded in favor of the first-level key. | ||||
| // | ||||
| // Example #1: | ||||
| // | ||||
| //	Input: | ||||
| //	  { | ||||
| //	    "user.name": "alex", | ||||
| //	    "user.document.id": "document_id" | ||||
| //	    "user.document.number": "document_number" | ||||
| //	  } | ||||
| //	Output: | ||||
| //	  { | ||||
| //	    "user": { | ||||
| //	      "name": "alex", | ||||
| //	      "document": { | ||||
| //	        "id": "document_id" | ||||
| //	        "number": "document_number" | ||||
| //	      } | ||||
| //	    } | ||||
| //	  } | ||||
| // | ||||
| // Example #2 (with conflicts): | ||||
| // | ||||
| //	Input: | ||||
| //	  { | ||||
| //	    "user": "alex", | ||||
| //	    "user.document.id": "document_id" | ||||
| //	    "user.document.number": "document_number" | ||||
| //	  } | ||||
| //	Output: | ||||
| //	  { | ||||
| //	    "user": "alex" | ||||
| //	  } | ||||
| func FlattenMap(input map[string]interface{}) map[string]interface{} { | ||||
| 	result := make(map[string]interface{}) | ||||
|  | ||||
| 	for k, v := range input { | ||||
| 		parts := strings.Split(k, ".") | ||||
|  | ||||
| 		if len(parts) == 1 { | ||||
| 			result[k] = v | ||||
| 			continue | ||||
| 		} | ||||
| 		em := make(map[string]interface{}) | ||||
| 		em[ps[len(ps)-1]] = v | ||||
| 		for i := len(ps) - 2; i > 0; i-- { | ||||
| 			nm := make(map[string]interface{}) | ||||
| 			nm[ps[i]] = em | ||||
| 			em = nm | ||||
|  | ||||
| 		current := result | ||||
|  | ||||
| 		for i, part := range parts { | ||||
| 			// last element in the path | ||||
| 			if i == len(parts)-1 { | ||||
| 				current[part] = v | ||||
| 				break | ||||
| 			} | ||||
| 		if vm, ok := nb[ps[0]]; ok { | ||||
| 			// nested map | ||||
| 			nm := vm.(map[string]interface{}) | ||||
| 			for vk, vv := range em { | ||||
| 				nm[vk] = vv | ||||
|  | ||||
| 			// initialize map for current level if not exist | ||||
| 			if _, ok := current[part]; !ok { | ||||
| 				current[part] = make(map[string]interface{}) | ||||
| 			} | ||||
| 			nb[ps[0]] = nm | ||||
|  | ||||
| 			if nested, ok := current[part].(map[string]interface{}); ok { | ||||
| 				current = nested // continue to the nested map | ||||
| 			} else { | ||||
| 			nb[ps[0]] = em | ||||
| 				break // if current element is not a map, ignore it | ||||
| 			} | ||||
| 		} | ||||
| 	return nb | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
|  | ||||
| /* | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	rutil "go.unistack.org/micro/v3/util/reflect" | ||||
| ) | ||||
|  | ||||
| @@ -319,3 +320,140 @@ func TestIsZero(t *testing.T) { | ||||
|  | ||||
| 	// t.Logf("XX %#+v\n", ok) | ||||
| } | ||||
|  | ||||
| func TestFlattenMap(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name     string | ||||
| 		input    map[string]interface{} | ||||
| 		expected map[string]interface{} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "empty map", | ||||
| 			input:    map[string]interface{}{}, | ||||
| 			expected: map[string]interface{}{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "nil map", | ||||
| 			input:    nil, | ||||
| 			expected: map[string]interface{}{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "single level", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"username": "username", | ||||
| 				"password": "password", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"username": "username", | ||||
| 				"password": "password", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "two level", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"order_id":      "order_id", | ||||
| 				"user.name":     "username", | ||||
| 				"user.password": "password", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"order_id": "order_id", | ||||
| 				"user": map[string]interface{}{ | ||||
| 					"name":     "username", | ||||
| 					"password": "password", | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "three level", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"order_id":             "order_id", | ||||
| 				"user.name":            "username", | ||||
| 				"user.password":        "password", | ||||
| 				"user.document.id":     "document_id", | ||||
| 				"user.document.number": "document_number", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"order_id": "order_id", | ||||
| 				"user": map[string]interface{}{ | ||||
| 					"name":     "username", | ||||
| 					"password": "password", | ||||
| 					"document": map[string]interface{}{ | ||||
| 						"id":     "document_id", | ||||
| 						"number": "document_number", | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "four level", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"order_id":                    "order_id", | ||||
| 				"user.name":                   "username", | ||||
| 				"user.password":               "password", | ||||
| 				"user.document.id":            "document_id", | ||||
| 				"user.document.number":        "document_number", | ||||
| 				"user.info.permissions.read":  "available", | ||||
| 				"user.info.permissions.write": "available", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"order_id": "order_id", | ||||
| 				"user": map[string]interface{}{ | ||||
| 					"name":     "username", | ||||
| 					"password": "password", | ||||
| 					"document": map[string]interface{}{ | ||||
| 						"id":     "document_id", | ||||
| 						"number": "document_number", | ||||
| 					}, | ||||
| 					"info": map[string]interface{}{ | ||||
| 						"permissions": map[string]interface{}{ | ||||
| 							"read":  "available", | ||||
| 							"write": "available", | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "key conflicts", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"user":          "user", | ||||
| 				"user.name":     "username", | ||||
| 				"user.password": "password", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"user": "user", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "overwriting conflicts", | ||||
| 			input: map[string]interface{}{ | ||||
| 				"order_id":             "order_id", | ||||
| 				"user.document.id":     "document_id", | ||||
| 				"user.document.number": "document_number", | ||||
| 				"user.info.address":    "address", | ||||
| 				"user.info.phone":      "phone", | ||||
| 			}, | ||||
| 			expected: map[string]interface{}{ | ||||
| 				"order_id": "order_id", | ||||
| 				"user": map[string]interface{}{ | ||||
| 					"document": map[string]interface{}{ | ||||
| 						"id":     "document_id", | ||||
| 						"number": "document_number", | ||||
| 					}, | ||||
| 					"info": map[string]interface{}{ | ||||
| 						"address": "address", | ||||
| 						"phone":   "phone", | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			for range 100 { // need to exclude the impact of key order in the map on the test. | ||||
| 				require.Equal(t, tt.expected, rutil.FlattenMap(tt.input)) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -71,14 +71,6 @@ func CopyService(service *register.Service) *register.Service { | ||||
| 	} | ||||
| 	s.Nodes = nodes | ||||
|  | ||||
| 	// copy endpoints | ||||
| 	eps := make([]*register.Endpoint, len(service.Endpoints)) | ||||
| 	for j, ep := range service.Endpoints { | ||||
| 		e := ®ister.Endpoint{} | ||||
| 		*e = *ep | ||||
| 		eps[j] = e | ||||
| 	} | ||||
| 	s.Endpoints = eps | ||||
| 	return s | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -35,11 +35,11 @@ func TestUnmarshalYAML(t *testing.T) { | ||||
| 		t.Fatalf("invalid duration %v != 10000000", v.TTL) | ||||
| 	} | ||||
|  | ||||
| 	err = yaml.Unmarshal([]byte(`{"ttl":"1y"}`), v) | ||||
| 	err = yaml.Unmarshal([]byte(`{"ttl":"1d"}`), v) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if *(v.TTL) != 31622400000000000 { | ||||
| 		t.Fatalf("invalid duration %v != 31622400000000000", v.TTL) | ||||
| 	} else if *(v.TTL) != 86400000000000 { | ||||
| 		t.Fatalf("invalid duration %v != 86400000000000", *v.TTL) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -68,11 +68,11 @@ func TestUnmarshalJSON(t *testing.T) { | ||||
| 		t.Fatalf("invalid duration %v != 10000000", v.TTL) | ||||
| 	} | ||||
|  | ||||
| 	err = json.Unmarshal([]byte(`{"ttl":"1y"}`), v) | ||||
| 	err = json.Unmarshal([]byte(`{"ttl":"1d"}`), v) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} else if v.TTL != 31622400000000000 { | ||||
| 		t.Fatalf("invalid duration %v != 31622400000000000", v.TTL) | ||||
| 	} else if v.TTL != 86400000000000 { | ||||
| 		t.Fatalf("invalid duration %v != 86400000000000", v.TTL) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -87,11 +87,11 @@ func TestParseDuration(t *testing.T) { | ||||
| 	if td.String() != "340h0m0s" { | ||||
| 		t.Fatalf("ParseDuration 14d != 340h0m0s : %s", td.String()) | ||||
| 	} | ||||
| 	td, err = ParseDuration("1y") | ||||
| 	td, err = ParseDuration("1d") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("ParseDuration error: %v", err) | ||||
| 	} | ||||
| 	if td.String() != "8784h0m0s" { | ||||
| 		t.Fatalf("ParseDuration 1y != 8784h0m0s : %s", td.String()) | ||||
| 	if td.String() != "24h0m0s" { | ||||
| 		t.Fatalf("ParseDuration 1d != 24h0m0s : %s", td.String()) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user