gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

stmt.go (2863B)


      1 package otelsql
      2 
      3 import (
      4 	"context"
      5 	"database/sql/driver"
      6 
      7 	"go.opentelemetry.io/otel/trace"
      8 )
      9 
     10 type otelStmt struct {
     11 	driver.Stmt
     12 
     13 	query   string
     14 	instrum *dbInstrum
     15 
     16 	execCtx  stmtExecCtxFunc
     17 	queryCtx stmtQueryCtxFunc
     18 }
     19 
     20 var _ driver.Stmt = (*otelStmt)(nil)
     21 
     22 func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt {
     23 	s := &otelStmt{
     24 		Stmt:    stmt,
     25 		query:   query,
     26 		instrum: instrum,
     27 	}
     28 	s.execCtx = s.createExecCtxFunc(stmt)
     29 	s.queryCtx = s.createQueryCtxFunc(stmt)
     30 	return s
     31 }
     32 
     33 //------------------------------------------------------------------------------
     34 
     35 var _ driver.StmtExecContext = (*otelStmt)(nil)
     36 
     37 func (stmt *otelStmt) ExecContext(
     38 	ctx context.Context, args []driver.NamedValue,
     39 ) (driver.Result, error) {
     40 	return stmt.execCtx(ctx, args)
     41 }
     42 
     43 type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)
     44 
     45 func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc {
     46 	var fn stmtExecCtxFunc
     47 
     48 	if execer, ok := s.Stmt.(driver.StmtExecContext); ok {
     49 		fn = execer.ExecContext
     50 	} else {
     51 		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
     52 			vArgs, err := namedValueToValue(args)
     53 			if err != nil {
     54 				return nil, err
     55 			}
     56 			return stmt.Exec(vArgs)
     57 		}
     58 	}
     59 
     60 	return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
     61 		var res driver.Result
     62 		err := s.instrum.withSpan(ctx, "stmt.Exec", s.query,
     63 			func(ctx context.Context, span trace.Span) error {
     64 				var err error
     65 				res, err = fn(ctx, args)
     66 				if err != nil {
     67 					return err
     68 				}
     69 
     70 				if span.IsRecording() {
     71 					rows, err := res.RowsAffected()
     72 					if err == nil {
     73 						span.SetAttributes(dbRowsAffected.Int64(rows))
     74 					}
     75 				}
     76 
     77 				return nil
     78 			})
     79 		return res, err
     80 	}
     81 }
     82 
     83 //------------------------------------------------------------------------------
     84 
     85 var _ driver.StmtQueryContext = (*otelStmt)(nil)
     86 
     87 func (stmt *otelStmt) QueryContext(
     88 	ctx context.Context, args []driver.NamedValue,
     89 ) (driver.Rows, error) {
     90 	return stmt.queryCtx(ctx, args)
     91 }
     92 
     93 type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)
     94 
     95 func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc {
     96 	var fn stmtQueryCtxFunc
     97 
     98 	if queryer, ok := s.Stmt.(driver.StmtQueryContext); ok {
     99 		fn = queryer.QueryContext
    100 	} else {
    101 		fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    102 			vArgs, err := namedValueToValue(args)
    103 			if err != nil {
    104 				return nil, err
    105 			}
    106 			return s.Query(vArgs)
    107 		}
    108 	}
    109 
    110 	return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    111 		var rows driver.Rows
    112 		err := s.instrum.withSpan(ctx, "stmt.Query", s.query,
    113 			func(ctx context.Context, span trace.Span) error {
    114 				var err error
    115 				rows, err = fn(ctx, args)
    116 				return err
    117 			})
    118 		return rows, err
    119 	}
    120 }