From 95860fc561f04e5d5ab9c41ffa95073ed1961b70 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Mon, 16 Jan 2023 15:19:30 +0300 Subject: [PATCH 1/3] small fixes for span logs Signed-off-by: Vasiliy Tolstov --- tx.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tx.go b/tx.go index 7aae403..01defaf 100644 --- a/tx.go +++ b/tx.go @@ -24,6 +24,11 @@ func (w *wrapperTx) Commit() error { err := w.tx.Commit() td := time.Since(ts) + if err != nil { + w.span.AddLabels("error", true) + w.span.AddLabels("err", err.Error()) + } + if w.opts.LoggerEnabled { w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Commit", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } @@ -40,6 +45,11 @@ func (w *wrapperTx) Rollback() error { err := w.tx.Rollback() td := time.Since(ts) + if err != nil { + w.span.AddLabels("error", true) + w.span.AddLabels("err", err.Error()) + } + if w.opts.LoggerEnabled { w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Rollback", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } -- 2.45.2 From 9ba10e0c3e141933d482a508d61b97b3d02973c1 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Mon, 16 Jan 2023 23:26:58 +0300 Subject: [PATCH 2/3] fix Signed-off-by: Vasiliy Tolstov --- common.go | 5 ++--- conn.go | 47 +++++++++++++++++++++++++++++++++++++++-------- stmt.go | 21 +++++++++++++++++++-- tx.go | 12 ++++++++++-- 4 files changed, 70 insertions(+), 15 deletions(-) diff --git a/common.go b/common.go index 3b59ff8..b74452d 100644 --- a/common.go +++ b/common.go @@ -24,7 +24,7 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { // namedValueToLabels convert driver arguments to interface{} slice func namedValueToLabels(named []driver.NamedValue) []interface{} { - largs := make([]interface{}, len(named)*2) + largs := make([]interface{}, 0, len(named)*2) var name string for _, param := range named { if param.Name != "" { @@ -32,8 +32,7 @@ func namedValueToLabels(named []driver.NamedValue) []interface{} { } else { name = fmt.Sprintf("$%d", param.Ordinal) } - - largs = append(largs, name, param.Value) + largs = append(largs, fmt.Sprintf("%s=%s", name, param.Value)) } return largs } diff --git a/conn.go b/conn.go index f9612db..ad334df 100644 --- a/conn.go +++ b/conn.go @@ -5,12 +5,17 @@ import ( "database/sql/driver" "fmt" "time" + + "go.unistack.org/micro/v3/tracer" ) +var _ driver.Conn = &wrapperConn{} + // wrapperConn defines a wrapper for driver.Conn type wrapperConn struct { conn driver.Conn opts Options + ctx context.Context } // Prepare implements driver.Conn Prepare @@ -91,7 +96,7 @@ func (w *wrapperConn) Begin() (driver.Tx, error) { // BeginTx implements driver.ConnBeginTx BeginTx func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - nctx, span := w.opts.Tracer.Start(ctx, "BeginTx") + nctx, span := w.opts.Tracer.Start(ctx, "Transaction") span.AddLabels("method", "BeginTx") name := getQueryName(ctx) if name != "" { @@ -120,7 +125,8 @@ func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive if w.opts.LoggerEnabled { w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "BeginTx", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } - return &wrapperTx{tx: tx, opts: w.opts, span: span}, nil + w.ctx = nctx + return &wrapperTx{ctx: ctx, tx: tx, opts: w.opts, span: span, conn: w}, nil } ts := time.Now() // nolint:staticcheck @@ -140,12 +146,19 @@ func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive if w.opts.LoggerEnabled { w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "BeginTx", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } - return tx, nil + w.ctx = nctx + return &wrapperTx{ctx: ctx, tx: tx, opts: w.opts, span: span, conn: w}, nil } // PrepareContext implements driver.ConnPrepareContext PrepareContext func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - nctx, span := w.opts.Tracer.Start(ctx, "PrepareContext") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "PrepareContext") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "PrepareContext") + } span.AddLabels("method", "PrepareContext") name := getQueryName(ctx) if name != "" { @@ -176,7 +189,7 @@ func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver. if w.opts.LoggerEnabled { w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "PrepareContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } - return &wrapperStmt{stmt: stmt, opts: w.opts}, nil + return &wrapperStmt{stmt: stmt, opts: w.opts, ctx: nctx}, nil } ts := time.Now() stmt, err := w.conn.Prepare(query) @@ -227,7 +240,13 @@ func (w *wrapperConn) Exec(query string, args []driver.Value) (driver.Result, er // Exec implements driver.StmtExecContext ExecContext func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - nctx, span := w.opts.Tracer.Start(ctx, "ExecContext") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "ExecContext") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "ExecContext") + } span.AddLabels("method", "ExecContext") name := getQueryName(ctx) if name != "" { @@ -292,7 +311,13 @@ func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driv // Ping implements driver.Pinger Ping func (w *wrapperConn) Ping(ctx context.Context) error { if conn, ok := w.conn.(driver.Pinger); ok { - nctx, span := w.opts.Tracer.Start(ctx, "Ping") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "Ping") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "Ping") + } defer span.Finish() labels := []string{labelMethod, "Ping"} ts := time.Now() @@ -348,7 +373,13 @@ func (w *wrapperConn) Query(query string, args []driver.Value) (driver.Rows, err // QueryContext implements Driver.QueryerContext QueryContext func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - nctx, span := w.opts.Tracer.Start(ctx, "QueryContext") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "QueryContext") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "QueryContext") + } span.AddLabels("method", "QueryContext") name := getQueryName(ctx) if name != "" { diff --git a/stmt.go b/stmt.go index f5705ee..707362c 100644 --- a/stmt.go +++ b/stmt.go @@ -5,12 +5,17 @@ import ( "database/sql/driver" "fmt" "time" + + "go.unistack.org/micro/v3/tracer" ) +var _ driver.Stmt = &wrapperStmt{} + // wrapperStmt defines a wrapper for driver.Stmt type wrapperStmt struct { stmt driver.Stmt opts Options + ctx context.Context } // Close implements driver.Stmt Close @@ -85,7 +90,13 @@ func (w *wrapperStmt) Query(args []driver.Value) (driver.Rows, error) { // ExecContext implements driver.ExecerContext ExecContext func (w *wrapperStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - nctx, span := w.opts.Tracer.Start(ctx, "ExecContext") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "ExecContext") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "ExecContext") + } span.AddLabels("method", "ExecContext") name := getQueryName(ctx) if name != "" { @@ -153,7 +164,13 @@ func (w *wrapperStmt) ExecContext(ctx context.Context, query string, args []driv // QueryContext implements Driver.QueryerContext QueryContext func (w *wrapperStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - nctx, span := w.opts.Tracer.Start(ctx, "QueryContext") + var nctx context.Context + var span tracer.Span + if w.ctx != nil { + nctx, span = w.opts.Tracer.Start(w.ctx, "QueryContext") + } else { + nctx, span = w.opts.Tracer.Start(ctx, "QueryContext") + } span.AddLabels("method", "QueryContext") name := getQueryName(ctx) if name != "" { diff --git a/tx.go b/tx.go index 7aae403..5fad4e0 100644 --- a/tx.go +++ b/tx.go @@ -8,11 +8,15 @@ import ( "go.unistack.org/micro/v3/tracer" ) +var _ driver.Tx = &wrapperTx{} + // wrapperTx defines a wrapper for driver.Tx type wrapperTx struct { tx driver.Tx span tracer.Span opts Options + conn *wrapperConn + ctx context.Context } // Commit implements driver.Tx Commit @@ -25,9 +29,11 @@ func (w *wrapperTx) Commit() error { td := time.Since(ts) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Commit", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Commit", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } + w.ctx = nil + return err } @@ -41,8 +47,10 @@ func (w *wrapperTx) Rollback() error { td := time.Since(ts) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Rollback", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Rollback", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) } + w.ctx = nil + return err } -- 2.45.2 From 066a62c1e357dcc9d51af96af41b26a9f6b8dfb9 Mon Sep 17 00:00:00 2001 From: Vasiliy Tolstov Date: Mon, 16 Jan 2023 23:38:29 +0300 Subject: [PATCH 3/3] fix context Signed-off-by: Vasiliy Tolstov --- conn.go | 83 ++++++++++++++++++++++++++++++++++++++----------------- driver.go | 7 +++-- stmt.go | 36 ++++++++++++++++++------ tx.go | 4 +-- 4 files changed, 91 insertions(+), 39 deletions(-) diff --git a/conn.go b/conn.go index ad334df..128ba7f 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,13 @@ type wrapperConn struct { // Prepare implements driver.Conn Prepare func (w *wrapperConn) Prepare(query string) (driver.Stmt, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } + labels := []string{labelMethod, "Prepare", labelQuery, labelUnknown} ts := time.Now() stmt, err := w.conn.Prepare(query) @@ -31,7 +38,7 @@ func (w *wrapperConn) Prepare(query string) (driver.Stmt, error) { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Prepare", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Prepare", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -40,14 +47,21 @@ func (w *wrapperConn) Prepare(query string) (driver.Stmt, error) { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Prepare", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } - return &wrapperStmt{stmt: stmt, opts: w.opts}, nil + return &wrapperStmt{stmt: stmt, opts: w.opts, ctx: ctx}, nil } // Close implements driver.Conn Close func (w *wrapperConn) Close() error { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } + labels := []string{labelMethod, "Close"} ts := time.Now() err := w.conn.Close() @@ -62,7 +76,7 @@ func (w *wrapperConn) Close() error { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Close", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return err @@ -70,6 +84,13 @@ func (w *wrapperConn) Close() error { // Begin implements driver.Conn Begin func (w *wrapperConn) Begin() (driver.Tx, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } + labels := []string{labelMethod, "Begin"} ts := time.Now() // nolint:staticcheck @@ -81,7 +102,7 @@ func (w *wrapperConn) Begin() (driver.Tx, error) { w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Begin", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Begin", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -89,9 +110,9 @@ func (w *wrapperConn) Begin() (driver.Tx, error) { w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Begin", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Begin", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } - return &wrapperTx{tx: tx, opts: w.opts}, nil + return &wrapperTx{tx: tx, opts: w.opts, ctx: ctx}, nil } // BeginTx implements driver.ConnBeginTx BeginTx @@ -117,13 +138,13 @@ func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive span.AddLabels("error", true) span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "BeginTx", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "BeginTx", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } w.opts.Meter.Counter(meterRequestTotal, append(labels, labelStatus, labelSuccess)...).Inc() if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "BeginTx", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "BeginTx", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } w.ctx = nctx return &wrapperTx{ctx: ctx, tx: tx, opts: w.opts, span: span, conn: w}, nil @@ -144,7 +165,7 @@ func (w *wrapperConn) BeginTx(ctx context.Context, opts driver.TxOptions) (drive w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "BeginTx", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "BeginTx", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } w.ctx = nctx return &wrapperTx{ctx: ctx, tx: tx, opts: w.opts, span: span, conn: w}, nil @@ -179,7 +200,7 @@ func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver. span.AddLabels("error", true) span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "PrepareContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "PrepareContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -187,7 +208,7 @@ func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver. w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "PrepareContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "PrepareContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return &wrapperStmt{stmt: stmt, opts: w.opts, ctx: nctx}, nil } @@ -206,13 +227,19 @@ func (w *wrapperConn) PrepareContext(ctx context.Context, query string) (driver. w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "PrepareContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "PrepareContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return stmt, nil } // Exec implements driver.Execer Exec func (w *wrapperConn) Exec(query string, args []driver.Value) (driver.Result, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } // nolint:staticcheck labels := []string{labelMethod, "Exec", labelQuery, labelUnknown} if execer, ok := w.conn.(driver.Execer); ok { @@ -228,12 +255,12 @@ func (w *wrapperConn) Exec(query string, args []driver.Value) (driver.Result, er w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Exec", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Exec", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return res, err } if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Exec", labelUnknown, 0, ErrUnsupported)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Exec", labelUnknown, 0, ErrUnsupported)...).Log(ctx, w.opts.LoggerLevel) } return nil, ErrUnsupported } @@ -275,7 +302,7 @@ func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driv w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return res, err } @@ -284,7 +311,7 @@ func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driv span.AddLabels("error", true) span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", labelUnknown, 0, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", labelUnknown, 0, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -303,7 +330,7 @@ func (w *wrapperConn) ExecContext(ctx context.Context, query string, args []driv w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return res, err } @@ -329,7 +356,7 @@ func (w *wrapperConn) Ping(ctx context.Context) error { span.AddLabels("error", true) span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Ping", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Ping", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return err } else { @@ -339,13 +366,19 @@ func (w *wrapperConn) Ping(ctx context.Context) error { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) } if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Ping", labelUnknown, 0, ErrUnsupported)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Ping", labelUnknown, 0, ErrUnsupported)...).Log(ctx, w.opts.LoggerLevel) } return ErrUnsupported } // Query implements driver.Queryer Query func (w *wrapperConn) Query(query string, args []driver.Value) (driver.Rows, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } // nolint:staticcheck if conn, ok := w.conn.(driver.Queryer); ok { labels := []string{labelMethod, "Query", labelQuery, labelUnknown} @@ -361,12 +394,12 @@ func (w *wrapperConn) Query(query string, args []driver.Value) (driver.Rows, err w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Query", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Query", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return rows, err } if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Query", labelUnknown, 0, ErrUnsupported)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Query", labelUnknown, 0, ErrUnsupported)...).Log(ctx, w.opts.LoggerLevel) } return nil, ErrUnsupported } @@ -407,7 +440,7 @@ func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []dri w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return rows, err } @@ -416,7 +449,7 @@ func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []dri span.AddLabels("error", true) span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, 0, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, 0, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -435,7 +468,7 @@ func (w *wrapperConn) QueryContext(ctx context.Context, query string, args []dri w.opts.Meter.Summary(meterRequestLatencyMicroseconds, labels...).Update(te) w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return rows, err } diff --git a/driver.go b/driver.go index eb44e14..6b9e4dc 100644 --- a/driver.go +++ b/driver.go @@ -10,11 +10,12 @@ import ( type wrapperDriver struct { driver driver.Driver opts Options + ctx context.Context } // NewWrapper creates and returns a new SQL driver with passed capabilities func NewWrapper(d driver.Driver, opts ...Option) driver.Driver { - return &wrapperDriver{driver: d, opts: NewOptions(opts...)} + return &wrapperDriver{driver: d, opts: NewOptions(opts...), ctx: context.Background()} } // Open implements driver.Driver Open @@ -24,12 +25,12 @@ func (w *wrapperDriver) Open(name string) (driver.Conn, error) { td := time.Since(ts) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Open", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Open", labelUnknown, td, err)...).Log(w.ctx, w.opts.LoggerLevel) } if err != nil { return nil, err } - return &wrapperConn{conn: c, opts: w.opts}, nil + return &wrapperConn{conn: c, opts: w.opts, ctx: w.ctx}, nil } diff --git a/stmt.go b/stmt.go index 707362c..dda47e6 100644 --- a/stmt.go +++ b/stmt.go @@ -20,6 +20,12 @@ type wrapperStmt struct { // Close implements driver.Stmt Close func (w *wrapperStmt) Close() error { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } labels := []string{labelMethod, "Close"} ts := time.Now() err := w.stmt.Close() @@ -34,7 +40,7 @@ func (w *wrapperStmt) Close() error { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Close", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Close", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return err } @@ -46,6 +52,12 @@ func (w *wrapperStmt) NumInput() int { // Exec implements driver.Stmt Exec func (w *wrapperStmt) Exec(args []driver.Value) (driver.Result, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } labels := []string{labelMethod, "Exec"} ts := time.Now() // nolint:staticcheck @@ -61,13 +73,19 @@ func (w *wrapperStmt) Exec(args []driver.Value) (driver.Result, error) { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Exec", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Exec", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return res, err } // Query implements driver.Stmt Query func (w *wrapperStmt) Query(args []driver.Value) (driver.Rows, error) { + var ctx context.Context + if w.ctx != nil { + ctx = w.ctx + } else { + ctx = context.Background() + } labels := []string{labelMethod, "Query"} ts := time.Now() // nolint:staticcheck @@ -83,7 +101,7 @@ func (w *wrapperStmt) Query(args []driver.Value) (driver.Rows, error) { w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "Query", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "Query", labelUnknown, td, err)...).Log(ctx, w.opts.LoggerLevel) } return rows, err } @@ -125,7 +143,7 @@ func (w *wrapperStmt) ExecContext(ctx context.Context, query string, args []driv w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return res, err } @@ -136,7 +154,7 @@ func (w *wrapperStmt) ExecContext(ctx context.Context, query string, args []driv span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", name, 0, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", name, 0, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -157,7 +175,7 @@ func (w *wrapperStmt) ExecContext(ctx context.Context, query string, args []driv w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "ExecContext", name, td, err)).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "ExecContext", name, td, err)).Log(ctx, w.opts.LoggerLevel) } return res, err } @@ -200,7 +218,7 @@ func (w *wrapperStmt) QueryContext(ctx context.Context, query string, args []dri w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, td, err)).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)).Log(ctx, w.opts.LoggerLevel) } return rows, err } @@ -212,7 +230,7 @@ func (w *wrapperStmt) QueryContext(ctx context.Context, query string, args []dri span.AddLabels("err", err.Error()) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, 0, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, 0, err)...).Log(ctx, w.opts.LoggerLevel) } return nil, err } @@ -233,7 +251,7 @@ func (w *wrapperStmt) QueryContext(ctx context.Context, query string, args []dri w.opts.Meter.Histogram(meterRequestDurationSeconds, labels...).Update(te) if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(context.TODO(), "QueryContext", name, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(ctx, "QueryContext", name, td, err)...).Log(ctx, w.opts.LoggerLevel) } return rows, err } diff --git a/tx.go b/tx.go index 9a56cbf..e0b6bd9 100644 --- a/tx.go +++ b/tx.go @@ -34,7 +34,7 @@ func (w *wrapperTx) Commit() error { } if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Commit", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Commit", labelUnknown, td, err)...).Log(w.ctx, w.opts.LoggerLevel) } w.ctx = nil @@ -57,7 +57,7 @@ func (w *wrapperTx) Rollback() error { } if w.opts.LoggerEnabled { - w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Rollback", labelUnknown, td, err)...).Log(context.TODO(), w.opts.LoggerLevel) + w.opts.Logger.Fields(w.opts.LoggerObserver(w.ctx, "Rollback", labelUnknown, td, err)...).Log(w.ctx, w.opts.LoggerLevel) } w.ctx = nil -- 2.45.2