Compare commits
	
		
			29 Commits
		
	
	
		
			b6d2d459c5
			...
			v3.11.46
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3e86864ce7 | |||
|  | a68d3b24b8 | ||
| 9c22ae5384 | |||
|  | 16bad9a0cd | ||
|  | 3c779b248f | ||
| bbc7512054 | |||
|  | dd810e4ae0 | ||
| 236ed47ab1 | |||
|  | 909bcf51a4 | ||
| 9c24001f52 | |||
| 680cd6f708 | |||
|  | 3fcf3bef6d | ||
|  | 0ecd6199d4 | ||
| 14a30fb6a7 | |||
| 7c613072df | |||
|  | c55e212270 | ||
| 100bc006bb | |||
| 5dbfe8a7a6 | |||
| 6be077dbe8 | |||
| b4878211ee | |||
| ec9178c6d4 | |||
| ae63d44866 | |||
| 883e79216a | |||
| fa636ef6a9 | |||
| cdb81a9ba3 | |||
| 413c6cc2f0 | |||
|  | f56bd70136 | ||
| b51b4107a8 | |||
| 2067c9de6b | 
| @@ -3,14 +3,16 @@ name: coverage | |||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     branches: [ main, v3, v4 ] |     branches: [ main, v3, v4 ] | ||||||
|  |     paths-ignore: | ||||||
|  |       - '.github/**' | ||||||
|  |       - '.gitea/**' | ||||||
|   pull_request: |   pull_request: | ||||||
|     branches: [ main, v3, v4 ] |     branches: [ main, v3, v4 ] | ||||||
|   # Allows you to run this workflow manually from the Actions tab |  | ||||||
|   workflow_dispatch: |  | ||||||
| 
 | 
 | ||||||
| jobs: | jobs: | ||||||
| 
 | 
 | ||||||
|   build: |   build: | ||||||
|  |     if: github.server_url != 'https://github.com' | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|     - name: checkout code |     - name: checkout code | ||||||
| @@ -39,8 +41,8 @@ jobs: | |||||||
|       name: autocommit |       name: autocommit | ||||||
|       with: |       with: | ||||||
|         commit_message: Apply Code Coverage Badge |         commit_message: Apply Code Coverage Badge | ||||||
|         skip_fetch: true |         skip_fetch: false | ||||||
|         skip_checkout: true |         skip_checkout: false | ||||||
|         file_pattern: ./README.md |         file_pattern: ./README.md | ||||||
| 
 | 
 | ||||||
|     - name: push |     - name: push | ||||||
| @@ -3,10 +3,10 @@ name: lint | |||||||
| on: | on: | ||||||
|   pull_request: |   pull_request: | ||||||
|     types: [opened, reopened, synchronize] |     types: [opened, reopened, synchronize] | ||||||
|     branches: |     branches: [ master, v3, v4 ] | ||||||
|     - master |     paths-ignore: | ||||||
|     - v3 |       - '.github/**' | ||||||
|     - v4 |       - '.gitea/**' | ||||||
| 
 | 
 | ||||||
| jobs: | jobs: | ||||||
|   lint: |   lint: | ||||||
| @@ -24,6 +24,6 @@ jobs: | |||||||
|     - name: setup deps |     - name: setup deps | ||||||
|       run: go get -v ./... |       run: go get -v ./... | ||||||
|     - name: run lint |     - name: run lint | ||||||
|       uses: https://github.com/golangci/golangci-lint-action@v6 |       uses: golangci/golangci-lint-action@v6 | ||||||
|       with: |       with: | ||||||
|         version: 'latest' |         version: 'latest' | ||||||
| @@ -3,15 +3,12 @@ name: test | |||||||
| on: | on: | ||||||
|   pull_request: |   pull_request: | ||||||
|     types: [opened, reopened, synchronize] |     types: [opened, reopened, synchronize] | ||||||
|     branches: |     branches: [ master, v3, v4 ] | ||||||
|     - master |  | ||||||
|     - v3 |  | ||||||
|     - v4 |  | ||||||
|   push: |   push: | ||||||
|     branches: |     branches: [ master, v3, v4 ] | ||||||
|     - master |     paths-ignore: | ||||||
|     - v3 |       - '.github/**' | ||||||
|     - v4 |       - '.gitea/**' | ||||||
| 
 | 
 | ||||||
| jobs: | jobs: | ||||||
|   test: |   test: | ||||||
| @@ -3,15 +3,12 @@ name: test | |||||||
| on: | on: | ||||||
|   pull_request: |   pull_request: | ||||||
|     types: [opened, reopened, synchronize] |     types: [opened, reopened, synchronize] | ||||||
|     branches: |     branches: [ master, v3, v4 ] | ||||||
|     - master |  | ||||||
|     - v3 |  | ||||||
|     - v4 |  | ||||||
|   push: |   push: | ||||||
|     branches: |     branches: [ master, v3, v4 ] | ||||||
|     - master |     paths-ignore: | ||||||
|     - v3 |       - '.github/**' | ||||||
|     - v4 |       - '.gitea/**' | ||||||
| 
 | 
 | ||||||
| jobs: | jobs: | ||||||
|   test: |   test: | ||||||
| @@ -35,19 +32,19 @@ jobs: | |||||||
|         go-version: 'stable' |         go-version: 'stable' | ||||||
|     - name: setup go work |     - name: setup go work | ||||||
|       env: |       env: | ||||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work |         GOWORK: ${{ github.workspace }}/go.work | ||||||
|       run: | |       run: | | ||||||
|         go work init |         go work init | ||||||
|         go work use . |         go work use . | ||||||
|         go work use micro-tests |         go work use micro-tests | ||||||
|     - name: setup deps |     - name: setup deps | ||||||
|       env: |       env: | ||||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work |         GOWORK: ${{ github.workspace }}/go.work | ||||||
|       run: go get -v ./... |       run: go get -v ./... | ||||||
|     - name: run tests |     - name: run tests | ||||||
|       env: |       env: | ||||||
|         INTEGRATION_TESTS: yes |         INTEGRATION_TESTS: yes | ||||||
|         GOWORK: /workspace/${{ github.repository_owner }}/go.work |         GOWORK: ${{ github.workspace }}/go.work | ||||||
|       run: | |       run: | | ||||||
|         cd micro-tests |         cd micro-tests | ||||||
|         go test -mod readonly -v ./... || true |         go test -mod readonly -v ./... || true | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| run: | run: | ||||||
|   concurrency: 8 |   concurrency: 8 | ||||||
|   deadline: 5m |   timeout: 5m | ||||||
|   issues-exit-code: 1 |   issues-exit-code: 1 | ||||||
|   tests: true |   tests: true | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,5 +1,5 @@ | |||||||
| # Micro | # Micro | ||||||
|  |  | ||||||
| [](https://opensource.org/licenses/Apache-2.0) | [](https://opensource.org/licenses/Apache-2.0) | ||||||
| [](https://pkg.go.dev/go.unistack.org/micro/v3?tab=overview) | [](https://pkg.go.dev/go.unistack.org/micro/v3?tab=overview) | ||||||
| [](https://git.unistack.org/unistack-org/micro/actions?query=workflow%3Abuild+branch%3Av3+event%3Apush) | [](https://git.unistack.org/unistack-org/micro/actions?query=workflow%3Abuild+branch%3Av3+event%3Apush) | ||||||
| @@ -9,20 +9,20 @@ Micro is a standard library for microservices. | |||||||
|  |  | ||||||
| ## Overview | ## Overview | ||||||
|  |  | ||||||
| Micro provides the core requirements for distributed systems development including RPC and Event driven communication.  | Micro provides the core requirements for distributed systems development including SYNC and ASYNC communication.  | ||||||
|  |  | ||||||
| ## Features | ## Features | ||||||
|  |  | ||||||
| Micro abstracts away the details of distributed systems. Here are the main features. | Micro abstracts away the details of distributed systems. Main features: | ||||||
|  |  | ||||||
| - **Dynamic Config** - Load and hot reload dynamic config from anywhere. The config interface provides a way to load application  | - **Dynamic Config** - Load and hot reload dynamic config from anywhere. The config interface provides a way to load application  | ||||||
| level config from any source such as env vars, cmdline, file, consul, vault... You can merge the sources and even define fallbacks. | level config from any source such as env vars, cmdline, file, consul, vault, etc... You can merge the sources and even define fallbacks. | ||||||
|  |  | ||||||
| - **Data Storage** - A simple data store interface to read, write and delete records. It includes support for memory, file and  | - **Data Storage** - A simple data store interface to read, write and delete records. It includes support for memory, file and  | ||||||
| s3. State and persistence becomes a core requirement beyond prototyping and Micro looks to build that into the framework. | s3. State and persistence becomes a core requirement beyond prototyping and Micro looks to build that into the framework. | ||||||
|  |  | ||||||
| - **Service Discovery** - Automatic service registration and name resolution. Service discovery is at the core of micro service  | - **Service Discovery** - Automatic service registration and name resolution. Service discovery is at the core of micro service  | ||||||
| development. When service A needs to speak to service B it needs the location of that service. | development. | ||||||
|  |  | ||||||
| - **Message Encoding** - Dynamic message encoding based on content-type. The client and server will use codecs along with content-type  | - **Message Encoding** - Dynamic message encoding based on content-type. The client and server will use codecs along with content-type  | ||||||
| to seamlessly encode and decode Go types for you. Any variety of messages could be encoded and sent from different clients. The client  | to seamlessly encode and decode Go types for you. Any variety of messages could be encoded and sent from different clients. The client  | ||||||
|   | |||||||
							
								
								
									
										15
									
								
								SECURITY.md
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								SECURITY.md
									
									
									
									
									
								
							| @@ -1,15 +0,0 @@ | |||||||
| # Security Policy |  | ||||||
|  |  | ||||||
| ## Supported Versions |  | ||||||
|  |  | ||||||
| Use this section to tell people about which versions of your project are |  | ||||||
| currently being supported with security updates. |  | ||||||
|  |  | ||||||
| | Version | Supported          | |  | ||||||
| | ------- | ------------------ | |  | ||||||
| | 3.7.x   | :white_check_mark: | |  | ||||||
| | < 3.7.0 | :x:                | |  | ||||||
|  |  | ||||||
| ## Reporting a Vulnerability |  | ||||||
|  |  | ||||||
| If you find any issue, please create github issue in this repo |  | ||||||
| @@ -17,11 +17,6 @@ import ( | |||||||
| 	"go.unistack.org/micro/v3/tracer" | 	"go.unistack.org/micro/v3/tracer" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // DefaultCodecs will be used to encode/decode data |  | ||||||
| var DefaultCodecs = map[string]codec.Codec{ |  | ||||||
| 	"application/octet-stream": codec.NewCodec(), |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type noopClient struct { | type noopClient struct { | ||||||
| 	funcPublish      FuncPublish | 	funcPublish      FuncPublish | ||||||
| 	funcBatchPublish FuncBatchPublish | 	funcBatchPublish FuncBatchPublish | ||||||
| @@ -178,9 +173,6 @@ func (n *noopClient) newCodec(contentType string) (codec.Codec, error) { | |||||||
| 	if cf, ok := n.opts.Codecs[contentType]; ok { | 	if cf, ok := n.opts.Codecs[contentType]; ok { | ||||||
| 		return cf, nil | 		return cf, nil | ||||||
| 	} | 	} | ||||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { |  | ||||||
| 		return cf, nil |  | ||||||
| 	} |  | ||||||
| 	return nil, codec.ErrUnknownContentType | 	return nil, codec.ErrUnknownContentType | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ package client | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/codec" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type testHook struct { | type testHook struct { | ||||||
| @@ -19,7 +21,7 @@ func (t *testHook) Publish(fn FuncPublish) FuncPublish { | |||||||
| func TestNoopHook(t *testing.T) { | func TestNoopHook(t *testing.T) { | ||||||
| 	h := &testHook{} | 	h := &testHook{} | ||||||
|  |  | ||||||
| 	c := NewClient(Hooks(HookPublish(h.Publish))) | 	c := NewClient(Codec("application/octet-stream", codec.NewCodec()), Hooks(HookPublish(h.Publish))) | ||||||
|  |  | ||||||
| 	if err := c.Init(); err != nil { | 	if err := c.Init(); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|   | |||||||
| @@ -198,7 +198,7 @@ func NewOptions(opts ...Option) Options { | |||||||
| 	options := Options{ | 	options := Options{ | ||||||
| 		Context:     context.Background(), | 		Context:     context.Background(), | ||||||
| 		ContentType: DefaultContentType, | 		ContentType: DefaultContentType, | ||||||
| 		Codecs:      DefaultCodecs, | 		Codecs:      make(map[string]codec.Codec), | ||||||
| 		CallOptions: CallOptions{ | 		CallOptions: CallOptions{ | ||||||
| 			Context:        context.Background(), | 			Context:        context.Background(), | ||||||
| 			Backoff:        DefaultBackoff, | 			Backoff:        DefaultBackoff, | ||||||
|   | |||||||
							
								
								
									
										235
									
								
								cluster/hasql/cluster.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								cluster/hasql/cluster.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,235 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"reflect" | ||||||
|  | 	"unsafe" | ||||||
|  |  | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func newSQLRowError() *sql.Row { | ||||||
|  | 	row := &sql.Row{} | ||||||
|  | 	t := reflect.TypeOf(row).Elem() | ||||||
|  | 	field, _ := t.FieldByName("err") | ||||||
|  | 	rowPtr := unsafe.Pointer(row) | ||||||
|  | 	errFieldPtr := unsafe.Pointer(uintptr(rowPtr) + field.Offset) | ||||||
|  | 	errPtr := (*error)(errFieldPtr) | ||||||
|  | 	*errPtr = ErrorNoAliveNodes | ||||||
|  | 	return row | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClusterQuerier interface { | ||||||
|  | 	Querier | ||||||
|  | 	WaitForNodes(ctx context.Context, criterion ...hasql.NodeStateCriterion) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Cluster struct { | ||||||
|  | 	hasql   *hasql.Cluster[Querier] | ||||||
|  | 	options ClusterOptions | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewCluster returns [Querier] that provides cluster of nodes | ||||||
|  | func NewCluster[T Querier](opts ...ClusterOption) (ClusterQuerier, error) { | ||||||
|  | 	options := ClusterOptions{Context: context.Background()} | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		opt(&options) | ||||||
|  | 	} | ||||||
|  | 	if options.NodeChecker == nil { | ||||||
|  | 		return nil, ErrClusterChecker | ||||||
|  | 	} | ||||||
|  | 	if options.NodeDiscoverer == nil { | ||||||
|  | 		return nil, ErrClusterDiscoverer | ||||||
|  | 	} | ||||||
|  | 	if options.NodePicker == nil { | ||||||
|  | 		return nil, ErrClusterPicker | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if options.Retries < 1 { | ||||||
|  | 		options.Retries = 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if options.NodeStateCriterion == 0 { | ||||||
|  | 		options.NodeStateCriterion = hasql.Primary | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	options.Options = append(options.Options, hasql.WithNodePicker(options.NodePicker)) | ||||||
|  | 	if p, ok := options.NodePicker.(*CustomPicker[Querier]); ok { | ||||||
|  | 		p.opts.Priority = options.NodePriority | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c, err := hasql.NewCluster( | ||||||
|  | 		options.NodeDiscoverer, | ||||||
|  | 		options.NodeChecker, | ||||||
|  | 		options.Options..., | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &Cluster{hasql: c, options: options}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||||
|  | 	var tx *sql.Tx | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if tx, err = n.DB().BeginTx(ctx, opts); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if tx == nil && err == nil { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tx, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) Close() error { | ||||||
|  | 	return c.hasql.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) Conn(ctx context.Context) (*sql.Conn, error) { | ||||||
|  | 	var conn *sql.Conn | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if conn, err = n.DB().Conn(ctx); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if conn == nil && err == nil { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return conn, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	var res sql.Result | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||||
|  | 	var res *sql.Stmt | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().PrepareContext(ctx, query); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	var res *sql.Rows | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||||
|  | 	var res *sql.Row | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			res = n.DB().QueryRowContext(ctx, query, args...) | ||||||
|  | 			if res.Err() == nil { | ||||||
|  | 				return false | ||||||
|  | 			} else if res.Err() != nil && retries >= c.options.Retries { | ||||||
|  | 				return false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return true | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil { | ||||||
|  | 		res = newSQLRowError() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) PingContext(ctx context.Context) error { | ||||||
|  | 	var err error | ||||||
|  | 	var ok bool | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		ok = true | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if err = n.DB().PingContext(ctx); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if !ok { | ||||||
|  | 		err = ErrorNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) WaitForNodes(ctx context.Context, criterions ...hasql.NodeStateCriterion) error { | ||||||
|  | 	for _, criterion := range criterions { | ||||||
|  | 		if _, err := c.hasql.WaitForNode(ctx, criterion); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										171
									
								
								cluster/hasql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								cluster/hasql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/DATA-DOG/go-sqlmock" | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNewCluster(t *testing.T) { | ||||||
|  | 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbMaster.Close() | ||||||
|  | 	dbMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(1, 0)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("master-dc1")) | ||||||
|  |  | ||||||
|  | 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbDRMaster.Close() | ||||||
|  | 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 40)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster1-dc2")) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC1.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC2.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	c, err := NewCluster[Querier]( | ||||||
|  | 		WithClusterContext(tctx), | ||||||
|  | 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||||
|  | 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||||
|  | 			CustomPickerMaxLag(100), | ||||||
|  | 		)), | ||||||
|  | 		WithClusterNodes( | ||||||
|  | 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||||
|  | 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||||
|  | 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||||
|  | 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||||
|  | 		), | ||||||
|  | 		WithClusterOptions( | ||||||
|  | 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||||
|  | 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||||
|  | 		), | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer c.Close() | ||||||
|  |  | ||||||
|  | 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	time.Sleep(500 * time.Millisecond) | ||||||
|  |  | ||||||
|  | 	node1Name := "" | ||||||
|  | 	fmt.Printf("check for Standby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.Standby), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node1Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "slave-dc1" != node1Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node1Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	node2Name := "" | ||||||
|  | 	fmt.Printf("check for PreferStandby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node2Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "slave-dc1" != node2Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node2Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	node3Name := "" | ||||||
|  | 	fmt.Printf("check for PreferPrimary\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferPrimary), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node3Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "master-dc1" != node3Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "master-dc1", node3Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`.*`).WillReturnRows(sqlmock.NewRows([]string{"role"}).RowError(1, fmt.Errorf("row error"))) | ||||||
|  |  | ||||||
|  | 	time.Sleep(2 * time.Second) | ||||||
|  |  | ||||||
|  | 	fmt.Printf("check for PreferStandby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() == nil { | ||||||
|  | 		t.Fatal("must return error") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if dbMasterErr := dbMasterMock.ExpectationsWereMet(); dbMasterErr != nil { | ||||||
|  | 		t.Error(dbMasterErr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										25
									
								
								cluster/hasql/db.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								cluster/hasql/db.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Querier interface { | ||||||
|  | 	// Basic connection methods | ||||||
|  | 	PingContext(ctx context.Context) error | ||||||
|  | 	Close() error | ||||||
|  |  | ||||||
|  | 	// Query methods with context | ||||||
|  | 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||||
|  | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||||
|  | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||||
|  |  | ||||||
|  | 	// Prepared statements with context | ||||||
|  | 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||||
|  |  | ||||||
|  | 	// Transaction management with context | ||||||
|  | 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||||
|  |  | ||||||
|  | 	Conn(ctx context.Context) (*sql.Conn, error) | ||||||
|  | } | ||||||
							
								
								
									
										295
									
								
								cluster/hasql/driver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										295
									
								
								cluster/hasql/driver.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,295 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"io" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // OpenDBWithCluster creates a [*sql.DB] that uses the [ClusterQuerier] | ||||||
|  | func OpenDBWithCluster(db ClusterQuerier) (*sql.DB, error) { | ||||||
|  | 	driver := NewClusterDriver(db) | ||||||
|  | 	connector, err := driver.OpenConnector("") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return sql.OpenDB(connector), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ClusterDriver implements [driver.Driver] and driver.Connector for an existing [Querier] | ||||||
|  | type ClusterDriver struct { | ||||||
|  | 	db ClusterQuerier | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewClusterDriver creates a new [driver.Driver] that uses an existing [ClusterQuerier] | ||||||
|  | func NewClusterDriver(db ClusterQuerier) *ClusterDriver { | ||||||
|  | 	return &ClusterDriver{db: db} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Open implements [driver.Driver.Open] | ||||||
|  | func (d *ClusterDriver) Open(name string) (driver.Conn, error) { | ||||||
|  | 	return d.Connect(context.Background()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // OpenConnector implements [driver.DriverContext.OpenConnector] | ||||||
|  | func (d *ClusterDriver) OpenConnector(name string) (driver.Connector, error) { | ||||||
|  | 	return d, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Connect implements [driver.Connector.Connect] | ||||||
|  | func (d *ClusterDriver) Connect(ctx context.Context) (driver.Conn, error) { | ||||||
|  | 	conn, err := d.db.Conn(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &dbConn{conn: conn}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Driver implements [driver.Connector.Driver] | ||||||
|  | func (d *ClusterDriver) Driver() driver.Driver { | ||||||
|  | 	return d | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dbConn implements driver.Conn with both context and legacy methods | ||||||
|  | type dbConn struct { | ||||||
|  | 	conn *sql.Conn | ||||||
|  | 	mu   sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Prepare implements [driver.Conn.Prepare] (legacy method) | ||||||
|  | func (c *dbConn) Prepare(query string) (driver.Stmt, error) { | ||||||
|  | 	return c.PrepareContext(context.Background(), query) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PrepareContext implements [driver.ConnPrepareContext.PrepareContext] | ||||||
|  | func (c *dbConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||||
|  | 	c.mu.Lock() | ||||||
|  | 	defer c.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	stmt, err := c.conn.PrepareContext(ctx, query) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dbStmt{stmt: stmt}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Exec implements [driver.Execer.Exec] (legacy method) | ||||||
|  | func (c *dbConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||||
|  | 	namedArgs := make([]driver.NamedValue, len(args)) | ||||||
|  | 	for i, value := range args { | ||||||
|  | 		namedArgs[i] = driver.NamedValue{Value: value} | ||||||
|  | 	} | ||||||
|  | 	return c.ExecContext(context.Background(), query, namedArgs) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ExecContext implements [driver.ExecerContext.ExecContext] | ||||||
|  | func (c *dbConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	c.mu.Lock() | ||||||
|  | 	defer c.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	// Convert driver.NamedValue to any | ||||||
|  | 	interfaceArgs := make([]any, len(args)) | ||||||
|  | 	for i, arg := range args { | ||||||
|  | 		interfaceArgs[i] = arg.Value | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.conn.ExecContext(ctx, query, interfaceArgs...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Query implements [driver.Queryer.Query] (legacy method) | ||||||
|  | func (c *dbConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||||
|  | 	namedArgs := make([]driver.NamedValue, len(args)) | ||||||
|  | 	for i, value := range args { | ||||||
|  | 		namedArgs[i] = driver.NamedValue{Value: value} | ||||||
|  | 	} | ||||||
|  | 	return c.QueryContext(context.Background(), query, namedArgs) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QueryContext implements [driver.QueryerContext.QueryContext] | ||||||
|  | func (c *dbConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	c.mu.Lock() | ||||||
|  | 	defer c.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	// Convert driver.NamedValue to any | ||||||
|  | 	interfaceArgs := make([]any, len(args)) | ||||||
|  | 	for i, arg := range args { | ||||||
|  | 		interfaceArgs[i] = arg.Value | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rows, err := c.conn.QueryContext(ctx, query, interfaceArgs...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dbRows{rows: rows}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Begin implements [driver.Conn.Begin] (legacy method) | ||||||
|  | func (c *dbConn) Begin() (driver.Tx, error) { | ||||||
|  | 	return c.BeginTx(context.Background(), driver.TxOptions{}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // BeginTx implements [driver.ConnBeginTx.BeginTx] | ||||||
|  | func (c *dbConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||||
|  | 	c.mu.Lock() | ||||||
|  | 	defer c.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	sqlOpts := &sql.TxOptions{ | ||||||
|  | 		Isolation: sql.IsolationLevel(opts.Isolation), | ||||||
|  | 		ReadOnly:  opts.ReadOnly, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tx, err := c.conn.BeginTx(ctx, sqlOpts) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dbTx{tx: tx}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Ping implements [driver.Pinger.Ping] | ||||||
|  | func (c *dbConn) Ping(ctx context.Context) error { | ||||||
|  | 	return c.conn.PingContext(ctx) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements [driver.Conn.Close] | ||||||
|  | func (c *dbConn) Close() error { | ||||||
|  | 	return c.conn.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // IsValid implements [driver.Validator.IsValid] | ||||||
|  | func (c *dbConn) IsValid() bool { | ||||||
|  | 	// Ping with a short timeout to check if the connection is still valid | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	return c.conn.PingContext(ctx) == nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dbStmt implements [driver.Stmt] with both context and legacy methods | ||||||
|  | type dbStmt struct { | ||||||
|  | 	stmt *sql.Stmt | ||||||
|  | 	mu   sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements [driver.Stmt.Close] | ||||||
|  | func (s *dbStmt) Close() error { | ||||||
|  | 	s.mu.Lock() | ||||||
|  | 	defer s.mu.Unlock() | ||||||
|  | 	return s.stmt.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements [driver.Stmt.NumInput] | ||||||
|  | func (s *dbStmt) NumInput() int { | ||||||
|  | 	return -1 // Number of parameters is unknown | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Exec implements [driver.Stmt.Exec] (legacy method) | ||||||
|  | func (s *dbStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||||
|  | 	namedArgs := make([]driver.NamedValue, len(args)) | ||||||
|  | 	for i, value := range args { | ||||||
|  | 		namedArgs[i] = driver.NamedValue{Value: value} | ||||||
|  | 	} | ||||||
|  | 	return s.ExecContext(context.Background(), namedArgs) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ExecContext implements [driver.StmtExecContext.ExecContext] | ||||||
|  | func (s *dbStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	s.mu.Lock() | ||||||
|  | 	defer s.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	interfaceArgs := make([]any, len(args)) | ||||||
|  | 	for i, arg := range args { | ||||||
|  | 		interfaceArgs[i] = arg.Value | ||||||
|  | 	} | ||||||
|  | 	return s.stmt.ExecContext(ctx, interfaceArgs...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Query implements [driver.Stmt.Query] (legacy method) | ||||||
|  | func (s *dbStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||||
|  | 	namedArgs := make([]driver.NamedValue, len(args)) | ||||||
|  | 	for i, value := range args { | ||||||
|  | 		namedArgs[i] = driver.NamedValue{Value: value} | ||||||
|  | 	} | ||||||
|  | 	return s.QueryContext(context.Background(), namedArgs) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QueryContext implements [driver.StmtQueryContext.QueryContext] | ||||||
|  | func (s *dbStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	s.mu.Lock() | ||||||
|  | 	defer s.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	interfaceArgs := make([]any, len(args)) | ||||||
|  | 	for i, arg := range args { | ||||||
|  | 		interfaceArgs[i] = arg.Value | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rows, err := s.stmt.QueryContext(ctx, interfaceArgs...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dbRows{rows: rows}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dbRows implements [driver.Rows] | ||||||
|  | type dbRows struct { | ||||||
|  | 	rows *sql.Rows | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Columns implements [driver.Rows.Columns] | ||||||
|  | func (r *dbRows) Columns() []string { | ||||||
|  | 	cols, err := r.rows.Columns() | ||||||
|  | 	if err != nil { | ||||||
|  | 		// This shouldn't happen if the query was successful | ||||||
|  | 		return []string{} | ||||||
|  | 	} | ||||||
|  | 	return cols | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements [driver.Rows.Close] | ||||||
|  | func (r *dbRows) Close() error { | ||||||
|  | 	return r.rows.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Next implements [driver.Rows.Next] | ||||||
|  | func (r *dbRows) Next(dest []driver.Value) error { | ||||||
|  | 	if !r.rows.Next() { | ||||||
|  | 		if err := r.rows.Err(); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a slice of interfaces to scan into | ||||||
|  | 	scanArgs := make([]any, len(dest)) | ||||||
|  | 	for i := range scanArgs { | ||||||
|  | 		scanArgs[i] = &dest[i] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return r.rows.Scan(scanArgs...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dbTx implements [driver.Tx] | ||||||
|  | type dbTx struct { | ||||||
|  | 	tx *sql.Tx | ||||||
|  | 	mu sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Commit implements [driver.Tx.Commit] | ||||||
|  | func (t *dbTx) Commit() error { | ||||||
|  | 	t.mu.Lock() | ||||||
|  | 	defer t.mu.Unlock() | ||||||
|  | 	return t.tx.Commit() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Rollback implements [driver.Tx.Rollback] | ||||||
|  | func (t *dbTx) Rollback() error { | ||||||
|  | 	t.mu.Lock() | ||||||
|  | 	defer t.mu.Unlock() | ||||||
|  | 	return t.tx.Rollback() | ||||||
|  | } | ||||||
							
								
								
									
										141
									
								
								cluster/hasql/driver_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								cluster/hasql/driver_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/DATA-DOG/go-sqlmock" | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDriver(t *testing.T) { | ||||||
|  | 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbMaster.Close() | ||||||
|  | 	dbMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(1, 0)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("master-dc1")) | ||||||
|  |  | ||||||
|  | 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbDRMaster.Close() | ||||||
|  | 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 40)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster1-dc2")) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC1.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC2.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	c, err := NewCluster[Querier]( | ||||||
|  | 		WithClusterContext(tctx), | ||||||
|  | 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||||
|  | 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||||
|  | 			CustomPickerMaxLag(100), | ||||||
|  | 		)), | ||||||
|  | 		WithClusterNodes( | ||||||
|  | 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||||
|  | 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||||
|  | 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||||
|  | 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||||
|  | 		), | ||||||
|  | 		WithClusterOptions( | ||||||
|  | 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||||
|  | 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||||
|  | 		), | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer c.Close() | ||||||
|  |  | ||||||
|  | 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	db, err := OpenDBWithCluster(c) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Use context methods | ||||||
|  | 	row := db.QueryRowContext(NodeStateCriterion(t.Context(), hasql.Primary), "SELECT node_name as name") | ||||||
|  | 	if err = row.Err(); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	nodeName := "" | ||||||
|  | 	if err = row.Scan(&nodeName); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if nodeName != "master-dc1" { | ||||||
|  | 		t.Fatalf("invalid node_name %s != %s", "master-dc1", nodeName) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										10
									
								
								cluster/hasql/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								cluster/hasql/error.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import "errors" | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	ErrClusterChecker    = errors.New("cluster node checker required") | ||||||
|  | 	ErrClusterDiscoverer = errors.New("cluster node discoverer required") | ||||||
|  | 	ErrClusterPicker     = errors.New("cluster node picker required") | ||||||
|  | 	ErrorNoAliveNodes    = errors.New("cluster no alive nodes") | ||||||
|  | ) | ||||||
							
								
								
									
										110
									
								
								cluster/hasql/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								cluster/hasql/options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,110 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"math" | ||||||
|  |  | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ClusterOptions contains cluster specific options | ||||||
|  | type ClusterOptions struct { | ||||||
|  | 	NodeChecker        hasql.NodeChecker | ||||||
|  | 	NodePicker         hasql.NodePicker[Querier] | ||||||
|  | 	NodeDiscoverer     hasql.NodeDiscoverer[Querier] | ||||||
|  | 	Options            []hasql.ClusterOpt[Querier] | ||||||
|  | 	Context            context.Context | ||||||
|  | 	Retries            int | ||||||
|  | 	NodePriority       map[string]int32 | ||||||
|  | 	NodeStateCriterion hasql.NodeStateCriterion | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ClusterOption apply cluster options to ClusterOptions | ||||||
|  | type ClusterOption func(*ClusterOptions) | ||||||
|  |  | ||||||
|  | // WithClusterNodeChecker pass hasql.NodeChecker to cluster options | ||||||
|  | func WithClusterNodeChecker(c hasql.NodeChecker) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeChecker = c | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodePicker pass hasql.NodePicker to cluster options | ||||||
|  | func WithClusterNodePicker(p hasql.NodePicker[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodePicker = p | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodeDiscoverer pass hasql.NodeDiscoverer to cluster options | ||||||
|  | func WithClusterNodeDiscoverer(d hasql.NodeDiscoverer[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeDiscoverer = d | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithRetries retry count on other nodes in case of error | ||||||
|  | func WithRetries(n int) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Retries = n | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterContext pass context.Context to cluster options and used for checks | ||||||
|  | func WithClusterContext(ctx context.Context) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Context = ctx | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterOptions pass hasql.ClusterOpt | ||||||
|  | func WithClusterOptions(opts ...hasql.ClusterOpt[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Options = append(o.Options, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodeStateCriterion pass default hasql.NodeStateCriterion | ||||||
|  | func WithClusterNodeStateCriterion(c hasql.NodeStateCriterion) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeStateCriterion = c | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClusterNode struct { | ||||||
|  | 	Name     string | ||||||
|  | 	DB       Querier | ||||||
|  | 	Priority int32 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodes create cluster with static NodeDiscoverer | ||||||
|  | func WithClusterNodes(cns ...ClusterNode) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		nodes := make([]*hasql.Node[Querier], 0, len(cns)) | ||||||
|  | 		if o.NodePriority == nil { | ||||||
|  | 			o.NodePriority = make(map[string]int32, len(cns)) | ||||||
|  | 		} | ||||||
|  | 		for _, cn := range cns { | ||||||
|  | 			nodes = append(nodes, hasql.NewNode(cn.Name, cn.DB)) | ||||||
|  | 			if cn.Priority == 0 { | ||||||
|  | 				cn.Priority = math.MaxInt32 | ||||||
|  | 			} | ||||||
|  | 			o.NodePriority[cn.Name] = cn.Priority | ||||||
|  | 		} | ||||||
|  | 		o.NodeDiscoverer = hasql.NewStaticNodeDiscoverer(nodes...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type nodeStateCriterionKey struct{} | ||||||
|  |  | ||||||
|  | // NodeStateCriterion inject hasql.NodeStateCriterion to context | ||||||
|  | func NodeStateCriterion(ctx context.Context, c hasql.NodeStateCriterion) context.Context { | ||||||
|  | 	return context.WithValue(ctx, nodeStateCriterionKey{}, c) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { | ||||||
|  | 	if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { | ||||||
|  | 		return v | ||||||
|  | 	} | ||||||
|  | 	return c.options.NodeStateCriterion | ||||||
|  | } | ||||||
							
								
								
									
										113
									
								
								cluster/hasql/picker.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								cluster/hasql/picker.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // compile time guard | ||||||
|  | var _ hasql.NodePicker[Querier] = (*CustomPicker[Querier])(nil) | ||||||
|  |  | ||||||
|  | // CustomPickerOptions holds options to pick nodes | ||||||
|  | type CustomPickerOptions struct { | ||||||
|  | 	MaxLag   int | ||||||
|  | 	Priority map[string]int32 | ||||||
|  | 	Retries  int | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CustomPickerOption func apply option to CustomPickerOptions | ||||||
|  | type CustomPickerOption func(*CustomPickerOptions) | ||||||
|  |  | ||||||
|  | // CustomPickerMaxLag specifies max lag for which node can be used | ||||||
|  | func CustomPickerMaxLag(n int) CustomPickerOption { | ||||||
|  | 	return func(o *CustomPickerOptions) { | ||||||
|  | 		o.MaxLag = n | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewCustomPicker creates new node picker | ||||||
|  | func NewCustomPicker[T Querier](opts ...CustomPickerOption) *CustomPicker[Querier] { | ||||||
|  | 	options := CustomPickerOptions{} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  | 	return &CustomPicker[Querier]{opts: options} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CustomPicker holds node picker options | ||||||
|  | type CustomPicker[T Querier] struct { | ||||||
|  | 	opts CustomPickerOptions | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PickNode used to return specific node | ||||||
|  | func (p *CustomPicker[T]) PickNode(cnodes []hasql.CheckedNode[T]) hasql.CheckedNode[T] { | ||||||
|  | 	for _, n := range cnodes { | ||||||
|  | 		fmt.Printf("node %s\n", n.Node.String()) | ||||||
|  | 	} | ||||||
|  | 	return cnodes[0] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *CustomPicker[T]) getPriority(nodeName string) int32 { | ||||||
|  | 	if prio, ok := p.opts.Priority[nodeName]; ok { | ||||||
|  | 		return prio | ||||||
|  | 	} | ||||||
|  | 	return math.MaxInt32 // Default to lowest priority | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CompareNodes used to sort nodes | ||||||
|  | func (p *CustomPicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int { | ||||||
|  | 	// Get replication lag values | ||||||
|  | 	aLag := a.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||||
|  | 	bLag := b.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||||
|  |  | ||||||
|  | 	// First check that lag lower then MaxLag | ||||||
|  | 	if aLag > p.opts.MaxLag && bLag > p.opts.MaxLag { | ||||||
|  | 		return 0 // both are equal | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If one node exceeds MaxLag and the other doesn't, prefer the one that doesn't | ||||||
|  | 	if aLag > p.opts.MaxLag { | ||||||
|  | 		return 1 // b is better | ||||||
|  | 	} | ||||||
|  | 	if bLag > p.opts.MaxLag { | ||||||
|  | 		return -1 // a is better | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get node priorities | ||||||
|  | 	aPrio := p.getPriority(a.Node.String()) | ||||||
|  | 	bPrio := p.getPriority(b.Node.String()) | ||||||
|  |  | ||||||
|  | 	// if both priority equals | ||||||
|  | 	if aPrio == bPrio { | ||||||
|  | 		// First compare by replication lag | ||||||
|  | 		if aLag < bLag { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		if aLag > bLag { | ||||||
|  | 			return 1 | ||||||
|  | 		} | ||||||
|  | 		// If replication lag is equal, compare by latency | ||||||
|  | 		aLatency := a.Info.(interface{ Latency() time.Duration }).Latency() | ||||||
|  | 		bLatency := b.Info.(interface{ Latency() time.Duration }).Latency() | ||||||
|  |  | ||||||
|  | 		if aLatency < bLatency { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		if aLatency > bLatency { | ||||||
|  | 			return 1 | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// If lag and latency is equal | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If priorities are different, prefer the node with lower priority value | ||||||
|  | 	if aPrio < bPrio { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return 1 | ||||||
|  | } | ||||||
							
								
								
									
										531
									
								
								cluster/sql/cluster.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										531
									
								
								cluster/sql/cluster.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,531 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"reflect" | ||||||
|  | 	"time" | ||||||
|  | 	"unsafe" | ||||||
|  |  | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var errNoAliveNodes = errors.New("no alive nodes") | ||||||
|  |  | ||||||
|  | func newSQLRowError() *sql.Row { | ||||||
|  | 	row := &sql.Row{} | ||||||
|  | 	t := reflect.TypeOf(row).Elem() | ||||||
|  | 	field, _ := t.FieldByName("err") | ||||||
|  | 	rowPtr := unsafe.Pointer(row) | ||||||
|  | 	errFieldPtr := unsafe.Pointer(uintptr(rowPtr) + field.Offset) | ||||||
|  | 	errPtr := (*error)(errFieldPtr) | ||||||
|  | 	*errPtr = errNoAliveNodes | ||||||
|  | 	return row | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClusterQuerier interface { | ||||||
|  | 	Querier | ||||||
|  | 	WaitForNodes(ctx context.Context, criterion ...hasql.NodeStateCriterion) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Querier interface { | ||||||
|  | 	// Basic connection methods | ||||||
|  | 	PingContext(ctx context.Context) error | ||||||
|  | 	Close() error | ||||||
|  |  | ||||||
|  | 	// Query methods with context | ||||||
|  | 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||||
|  | 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||||
|  | 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||||
|  |  | ||||||
|  | 	// Prepared statements with context | ||||||
|  | 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||||
|  |  | ||||||
|  | 	// Transaction management with context | ||||||
|  | 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||||
|  |  | ||||||
|  | 	// Connection pool management | ||||||
|  | 	SetConnMaxLifetime(d time.Duration) | ||||||
|  | 	SetConnMaxIdleTime(d time.Duration) | ||||||
|  | 	SetMaxOpenConns(n int) | ||||||
|  | 	SetMaxIdleConns(n int) | ||||||
|  | 	Stats() sql.DBStats | ||||||
|  |  | ||||||
|  | 	Conn(ctx context.Context) (*sql.Conn, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	ErrClusterChecker    = errors.New("cluster node checker required") | ||||||
|  | 	ErrClusterDiscoverer = errors.New("cluster node discoverer required") | ||||||
|  | 	ErrClusterPicker     = errors.New("cluster node picker required") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Cluster struct { | ||||||
|  | 	hasql   *hasql.Cluster[Querier] | ||||||
|  | 	options ClusterOptions | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewCluster returns Querier that provides cluster of nodes | ||||||
|  | func NewCluster[T Querier](opts ...ClusterOption) (ClusterQuerier, error) { | ||||||
|  | 	options := ClusterOptions{Context: context.Background()} | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		opt(&options) | ||||||
|  | 	} | ||||||
|  | 	if options.NodeChecker == nil { | ||||||
|  | 		return nil, ErrClusterChecker | ||||||
|  | 	} | ||||||
|  | 	if options.NodeDiscoverer == nil { | ||||||
|  | 		return nil, ErrClusterDiscoverer | ||||||
|  | 	} | ||||||
|  | 	if options.NodePicker == nil { | ||||||
|  | 		return nil, ErrClusterPicker | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if options.Retries < 1 { | ||||||
|  | 		options.Retries = 1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if options.NodeStateCriterion == 0 { | ||||||
|  | 		options.NodeStateCriterion = hasql.Primary | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	options.Options = append(options.Options, hasql.WithNodePicker(options.NodePicker)) | ||||||
|  | 	if p, ok := options.NodePicker.(*CustomPicker[Querier]); ok { | ||||||
|  | 		p.opts.Priority = options.NodePriority | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c, err := hasql.NewCluster( | ||||||
|  | 		options.NodeDiscoverer, | ||||||
|  | 		options.NodeChecker, | ||||||
|  | 		options.Options..., | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &Cluster{hasql: c, options: options}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // compile time guard | ||||||
|  | var _ hasql.NodePicker[Querier] = (*CustomPicker[Querier])(nil) | ||||||
|  |  | ||||||
|  | type nodeStateCriterionKey struct{} | ||||||
|  |  | ||||||
|  | // NodeStateCriterion inject hasql.NodeStateCriterion to context | ||||||
|  | func NodeStateCriterion(ctx context.Context, c hasql.NodeStateCriterion) context.Context { | ||||||
|  | 	return context.WithValue(ctx, nodeStateCriterionKey{}, c) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CustomPickerOptions holds options to pick nodes | ||||||
|  | type CustomPickerOptions struct { | ||||||
|  | 	MaxLag   int | ||||||
|  | 	Priority map[string]int32 | ||||||
|  | 	Retries  int | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CustomPickerOption func apply option to CustomPickerOptions | ||||||
|  | type CustomPickerOption func(*CustomPickerOptions) | ||||||
|  |  | ||||||
|  | // CustomPickerMaxLag specifies max lag for which node can be used | ||||||
|  | func CustomPickerMaxLag(n int) CustomPickerOption { | ||||||
|  | 	return func(o *CustomPickerOptions) { | ||||||
|  | 		o.MaxLag = n | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewCustomPicker creates new node picker | ||||||
|  | func NewCustomPicker[T Querier](opts ...CustomPickerOption) *CustomPicker[Querier] { | ||||||
|  | 	options := CustomPickerOptions{} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  | 	return &CustomPicker[Querier]{opts: options} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CustomPicker holds node picker options | ||||||
|  | type CustomPicker[T Querier] struct { | ||||||
|  | 	opts CustomPickerOptions | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PickNode used to return specific node | ||||||
|  | func (p *CustomPicker[T]) PickNode(cnodes []hasql.CheckedNode[T]) hasql.CheckedNode[T] { | ||||||
|  | 	for _, n := range cnodes { | ||||||
|  | 		fmt.Printf("node %s\n", n.Node.String()) | ||||||
|  | 	} | ||||||
|  | 	return cnodes[0] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *CustomPicker[T]) getPriority(nodeName string) int32 { | ||||||
|  | 	if prio, ok := p.opts.Priority[nodeName]; ok { | ||||||
|  | 		return prio | ||||||
|  | 	} | ||||||
|  | 	return math.MaxInt32 // Default to lowest priority | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CompareNodes used to sort nodes | ||||||
|  | func (p *CustomPicker[T]) CompareNodes(a, b hasql.CheckedNode[T]) int { | ||||||
|  | 	fmt.Printf("CompareNodes %s %s\n", a.Node.String(), b.Node.String()) | ||||||
|  | 	// Get replication lag values | ||||||
|  | 	aLag := a.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||||
|  | 	bLag := b.Info.(interface{ ReplicationLag() int }).ReplicationLag() | ||||||
|  |  | ||||||
|  | 	// First check that lag lower then MaxLag | ||||||
|  | 	if aLag > p.opts.MaxLag && bLag > p.opts.MaxLag { | ||||||
|  | 		fmt.Printf("CompareNodes aLag > p.opts.MaxLag && bLag > p.opts.MaxLag\n") | ||||||
|  | 		return 0 // both are equal | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If one node exceeds MaxLag and the other doesn't, prefer the one that doesn't | ||||||
|  | 	if aLag > p.opts.MaxLag { | ||||||
|  | 		fmt.Printf("CompareNodes aLag > p.opts.MaxLag\n") | ||||||
|  | 		return 1 // b is better | ||||||
|  | 	} | ||||||
|  | 	if bLag > p.opts.MaxLag { | ||||||
|  | 		fmt.Printf("CompareNodes bLag > p.opts.MaxLag\n") | ||||||
|  | 		return -1 // a is better | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get node priorities | ||||||
|  | 	aPrio := p.getPriority(a.Node.String()) | ||||||
|  | 	bPrio := p.getPriority(b.Node.String()) | ||||||
|  |  | ||||||
|  | 	// if both priority equals | ||||||
|  | 	if aPrio == bPrio { | ||||||
|  | 		fmt.Printf("CompareNodes aPrio == bPrio\n") | ||||||
|  | 		// First compare by replication lag | ||||||
|  | 		if aLag < bLag { | ||||||
|  | 			fmt.Printf("CompareNodes aLag < bLag\n") | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		if aLag > bLag { | ||||||
|  | 			fmt.Printf("CompareNodes aLag > bLag\n") | ||||||
|  | 			return 1 | ||||||
|  | 		} | ||||||
|  | 		// If replication lag is equal, compare by latency | ||||||
|  | 		aLatency := a.Info.(interface{ Latency() time.Duration }).Latency() | ||||||
|  | 		bLatency := b.Info.(interface{ Latency() time.Duration }).Latency() | ||||||
|  |  | ||||||
|  | 		if aLatency < bLatency { | ||||||
|  | 			return -1 | ||||||
|  | 		} | ||||||
|  | 		if aLatency > bLatency { | ||||||
|  | 			return 1 | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// If lag and latency is equal | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If priorities are different, prefer the node with lower priority value | ||||||
|  | 	if aPrio < bPrio { | ||||||
|  | 		return -1 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return 1 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ClusterOptions contains cluster specific options | ||||||
|  | type ClusterOptions struct { | ||||||
|  | 	NodeChecker        hasql.NodeChecker | ||||||
|  | 	NodePicker         hasql.NodePicker[Querier] | ||||||
|  | 	NodeDiscoverer     hasql.NodeDiscoverer[Querier] | ||||||
|  | 	Options            []hasql.ClusterOpt[Querier] | ||||||
|  | 	Context            context.Context | ||||||
|  | 	Retries            int | ||||||
|  | 	NodePriority       map[string]int32 | ||||||
|  | 	NodeStateCriterion hasql.NodeStateCriterion | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ClusterOption apply cluster options to ClusterOptions | ||||||
|  | type ClusterOption func(*ClusterOptions) | ||||||
|  |  | ||||||
|  | // WithClusterNodeChecker pass hasql.NodeChecker to cluster options | ||||||
|  | func WithClusterNodeChecker(c hasql.NodeChecker) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeChecker = c | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodePicker pass hasql.NodePicker to cluster options | ||||||
|  | func WithClusterNodePicker(p hasql.NodePicker[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodePicker = p | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodeDiscoverer pass hasql.NodeDiscoverer to cluster options | ||||||
|  | func WithClusterNodeDiscoverer(d hasql.NodeDiscoverer[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeDiscoverer = d | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithRetries retry count on other nodes in case of error | ||||||
|  | func WithRetries(n int) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Retries = n | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterContext pass context.Context to cluster options and used for checks | ||||||
|  | func WithClusterContext(ctx context.Context) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Context = ctx | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterOptions pass hasql.ClusterOpt | ||||||
|  | func WithClusterOptions(opts ...hasql.ClusterOpt[Querier]) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.Options = append(o.Options, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodeStateCriterion pass default hasql.NodeStateCriterion | ||||||
|  | func WithClusterNodeStateCriterion(c hasql.NodeStateCriterion) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		o.NodeStateCriterion = c | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type ClusterNode struct { | ||||||
|  | 	Name     string | ||||||
|  | 	DB       Querier | ||||||
|  | 	Priority int32 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WithClusterNodes create cluster with static NodeDiscoverer | ||||||
|  | func WithClusterNodes(cns ...ClusterNode) ClusterOption { | ||||||
|  | 	return func(o *ClusterOptions) { | ||||||
|  | 		nodes := make([]*hasql.Node[Querier], 0, len(cns)) | ||||||
|  | 		if o.NodePriority == nil { | ||||||
|  | 			o.NodePriority = make(map[string]int32, len(cns)) | ||||||
|  | 		} | ||||||
|  | 		for _, cn := range cns { | ||||||
|  | 			nodes = append(nodes, hasql.NewNode(cn.Name, cn.DB)) | ||||||
|  | 			if cn.Priority == 0 { | ||||||
|  | 				cn.Priority = math.MaxInt32 | ||||||
|  | 			} | ||||||
|  | 			o.NodePriority[cn.Name] = cn.Priority | ||||||
|  | 		} | ||||||
|  | 		o.NodeDiscoverer = hasql.NewStaticNodeDiscoverer(nodes...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||||
|  | 	var tx *sql.Tx | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if tx, err = n.DB().BeginTx(ctx, opts); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if tx == nil && err == nil { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tx, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) Close() error { | ||||||
|  | 	return c.hasql.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) Conn(ctx context.Context) (*sql.Conn, error) { | ||||||
|  | 	var conn *sql.Conn | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if conn, err = n.DB().Conn(ctx); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if conn == nil && err == nil { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return conn, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	var res sql.Result | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().ExecContext(ctx, query, args...); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||||
|  | 	var res *sql.Stmt | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().PrepareContext(ctx, query); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	var res *sql.Rows | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if res, err = n.DB().QueryContext(ctx, query); err != nil && err != sql.ErrNoRows && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil && err == nil { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||||
|  | 	var res *sql.Row | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			res = n.DB().QueryRowContext(ctx, query, args...) | ||||||
|  | 			if res.Err() == nil { | ||||||
|  | 				return false | ||||||
|  | 			} else if res.Err() != nil && retries >= c.options.Retries { | ||||||
|  | 				return false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return true | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if res == nil { | ||||||
|  | 		res = newSQLRowError() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return res | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) PingContext(ctx context.Context) error { | ||||||
|  | 	var err error | ||||||
|  | 	var ok bool | ||||||
|  |  | ||||||
|  | 	retries := 0 | ||||||
|  | 	c.hasql.NodesIter(c.getNodeStateCriterion(ctx))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		ok = true | ||||||
|  | 		for ; retries < c.options.Retries; retries++ { | ||||||
|  | 			if err = n.DB().PingContext(ctx); err != nil && retries >= c.options.Retries { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if !ok { | ||||||
|  | 		err = errNoAliveNodes | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) WaitForNodes(ctx context.Context, criterions ...hasql.NodeStateCriterion) error { | ||||||
|  | 	for _, criterion := range criterions { | ||||||
|  | 		if _, err := c.hasql.WaitForNode(ctx, criterion); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) SetConnMaxLifetime(td time.Duration) { | ||||||
|  | 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		n.DB().SetConnMaxIdleTime(td) | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) SetConnMaxIdleTime(td time.Duration) { | ||||||
|  | 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		n.DB().SetConnMaxIdleTime(td) | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) SetMaxOpenConns(nc int) { | ||||||
|  | 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		n.DB().SetMaxOpenConns(nc) | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) SetMaxIdleConns(nc int) { | ||||||
|  | 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		n.DB().SetMaxIdleConns(nc) | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) Stats() sql.DBStats { | ||||||
|  | 	s := sql.DBStats{} | ||||||
|  | 	c.hasql.NodesIter(hasql.NodeStateCriterion(hasql.Alive))(func(n *hasql.Node[Querier]) bool { | ||||||
|  | 		st := n.DB().Stats() | ||||||
|  | 		s.Idle += st.Idle | ||||||
|  | 		s.InUse += st.InUse | ||||||
|  | 		s.MaxIdleClosed += st.MaxIdleClosed | ||||||
|  | 		s.MaxIdleTimeClosed += st.MaxIdleTimeClosed | ||||||
|  | 		s.MaxOpenConnections += st.MaxOpenConnections | ||||||
|  | 		s.OpenConnections += st.OpenConnections | ||||||
|  | 		s.WaitCount += st.WaitCount | ||||||
|  | 		s.WaitDuration += st.WaitDuration | ||||||
|  | 		return false | ||||||
|  | 	}) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cluster) getNodeStateCriterion(ctx context.Context) hasql.NodeStateCriterion { | ||||||
|  | 	if v, ok := ctx.Value(nodeStateCriterionKey{}).(hasql.NodeStateCriterion); ok { | ||||||
|  | 		return v | ||||||
|  | 	} | ||||||
|  | 	return c.options.NodeStateCriterion | ||||||
|  | } | ||||||
							
								
								
									
										171
									
								
								cluster/sql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								cluster/sql/cluster_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,171 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/DATA-DOG/go-sqlmock" | ||||||
|  | 	"golang.yandex/hasql/v2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNewCluster(t *testing.T) { | ||||||
|  | 	dbMaster, dbMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbMaster.Close() | ||||||
|  | 	dbMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(1, 0)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("master-dc1")) | ||||||
|  |  | ||||||
|  | 	dbDRMaster, dbDRMasterMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbDRMaster.Close() | ||||||
|  | 	dbDRMasterMock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 40)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster1-dc2")) | ||||||
|  |  | ||||||
|  | 	dbDRMasterMock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("drmaster")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1, dbSlaveDC1Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC1.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2, dbSlaveDC2Mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer dbSlaveDC2.Close() | ||||||
|  | 	dbSlaveDC1Mock.MatchExpectationsInOrder(false) | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`.*pg_is_in_recovery.*`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRowsWithColumnDefinition( | ||||||
|  | 			sqlmock.NewColumn("role").OfType("int8", 0), | ||||||
|  | 			sqlmock.NewColumn("replication_lag").OfType("int8", 0)). | ||||||
|  | 			AddRow(2, 50)). | ||||||
|  | 		RowsWillBeClosed(). | ||||||
|  | 		WithoutArgs() | ||||||
|  |  | ||||||
|  | 	dbSlaveDC2Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	tctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	c, err := NewCluster[Querier]( | ||||||
|  | 		WithClusterContext(tctx), | ||||||
|  | 		WithClusterNodeChecker(hasql.PostgreSQLChecker), | ||||||
|  | 		WithClusterNodePicker(NewCustomPicker[Querier]( | ||||||
|  | 			CustomPickerMaxLag(100), | ||||||
|  | 		)), | ||||||
|  | 		WithClusterNodes( | ||||||
|  | 			ClusterNode{"slave-dc1", dbSlaveDC1, 1}, | ||||||
|  | 			ClusterNode{"master-dc1", dbMaster, 1}, | ||||||
|  | 			ClusterNode{"slave-dc2", dbSlaveDC2, 2}, | ||||||
|  | 			ClusterNode{"drmaster1-dc2", dbDRMaster, 0}, | ||||||
|  | 		), | ||||||
|  | 		WithClusterOptions( | ||||||
|  | 			hasql.WithUpdateInterval[Querier](2*time.Second), | ||||||
|  | 			hasql.WithUpdateTimeout[Querier](1*time.Second), | ||||||
|  | 		), | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer c.Close() | ||||||
|  |  | ||||||
|  | 	if err = c.WaitForNodes(tctx, hasql.Primary, hasql.Standby); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	time.Sleep(500 * time.Millisecond) | ||||||
|  |  | ||||||
|  | 	node1Name := "" | ||||||
|  | 	fmt.Printf("check for Standby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.Standby), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node1Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "slave-dc1" != node1Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node1Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`SELECT node_name as name`).WillReturnRows( | ||||||
|  | 		sqlmock.NewRows([]string{"name"}). | ||||||
|  | 			AddRow("slave-dc1")) | ||||||
|  |  | ||||||
|  | 	node2Name := "" | ||||||
|  | 	fmt.Printf("check for PreferStandby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node2Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "slave-dc1" != node2Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "slave-dc1", node2Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	node3Name := "" | ||||||
|  | 	fmt.Printf("check for PreferPrimary\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferPrimary), "SELECT node_name as name"); row.Err() != nil { | ||||||
|  | 		t.Fatal(row.Err()) | ||||||
|  | 	} else if err = row.Scan(&node3Name); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} else if "master-dc1" != node3Name { | ||||||
|  | 		t.Fatalf("invalid node name %s != %s", "master-dc1", node3Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dbSlaveDC1Mock.ExpectQuery(`.*`).WillReturnRows(sqlmock.NewRows([]string{"role"}).RowError(1, fmt.Errorf("row error"))) | ||||||
|  |  | ||||||
|  | 	time.Sleep(2 * time.Second) | ||||||
|  |  | ||||||
|  | 	fmt.Printf("check for PreferStandby\n") | ||||||
|  | 	if row := c.QueryRowContext(NodeStateCriterion(tctx, hasql.PreferStandby), "SELECT node_name as name"); row.Err() == nil { | ||||||
|  | 		t.Fatal("must return error") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if dbMasterErr := dbMasterMock.ExpectationsWereMet(); dbMasterErr != nil { | ||||||
|  | 		t.Error(dbMasterErr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | |||||||
| module go.unistack.org/micro/v3 | module go.unistack.org/micro/v3 | ||||||
|  |  | ||||||
| go 1.22.0 | go 1.24.0 | ||||||
|  |  | ||||||
| require ( | require ( | ||||||
| 	dario.cat/mergo v1.0.1 | 	dario.cat/mergo v1.0.1 | ||||||
| @@ -11,9 +11,11 @@ require ( | |||||||
| 	github.com/matoous/go-nanoid v1.5.1 | 	github.com/matoous/go-nanoid v1.5.1 | ||||||
| 	github.com/patrickmn/go-cache v2.1.0+incompatible | 	github.com/patrickmn/go-cache v2.1.0+incompatible | ||||||
| 	github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5 | 	github.com/silas/dag v0.0.0-20220518035006-a7e85ada93c5 | ||||||
|  | 	github.com/stretchr/testify v1.10.0 | ||||||
| 	go.uber.org/automaxprocs v1.6.0 | 	go.uber.org/automaxprocs v1.6.0 | ||||||
| 	go.unistack.org/micro-proto/v3 v3.4.1 | 	go.unistack.org/micro-proto/v3 v3.4.1 | ||||||
| 	golang.org/x/sync v0.10.0 | 	golang.org/x/sync v0.10.0 | ||||||
|  | 	golang.yandex/hasql/v2 v2.1.0 | ||||||
| 	google.golang.org/grpc v1.69.2 | 	google.golang.org/grpc v1.69.2 | ||||||
| 	google.golang.org/protobuf v1.36.1 | 	google.golang.org/protobuf v1.36.1 | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 | 	gopkg.in/yaml.v3 v3.0.1 | ||||||
| @@ -33,7 +35,6 @@ require ( | |||||||
| 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect | 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect | ||||||
| 	github.com/rogpeppe/go-internal v1.13.1 // indirect | 	github.com/rogpeppe/go-internal v1.13.1 // indirect | ||||||
| 	github.com/sirupsen/logrus v1.9.3 // indirect | 	github.com/sirupsen/logrus v1.9.3 // indirect | ||||||
| 	github.com/stretchr/testify v1.10.0 // indirect |  | ||||||
| 	go.uber.org/goleak v1.3.0 // indirect | 	go.uber.org/goleak v1.3.0 // indirect | ||||||
| 	golang.org/x/exp v0.0.0-20241210194714-1829a127f884 // indirect | 	golang.org/x/exp v0.0.0-20241210194714-1829a127f884 // indirect | ||||||
| 	golang.org/x/net v0.33.0 // indirect | 	golang.org/x/net v0.33.0 // indirect | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -89,6 +89,8 @@ golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= | |||||||
| golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||||||
| golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= | ||||||
| golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= | ||||||
|  | golang.yandex/hasql/v2 v2.1.0 h1:7CaFFWeHoK5TvA+QvZzlKHlIN5sqNpqM8NSrXskZD/k= | ||||||
|  | golang.yandex/hasql/v2 v2.1.0/go.mod h1:3Au1AxuJDCTXmS117BpbI6e+70kGWeyLR1qJAH6HdtA= | ||||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484 h1:Z7FRVJPSMaHQxD0uXU8WdgFh8PseLM8Q8NzhnpMrBhQ= | google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484 h1:Z7FRVJPSMaHQxD0uXU8WdgFh8PseLM8Q8NzhnpMrBhQ= | ||||||
| google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA= | google.golang.org/genproto/googleapis/rpc v0.0.0-20241216192217-9240e9c98484/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA= | ||||||
| google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= | google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= | ||||||
|   | |||||||
							
								
								
									
										76
									
								
								hooks/metadata/metadata.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								hooks/metadata/metadata.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | |||||||
|  | package metadata | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/client" | ||||||
|  | 	"go.unistack.org/micro/v3/metadata" | ||||||
|  | 	"go.unistack.org/micro/v3/server" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var DefaultMetadataKeys = []string{"x-request-id"} | ||||||
|  |  | ||||||
|  | type hook struct { | ||||||
|  | 	keys []string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewHook(keys ...string) *hook { | ||||||
|  | 	return &hook{keys: keys} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func metadataCopy(ctx context.Context, keys []string) context.Context { | ||||||
|  | 	if keys == nil { | ||||||
|  | 		return ctx | ||||||
|  | 	} | ||||||
|  | 	if imd, iok := metadata.FromIncomingContext(ctx); iok && imd != nil { | ||||||
|  | 		omd, ook := metadata.FromOutgoingContext(ctx) | ||||||
|  | 		if !ook || omd == nil { | ||||||
|  | 			omd = metadata.New(len(keys)) | ||||||
|  | 		} | ||||||
|  | 		for _, k := range keys { | ||||||
|  | 			if v, ok := imd.Get(k); ok && v != "" { | ||||||
|  | 				omd.Set(k, v) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if !ook { | ||||||
|  | 			ctx = metadata.NewOutgoingContext(ctx, omd) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return ctx | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||||
|  | 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), req, rsp, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||||
|  | 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), req, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||||
|  | 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), msg, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||||
|  | 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), msgs, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||||
|  | 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), req, rsp) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||||
|  | 	return func(ctx context.Context, msg server.Message) error { | ||||||
|  | 		return next(metadataCopy(ctx, w.keys), msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										94
									
								
								hooks/recovery/recovery.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								hooks/recovery/recovery.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | |||||||
|  | package recovery | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/errors" | ||||||
|  | 	"go.unistack.org/micro/v3/server" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func NewOptions(opts ...Option) Options { | ||||||
|  | 	options := Options{ | ||||||
|  | 		ServerHandlerFn:    DefaultServerHandlerFn, | ||||||
|  | 		ServerSubscriberFn: DefaultServerSubscriberFn, | ||||||
|  | 	} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  | 	return options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Options struct { | ||||||
|  | 	ServerHandlerFn    func(context.Context, server.Request, interface{}, error) error | ||||||
|  | 	ServerSubscriberFn func(context.Context, server.Message, error) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Option func(*Options) | ||||||
|  |  | ||||||
|  | func ServerHandlerFunc(fn func(context.Context, server.Request, interface{}, error) error) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ServerHandlerFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ServerSubscriberFunc(fn func(context.Context, server.Message, error) error) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ServerSubscriberFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	DefaultServerHandlerFn = func(ctx context.Context, req server.Request, rsp interface{}, err error) error { | ||||||
|  | 		return errors.BadRequest("", "%v", err) | ||||||
|  | 	} | ||||||
|  | 	DefaultServerSubscriberFn = func(ctx context.Context, req server.Message, err error) error { | ||||||
|  | 		return errors.BadRequest("", "%v", err) | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var Hook = NewHook() | ||||||
|  |  | ||||||
|  | type hook struct { | ||||||
|  | 	opts Options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewHook(opts ...Option) *hook { | ||||||
|  | 	return &hook{opts: NewOptions(opts...)} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||||
|  | 	return func(ctx context.Context, req server.Request, rsp interface{}) (err error) { | ||||||
|  | 		defer func() { | ||||||
|  | 			r := recover() | ||||||
|  | 			switch verr := r.(type) { | ||||||
|  | 			case nil: | ||||||
|  | 				return | ||||||
|  | 			case error: | ||||||
|  | 				err = w.opts.ServerHandlerFn(ctx, req, rsp, verr) | ||||||
|  | 			default: | ||||||
|  | 				err = w.opts.ServerHandlerFn(ctx, req, rsp, fmt.Errorf("%v", r)) | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 		err = next(ctx, req, rsp) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||||
|  | 	return func(ctx context.Context, msg server.Message) (err error) { | ||||||
|  | 		defer func() { | ||||||
|  | 			r := recover() | ||||||
|  | 			switch verr := r.(type) { | ||||||
|  | 			case nil: | ||||||
|  | 				return | ||||||
|  | 			case error: | ||||||
|  | 				err = w.opts.ServerSubscriberFn(ctx, msg, verr) | ||||||
|  | 			default: | ||||||
|  | 				err = w.opts.ServerSubscriberFn(ctx, msg, fmt.Errorf("%v", r)) | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 		err = next(ctx, msg) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										139
									
								
								hooks/requestid/requestid.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								hooks/requestid/requestid.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | |||||||
|  | package requestid | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"net/textproto" | ||||||
|  | 	 | ||||||
|  | 	"go.unistack.org/micro/v3/client" | ||||||
|  | 	"go.unistack.org/micro/v3/metadata" | ||||||
|  | 	"go.unistack.org/micro/v3/server" | ||||||
|  | 	"go.unistack.org/micro/v3/util/id" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type XRequestIDKey struct{} | ||||||
|  |  | ||||||
|  | // DefaultMetadataKey contains metadata key | ||||||
|  | var DefaultMetadataKey = textproto.CanonicalMIMEHeaderKey("x-request-id") | ||||||
|  |  | ||||||
|  | // DefaultMetadataFunc wil be used if user not provide own func to fill metadata | ||||||
|  | var DefaultMetadataFunc = func(ctx context.Context) (context.Context, error) { | ||||||
|  | 	var xid string | ||||||
|  |  | ||||||
|  | 	cid, cok := ctx.Value(XRequestIDKey{}).(string) | ||||||
|  | 	if cok && cid != "" { | ||||||
|  | 		xid = cid | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	imd, iok := metadata.FromIncomingContext(ctx) | ||||||
|  | 	if !iok || imd == nil { | ||||||
|  | 		imd = metadata.New(1) | ||||||
|  | 		ctx = metadata.NewIncomingContext(ctx, imd) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	omd, ook := metadata.FromOutgoingContext(ctx) | ||||||
|  | 	if !ook || omd == nil { | ||||||
|  | 		omd = metadata.New(1) | ||||||
|  | 		ctx = metadata.NewOutgoingContext(ctx, omd) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if xid == "" { | ||||||
|  | 		var id string | ||||||
|  | 		if id, iok = imd.Get(DefaultMetadataKey); iok && id != "" { | ||||||
|  | 			xid = id | ||||||
|  | 		} | ||||||
|  | 		if id, ook = omd.Get(DefaultMetadataKey); ook && id != "" { | ||||||
|  | 			xid = id | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if xid == "" { | ||||||
|  | 		var err error | ||||||
|  | 		xid, err = id.New() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return ctx, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !cok { | ||||||
|  | 		ctx = context.WithValue(ctx, XRequestIDKey{}, xid) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !iok { | ||||||
|  | 		imd.Set(DefaultMetadataKey, xid) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !ook { | ||||||
|  | 		omd.Set(DefaultMetadataKey, xid) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return ctx, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type hook struct{} | ||||||
|  |  | ||||||
|  | func NewHook() *hook { | ||||||
|  | 	return &hook{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||||
|  | 	return func(ctx context.Context, msg server.Message) error { | ||||||
|  | 		var err error | ||||||
|  | 		if xid, ok := msg.Header()[DefaultMetadataKey]; ok { | ||||||
|  | 			ctx = context.WithValue(ctx, XRequestIDKey{}, xid) | ||||||
|  | 		} | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||||
|  | 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||||
|  | 		var err error | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, req, rsp) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||||
|  | 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		var err error | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msgs, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||||
|  | 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		var err error | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msg, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||||
|  | 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
|  | 		var err error | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, req, rsp, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||||
|  | 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||||
|  | 		var err error | ||||||
|  | 		if ctx, err = DefaultMetadataFunc(ctx); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, req, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										33
									
								
								hooks/requestid/requestid_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								hooks/requestid/requestid_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | package requestid | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/metadata" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDefaultMetadataFunc(t *testing.T) { | ||||||
|  | 	ctx := context.TODO() | ||||||
|  |  | ||||||
|  | 	nctx, err := DefaultMetadataFunc(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("%v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	imd, ok := metadata.FromIncomingContext(nctx) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("md missing in incoming context") | ||||||
|  | 	} | ||||||
|  | 	omd, ok := metadata.FromOutgoingContext(nctx) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatalf("md missing in outgoing context") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, iok := imd.Get(DefaultMetadataKey) | ||||||
|  | 	_, ook := omd.Get(DefaultMetadataKey) | ||||||
|  |  | ||||||
|  | 	if !iok || !ook { | ||||||
|  | 		t.Fatalf("missing metadata key value") | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										51
									
								
								hooks/sql/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								hooks/sql/common.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"runtime" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //go:generate sh -c "go run gen.go > wrap_gen.go" | ||||||
|  |  | ||||||
|  | // namedValueToValue converts driver arguments of NamedValue format to Value format. Implemented in the same way as in | ||||||
|  | // database/sql ctxutil.go. | ||||||
|  | func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | ||||||
|  | 	dargs := make([]driver.Value, len(named)) | ||||||
|  | 	for n, param := range named { | ||||||
|  | 		if len(param.Name) > 0 { | ||||||
|  | 			return nil, errors.New("sql: driver does not support the use of Named Parameters") | ||||||
|  | 		} | ||||||
|  | 		dargs[n] = param.Value | ||||||
|  | 	} | ||||||
|  | 	return dargs, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // namedValueToLabels convert driver arguments to interface{} slice | ||||||
|  | func namedValueToLabels(named []driver.NamedValue) []interface{} { | ||||||
|  | 	largs := make([]interface{}, 0, len(named)*2) | ||||||
|  | 	var name string | ||||||
|  | 	for _, param := range named { | ||||||
|  | 		if param.Name != "" { | ||||||
|  | 			name = param.Name | ||||||
|  | 		} else { | ||||||
|  | 			name = fmt.Sprintf("$%d", param.Ordinal) | ||||||
|  | 		} | ||||||
|  | 		largs = append(largs, fmt.Sprintf("%s=%v", name, param.Value)) | ||||||
|  | 	} | ||||||
|  | 	return largs | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // getCallerName get the name of the function A where A() -> B() -> GetFunctionCallerName() | ||||||
|  | func getCallerName() string { | ||||||
|  | 	pc, _, _, ok := runtime.Caller(3) | ||||||
|  | 	details := runtime.FuncForPC(pc) | ||||||
|  | 	var callerName string | ||||||
|  | 	if ok && details != nil { | ||||||
|  | 		callerName = details.Name() | ||||||
|  | 	} else { | ||||||
|  | 		callerName = labelUnknown | ||||||
|  | 	} | ||||||
|  | 	return callerName | ||||||
|  | } | ||||||
							
								
								
									
										467
									
								
								hooks/sql/conn.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										467
									
								
								hooks/sql/conn.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,467 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/hooks/requestid" | ||||||
|  | 	"go.unistack.org/micro/v3/tracer" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	_ driver.Conn               = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.ConnBeginTx        = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.ConnPrepareContext = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.Pinger             = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.Validator          = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.Queryer            = (*wrapperConn)(nil) // nolint:staticcheck | ||||||
|  | 	_ driver.QueryerContext     = (*wrapperConn)(nil) | ||||||
|  | 	_ driver.Execer             = (*wrapperConn)(nil) // nolint:staticcheck | ||||||
|  | 	_ driver.ExecerContext      = (*wrapperConn)(nil) | ||||||
|  | 	//	_ driver.Connector | ||||||
|  | 	//	_ driver.Driver | ||||||
|  | 	//	_ driver.DriverContext | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // wrapperConn defines a wrapper for driver.Conn | ||||||
|  | type wrapperConn struct { | ||||||
|  | 	d     *wrapperDriver | ||||||
|  | 	dname string | ||||||
|  | 	conn  driver.Conn | ||||||
|  | 	opts  Options | ||||||
|  | 	ctx   context.Context | ||||||
|  | 	//span  tracer.Span | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements driver.Conn Close | ||||||
|  | func (w *wrapperConn) Close() error { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Close"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	err := w.conn.Close() | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Close", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Begin implements driver.Conn Begin | ||||||
|  | func (w *wrapperConn) Begin() (driver.Tx, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	labels := []string{labelMethod, "Begin"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	tx, err := w.conn.Begin() // nolint:staticcheck | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Begin", getCallerName(), td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Begin", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return &wrapperTx{tx: tx, opts: w.opts, ctx: ctx}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // BeginTx implements driver.ConnBeginTx BeginTx | ||||||
|  | func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	nctx, span := w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	span.AddLabels("db.method", "BeginTx") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "BeginTx", labelQuery, name} | ||||||
|  |  | ||||||
|  | 	connBeginTx, ok := w.conn.(driver.ConnBeginTx) | ||||||
|  | 	if !ok { | ||||||
|  | 		return w.Begin() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	tx, err := connBeginTx.BeginTx(nctx, opts) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "BeginTx", getCallerName(), td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "BeginTx", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return &wrapperTx{tx: tx, opts: w.opts, ctx: ctx, span: span}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Prepare implements driver.Conn Prepare | ||||||
|  | func (w *wrapperConn) Prepare(query string) (driver.Stmt, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Prepare", labelQuery, getCallerName()} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	stmt, err := w.conn.Prepare(query) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Prepare", getCallerName(), td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Prepare", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return wrapStmt(stmt, query, w.opts), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PrepareContext implements driver.ConnPrepareContext PrepareContext | ||||||
|  | func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||||
|  | 	var nctx context.Context | ||||||
|  | 	var span tracer.Span | ||||||
|  |  | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} else { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} | ||||||
|  | 	span.AddLabels("db.method", "PrepareContext") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "PrepareContext", labelQuery, name} | ||||||
|  | 	conn, ok := w.conn.(driver.ConnPrepareContext) | ||||||
|  | 	if !ok { | ||||||
|  | 		return w.Prepare(query) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	stmt, err := conn.PrepareContext(nctx, query) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "PrepareContext", getCallerName(), td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "PrepareContext", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return wrapStmt(stmt, query, w.opts), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Exec implements driver.Execer Exec | ||||||
|  | func (w *wrapperConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Exec", labelQuery, getCallerName()} | ||||||
|  |  | ||||||
|  | 	// nolint:staticcheck | ||||||
|  | 	conn, ok := w.conn.(driver.Execer) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, driver.ErrSkip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	res, err := conn.Exec(query, args) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Exec", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Exec implements driver.StmtExecContext ExecContext | ||||||
|  | func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	var nctx context.Context | ||||||
|  | 	var span tracer.Span | ||||||
|  |  | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} else { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} | ||||||
|  | 	span.AddLabels("db.method", "ExecContext") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	defer span.Finish() | ||||||
|  | 	if len(args) > 0 { | ||||||
|  | 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "ExecContext", labelQuery, name} | ||||||
|  |  | ||||||
|  | 	conn, ok := w.conn.(driver.ExecerContext) | ||||||
|  | 	if !ok { | ||||||
|  | 		// nolint:staticcheck | ||||||
|  | 		return nil, driver.ErrSkip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	res, err := conn.ExecContext(nctx, query, args) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Ping implements driver.Pinger Ping | ||||||
|  | func (w *wrapperConn) Ping(ctx context.Context) error { | ||||||
|  | 	conn, ok := w.conn.(driver.Pinger) | ||||||
|  |  | ||||||
|  | 	if !ok { | ||||||
|  | 		// fallback path to check db alive | ||||||
|  | 		pc, err := w.d.Open(w.dname) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return pc.Close() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var nctx context.Context //nolint: gosimple | ||||||
|  | 	nctx = ctx | ||||||
|  | 	/* | ||||||
|  | 		var span tracer.Span | ||||||
|  | 		if w.ctx != nil { | ||||||
|  | 			nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 		} else { | ||||||
|  | 			nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 		} | ||||||
|  | 		span.AddLabels("db.method", "Ping") | ||||||
|  | 		defer span.Finish() | ||||||
|  | 	*/ | ||||||
|  | 	labels := []string{labelMethod, "Ping"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	err := conn.Ping(nctx) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		// span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Ping", getCallerName(), td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return err | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Query implements driver.Queryer Query | ||||||
|  | func (w *wrapperConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	// nolint:staticcheck | ||||||
|  | 	conn, ok := w.conn.(driver.Queryer) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, driver.ErrSkip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	labels := []string{labelMethod, "Query", labelQuery, getCallerName()} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	rows, err := conn.Query(query, args) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Query", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QueryContext implements Driver.QueryerContext QueryContext | ||||||
|  | func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	var nctx context.Context | ||||||
|  | 	var span tracer.Span | ||||||
|  |  | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} else { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} | ||||||
|  | 	span.AddLabels("db.method", "QueryContext") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	defer span.Finish() | ||||||
|  | 	if len(args) > 0 { | ||||||
|  | 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "QueryContext", labelQuery, name} | ||||||
|  | 	conn, ok := w.conn.(driver.QueryerContext) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, driver.ErrSkip | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	rows, err := conn.QueryContext(nctx, query, args) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CheckNamedValue implements driver.NamedValueChecker | ||||||
|  | func (w *wrapperConn) CheckNamedValue(v *driver.NamedValue) error { | ||||||
|  | 	s, ok := w.conn.(driver.NamedValueChecker) | ||||||
|  | 	if !ok { | ||||||
|  | 		return driver.ErrSkip | ||||||
|  | 	} | ||||||
|  | 	return s.CheckNamedValue(v) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // IsValid implements driver.Validator | ||||||
|  | func (w *wrapperConn) IsValid() bool { | ||||||
|  | 	v, ok := w.conn.(driver.Validator) | ||||||
|  | 	if !ok { | ||||||
|  | 		return w.conn != nil | ||||||
|  | 	} | ||||||
|  | 	return v.IsValid() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *wrapperConn) ResetSession(ctx context.Context) error { | ||||||
|  | 	s, ok := w.conn.(driver.SessionResetter) | ||||||
|  | 	if !ok { | ||||||
|  | 		return driver.ErrSkip | ||||||
|  | 	} | ||||||
|  | 	return s.ResetSession(ctx) | ||||||
|  | } | ||||||
							
								
								
									
										94
									
								
								hooks/sql/driver.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								hooks/sql/driver.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | // _ driver.DriverContext = (*wrapperDriver)(nil) | ||||||
|  | // _ driver.Connector     = (*wrapperDriver)(nil) | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | type conn interface { | ||||||
|  | 	driver.Pinger | ||||||
|  | 	driver.Execer | ||||||
|  | 	driver.ExecerContext | ||||||
|  | 	driver.Queryer | ||||||
|  | 	driver.QueryerContext | ||||||
|  | 	driver.Conn | ||||||
|  | 	driver.ConnPrepareContext | ||||||
|  | 	driver.ConnBeginTx | ||||||
|  | } | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | // wrapperDriver defines a wrapper for driver.Driver | ||||||
|  | type wrapperDriver struct { | ||||||
|  | 	driver driver.Driver | ||||||
|  | 	opts   Options | ||||||
|  | 	ctx    context.Context | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewWrapper creates and returns a new SQL driver with passed capabilities | ||||||
|  | func NewWrapper(d driver.Driver, opts ...Option) driver.Driver { | ||||||
|  | 	return &wrapperDriver{driver: d, opts: NewOptions(opts...), ctx: context.Background()} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type wrappedConnector struct { | ||||||
|  | 	connector driver.Connector | ||||||
|  | //	name      string | ||||||
|  | 	opts      Options | ||||||
|  | 	ctx       context.Context | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewWrapperConnector(c driver.Connector, opts ...Option) driver.Connector { | ||||||
|  | 	return &wrappedConnector{connector: c, opts: NewOptions(opts...), ctx: context.Background()} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Connect implements driver.Driver Connect | ||||||
|  | func (w *wrappedConnector) Connect(ctx context.Context) (driver.Conn, error) { | ||||||
|  | 	return w.connector.Connect(ctx) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Driver implements driver.Driver Driver | ||||||
|  | func (w *wrappedConnector) Driver() driver.Driver { | ||||||
|  | 	return w.connector.Driver() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | // Connect implements driver.Driver OpenConnector | ||||||
|  | func (w *wrapperDriver) OpenConnector(name string) (driver.Conn, error) { | ||||||
|  | 	return &wrapperConnector{driver: w.driver, name: name, opts: w.opts}, nil | ||||||
|  | } | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | // Open implements driver.Driver Open | ||||||
|  | func (w *wrapperDriver) Open(name string) (driver.Conn, error) { | ||||||
|  | 	// ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout | ||||||
|  | 	// defer cancel() | ||||||
|  |  | ||||||
|  | 	/* | ||||||
|  | 		connector, err := w.OpenConnector(name) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		return connector.Connect(ctx) | ||||||
|  | 	*/ | ||||||
|  |  | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	c, err := w.driver.Open(name) | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled { | ||||||
|  | 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Open", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	_ = td | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return wrapConn(c, w.opts), nil | ||||||
|  | } | ||||||
							
								
								
									
										165
									
								
								hooks/sql/gen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								hooks/sql/gen.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,165 @@ | |||||||
|  | //go:build ignore | ||||||
|  |  | ||||||
|  | package main | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/md5" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"sort" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var connIfaces = []string{ | ||||||
|  | 	"driver.ConnBeginTx", | ||||||
|  | 	"driver.ConnPrepareContext", | ||||||
|  | 	"driver.Execer", | ||||||
|  | 	"driver.ExecerContext", | ||||||
|  | 	"driver.NamedValueChecker", | ||||||
|  | 	"driver.Pinger", | ||||||
|  | 	"driver.Queryer", | ||||||
|  | 	"driver.QueryerContext", | ||||||
|  | 	"driver.SessionResetter", | ||||||
|  | 	"driver.Validator", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var stmtIfaces = []string{ | ||||||
|  | 	"driver.StmtExecContext", | ||||||
|  | 	"driver.StmtQueryContext", | ||||||
|  | 	"driver.ColumnConverter", | ||||||
|  | 	"driver.NamedValueChecker", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getHash(s []string) string { | ||||||
|  | 	h := md5.New() | ||||||
|  | 	io.WriteString(h, strings.Join(s, "|")) | ||||||
|  | 	return fmt.Sprintf("%x", h.Sum(nil)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func main() { | ||||||
|  | 	comboConn := all(connIfaces) | ||||||
|  |  | ||||||
|  | 	sort.Slice(comboConn, func(i, j int) bool { | ||||||
|  | 		return len(comboConn[i]) < len(comboConn[j]) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	comboStmt := all(stmtIfaces) | ||||||
|  |  | ||||||
|  | 	sort.Slice(comboStmt, func(i, j int) bool { | ||||||
|  | 		return len(comboStmt[i]) < len(comboStmt[j]) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	b := bytes.NewBuffer(nil) | ||||||
|  | 	b.WriteString("// Code generated. DO NOT EDIT.\n\n") | ||||||
|  | 	b.WriteString("package sql\n\n") | ||||||
|  | 	b.WriteString(`import "database/sql/driver"`) | ||||||
|  | 	b.WriteString("\n\n") | ||||||
|  |  | ||||||
|  | 	b.WriteString("func wrapConn(dc driver.Conn, opts Options) driver.Conn {\n") | ||||||
|  | 	b.WriteString("\tc := &wrapperConn{conn: dc, opts: opts}\n") | ||||||
|  |  | ||||||
|  | 	for idx := len(comboConn) - 1; idx >= 0; idx-- { | ||||||
|  | 		ifaces := comboConn[idx] | ||||||
|  | 		n := len(ifaces) | ||||||
|  | 		if n == 0 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		h := getHash(ifaces) | ||||||
|  | 		b.WriteString(fmt.Sprintf("\tif _, ok := dc.(wrapConn%04d_%s); ok {\n", n, h)) | ||||||
|  | 		b.WriteString("\t\treturn struct {\n") | ||||||
|  | 		b.WriteString(strings.Join(append([]string{"\t\t\tdriver.Conn"}, ifaces...), "\n\t\t\t")) | ||||||
|  | 		b.WriteString("\n\t\t}{") | ||||||
|  | 		for idx := range ifaces { | ||||||
|  | 			if idx > 0 { | ||||||
|  | 				b.WriteString(", ") | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} else if idx == 0 { | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} else { | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(", c}\n") | ||||||
|  | 		b.WriteString("\t}\n\n") | ||||||
|  | 	} | ||||||
|  | 	b.WriteString("\treturn c\n") | ||||||
|  | 	b.WriteString("}\n\n") | ||||||
|  |  | ||||||
|  | 	for idx := len(comboConn) - 1; idx >= 0; idx-- { | ||||||
|  | 		ifaces := comboConn[idx] | ||||||
|  | 		n := len(ifaces) | ||||||
|  | 		if n == 0 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		h := getHash(ifaces) | ||||||
|  | 		b.WriteString(fmt.Sprintf("// %s\n", strings.Join(ifaces, "|"))) | ||||||
|  | 		b.WriteString(fmt.Sprintf("type wrapConn%04d_%s interface {\n", n, h)) | ||||||
|  | 		for _, iface := range ifaces { | ||||||
|  | 			b.WriteString(fmt.Sprintf("\t%s\n", iface)) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString("}\n\n") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	b.WriteString("func wrapStmt(stmt driver.Stmt, query string, opts Options) driver.Stmt {\n") | ||||||
|  | 	b.WriteString("\tc := &wrapperStmt{stmt: stmt, query: query, opts: opts}\n") | ||||||
|  |  | ||||||
|  | 	for idx := len(comboStmt) - 1; idx >= 0; idx-- { | ||||||
|  | 		ifaces := comboStmt[idx] | ||||||
|  | 		n := len(ifaces) | ||||||
|  | 		if n == 0 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		h := getHash(ifaces) | ||||||
|  | 		b.WriteString(fmt.Sprintf("\tif _, ok := stmt.(wrapStmt%04d_%s); ok {\n", n, h)) | ||||||
|  | 		b.WriteString("\t\treturn struct {\n") | ||||||
|  | 		b.WriteString(strings.Join(append([]string{"\t\t\tdriver.Stmt"}, ifaces...), "\n\t\t\t")) | ||||||
|  | 		b.WriteString("\n\t\t}{") | ||||||
|  | 		for idx := range ifaces { | ||||||
|  | 			if idx > 0 { | ||||||
|  | 				b.WriteString(", ") | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} else if idx == 0 { | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} else { | ||||||
|  | 				b.WriteString("c") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(", c}\n") | ||||||
|  | 		b.WriteString("\t}\n\n") | ||||||
|  | 	} | ||||||
|  | 	b.WriteString("\treturn c\n") | ||||||
|  | 	b.WriteString("}\n") | ||||||
|  |  | ||||||
|  | 	for idx := len(comboStmt) - 1; idx >= 0; idx-- { | ||||||
|  | 		ifaces := comboStmt[idx] | ||||||
|  | 		n := len(ifaces) | ||||||
|  | 		if n == 0 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		h := getHash(ifaces) | ||||||
|  | 		b.WriteString(fmt.Sprintf("\n// %s\n", strings.Join(ifaces, "|"))) | ||||||
|  | 		b.WriteString(fmt.Sprintf("type wrapStmt%04d_%s interface {\n", n, h)) | ||||||
|  | 		for _, iface := range ifaces { | ||||||
|  | 			b.WriteString(fmt.Sprintf("\t%s\n", iface)) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString("}\n") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	fmt.Printf("%s", b.String()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // all returns all combinations for a given string array. | ||||||
|  | func all[T any](set []T) (subsets [][]T) { | ||||||
|  | 	length := uint(len(set)) | ||||||
|  | 	for subsetBits := 1; subsetBits < (1 << length); subsetBits++ { | ||||||
|  | 		var subset []T | ||||||
|  | 		for object := uint(0); object < length; object++ { | ||||||
|  | 			if (subsetBits>>object)&1 == 1 { | ||||||
|  | 				subset = append(subset, set[object]) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		subsets = append(subsets, subset) | ||||||
|  | 	} | ||||||
|  | 	return subsets | ||||||
|  | } | ||||||
							
								
								
									
										172
									
								
								hooks/sql/options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								hooks/sql/options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/logger" | ||||||
|  | 	"go.unistack.org/micro/v3/meter" | ||||||
|  | 	"go.unistack.org/micro/v3/tracer" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	// DefaultMeterStatsInterval holds default stats interval | ||||||
|  | 	DefaultMeterStatsInterval = 5 * time.Second | ||||||
|  | 	// DefaultLoggerObserver used to prepare labels for logger | ||||||
|  | 	DefaultLoggerObserver = func(ctx context.Context, method string, query string, td time.Duration, err error) []interface{} { | ||||||
|  | 		labels := []interface{}{"db.method", method, "took", fmt.Sprintf("%v", td)} | ||||||
|  | 		if err != nil { | ||||||
|  | 			labels = append(labels, "error", err.Error()) | ||||||
|  | 		} | ||||||
|  | 		if query != labelUnknown { | ||||||
|  | 			labels = append(labels, "query", query) | ||||||
|  | 		} | ||||||
|  | 		return labels | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	MaxOpenConnections = "micro_sql_max_open_conn" | ||||||
|  | 	OpenConnections    = "micro_sql_open_conn" | ||||||
|  | 	InuseConnections   = "micro_sql_inuse_conn" | ||||||
|  | 	IdleConnections    = "micro_sql_idle_conn" | ||||||
|  | 	WaitConnections    = "micro_sql_waited_conn" | ||||||
|  | 	BlockedSeconds     = "micro_sql_blocked_seconds" | ||||||
|  | 	MaxIdleClosed      = "micro_sql_max_idle_closed" | ||||||
|  | 	MaxIdletimeClosed  = "micro_sql_closed_max_idle" | ||||||
|  | 	MaxLifetimeClosed  = "micro_sql_closed_max_lifetime" | ||||||
|  |  | ||||||
|  | 	meterRequestTotal               = "micro_sql_request_total" | ||||||
|  | 	meterRequestLatencyMicroseconds = "micro_sql_latency_microseconds" | ||||||
|  | 	meterRequestDurationSeconds     = "micro_sql_request_duration_seconds" | ||||||
|  |  | ||||||
|  | 	labelUnknown  = "unknown" | ||||||
|  | 	labelQuery    = "db_statement" | ||||||
|  | 	labelMethod   = "db_method" | ||||||
|  | 	labelStatus   = "status" | ||||||
|  | 	labelSuccess  = "success" | ||||||
|  | 	labelFailure  = "failure" | ||||||
|  | 	labelHost     = "db_host" | ||||||
|  | 	labelDatabase = "db_name" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Options struct holds wrapper options | ||||||
|  | type Options struct { | ||||||
|  | 	Logger             logger.Logger | ||||||
|  | 	Meter              meter.Meter | ||||||
|  | 	Tracer             tracer.Tracer | ||||||
|  | 	DatabaseHost       string | ||||||
|  | 	DatabaseName       string | ||||||
|  | 	MeterStatsInterval time.Duration | ||||||
|  | 	LoggerLevel        logger.Level | ||||||
|  | 	LoggerEnabled      bool | ||||||
|  | 	LoggerObserver     func(ctx context.Context, method string, name string, td time.Duration, err error) []interface{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Option func signature | ||||||
|  | type Option func(*Options) | ||||||
|  |  | ||||||
|  | // NewOptions create new Options struct from provided option slice | ||||||
|  | func NewOptions(opts ...Option) Options { | ||||||
|  | 	options := Options{ | ||||||
|  | 		Logger:             logger.DefaultLogger, | ||||||
|  | 		Meter:              meter.DefaultMeter, | ||||||
|  | 		Tracer:             tracer.DefaultTracer, | ||||||
|  | 		MeterStatsInterval: DefaultMeterStatsInterval, | ||||||
|  | 		LoggerLevel:        logger.ErrorLevel, | ||||||
|  | 		LoggerObserver:     DefaultLoggerObserver, | ||||||
|  | 	} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	options.Meter = options.Meter.Clone( | ||||||
|  | 		meter.Labels( | ||||||
|  | 			labelHost, options.DatabaseHost, | ||||||
|  | 			labelDatabase, options.DatabaseName, | ||||||
|  | 		), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	options.Logger = options.Logger.Clone(logger.WithAddCallerSkipCount(1)) | ||||||
|  |  | ||||||
|  | 	return options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MetricInterval specifies stats interval for *sql.DB | ||||||
|  | func MetricInterval(td time.Duration) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.MeterStatsInterval = td | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DatabaseHost(host string) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.DatabaseHost = host | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DatabaseName(name string) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.DatabaseName = name | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Meter passes meter.Meter to wrapper | ||||||
|  | func Meter(m meter.Meter) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Meter = m | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Logger passes logger.Logger to wrapper | ||||||
|  | func Logger(l logger.Logger) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Logger = l | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LoggerEnabled enable sql logging | ||||||
|  | func LoggerEnabled(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.LoggerEnabled = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LoggerLevel passes logger.Level option | ||||||
|  | func LoggerLevel(lvl logger.Level) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.LoggerLevel = lvl | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LoggerObserver passes observer to fill logger fields | ||||||
|  | func LoggerObserver(obs func(context.Context, string, string, time.Duration, error) []interface{}) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.LoggerObserver = obs | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Tracer passes tracer.Tracer to wrapper | ||||||
|  | func Tracer(t tracer.Tracer) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Tracer = t | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type queryNameKey struct{} | ||||||
|  |  | ||||||
|  | // QueryName passes query name to wrapper func | ||||||
|  | func QueryName(ctx context.Context, name string) context.Context { | ||||||
|  | 	if ctx == nil { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	return context.WithValue(ctx, queryNameKey{}, name) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getQueryName(ctx context.Context) string { | ||||||
|  | 	if v, ok := ctx.Value(queryNameKey{}).(string); ok && v != labelUnknown { | ||||||
|  | 		return v | ||||||
|  | 	} | ||||||
|  | 	return getCallerName() | ||||||
|  | } | ||||||
							
								
								
									
										41
									
								
								hooks/sql/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								hooks/sql/stats.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Statser interface { | ||||||
|  | 	Stats() sql.DBStats | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewStatsMeter(ctx context.Context, db Statser, opts ...Option) { | ||||||
|  | 	options := NewOptions(opts...) | ||||||
|  |  | ||||||
|  | 	go func() { | ||||||
|  | 		ticker := time.NewTicker(options.MeterStatsInterval) | ||||||
|  | 		defer ticker.Stop() | ||||||
|  |  | ||||||
|  | 		for { | ||||||
|  | 			select { | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				return | ||||||
|  | 			case <-ticker.C: | ||||||
|  | 				if db == nil { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				stats := db.Stats() | ||||||
|  | 				options.Meter.Counter(MaxOpenConnections).Set(uint64(stats.MaxOpenConnections)) | ||||||
|  | 				options.Meter.Counter(OpenConnections).Set(uint64(stats.OpenConnections)) | ||||||
|  | 				options.Meter.Counter(InuseConnections).Set(uint64(stats.InUse)) | ||||||
|  | 				options.Meter.Counter(IdleConnections).Set(uint64(stats.Idle)) | ||||||
|  | 				options.Meter.Counter(WaitConnections).Set(uint64(stats.WaitCount)) | ||||||
|  | 				options.Meter.FloatCounter(BlockedSeconds).Set(stats.WaitDuration.Seconds()) | ||||||
|  | 				options.Meter.Counter(MaxIdleClosed).Set(uint64(stats.MaxIdleClosed)) | ||||||
|  | 				options.Meter.Counter(MaxIdletimeClosed).Set(uint64(stats.MaxIdleTimeClosed)) | ||||||
|  | 				options.Meter.Counter(MaxLifetimeClosed).Set(uint64(stats.MaxLifetimeClosed)) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | } | ||||||
							
								
								
									
										287
									
								
								hooks/sql/stmt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								hooks/sql/stmt.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,287 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/hooks/requestid" | ||||||
|  | 	"go.unistack.org/micro/v3/tracer" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	_ driver.Stmt              = (*wrapperStmt)(nil) | ||||||
|  | 	_ driver.StmtQueryContext  = (*wrapperStmt)(nil) | ||||||
|  | 	_ driver.StmtExecContext   = (*wrapperStmt)(nil) | ||||||
|  | 	_ driver.NamedValueChecker = (*wrapperStmt)(nil) | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // wrapperStmt defines a wrapper for driver.Stmt | ||||||
|  | type wrapperStmt struct { | ||||||
|  | 	stmt  driver.Stmt | ||||||
|  | 	opts  Options | ||||||
|  | 	query string | ||||||
|  | 	ctx   context.Context | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close implements driver.Stmt Close | ||||||
|  | func (w *wrapperStmt) Close() error { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Close"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	err := w.stmt.Close() | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Close", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NumInput implements driver.Stmt NumInput | ||||||
|  | func (w *wrapperStmt) NumInput() int { | ||||||
|  | 	return w.stmt.NumInput() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CheckNamedValue implements driver.NamedValueChecker | ||||||
|  | func (w *wrapperStmt) CheckNamedValue(v *driver.NamedValue) error { | ||||||
|  | 	s, ok := w.stmt.(driver.NamedValueChecker) | ||||||
|  | 	if !ok { | ||||||
|  | 		return driver.ErrSkip | ||||||
|  | 	} | ||||||
|  | 	return s.CheckNamedValue(v) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Exec implements driver.Stmt Exec | ||||||
|  | func (w *wrapperStmt) Exec(args []driver.Value) (driver.Result, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Exec"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	res, err := w.stmt.Exec(args) // nolint:staticcheck | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Exec", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Query implements driver.Stmt Query | ||||||
|  | func (w *wrapperStmt) Query(args []driver.Value) (driver.Rows, error) { | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		ctx = w.ctx | ||||||
|  | 	} else { | ||||||
|  | 		ctx = context.Background() | ||||||
|  | 	} | ||||||
|  | 	_ = ctx | ||||||
|  | 	labels := []string{labelMethod, "Query"} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	rows, err := w.stmt.Query(args) // nolint:staticcheck | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "Query", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ColumnConverter implements driver.ColumnConverter | ||||||
|  | func (w *wrapperStmt) ColumnConverter(idx int) driver.ValueConverter { | ||||||
|  | 	s, ok := w.stmt.(driver.ColumnConverter) // nolint:staticcheck | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return s.ColumnConverter(idx) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ExecContext implements driver.StmtExecContext ExecContext | ||||||
|  | func (w *wrapperStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	var nctx context.Context | ||||||
|  | 	var span tracer.Span | ||||||
|  |  | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} else { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} | ||||||
|  | 	span.AddLabels("db.method", "ExecContext") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	defer span.Finish() | ||||||
|  | 	if len(args) > 0 { | ||||||
|  | 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||||
|  | 	} | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "ExecContext", labelQuery, name} | ||||||
|  |  | ||||||
|  | 	if conn, ok := w.stmt.(driver.StmtExecContext); ok { | ||||||
|  | 		ts := time.Now() | ||||||
|  | 		res, err := conn.ExecContext(nctx, args) | ||||||
|  | 		td := time.Since(ts) | ||||||
|  | 		te := td.Seconds() | ||||||
|  | 		if err != nil { | ||||||
|  | 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 			span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		} else { | ||||||
|  | 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 		} | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return res, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	values, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, 0, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	res, err := w.Exec(values) // nolint:staticcheck | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return res, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QueryContext implements driver.StmtQueryContext StmtQueryContext | ||||||
|  | func (w *wrapperStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	var nctx context.Context | ||||||
|  | 	var span tracer.Span | ||||||
|  |  | ||||||
|  | 	name := getQueryName(ctx) | ||||||
|  | 	if w.ctx != nil { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(w.ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} else { | ||||||
|  | 		nctx, span = w.opts.Tracer.Start(ctx, "sdk.database", tracer.WithSpanKind(tracer.SpanKindClient)) | ||||||
|  | 	} | ||||||
|  | 	span.AddLabels("db.method", "QueryContext") | ||||||
|  | 	span.AddLabels("db.statement", name) | ||||||
|  | 	defer span.Finish() | ||||||
|  | 	if len(args) > 0 { | ||||||
|  | 		span.AddLabels("db.args", fmt.Sprintf("%v", namedValueToLabels(args))) | ||||||
|  | 	} | ||||||
|  | 	if id, ok := ctx.Value(requestid.XRequestIDKey{}).(string); ok { | ||||||
|  | 		span.AddLabels("x-request-id", id) | ||||||
|  | 	} | ||||||
|  | 	labels := []string{labelMethod, "QueryContext", labelQuery, name} | ||||||
|  | 	if conn, ok := w.stmt.(driver.StmtQueryContext); ok { | ||||||
|  | 		ts := time.Now() | ||||||
|  | 		rows, err := conn.QueryContext(nctx, args) | ||||||
|  | 		td := time.Since(ts) | ||||||
|  | 		te := td.Seconds() | ||||||
|  | 		if err != nil { | ||||||
|  | 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 			span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		} else { | ||||||
|  | 			w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 		w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return rows, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	values, err := namedValueToValue(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  |  | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		/* | ||||||
|  | 			if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 				w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, 0, err)...) | ||||||
|  | 			} | ||||||
|  | 		*/ | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	rows, err := w.Query(values) // nolint:staticcheck | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	te := td.Seconds() | ||||||
|  | 	if err != nil { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelFailure)...).Inc() | ||||||
|  | 		span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 	} else { | ||||||
|  | 		w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) | ||||||
|  | 	w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
							
								
								
									
										63
									
								
								hooks/sql/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								hooks/sql/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/tracer" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var _ driver.Tx = (*wrapperTx)(nil) | ||||||
|  |  | ||||||
|  | // wrapperTx defines a wrapper for driver.Tx | ||||||
|  | type wrapperTx struct { | ||||||
|  | 	tx   driver.Tx | ||||||
|  | 	span tracer.Span | ||||||
|  | 	opts Options | ||||||
|  | 	ctx  context.Context | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Commit implements driver.Tx Commit | ||||||
|  | func (w *wrapperTx) Commit() error { | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	err := w.tx.Commit() | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	_ = td | ||||||
|  | 	if w.span != nil { | ||||||
|  | 		if err != nil { | ||||||
|  | 			w.span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		} | ||||||
|  | 		w.span.Finish() | ||||||
|  | 	} | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Commit", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	w.ctx = nil | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Rollback implements driver.Tx Rollback | ||||||
|  | func (w *wrapperTx) Rollback() error { | ||||||
|  | 	ts := time.Now() | ||||||
|  | 	err := w.tx.Rollback() | ||||||
|  | 	td := time.Since(ts) | ||||||
|  | 	_ = td | ||||||
|  | 	if w.span != nil { | ||||||
|  | 		if err != nil { | ||||||
|  | 			w.span.SetStatus(tracer.SpanStatusError, err.Error()) | ||||||
|  | 		} | ||||||
|  | 		w.span.Finish() | ||||||
|  | 	} | ||||||
|  | 	/* | ||||||
|  | 		if w.opts.LoggerEnabled && w.opts.Logger.V(w.opts.LoggerLevel) { | ||||||
|  | 			w.opts.Logger.Log(w.ctx, w.opts.LoggerLevel, w.opts.LoggerObserver(w.ctx, "Rollback", getCallerName(), td, err)...) | ||||||
|  | 		} | ||||||
|  | 	*/ | ||||||
|  | 	w.ctx = nil | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
							
								
								
									
										19
									
								
								hooks/sql/wrap.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								hooks/sql/wrap.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | func wrapDriver(d driver.Driver, opts Options) driver.Driver { | ||||||
|  | 	if _, ok := d.(driver.DriverContext); ok { | ||||||
|  | 		return &wrapperDriver{driver: d, opts: opts} | ||||||
|  | 	} | ||||||
|  | 	return struct{ driver.Driver }{&wrapperDriver{driver: d, opts: opts}} | ||||||
|  | } | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | // WrapConn allows an existing driver.Conn to be wrapped. | ||||||
|  | func WrapConn(c driver.Conn, opts ...Option) driver.Conn { | ||||||
|  | 	return wrapConn(c, NewOptions(opts...)) | ||||||
|  | } | ||||||
							
								
								
									
										20699
									
								
								hooks/sql/wrap_gen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20699
									
								
								hooks/sql/wrap_gen.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										194
									
								
								hooks/validator/validator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								hooks/validator/validator.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,194 @@ | |||||||
|  | package validator | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	"go.unistack.org/micro/v3/client" | ||||||
|  | 	"go.unistack.org/micro/v3/errors" | ||||||
|  | 	"go.unistack.org/micro/v3/server" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	DefaultClientErrorFunc = func(req client.Request, rsp interface{}, err error) error { | ||||||
|  | 		if rsp != nil { | ||||||
|  | 			return errors.BadGateway(req.Service(), "%v", err) | ||||||
|  | 		} | ||||||
|  | 		return errors.BadRequest(req.Service(), "%v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	DefaultServerErrorFunc = func(req server.Request, rsp interface{}, err error) error { | ||||||
|  | 		if rsp != nil { | ||||||
|  | 			return errors.BadGateway(req.Service(), "%v", err) | ||||||
|  | 		} | ||||||
|  | 		return errors.BadRequest(req.Service(), "%v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	DefaultPublishErrorFunc = func(msg client.Message, err error) error { | ||||||
|  | 		return errors.BadRequest(msg.Topic(), "%v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	DefaultSubscribeErrorFunc = func(msg server.Message, err error) error { | ||||||
|  | 		return errors.BadRequest(msg.Topic(), "%v", err) | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ( | ||||||
|  | 	ClientErrorFunc    func(client.Request, interface{}, error) error | ||||||
|  | 	ServerErrorFunc    func(server.Request, interface{}, error) error | ||||||
|  | 	PublishErrorFunc   func(client.Message, error) error | ||||||
|  | 	SubscribeErrorFunc func(server.Message, error) error | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Options struct holds wrapper options | ||||||
|  | type Options struct { | ||||||
|  | 	ClientErrorFn          ClientErrorFunc | ||||||
|  | 	ServerErrorFn          ServerErrorFunc | ||||||
|  | 	PublishErrorFn         PublishErrorFunc | ||||||
|  | 	SubscribeErrorFn       SubscribeErrorFunc | ||||||
|  | 	ClientValidateResponse bool | ||||||
|  | 	ServerValidateResponse bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Option func signature | ||||||
|  | type Option func(*Options) | ||||||
|  |  | ||||||
|  | func ClientValidateResponse(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ClientValidateResponse = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ServerValidateResponse(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ClientValidateResponse = b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ClientReqErrorFn(fn ClientErrorFunc) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ClientErrorFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ServerErrorFn(fn ServerErrorFunc) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.ServerErrorFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func PublishErrorFn(fn PublishErrorFunc) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.PublishErrorFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SubscribeErrorFn(fn SubscribeErrorFunc) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.SubscribeErrorFn = fn | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewOptions(opts ...Option) Options { | ||||||
|  | 	options := Options{ | ||||||
|  | 		ClientErrorFn:    DefaultClientErrorFunc, | ||||||
|  | 		ServerErrorFn:    DefaultServerErrorFunc, | ||||||
|  | 		PublishErrorFn:   DefaultPublishErrorFunc, | ||||||
|  | 		SubscribeErrorFn: DefaultSubscribeErrorFunc, | ||||||
|  | 	} | ||||||
|  | 	for _, o := range opts { | ||||||
|  | 		o(&options) | ||||||
|  | 	} | ||||||
|  | 	return options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewHook(opts ...Option) *hook { | ||||||
|  | 	return &hook{opts: NewOptions(opts...)} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type validator interface { | ||||||
|  | 	Validate() error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type hook struct { | ||||||
|  | 	opts Options | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientCall(next client.FuncCall) client.FuncCall { | ||||||
|  | 	return func(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { | ||||||
|  | 		if v, ok := req.Body().(validator); ok { | ||||||
|  | 			if err := v.Validate(); err != nil { | ||||||
|  | 				return w.opts.ClientErrorFn(req, nil, err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		err := next(ctx, req, rsp, opts...) | ||||||
|  | 		if v, ok := rsp.(validator); ok && w.opts.ClientValidateResponse { | ||||||
|  | 			if verr := v.Validate(); verr != nil { | ||||||
|  | 				return w.opts.ClientErrorFn(req, rsp, verr) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientStream(next client.FuncStream) client.FuncStream { | ||||||
|  | 	return func(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { | ||||||
|  | 		if v, ok := req.Body().(validator); ok { | ||||||
|  | 			if err := v.Validate(); err != nil { | ||||||
|  | 				return nil, w.opts.ClientErrorFn(req, nil, err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, req, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientPublish(next client.FuncPublish) client.FuncPublish { | ||||||
|  | 	return func(ctx context.Context, msg client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		if v, ok := msg.Payload().(validator); ok { | ||||||
|  | 			if err := v.Validate(); err != nil { | ||||||
|  | 				return w.opts.PublishErrorFn(msg, err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msg, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ClientBatchPublish(next client.FuncBatchPublish) client.FuncBatchPublish { | ||||||
|  | 	return func(ctx context.Context, msgs []client.Message, opts ...client.PublishOption) error { | ||||||
|  | 		for _, msg := range msgs { | ||||||
|  | 			if v, ok := msg.Payload().(validator); ok { | ||||||
|  | 				if err := v.Validate(); err != nil { | ||||||
|  | 					return w.opts.PublishErrorFn(msg, err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msgs, opts...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerHandler(next server.FuncHandler) server.FuncHandler { | ||||||
|  | 	return func(ctx context.Context, req server.Request, rsp interface{}) error { | ||||||
|  | 		if v, ok := req.Body().(validator); ok { | ||||||
|  | 			if err := v.Validate(); err != nil { | ||||||
|  | 				return w.opts.ServerErrorFn(req, nil, err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		err := next(ctx, req, rsp) | ||||||
|  | 		if v, ok := rsp.(validator); ok && w.opts.ServerValidateResponse { | ||||||
|  | 			if verr := v.Validate(); verr != nil { | ||||||
|  | 				return w.opts.ServerErrorFn(req, rsp, verr) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *hook) ServerSubscriber(next server.FuncSubHandler) server.FuncSubHandler { | ||||||
|  | 	return func(ctx context.Context, msg server.Message) error { | ||||||
|  | 		if v, ok := msg.Body().(validator); ok { | ||||||
|  | 			if err := v.Validate(); err != nil { | ||||||
|  | 				return w.opts.SubscribeErrorFn(msg, err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return next(ctx, msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -99,6 +99,7 @@ func WithAddFields(fields ...interface{}) Option { | |||||||
| 					iv, iok := o.Fields[i].(string) | 					iv, iok := o.Fields[i].(string) | ||||||
| 					jv, jok := fields[j].(string) | 					jv, jok := fields[j].(string) | ||||||
| 					if iok && jok && iv == jv { | 					if iok && jok && iv == jv { | ||||||
|  | 						o.Fields[i+1] = fields[j+1] | ||||||
| 						fields = slices.Delete(fields, j, j+2) | 						fields = slices.Delete(fields, j, j+2) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|   | |||||||
| @@ -278,7 +278,7 @@ func (s *slogLogger) printLog(ctx context.Context, lvl logger.Level, msg string, | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if (s.opts.AddStacktrace || lvl == logger.FatalLevel) || (s.opts.AddStacktrace && lvl == logger.ErrorLevel) { | 	if s.opts.AddStacktrace && (lvl == logger.FatalLevel || lvl == logger.ErrorLevel) { | ||||||
| 		stackInfo := make([]byte, 1024*1024) | 		stackInfo := make([]byte, 1024*1024) | ||||||
| 		if stackSize := runtime.Stack(stackInfo, false); stackSize > 0 { | 		if stackSize := runtime.Stack(stackInfo, false); stackSize > 0 { | ||||||
| 			traceLines := reTrace.Split(string(stackInfo[:stackSize]), -1) | 			traceLines := reTrace.Split(string(stackInfo[:stackSize]), -1) | ||||||
|   | |||||||
| @@ -21,7 +21,7 @@ import ( | |||||||
| func TestStacktrace(t *testing.T) { | func TestStacktrace(t *testing.T) { | ||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
| 	buf := bytes.NewBuffer(nil) | 	buf := bytes.NewBuffer(nil) | ||||||
| 	l := NewLogger(logger.WithLevel(logger.ErrorLevel), logger.WithOutput(buf), | 	l := NewLogger(logger.WithLevel(logger.DebugLevel), logger.WithOutput(buf), | ||||||
| 		WithHandlerFunc(slog.NewTextHandler), | 		WithHandlerFunc(slog.NewTextHandler), | ||||||
| 		logger.WithAddStacktrace(true), | 		logger.WithAddStacktrace(true), | ||||||
| 	) | 	) | ||||||
| @@ -62,7 +62,7 @@ func TestTime(t *testing.T) { | |||||||
| 		WithHandlerFunc(slog.NewTextHandler), | 		WithHandlerFunc(slog.NewTextHandler), | ||||||
| 		logger.WithAddStacktrace(true), | 		logger.WithAddStacktrace(true), | ||||||
| 		logger.WithTimeFunc(func() time.Time { | 		logger.WithTimeFunc(func() time.Time { | ||||||
| 			return time.Unix(0, 0) | 			return time.Unix(0, 0).UTC() | ||||||
| 		}), | 		}), | ||||||
| 	) | 	) | ||||||
| 	if err := l.Init(logger.WithFields("key1", "val1")); err != nil { | 	if err := l.Init(logger.WithFields("key1", "val1")); err != nil { | ||||||
| @@ -71,8 +71,7 @@ func TestTime(t *testing.T) { | |||||||
|  |  | ||||||
| 	l.Error(ctx, "msg1", errors.New("err")) | 	l.Error(ctx, "msg1", errors.New("err")) | ||||||
|  |  | ||||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T03:00:00.000000000+03:00`)) && | 	if !bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T00:00:00.000000000Z`)) { | ||||||
| 		!bytes.Contains(buf.Bytes(), []byte(`timestamp=1970-01-01T00:00:00.000000000Z`)) { |  | ||||||
| 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -124,7 +123,7 @@ func TestWithDedupKeysWithAddFields(t *testing.T) { | |||||||
|  |  | ||||||
| 	l.Info(ctx, "msg3") | 	l.Info(ctx, "msg3") | ||||||
|  |  | ||||||
| 	if !bytes.Contains(buf.Bytes(), []byte(`msg=msg3 key1=val1 key2=val2`)) { | 	if !bytes.Contains(buf.Bytes(), []byte(`msg=msg3 key1=val4 key2=val3`)) { | ||||||
| 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | 		t.Fatalf("logger error not works, buf contains: %s", buf.Bytes()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -21,11 +21,6 @@ import ( | |||||||
| 	"go.unistack.org/micro/v3/util/rand" | 	"go.unistack.org/micro/v3/util/rand" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // DefaultCodecs will be used to encode/decode |  | ||||||
| var DefaultCodecs = map[string]codec.Codec{ |  | ||||||
| 	"application/octet-stream": codec.NewCodec(), |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	defaultContentType = "application/json" | 	defaultContentType = "application/json" | ||||||
| ) | ) | ||||||
| @@ -93,9 +88,6 @@ func (n *noopServer) newCodec(contentType string) (codec.Codec, error) { | |||||||
| 	if cf, ok := n.opts.Codecs[contentType]; ok { | 	if cf, ok := n.opts.Codecs[contentType]; ok { | ||||||
| 		return cf, nil | 		return cf, nil | ||||||
| 	} | 	} | ||||||
| 	if cf, ok := DefaultCodecs[contentType]; ok { |  | ||||||
| 		return cf, nil |  | ||||||
| 	} |  | ||||||
| 	return nil, codec.ErrUnknownContentType | 	return nil, codec.ErrUnknownContentType | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										20
									
								
								service.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								service.go
									
									
									
									
									
								
							| @@ -104,6 +104,7 @@ type service struct { | |||||||
| 	done chan struct{} | 	done chan struct{} | ||||||
| 	opts Options | 	opts Options | ||||||
| 	sync.RWMutex | 	sync.RWMutex | ||||||
|  | 	stopped bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewService creates and returns a new Service based on the packages within. | // NewService creates and returns a new Service based on the packages within. | ||||||
| @@ -429,7 +430,7 @@ func (s *service) Stop() error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	close(s.done) | 	s.notifyShutdown() | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -453,10 +454,23 @@ func (s *service) Run() error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// wait on context cancel |  | ||||||
| 	<-s.done | 	<-s.done | ||||||
|  |  | ||||||
| 	return s.Stop() | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // notifyShutdown marks the service as stopped and closes the done channel. | ||||||
|  | // It ensures the channel is closed only once, preventing multiple closures. | ||||||
|  | func (s *service) notifyShutdown() { | ||||||
|  | 	s.Lock() | ||||||
|  | 	if s.stopped { | ||||||
|  | 		s.Unlock() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	s.stopped = true | ||||||
|  | 	s.Unlock() | ||||||
|  |  | ||||||
|  | 	close(s.done) | ||||||
| } | } | ||||||
|  |  | ||||||
| type Namer interface { | type Namer interface { | ||||||
|   | |||||||
| @@ -4,7 +4,9 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
| 	"go.unistack.org/micro/v3/broker" | 	"go.unistack.org/micro/v3/broker" | ||||||
| 	"go.unistack.org/micro/v3/client" | 	"go.unistack.org/micro/v3/client" | ||||||
| 	"go.unistack.org/micro/v3/config" | 	"go.unistack.org/micro/v3/config" | ||||||
| @@ -773,3 +775,41 @@ func Test_getNameIndex(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| */ | */ | ||||||
|  |  | ||||||
|  | func TestServiceShutdown(t *testing.T) { | ||||||
|  | 	defer func() { | ||||||
|  | 		if r := recover(); r != nil { | ||||||
|  | 			t.Fatalf("service shutdown failed: %v", r) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	s, ok := NewService().(*service) | ||||||
|  | 	require.NotNil(t, s) | ||||||
|  | 	require.True(t, ok) | ||||||
|  |  | ||||||
|  | 	require.NoError(t, s.Start()) | ||||||
|  | 	require.False(t, s.stopped) | ||||||
|  |  | ||||||
|  | 	require.NoError(t, s.Stop()) | ||||||
|  | 	require.True(t, s.stopped) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestServiceMultipleShutdowns(t *testing.T) { | ||||||
|  | 	defer func() { | ||||||
|  | 		if r := recover(); r != nil { | ||||||
|  | 			t.Fatalf("service shutdown failed: %v", r) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	s := NewService() | ||||||
|  |  | ||||||
|  | 	go func() { | ||||||
|  | 		time.Sleep(10 * time.Millisecond) | ||||||
|  | 		// first call | ||||||
|  | 		require.NoError(t, s.Stop()) | ||||||
|  | 		// duplicate call | ||||||
|  | 		require.NoError(t, s.Stop()) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	require.NoError(t, s.Run()) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -46,6 +46,10 @@ func (s memoryStringer) String() string { | |||||||
| 	return s.s | 	return s.s | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t *Tracer) Enabled() bool { | ||||||
|  | 	return t.opts.Enabled | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t *Tracer) Flush(_ context.Context) error { | func (t *Tracer) Flush(_ context.Context) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -89,6 +93,10 @@ func (s *Span) Tracer() tracer.Tracer { | |||||||
| 	return s.tracer | 	return s.tracer | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *Span) IsRecording() bool { | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
| type Event struct { | type Event struct { | ||||||
| 	name   string | 	name   string | ||||||
| 	labels []interface{} | 	labels []interface{} | ||||||
|   | |||||||
| @@ -18,6 +18,10 @@ func (t *noopTracer) Spans() []Span { | |||||||
| 	return t.spans | 	return t.spans | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (t *noopTracer) Enabled() bool { | ||||||
|  | 	return t.opts.Enabled | ||||||
|  | } | ||||||
|  |  | ||||||
| func (t *noopTracer) Start(ctx context.Context, name string, opts ...SpanOption) (context.Context, Span) { | func (t *noopTracer) Start(ctx context.Context, name string, opts ...SpanOption) (context.Context, Span) { | ||||||
| 	options := NewSpanOptions(opts...) | 	options := NewSpanOptions(opts...) | ||||||
| 	span := &noopSpan{ | 	span := &noopSpan{ | ||||||
| @@ -120,6 +124,10 @@ func (s *noopSpan) SetStatus(st SpanStatus, msg string) { | |||||||
| 	s.statusMsg = msg | 	s.statusMsg = msg | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *noopSpan) IsRecording() bool { | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
| // NewTracer returns new memory tracer | // NewTracer returns new memory tracer | ||||||
| func NewTracer(opts ...Option) Tracer { | func NewTracer(opts ...Option) Tracer { | ||||||
| 	return &noopTracer{ | 	return &noopTracer{ | ||||||
|   | |||||||
| @@ -142,6 +142,8 @@ type Options struct { | |||||||
| 	Name string | 	Name string | ||||||
| 	// ContextAttrFuncs contains funcs that provides tracing | 	// ContextAttrFuncs contains funcs that provides tracing | ||||||
| 	ContextAttrFuncs []ContextAttrFunc | 	ContextAttrFuncs []ContextAttrFunc | ||||||
|  | 	// Enabled specify trace status | ||||||
|  | 	Enabled bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // Option func signature | // Option func signature | ||||||
| @@ -181,6 +183,7 @@ func NewOptions(opts ...Option) Options { | |||||||
| 		Logger:           logger.DefaultLogger, | 		Logger:           logger.DefaultLogger, | ||||||
| 		Context:          context.Background(), | 		Context:          context.Background(), | ||||||
| 		ContextAttrFuncs: DefaultContextAttrFuncs, | 		ContextAttrFuncs: DefaultContextAttrFuncs, | ||||||
|  | 		Enabled:          true, | ||||||
| 	} | 	} | ||||||
| 	for _, o := range opts { | 	for _, o := range opts { | ||||||
| 		o(&options) | 		o(&options) | ||||||
| @@ -194,3 +197,10 @@ func Name(n string) Option { | |||||||
| 		o.Name = n | 		o.Name = n | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Disabled disable tracer | ||||||
|  | func Disabled(b bool) Option { | ||||||
|  | 	return func(o *Options) { | ||||||
|  | 		o.Enabled = !b | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -51,6 +51,8 @@ type Tracer interface { | |||||||
| 	// Extract(ctx context.Context) | 	// Extract(ctx context.Context) | ||||||
| 	// Flush flushes spans | 	// Flush flushes spans | ||||||
| 	Flush(ctx context.Context) error | 	Flush(ctx context.Context) error | ||||||
|  | 	// Enabled returns tracer status | ||||||
|  | 	Enabled() bool | ||||||
| } | } | ||||||
|  |  | ||||||
| type Span interface { | type Span interface { | ||||||
| @@ -78,4 +80,6 @@ type Span interface { | |||||||
| 	TraceID() string | 	TraceID() string | ||||||
| 	// SpanID returns span id | 	// SpanID returns span id | ||||||
| 	SpanID() string | 	SpanID() string | ||||||
|  | 	// IsRecording returns the recording state of the Span. | ||||||
|  | 	IsRecording() bool | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										84
									
								
								util/buffer/seeker_buffer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								util/buffer/seeker_buffer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | |||||||
|  | package buffer | ||||||
|  |  | ||||||
|  | import "io" | ||||||
|  |  | ||||||
|  | var _ interface { | ||||||
|  | 	io.ReadCloser | ||||||
|  | 	io.ReadSeeker | ||||||
|  | } = (*SeekerBuffer)(nil) | ||||||
|  |  | ||||||
|  | // Buffer is a ReadWriteCloser that supports seeking. It's intended to | ||||||
|  | // replicate the functionality of bytes.Buffer that I use in my projects. | ||||||
|  | // | ||||||
|  | // Note that the seeking is limited to the read marker; all writes are | ||||||
|  | // append-only. | ||||||
|  | type SeekerBuffer struct { | ||||||
|  | 	data []byte | ||||||
|  | 	pos  int64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewSeekerBuffer(data []byte) *SeekerBuffer { | ||||||
|  | 	return &SeekerBuffer{ | ||||||
|  | 		data: data, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SeekerBuffer) Read(p []byte) (int, error) { | ||||||
|  | 	if b.pos >= int64(len(b.data)) { | ||||||
|  | 		return 0, io.EOF | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	n := copy(p, b.data[b.pos:]) | ||||||
|  | 	b.pos += int64(n) | ||||||
|  | 	return n, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SeekerBuffer) Write(p []byte) (int, error) { | ||||||
|  | 	b.data = append(b.data, p...) | ||||||
|  | 	return len(p), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Seek sets the read pointer to pos. | ||||||
|  | func (b *SeekerBuffer) Seek(offset int64, whence int) (int64, error) { | ||||||
|  | 	switch whence { | ||||||
|  | 	case io.SeekStart: | ||||||
|  | 		b.pos = offset | ||||||
|  | 	case io.SeekEnd: | ||||||
|  | 		b.pos = int64(len(b.data)) + offset | ||||||
|  | 	case io.SeekCurrent: | ||||||
|  | 		b.pos += offset | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.pos, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Rewind resets the read pointer to 0. | ||||||
|  | func (b *SeekerBuffer) Rewind() error { | ||||||
|  | 	if _, err := b.Seek(0, io.SeekStart); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close clears all the data out of the buffer and sets the read position to 0. | ||||||
|  | func (b *SeekerBuffer) Close() error { | ||||||
|  | 	b.data = nil | ||||||
|  | 	b.pos = 0 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Reset clears all the data out of the buffer and sets the read position to 0. | ||||||
|  | func (b *SeekerBuffer) Reset() { | ||||||
|  | 	b.data = nil | ||||||
|  | 	b.pos = 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Len returns the length of data remaining to be read. | ||||||
|  | func (b *SeekerBuffer) Len() int { | ||||||
|  | 	return len(b.data[b.pos:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Bytes returns the underlying bytes from the current position. | ||||||
|  | func (b *SeekerBuffer) Bytes() []byte { | ||||||
|  | 	return b.data[b.pos:] | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								util/buffer/seeker_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								util/buffer/seeker_buffer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | package buffer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func noErrorT(t *testing.T, err error) { | ||||||
|  | 	if nil != err { | ||||||
|  | 		t.Fatalf("%s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func boolT(t *testing.T, cond bool, s ...string) { | ||||||
|  | 	if !cond { | ||||||
|  | 		what := strings.Join(s, ", ") | ||||||
|  | 		if len(what) > 0 { | ||||||
|  | 			what = ": " + what | ||||||
|  | 		} | ||||||
|  | 		t.Fatalf("assert.Bool failed%s", what) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestSeeking(t *testing.T) { | ||||||
|  | 	partA := []byte("hello, ") | ||||||
|  | 	partB := []byte("world!") | ||||||
|  |  | ||||||
|  | 	buf := NewSeekerBuffer(partA) | ||||||
|  |  | ||||||
|  | 	boolT(t, buf.Len() == len(partA), fmt.Sprintf("on init: have length %d, want length %d", buf.Len(), len(partA))) | ||||||
|  |  | ||||||
|  | 	b := make([]byte, 32) | ||||||
|  |  | ||||||
|  | 	n, err := buf.Read(b) | ||||||
|  | 	noErrorT(t, err) | ||||||
|  | 	boolT(t, buf.Len() == 0, fmt.Sprintf("after reading 1: have length %d, want length 0", buf.Len())) | ||||||
|  | 	boolT(t, n == len(partA), fmt.Sprintf("after reading 2: have length %d, want length %d", n, len(partA))) | ||||||
|  |  | ||||||
|  | 	n, err = buf.Write(partB) | ||||||
|  | 	noErrorT(t, err) | ||||||
|  | 	boolT(t, n == len(partB), fmt.Sprintf("after writing: have length %d, want length %d", n, len(partB))) | ||||||
|  |  | ||||||
|  | 	n, err = buf.Read(b) | ||||||
|  | 	noErrorT(t, err) | ||||||
|  | 	boolT(t, buf.Len() == 0, fmt.Sprintf("after rereading 1: have length %d, want length 0", buf.Len())) | ||||||
|  | 	boolT(t, n == len(partB), fmt.Sprintf("after rereading 2: have length %d, want length %d", n, len(partB))) | ||||||
|  |  | ||||||
|  | 	partsLen := len(partA) + len(partB) | ||||||
|  | 	_ = buf.Rewind() | ||||||
|  | 	boolT(t, buf.Len() == partsLen, fmt.Sprintf("after rewinding: have length %d, want length %d", buf.Len(), partsLen)) | ||||||
|  |  | ||||||
|  | 	buf.Close() | ||||||
|  | 	boolT(t, buf.Len() == 0, fmt.Sprintf("after closing, have length %d, want length 0", buf.Len())) | ||||||
|  | } | ||||||
| @@ -489,35 +489,74 @@ func URLMap(query string) (map[string]interface{}, error) { | |||||||
| 	return mp.(map[string]interface{}), nil | 	return mp.(map[string]interface{}), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // FlattenMap expand key.subkey to nested map | // FlattenMap flattens a nested map into a single-level map using dot notation for nested keys. | ||||||
| func FlattenMap(a map[string]interface{}) map[string]interface{} { | // In case of key conflicts, all nested levels will be discarded in favor of the first-level key. | ||||||
| 	// preprocess map | // | ||||||
| 	nb := make(map[string]interface{}, len(a)) | // Example #1: | ||||||
| 	for k, v := range a { | // | ||||||
| 		ps := strings.Split(k, ".") | //	Input: | ||||||
| 		if len(ps) == 1 { | //	  { | ||||||
| 			nb[k] = v | //	    "user.name": "alex", | ||||||
|  | //	    "user.document.id": "document_id" | ||||||
|  | //	    "user.document.number": "document_number" | ||||||
|  | //	  } | ||||||
|  | //	Output: | ||||||
|  | //	  { | ||||||
|  | //	    "user": { | ||||||
|  | //	      "name": "alex", | ||||||
|  | //	      "document": { | ||||||
|  | //	        "id": "document_id" | ||||||
|  | //	        "number": "document_number" | ||||||
|  | //	      } | ||||||
|  | //	    } | ||||||
|  | //	  } | ||||||
|  | // | ||||||
|  | // Example #2 (with conflicts): | ||||||
|  | // | ||||||
|  | //	Input: | ||||||
|  | //	  { | ||||||
|  | //	    "user": "alex", | ||||||
|  | //	    "user.document.id": "document_id" | ||||||
|  | //	    "user.document.number": "document_number" | ||||||
|  | //	  } | ||||||
|  | //	Output: | ||||||
|  | //	  { | ||||||
|  | //	    "user": "alex" | ||||||
|  | //	  } | ||||||
|  | func FlattenMap(input map[string]interface{}) map[string]interface{} { | ||||||
|  | 	result := make(map[string]interface{}) | ||||||
|  |  | ||||||
|  | 	for k, v := range input { | ||||||
|  | 		parts := strings.Split(k, ".") | ||||||
|  |  | ||||||
|  | 		if len(parts) == 1 { | ||||||
|  | 			result[k] = v | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		em := make(map[string]interface{}) |  | ||||||
| 		em[ps[len(ps)-1]] = v | 		current := result | ||||||
| 		for i := len(ps) - 2; i > 0; i-- { |  | ||||||
| 			nm := make(map[string]interface{}) | 		for i, part := range parts { | ||||||
| 			nm[ps[i]] = em | 			// last element in the path | ||||||
| 			em = nm | 			if i == len(parts)-1 { | ||||||
|  | 				current[part] = v | ||||||
|  | 				break | ||||||
| 			} | 			} | ||||||
| 		if vm, ok := nb[ps[0]]; ok { |  | ||||||
| 			// nested map | 			// initialize map for current level if not exist | ||||||
| 			nm := vm.(map[string]interface{}) | 			if _, ok := current[part]; !ok { | ||||||
| 			for vk, vv := range em { | 				current[part] = make(map[string]interface{}) | ||||||
| 				nm[vk] = vv |  | ||||||
| 			} | 			} | ||||||
| 			nb[ps[0]] = nm |  | ||||||
|  | 			if nested, ok := current[part].(map[string]interface{}); ok { | ||||||
|  | 				current = nested // continue to the nested map | ||||||
| 			} else { | 			} else { | ||||||
| 			nb[ps[0]] = em | 				break // if current element is not a map, ignore it | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	return nb | 	} | ||||||
|  |  | ||||||
|  | 	return result | ||||||
| } | } | ||||||
|  |  | ||||||
| /* | /* | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
| 	rutil "go.unistack.org/micro/v3/util/reflect" | 	rutil "go.unistack.org/micro/v3/util/reflect" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -319,3 +320,140 @@ func TestIsZero(t *testing.T) { | |||||||
|  |  | ||||||
| 	// t.Logf("XX %#+v\n", ok) | 	// t.Logf("XX %#+v\n", ok) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestFlattenMap(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name     string | ||||||
|  | 		input    map[string]interface{} | ||||||
|  | 		expected map[string]interface{} | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "empty map", | ||||||
|  | 			input:    map[string]interface{}{}, | ||||||
|  | 			expected: map[string]interface{}{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "nil map", | ||||||
|  | 			input:    nil, | ||||||
|  | 			expected: map[string]interface{}{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "single level", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"username": "username", | ||||||
|  | 				"password": "password", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"username": "username", | ||||||
|  | 				"password": "password", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "two level", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"order_id":      "order_id", | ||||||
|  | 				"user.name":     "username", | ||||||
|  | 				"user.password": "password", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"order_id": "order_id", | ||||||
|  | 				"user": map[string]interface{}{ | ||||||
|  | 					"name":     "username", | ||||||
|  | 					"password": "password", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "three level", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"order_id":             "order_id", | ||||||
|  | 				"user.name":            "username", | ||||||
|  | 				"user.password":        "password", | ||||||
|  | 				"user.document.id":     "document_id", | ||||||
|  | 				"user.document.number": "document_number", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"order_id": "order_id", | ||||||
|  | 				"user": map[string]interface{}{ | ||||||
|  | 					"name":     "username", | ||||||
|  | 					"password": "password", | ||||||
|  | 					"document": map[string]interface{}{ | ||||||
|  | 						"id":     "document_id", | ||||||
|  | 						"number": "document_number", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "four level", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"order_id":                    "order_id", | ||||||
|  | 				"user.name":                   "username", | ||||||
|  | 				"user.password":               "password", | ||||||
|  | 				"user.document.id":            "document_id", | ||||||
|  | 				"user.document.number":        "document_number", | ||||||
|  | 				"user.info.permissions.read":  "available", | ||||||
|  | 				"user.info.permissions.write": "available", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"order_id": "order_id", | ||||||
|  | 				"user": map[string]interface{}{ | ||||||
|  | 					"name":     "username", | ||||||
|  | 					"password": "password", | ||||||
|  | 					"document": map[string]interface{}{ | ||||||
|  | 						"id":     "document_id", | ||||||
|  | 						"number": "document_number", | ||||||
|  | 					}, | ||||||
|  | 					"info": map[string]interface{}{ | ||||||
|  | 						"permissions": map[string]interface{}{ | ||||||
|  | 							"read":  "available", | ||||||
|  | 							"write": "available", | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "key conflicts", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"user":          "user", | ||||||
|  | 				"user.name":     "username", | ||||||
|  | 				"user.password": "password", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"user": "user", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "overwriting conflicts", | ||||||
|  | 			input: map[string]interface{}{ | ||||||
|  | 				"order_id":             "order_id", | ||||||
|  | 				"user.document.id":     "document_id", | ||||||
|  | 				"user.document.number": "document_number", | ||||||
|  | 				"user.info.address":    "address", | ||||||
|  | 				"user.info.phone":      "phone", | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]interface{}{ | ||||||
|  | 				"order_id": "order_id", | ||||||
|  | 				"user": map[string]interface{}{ | ||||||
|  | 					"document": map[string]interface{}{ | ||||||
|  | 						"id":     "document_id", | ||||||
|  | 						"number": "document_number", | ||||||
|  | 					}, | ||||||
|  | 					"info": map[string]interface{}{ | ||||||
|  | 						"address": "address", | ||||||
|  | 						"phone":   "phone", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			for range 100 { // need to exclude the impact of key order in the map on the test. | ||||||
|  | 				require.Equal(t, tt.expected, rutil.FlattenMap(tt.input)) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user