stmt.go (2863B)
1 package otelsql 2 3 import ( 4 "context" 5 "database/sql/driver" 6 7 "go.opentelemetry.io/otel/trace" 8 ) 9 10 type otelStmt struct { 11 driver.Stmt 12 13 query string 14 instrum *dbInstrum 15 16 execCtx stmtExecCtxFunc 17 queryCtx stmtQueryCtxFunc 18 } 19 20 var _ driver.Stmt = (*otelStmt)(nil) 21 22 func newStmt(stmt driver.Stmt, query string, instrum *dbInstrum) *otelStmt { 23 s := &otelStmt{ 24 Stmt: stmt, 25 query: query, 26 instrum: instrum, 27 } 28 s.execCtx = s.createExecCtxFunc(stmt) 29 s.queryCtx = s.createQueryCtxFunc(stmt) 30 return s 31 } 32 33 //------------------------------------------------------------------------------ 34 35 var _ driver.StmtExecContext = (*otelStmt)(nil) 36 37 func (stmt *otelStmt) ExecContext( 38 ctx context.Context, args []driver.NamedValue, 39 ) (driver.Result, error) { 40 return stmt.execCtx(ctx, args) 41 } 42 43 type stmtExecCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) 44 45 func (s *otelStmt) createExecCtxFunc(stmt driver.Stmt) stmtExecCtxFunc { 46 var fn stmtExecCtxFunc 47 48 if execer, ok := s.Stmt.(driver.StmtExecContext); ok { 49 fn = execer.ExecContext 50 } else { 51 fn = func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 52 vArgs, err := namedValueToValue(args) 53 if err != nil { 54 return nil, err 55 } 56 return stmt.Exec(vArgs) 57 } 58 } 59 60 return func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 61 var res driver.Result 62 err := s.instrum.withSpan(ctx, "stmt.Exec", s.query, 63 func(ctx context.Context, span trace.Span) error { 64 var err error 65 res, err = fn(ctx, args) 66 if err != nil { 67 return err 68 } 69 70 if span.IsRecording() { 71 rows, err := res.RowsAffected() 72 if err == nil { 73 span.SetAttributes(dbRowsAffected.Int64(rows)) 74 } 75 } 76 77 return nil 78 }) 79 return res, err 80 } 81 } 82 83 //------------------------------------------------------------------------------ 84 85 var _ driver.StmtQueryContext = (*otelStmt)(nil) 86 87 func (stmt *otelStmt) QueryContext( 88 ctx context.Context, args []driver.NamedValue, 89 ) (driver.Rows, error) { 90 return stmt.queryCtx(ctx, args) 91 } 92 93 type stmtQueryCtxFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) 94 95 func (s *otelStmt) createQueryCtxFunc(stmt driver.Stmt) stmtQueryCtxFunc { 96 var fn stmtQueryCtxFunc 97 98 if queryer, ok := s.Stmt.(driver.StmtQueryContext); ok { 99 fn = queryer.QueryContext 100 } else { 101 fn = func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 102 vArgs, err := namedValueToValue(args) 103 if err != nil { 104 return nil, err 105 } 106 return s.Query(vArgs) 107 } 108 } 109 110 return func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 111 var rows driver.Rows 112 err := s.instrum.withSpan(ctx, "stmt.Query", s.query, 113 func(ctx context.Context, span trace.Span) error { 114 var err error 115 rows, err = fn(ctx, args) 116 return err 117 }) 118 return rows, err 119 } 120 }