sqlite.go (57777B)
1 // Copyright 2017 The Sqlite Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:generate go run generator.go -full-path-comments 6 7 package sqlite // import "modernc.org/sqlite" 8 9 import ( 10 "context" 11 "database/sql" 12 "database/sql/driver" 13 "errors" 14 "fmt" 15 "io" 16 "math" 17 "math/bits" 18 "net/url" 19 "reflect" 20 "runtime" 21 "strconv" 22 "strings" 23 "sync" 24 "sync/atomic" 25 "time" 26 "unsafe" 27 28 "modernc.org/libc" 29 "modernc.org/libc/sys/types" 30 sqlite3 "modernc.org/sqlite/lib" 31 ) 32 33 var ( 34 _ driver.Conn = (*conn)(nil) 35 _ driver.Driver = (*Driver)(nil) 36 //lint:ignore SA1019 TODO implement ExecerContext 37 _ driver.Execer = (*conn)(nil) 38 //lint:ignore SA1019 TODO implement QueryerContext 39 _ driver.Queryer = (*conn)(nil) 40 _ driver.Result = (*result)(nil) 41 _ driver.Rows = (*rows)(nil) 42 _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil) 43 _ driver.RowsColumnTypeLength = (*rows)(nil) 44 _ driver.RowsColumnTypeNullable = (*rows)(nil) 45 _ driver.RowsColumnTypePrecisionScale = (*rows)(nil) 46 _ driver.RowsColumnTypeScanType = (*rows)(nil) 47 _ driver.Stmt = (*stmt)(nil) 48 _ driver.Tx = (*tx)(nil) 49 _ error = (*Error)(nil) 50 ) 51 52 const ( 53 driverName = "sqlite" 54 ptrSize = unsafe.Sizeof(uintptr(0)) 55 sqliteLockedSharedcache = sqlite3.SQLITE_LOCKED | (1 << 8) 56 ) 57 58 // Error represents sqlite library error code. 59 type Error struct { 60 msg string 61 code int 62 } 63 64 // Error implements error. 65 func (e *Error) Error() string { return e.msg } 66 67 // Code returns the sqlite result code for this error. 68 func (e *Error) Code() int { return e.code } 69 70 var ( 71 // ErrorCodeString maps Error.Code() to its string representation. 72 ErrorCodeString = map[int]string{ 73 sqlite3.SQLITE_ABORT: "Callback routine requested an abort (SQLITE_ABORT)", 74 sqlite3.SQLITE_AUTH: "Authorization denied (SQLITE_AUTH)", 75 sqlite3.SQLITE_BUSY: "The database file is locked (SQLITE_BUSY)", 76 sqlite3.SQLITE_CANTOPEN: "Unable to open the database file (SQLITE_CANTOPEN)", 77 sqlite3.SQLITE_CONSTRAINT: "Abort due to constraint violation (SQLITE_CONSTRAINT)", 78 sqlite3.SQLITE_CORRUPT: "The database disk image is malformed (SQLITE_CORRUPT)", 79 sqlite3.SQLITE_DONE: "sqlite3_step() has finished executing (SQLITE_DONE)", 80 sqlite3.SQLITE_EMPTY: "Internal use only (SQLITE_EMPTY)", 81 sqlite3.SQLITE_ERROR: "Generic error (SQLITE_ERROR)", 82 sqlite3.SQLITE_FORMAT: "Not used (SQLITE_FORMAT)", 83 sqlite3.SQLITE_FULL: "Insertion failed because database is full (SQLITE_FULL)", 84 sqlite3.SQLITE_INTERNAL: "Internal logic error in SQLite (SQLITE_INTERNAL)", 85 sqlite3.SQLITE_INTERRUPT: "Operation terminated by sqlite3_interrupt()(SQLITE_INTERRUPT)", 86 sqlite3.SQLITE_IOERR | (1 << 8): "(SQLITE_IOERR_READ)", 87 sqlite3.SQLITE_IOERR | (10 << 8): "(SQLITE_IOERR_DELETE)", 88 sqlite3.SQLITE_IOERR | (11 << 8): "(SQLITE_IOERR_BLOCKED)", 89 sqlite3.SQLITE_IOERR | (12 << 8): "(SQLITE_IOERR_NOMEM)", 90 sqlite3.SQLITE_IOERR | (13 << 8): "(SQLITE_IOERR_ACCESS)", 91 sqlite3.SQLITE_IOERR | (14 << 8): "(SQLITE_IOERR_CHECKRESERVEDLOCK)", 92 sqlite3.SQLITE_IOERR | (15 << 8): "(SQLITE_IOERR_LOCK)", 93 sqlite3.SQLITE_IOERR | (16 << 8): "(SQLITE_IOERR_CLOSE)", 94 sqlite3.SQLITE_IOERR | (17 << 8): "(SQLITE_IOERR_DIR_CLOSE)", 95 sqlite3.SQLITE_IOERR | (2 << 8): "(SQLITE_IOERR_SHORT_READ)", 96 sqlite3.SQLITE_IOERR | (3 << 8): "(SQLITE_IOERR_WRITE)", 97 sqlite3.SQLITE_IOERR | (4 << 8): "(SQLITE_IOERR_FSYNC)", 98 sqlite3.SQLITE_IOERR | (5 << 8): "(SQLITE_IOERR_DIR_FSYNC)", 99 sqlite3.SQLITE_IOERR | (6 << 8): "(SQLITE_IOERR_TRUNCATE)", 100 sqlite3.SQLITE_IOERR | (7 << 8): "(SQLITE_IOERR_FSTAT)", 101 sqlite3.SQLITE_IOERR | (8 << 8): "(SQLITE_IOERR_UNLOCK)", 102 sqlite3.SQLITE_IOERR | (9 << 8): "(SQLITE_IOERR_RDLOCK)", 103 sqlite3.SQLITE_IOERR: "Some kind of disk I/O error occurred (SQLITE_IOERR)", 104 sqlite3.SQLITE_LOCKED | (1 << 8): "(SQLITE_LOCKED_SHAREDCACHE)", 105 sqlite3.SQLITE_LOCKED: "A table in the database is locked (SQLITE_LOCKED)", 106 sqlite3.SQLITE_MISMATCH: "Data type mismatch (SQLITE_MISMATCH)", 107 sqlite3.SQLITE_MISUSE: "Library used incorrectly (SQLITE_MISUSE)", 108 sqlite3.SQLITE_NOLFS: "Uses OS features not supported on host (SQLITE_NOLFS)", 109 sqlite3.SQLITE_NOMEM: "A malloc() failed (SQLITE_NOMEM)", 110 sqlite3.SQLITE_NOTADB: "File opened that is not a database file (SQLITE_NOTADB)", 111 sqlite3.SQLITE_NOTFOUND: "Unknown opcode in sqlite3_file_control() (SQLITE_NOTFOUND)", 112 sqlite3.SQLITE_NOTICE: "Notifications from sqlite3_log() (SQLITE_NOTICE)", 113 sqlite3.SQLITE_PERM: "Access permission denied (SQLITE_PERM)", 114 sqlite3.SQLITE_PROTOCOL: "Database lock protocol error (SQLITE_PROTOCOL)", 115 sqlite3.SQLITE_RANGE: "2nd parameter to sqlite3_bind out of range (SQLITE_RANGE)", 116 sqlite3.SQLITE_READONLY: "Attempt to write a readonly database (SQLITE_READONLY)", 117 sqlite3.SQLITE_ROW: "sqlite3_step() has another row ready (SQLITE_ROW)", 118 sqlite3.SQLITE_SCHEMA: "The database schema changed (SQLITE_SCHEMA)", 119 sqlite3.SQLITE_TOOBIG: "String or BLOB exceeds size limit (SQLITE_TOOBIG)", 120 sqlite3.SQLITE_WARNING: "Warnings from sqlite3_log() (SQLITE_WARNING)", 121 } 122 ) 123 124 func init() { 125 sql.Register(driverName, newDriver()) 126 } 127 128 type result struct { 129 lastInsertID int64 130 rowsAffected int 131 } 132 133 func newResult(c *conn) (_ *result, err error) { 134 r := &result{} 135 if r.rowsAffected, err = c.changes(); err != nil { 136 return nil, err 137 } 138 139 if r.lastInsertID, err = c.lastInsertRowID(); err != nil { 140 return nil, err 141 } 142 143 return r, nil 144 } 145 146 // LastInsertId returns the database's auto-generated ID after, for example, an 147 // INSERT into a table with primary key. 148 func (r *result) LastInsertId() (int64, error) { 149 if r == nil { 150 return 0, nil 151 } 152 153 return r.lastInsertID, nil 154 } 155 156 // RowsAffected returns the number of rows affected by the query. 157 func (r *result) RowsAffected() (int64, error) { 158 if r == nil { 159 return 0, nil 160 } 161 162 return int64(r.rowsAffected), nil 163 } 164 165 type rows struct { 166 allocs []uintptr 167 c *conn 168 columns []string 169 pstmt uintptr 170 171 doStep bool 172 empty bool 173 } 174 175 func newRows(c *conn, pstmt uintptr, allocs []uintptr, empty bool) (r *rows, err error) { 176 r = &rows{c: c, pstmt: pstmt, allocs: allocs, empty: empty} 177 178 defer func() { 179 if err != nil { 180 r.Close() 181 r = nil 182 } 183 }() 184 185 n, err := c.columnCount(pstmt) 186 if err != nil { 187 return nil, err 188 } 189 190 r.columns = make([]string, n) 191 for i := range r.columns { 192 if r.columns[i], err = r.c.columnName(pstmt, i); err != nil { 193 return nil, err 194 } 195 } 196 197 return r, nil 198 } 199 200 // Close closes the rows iterator. 201 func (r *rows) Close() (err error) { 202 for _, v := range r.allocs { 203 r.c.free(v) 204 } 205 r.allocs = nil 206 return r.c.finalize(r.pstmt) 207 } 208 209 // Columns returns the names of the columns. The number of columns of the 210 // result is inferred from the length of the slice. If a particular column name 211 // isn't known, an empty string should be returned for that entry. 212 func (r *rows) Columns() (c []string) { 213 return r.columns 214 } 215 216 // Next is called to populate the next row of data into the provided slice. The 217 // provided slice will be the same size as the Columns() are wide. 218 // 219 // Next should return io.EOF when there are no more rows. 220 func (r *rows) Next(dest []driver.Value) (err error) { 221 if r.empty { 222 return io.EOF 223 } 224 225 rc := sqlite3.SQLITE_ROW 226 if r.doStep { 227 if rc, err = r.c.step(r.pstmt); err != nil { 228 return err 229 } 230 } 231 232 r.doStep = true 233 switch rc { 234 case sqlite3.SQLITE_ROW: 235 if g, e := len(dest), len(r.columns); g != e { 236 return fmt.Errorf("sqlite: Next: have %v destination values, expected %v", g, e) 237 } 238 239 for i := range dest { 240 ct, err := r.c.columnType(r.pstmt, i) 241 if err != nil { 242 return err 243 } 244 245 switch ct { 246 case sqlite3.SQLITE_INTEGER: 247 v, err := r.c.columnInt64(r.pstmt, i) 248 if err != nil { 249 return err 250 } 251 252 dest[i] = v 253 case sqlite3.SQLITE_FLOAT: 254 v, err := r.c.columnDouble(r.pstmt, i) 255 if err != nil { 256 return err 257 } 258 259 dest[i] = v 260 case sqlite3.SQLITE_TEXT: 261 v, err := r.c.columnText(r.pstmt, i) 262 if err != nil { 263 return err 264 } 265 266 switch r.ColumnTypeDatabaseTypeName(i) { 267 case "DATE", "DATETIME", "TIMESTAMP": 268 dest[i], _ = r.c.parseTime(v) 269 default: 270 dest[i] = v 271 } 272 case sqlite3.SQLITE_BLOB: 273 v, err := r.c.columnBlob(r.pstmt, i) 274 if err != nil { 275 return err 276 } 277 278 dest[i] = v 279 case sqlite3.SQLITE_NULL: 280 dest[i] = nil 281 default: 282 return fmt.Errorf("internal error: rc %d", rc) 283 } 284 } 285 return nil 286 case sqlite3.SQLITE_DONE: 287 return io.EOF 288 default: 289 return r.c.errstr(int32(rc)) 290 } 291 } 292 293 // Inspired by mattn/go-sqlite3: https://github.com/mattn/go-sqlite3/blob/ab91e934/sqlite3.go#L210-L226 294 // 295 // These time.Parse formats handle formats 1 through 7 listed at https://www.sqlite.org/lang_datefunc.html. 296 var parseTimeFormats = []string{ 297 "2006-01-02 15:04:05.999999999-07:00", 298 "2006-01-02T15:04:05.999999999-07:00", 299 "2006-01-02 15:04:05.999999999", 300 "2006-01-02T15:04:05.999999999", 301 "2006-01-02 15:04", 302 "2006-01-02T15:04", 303 "2006-01-02", 304 } 305 306 // Attempt to parse s as a time. Return (s, false) if s is not 307 // recognized as a valid time encoding. 308 func (c *conn) parseTime(s string) (interface{}, bool) { 309 if v, ok := c.parseTimeString(s, strings.Index(s, "m=")); ok { 310 return v, true 311 } 312 313 ts := strings.TrimSuffix(s, "Z") 314 315 for _, f := range parseTimeFormats { 316 t, err := time.Parse(f, ts) 317 if err == nil { 318 return t, true 319 } 320 } 321 322 return s, false 323 } 324 325 // Attempt to parse s as a time string produced by t.String(). If x > 0 it's 326 // the index of substring "m=" within s. Return (s, false) if s is 327 // not recognized as a valid time encoding. 328 func (c *conn) parseTimeString(s0 string, x int) (interface{}, bool) { 329 s := s0 330 if x > 0 { 331 s = s[:x] // "2006-01-02 15:04:05.999999999 -0700 MST m=+9999" -> "2006-01-02 15:04:05.999999999 -0700 MST " 332 } 333 s = strings.TrimSpace(s) 334 if t, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", s); err == nil { 335 return t, true 336 } 337 338 return s0, false 339 } 340 341 // writeTimeFormats are the names and formats supported 342 // by the `_time_format` DSN query param. 343 var writeTimeFormats = map[string]string{ 344 "sqlite": parseTimeFormats[0], 345 } 346 347 func (c *conn) formatTime(t time.Time) string { 348 // Before configurable write time formats were supported, 349 // time.Time.String was used. Maintain that default to 350 // keep existing driver users formatting times the same. 351 if c.writeTimeFormat == "" { 352 return t.String() 353 } 354 return t.Format(c.writeTimeFormat) 355 } 356 357 // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return 358 // the database system type name without the length. Type names should be 359 // uppercase. Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", 360 // "CHAR", "TEXT", "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", 361 // "JSONB", "XML", "TIMESTAMP". 362 func (r *rows) ColumnTypeDatabaseTypeName(index int) string { 363 return strings.ToUpper(r.c.columnDeclType(r.pstmt, index)) 364 } 365 366 // RowsColumnTypeLength may be implemented by Rows. It should return the length 367 // of the column type if the column is a variable length type. If the column is 368 // not a variable length type ok should return false. If length is not limited 369 // other than system limits, it should return math.MaxInt64. The following are 370 // examples of returned values for various types: 371 // 372 // TEXT (math.MaxInt64, true) 373 // varchar(10) (10, true) 374 // nvarchar(10) (10, true) 375 // decimal (0, false) 376 // int (0, false) 377 // bytea(30) (30, true) 378 func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) { 379 t, err := r.c.columnType(r.pstmt, index) 380 if err != nil { 381 return 0, false 382 } 383 384 switch t { 385 case sqlite3.SQLITE_INTEGER: 386 return 0, false 387 case sqlite3.SQLITE_FLOAT: 388 return 0, false 389 case sqlite3.SQLITE_TEXT: 390 return math.MaxInt64, true 391 case sqlite3.SQLITE_BLOB: 392 return math.MaxInt64, true 393 case sqlite3.SQLITE_NULL: 394 return 0, false 395 default: 396 return 0, false 397 } 398 } 399 400 // RowsColumnTypeNullable may be implemented by Rows. The nullable value should 401 // be true if it is known the column may be null, or false if the column is 402 // known to be not nullable. If the column nullability is unknown, ok should be 403 // false. 404 func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { 405 return true, true 406 } 407 408 // RowsColumnTypePrecisionScale may be implemented by Rows. It should return 409 // the precision and scale for decimal types. If not applicable, ok should be 410 // false. The following are examples of returned values for various types: 411 // 412 // decimal(38, 4) (38, 4, true) 413 // int (0, 0, false) 414 // decimal (math.MaxInt64, math.MaxInt64, true) 415 func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { 416 return 0, 0, false 417 } 418 419 // RowsColumnTypeScanType may be implemented by Rows. It should return the 420 // value type that can be used to scan types into. For example, the database 421 // column type "bigint" this should return "reflect.TypeOf(int64(0))". 422 func (r *rows) ColumnTypeScanType(index int) reflect.Type { 423 t, err := r.c.columnType(r.pstmt, index) 424 if err != nil { 425 return reflect.TypeOf("") 426 } 427 428 switch t { 429 case sqlite3.SQLITE_INTEGER: 430 switch strings.ToLower(r.c.columnDeclType(r.pstmt, index)) { 431 case "boolean": 432 return reflect.TypeOf(false) 433 case "date", "datetime", "time", "timestamp": 434 return reflect.TypeOf(time.Time{}) 435 default: 436 return reflect.TypeOf(int64(0)) 437 } 438 case sqlite3.SQLITE_FLOAT: 439 return reflect.TypeOf(float64(0)) 440 case sqlite3.SQLITE_TEXT: 441 return reflect.TypeOf("") 442 case sqlite3.SQLITE_BLOB: 443 return reflect.SliceOf(reflect.TypeOf([]byte{})) 444 case sqlite3.SQLITE_NULL: 445 return reflect.TypeOf(nil) 446 default: 447 return reflect.TypeOf("") 448 } 449 } 450 451 type stmt struct { 452 c *conn 453 psql uintptr 454 } 455 456 func newStmt(c *conn, sql string) (*stmt, error) { 457 p, err := libc.CString(sql) 458 if err != nil { 459 return nil, err 460 } 461 stm := stmt{c: c, psql: p} 462 463 return &stm, nil 464 } 465 466 // Close closes the statement. 467 // 468 // As of Go 1.1, a Stmt will not be closed if it's in use by any queries. 469 func (s *stmt) Close() (err error) { 470 s.c.free(s.psql) 471 s.psql = 0 472 return nil 473 } 474 475 // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. 476 // 477 // 478 // Deprecated: Drivers should implement StmtExecContext instead (or 479 // additionally). 480 func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { //TODO StmtExecContext 481 return s.exec(context.Background(), toNamedValues(args)) 482 } 483 484 // toNamedValues converts []driver.Value to []driver.NamedValue 485 func toNamedValues(vals []driver.Value) (r []driver.NamedValue) { 486 r = make([]driver.NamedValue, len(vals)) 487 for i, val := range vals { 488 r[i] = driver.NamedValue{Value: val, Ordinal: i + 1} 489 } 490 return r 491 } 492 493 func (s *stmt) exec(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { 494 var pstmt uintptr 495 var done int32 496 if ctx != nil && ctx.Done() != nil { 497 defer interruptOnDone(ctx, s.c, &done)() 498 } 499 500 for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; { 501 if pstmt, err = s.c.prepareV2(&psql); err != nil { 502 return nil, err 503 } 504 505 if pstmt == 0 { 506 continue 507 } 508 err = func() (err error) { 509 n, err := s.c.bindParameterCount(pstmt) 510 if err != nil { 511 return err 512 } 513 514 if n != 0 { 515 allocs, err := s.c.bind(pstmt, n, args) 516 if err != nil { 517 return err 518 } 519 520 if len(allocs) != 0 { 521 defer func() { 522 for _, v := range allocs { 523 s.c.free(v) 524 } 525 }() 526 } 527 } 528 529 rc, err := s.c.step(pstmt) 530 if err != nil { 531 return err 532 } 533 534 switch rc & 0xff { 535 case sqlite3.SQLITE_DONE, sqlite3.SQLITE_ROW: 536 // nop 537 default: 538 return s.c.errstr(int32(rc)) 539 } 540 541 return nil 542 }() 543 544 if e := s.c.finalize(pstmt); e != nil && err == nil { 545 err = e 546 } 547 548 if err != nil { 549 return nil, err 550 } 551 } 552 return newResult(s.c) 553 } 554 555 // NumInput returns the number of placeholder parameters. 556 // 557 // If NumInput returns >= 0, the sql package will sanity check argument counts 558 // from callers and return errors to the caller before the statement's Exec or 559 // Query methods are called. 560 // 561 // NumInput may also return -1, if the driver doesn't know its number of 562 // placeholders. In that case, the sql package will not sanity check Exec or 563 // Query argument counts. 564 func (s *stmt) NumInput() (n int) { 565 return -1 566 } 567 568 // Query executes a query that may return rows, such as a 569 // SELECT. 570 // 571 // Deprecated: Drivers should implement StmtQueryContext instead (or 572 // additionally). 573 func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { //TODO StmtQueryContext 574 return s.query(context.Background(), toNamedValues(args)) 575 } 576 577 func (s *stmt) query(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) { 578 var pstmt uintptr 579 var done int32 580 if ctx != nil && ctx.Done() != nil { 581 defer interruptOnDone(ctx, s.c, &done)() 582 } 583 584 var allocs []uintptr 585 for psql := s.psql; *(*byte)(unsafe.Pointer(psql)) != 0 && atomic.LoadInt32(&done) == 0; { 586 if pstmt, err = s.c.prepareV2(&psql); err != nil { 587 return nil, err 588 } 589 590 if pstmt == 0 { 591 continue 592 } 593 594 err = func() (err error) { 595 n, err := s.c.bindParameterCount(pstmt) 596 if err != nil { 597 return err 598 } 599 600 if n != 0 { 601 if allocs, err = s.c.bind(pstmt, n, args); err != nil { 602 return err 603 } 604 } 605 606 rc, err := s.c.step(pstmt) 607 if err != nil { 608 return err 609 } 610 611 switch rc & 0xff { 612 case sqlite3.SQLITE_ROW: 613 if r != nil { 614 r.Close() 615 } 616 if r, err = newRows(s.c, pstmt, allocs, false); err != nil { 617 return err 618 } 619 620 pstmt = 0 621 return nil 622 case sqlite3.SQLITE_DONE: 623 if r == nil { 624 if r, err = newRows(s.c, pstmt, allocs, true); err != nil { 625 return err 626 } 627 pstmt = 0 628 return nil 629 } 630 631 // nop 632 default: 633 return s.c.errstr(int32(rc)) 634 } 635 636 if *(*byte)(unsafe.Pointer(psql)) == 0 { 637 if r != nil { 638 r.Close() 639 } 640 if r, err = newRows(s.c, pstmt, allocs, true); err != nil { 641 return err 642 } 643 644 pstmt = 0 645 } 646 return nil 647 }() 648 if e := s.c.finalize(pstmt); e != nil && err == nil { 649 err = e 650 } 651 652 if err != nil { 653 return nil, err 654 } 655 } 656 return r, err 657 } 658 659 type tx struct { 660 c *conn 661 } 662 663 func newTx(c *conn, opts driver.TxOptions) (*tx, error) { 664 r := &tx{c: c} 665 666 sql := "begin" 667 if !opts.ReadOnly && c.beginMode != "" { 668 sql = "begin " + c.beginMode 669 } 670 671 if err := r.exec(context.Background(), sql); err != nil { 672 return nil, err 673 } 674 675 return r, nil 676 } 677 678 // Commit implements driver.Tx. 679 func (t *tx) Commit() (err error) { 680 return t.exec(context.Background(), "commit") 681 } 682 683 // Rollback implements driver.Tx. 684 func (t *tx) Rollback() (err error) { 685 return t.exec(context.Background(), "rollback") 686 } 687 688 func (t *tx) exec(ctx context.Context, sql string) (err error) { 689 psql, err := libc.CString(sql) 690 if err != nil { 691 return err 692 } 693 694 defer t.c.free(psql) 695 //TODO use t.conn.ExecContext() instead 696 697 if ctx != nil && ctx.Done() != nil { 698 defer interruptOnDone(ctx, t.c, nil)() 699 } 700 701 if rc := sqlite3.Xsqlite3_exec(t.c.tls, t.c.db, psql, 0, 0, 0); rc != sqlite3.SQLITE_OK { 702 return t.c.errstr(rc) 703 } 704 705 return nil 706 } 707 708 // interruptOnDone sets up a goroutine to interrupt the provided db when the 709 // context is canceled, and returns a function the caller must defer so it 710 // doesn't interrupt after the caller finishes. 711 func interruptOnDone( 712 ctx context.Context, 713 c *conn, 714 done *int32, 715 ) func() { 716 if done == nil { 717 var d int32 718 done = &d 719 } 720 721 donech := make(chan struct{}) 722 723 go func() { 724 select { 725 case <-ctx.Done(): 726 // don't call interrupt if we were already done: it indicates that this 727 // call to exec is no longer running and we would be interrupting 728 // nothing, or even possibly an unrelated later call to exec. 729 if atomic.AddInt32(done, 1) == 1 { 730 c.interrupt(c.db) 731 } 732 case <-donech: 733 } 734 }() 735 736 // the caller is expected to defer this function 737 return func() { 738 // set the done flag so that a context cancellation right after the caller 739 // returns doesn't trigger a call to interrupt for some other statement. 740 atomic.AddInt32(done, 1) 741 close(donech) 742 } 743 } 744 745 type conn struct { 746 db uintptr // *sqlite3.Xsqlite3 747 tls *libc.TLS 748 749 // Context handling can cause conn.Close and conn.interrupt to be invoked 750 // concurrently. 751 sync.Mutex 752 753 writeTimeFormat string 754 beginMode string 755 } 756 757 func newConn(dsn string) (*conn, error) { 758 var query, vfsName string 759 760 // Parse the query parameters from the dsn and them from the dsn if not prefixed by file: 761 // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1046 762 // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1383 763 pos := strings.IndexRune(dsn, '?') 764 if pos >= 1 { 765 query = dsn[pos+1:] 766 var err error 767 vfsName, err = getVFSName(query) 768 if err != nil { 769 return nil, err 770 } 771 772 if !strings.HasPrefix(dsn, "file:") { 773 dsn = dsn[:pos] 774 } 775 } 776 777 c := &conn{tls: libc.NewTLS()} 778 db, err := c.openV2( 779 dsn, 780 vfsName, 781 sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE| 782 sqlite3.SQLITE_OPEN_FULLMUTEX| 783 sqlite3.SQLITE_OPEN_URI, 784 ) 785 if err != nil { 786 return nil, err 787 } 788 789 c.db = db 790 if err = c.extendedResultCodes(true); err != nil { 791 c.Close() 792 return nil, err 793 } 794 795 if err = applyQueryParams(c, query); err != nil { 796 c.Close() 797 return nil, err 798 } 799 800 return c, nil 801 } 802 803 func getVFSName(query string) (r string, err error) { 804 q, err := url.ParseQuery(query) 805 if err != nil { 806 return "", err 807 } 808 809 for _, v := range q["vfs"] { 810 if r != "" && r != v { 811 return "", fmt.Errorf("conflicting vfs query parameters: %v", q["vfs"]) 812 } 813 814 r = v 815 } 816 817 return r, nil 818 } 819 820 func applyQueryParams(c *conn, query string) error { 821 q, err := url.ParseQuery(query) 822 if err != nil { 823 return err 824 } 825 826 for _, v := range q["_pragma"] { 827 cmd := "pragma " + v 828 _, err := c.exec(context.Background(), cmd, nil) 829 if err != nil { 830 return err 831 } 832 } 833 834 if v := q.Get("_time_format"); v != "" { 835 f, ok := writeTimeFormats[v] 836 if !ok { 837 return fmt.Errorf("unknown _time_format %q", v) 838 } 839 c.writeTimeFormat = f 840 } 841 842 if v := q.Get("_txlock"); v != "" { 843 lower := strings.ToLower(v) 844 if lower != "deferred" && lower != "immediate" && lower != "exclusive" { 845 return fmt.Errorf("unknown _txlock %q", v) 846 } 847 c.beginMode = v 848 } 849 850 return nil 851 } 852 853 // const void *sqlite3_column_blob(sqlite3_stmt*, int iCol); 854 func (c *conn) columnBlob(pstmt uintptr, iCol int) (v []byte, err error) { 855 p := sqlite3.Xsqlite3_column_blob(c.tls, pstmt, int32(iCol)) 856 len, err := c.columnBytes(pstmt, iCol) 857 if err != nil { 858 return nil, err 859 } 860 861 if p == 0 || len == 0 { 862 return nil, nil 863 } 864 865 v = make([]byte, len) 866 copy(v, (*libc.RawMem)(unsafe.Pointer(p))[:len:len]) 867 return v, nil 868 } 869 870 // int sqlite3_column_bytes(sqlite3_stmt*, int iCol); 871 func (c *conn) columnBytes(pstmt uintptr, iCol int) (_ int, err error) { 872 v := sqlite3.Xsqlite3_column_bytes(c.tls, pstmt, int32(iCol)) 873 return int(v), nil 874 } 875 876 // const unsigned char *sqlite3_column_text(sqlite3_stmt*, int iCol); 877 func (c *conn) columnText(pstmt uintptr, iCol int) (v string, err error) { 878 p := sqlite3.Xsqlite3_column_text(c.tls, pstmt, int32(iCol)) 879 len, err := c.columnBytes(pstmt, iCol) 880 if err != nil { 881 return "", err 882 } 883 884 if p == 0 || len == 0 { 885 return "", nil 886 } 887 888 b := make([]byte, len) 889 copy(b, (*libc.RawMem)(unsafe.Pointer(p))[:len:len]) 890 return string(b), nil 891 } 892 893 // double sqlite3_column_double(sqlite3_stmt*, int iCol); 894 func (c *conn) columnDouble(pstmt uintptr, iCol int) (v float64, err error) { 895 v = sqlite3.Xsqlite3_column_double(c.tls, pstmt, int32(iCol)) 896 return v, nil 897 } 898 899 // sqlite3_int64 sqlite3_column_int64(sqlite3_stmt*, int iCol); 900 func (c *conn) columnInt64(pstmt uintptr, iCol int) (v int64, err error) { 901 v = sqlite3.Xsqlite3_column_int64(c.tls, pstmt, int32(iCol)) 902 return v, nil 903 } 904 905 // int sqlite3_column_type(sqlite3_stmt*, int iCol); 906 func (c *conn) columnType(pstmt uintptr, iCol int) (_ int, err error) { 907 v := sqlite3.Xsqlite3_column_type(c.tls, pstmt, int32(iCol)) 908 return int(v), nil 909 } 910 911 // const char *sqlite3_column_decltype(sqlite3_stmt*,int); 912 func (c *conn) columnDeclType(pstmt uintptr, iCol int) string { 913 return libc.GoString(sqlite3.Xsqlite3_column_decltype(c.tls, pstmt, int32(iCol))) 914 } 915 916 // const char *sqlite3_column_name(sqlite3_stmt*, int N); 917 func (c *conn) columnName(pstmt uintptr, n int) (string, error) { 918 p := sqlite3.Xsqlite3_column_name(c.tls, pstmt, int32(n)) 919 return libc.GoString(p), nil 920 } 921 922 // int sqlite3_column_count(sqlite3_stmt *pStmt); 923 func (c *conn) columnCount(pstmt uintptr) (_ int, err error) { 924 v := sqlite3.Xsqlite3_column_count(c.tls, pstmt) 925 return int(v), nil 926 } 927 928 // sqlite3_int64 sqlite3_last_insert_rowid(sqlite3*); 929 func (c *conn) lastInsertRowID() (v int64, _ error) { 930 return sqlite3.Xsqlite3_last_insert_rowid(c.tls, c.db), nil 931 } 932 933 // int sqlite3_changes(sqlite3*); 934 func (c *conn) changes() (int, error) { 935 v := sqlite3.Xsqlite3_changes(c.tls, c.db) 936 return int(v), nil 937 } 938 939 // int sqlite3_step(sqlite3_stmt*); 940 func (c *conn) step(pstmt uintptr) (int, error) { 941 for { 942 switch rc := sqlite3.Xsqlite3_step(c.tls, pstmt); rc { 943 case sqliteLockedSharedcache: 944 if err := c.retry(pstmt); err != nil { 945 return sqlite3.SQLITE_LOCKED, err 946 } 947 case 948 sqlite3.SQLITE_DONE, 949 sqlite3.SQLITE_ROW: 950 951 return int(rc), nil 952 default: 953 return int(rc), c.errstr(rc) 954 } 955 } 956 } 957 958 func (c *conn) retry(pstmt uintptr) error { 959 mu := mutexAlloc(c.tls) 960 (*mutex)(unsafe.Pointer(mu)).Lock() 961 rc := sqlite3.Xsqlite3_unlock_notify( 962 c.tls, 963 c.db, 964 *(*uintptr)(unsafe.Pointer(&struct { 965 f func(*libc.TLS, uintptr, int32) 966 }{unlockNotify})), 967 mu, 968 ) 969 if rc == sqlite3.SQLITE_LOCKED { // Deadlock, see https://www.sqlite.org/c3ref/unlock_notify.html 970 (*mutex)(unsafe.Pointer(mu)).Unlock() 971 mutexFree(c.tls, mu) 972 return c.errstr(rc) 973 } 974 975 (*mutex)(unsafe.Pointer(mu)).Lock() 976 (*mutex)(unsafe.Pointer(mu)).Unlock() 977 mutexFree(c.tls, mu) 978 if pstmt != 0 { 979 sqlite3.Xsqlite3_reset(c.tls, pstmt) 980 } 981 return nil 982 } 983 984 func unlockNotify(t *libc.TLS, ppArg uintptr, nArg int32) { 985 for i := int32(0); i < nArg; i++ { 986 mu := *(*uintptr)(unsafe.Pointer(ppArg)) 987 (*mutex)(unsafe.Pointer(mu)).Unlock() 988 ppArg += ptrSize 989 } 990 } 991 992 func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []uintptr, err error) { 993 defer func() { 994 if err == nil { 995 return 996 } 997 998 for _, v := range allocs { 999 c.free(v) 1000 } 1001 allocs = nil 1002 }() 1003 1004 for i := 1; i <= n; i++ { 1005 name, err := c.bindParameterName(pstmt, i) 1006 if err != nil { 1007 return allocs, err 1008 } 1009 1010 var found bool 1011 var v driver.NamedValue 1012 for _, v = range args { 1013 if name != "" { 1014 // For ?NNN and $NNN params, match if NNN == v.Ordinal. 1015 // 1016 // Supporting this for $NNN is a special case that makes eg 1017 // `select $1, $2, $3 ...` work without needing to use 1018 // sql.Named. 1019 if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) { 1020 found = true 1021 break 1022 } 1023 1024 // sqlite supports '$', '@' and ':' prefixes for string 1025 // identifiers and '?' for numeric, so we cannot 1026 // combine different prefixes with the same name 1027 // because `database/sql` requires variable names 1028 // to start with a letter 1029 if name[1:] == v.Name[:] { 1030 found = true 1031 break 1032 } 1033 } else { 1034 if v.Ordinal == i { 1035 found = true 1036 break 1037 } 1038 } 1039 } 1040 1041 if !found { 1042 if name != "" { 1043 return allocs, fmt.Errorf("missing named argument %q", name[1:]) 1044 } 1045 1046 return allocs, fmt.Errorf("missing argument with index %d", i) 1047 } 1048 1049 var p uintptr 1050 switch x := v.Value.(type) { 1051 case int64: 1052 if err := c.bindInt64(pstmt, i, x); err != nil { 1053 return allocs, err 1054 } 1055 case float64: 1056 if err := c.bindDouble(pstmt, i, x); err != nil { 1057 return allocs, err 1058 } 1059 case bool: 1060 v := 0 1061 if x { 1062 v = 1 1063 } 1064 if err := c.bindInt(pstmt, i, v); err != nil { 1065 return allocs, err 1066 } 1067 case []byte: 1068 if p, err = c.bindBlob(pstmt, i, x); err != nil { 1069 return allocs, err 1070 } 1071 case string: 1072 if p, err = c.bindText(pstmt, i, x); err != nil { 1073 return allocs, err 1074 } 1075 case time.Time: 1076 if p, err = c.bindText(pstmt, i, c.formatTime(x)); err != nil { 1077 return allocs, err 1078 } 1079 case nil: 1080 if p, err = c.bindNull(pstmt, i); err != nil { 1081 return allocs, err 1082 } 1083 default: 1084 return allocs, fmt.Errorf("sqlite: invalid driver.Value type %T", x) 1085 } 1086 if p != 0 { 1087 allocs = append(allocs, p) 1088 } 1089 } 1090 return allocs, nil 1091 } 1092 1093 // int sqlite3_bind_null(sqlite3_stmt*, int); 1094 func (c *conn) bindNull(pstmt uintptr, idx1 int) (uintptr, error) { 1095 if rc := sqlite3.Xsqlite3_bind_null(c.tls, pstmt, int32(idx1)); rc != sqlite3.SQLITE_OK { 1096 return 0, c.errstr(rc) 1097 } 1098 1099 return 0, nil 1100 } 1101 1102 // int sqlite3_bind_text(sqlite3_stmt*,int,const char*,int,void(*)(void*)); 1103 func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error) { 1104 p, err := libc.CString(value) 1105 if err != nil { 1106 return 0, err 1107 } 1108 1109 if rc := sqlite3.Xsqlite3_bind_text(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK { 1110 c.free(p) 1111 return 0, c.errstr(rc) 1112 } 1113 1114 return p, nil 1115 } 1116 1117 // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*)); 1118 func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) { 1119 if value != nil && len(value) == 0 { 1120 if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK { 1121 return 0, c.errstr(rc) 1122 } 1123 return 0, nil 1124 } 1125 1126 p, err := c.malloc(len(value)) 1127 if err != nil { 1128 return 0, err 1129 } 1130 if len(value) != 0 { 1131 copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) 1132 } 1133 if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK { 1134 c.free(p) 1135 return 0, c.errstr(rc) 1136 } 1137 1138 return p, nil 1139 } 1140 1141 // int sqlite3_bind_int(sqlite3_stmt*, int, int); 1142 func (c *conn) bindInt(pstmt uintptr, idx1, value int) (err error) { 1143 if rc := sqlite3.Xsqlite3_bind_int(c.tls, pstmt, int32(idx1), int32(value)); rc != sqlite3.SQLITE_OK { 1144 return c.errstr(rc) 1145 } 1146 1147 return nil 1148 } 1149 1150 // int sqlite3_bind_double(sqlite3_stmt*, int, double); 1151 func (c *conn) bindDouble(pstmt uintptr, idx1 int, value float64) (err error) { 1152 if rc := sqlite3.Xsqlite3_bind_double(c.tls, pstmt, int32(idx1), value); rc != 0 { 1153 return c.errstr(rc) 1154 } 1155 1156 return nil 1157 } 1158 1159 // int sqlite3_bind_int64(sqlite3_stmt*, int, sqlite3_int64); 1160 func (c *conn) bindInt64(pstmt uintptr, idx1 int, value int64) (err error) { 1161 if rc := sqlite3.Xsqlite3_bind_int64(c.tls, pstmt, int32(idx1), value); rc != sqlite3.SQLITE_OK { 1162 return c.errstr(rc) 1163 } 1164 1165 return nil 1166 } 1167 1168 // const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int); 1169 func (c *conn) bindParameterName(pstmt uintptr, i int) (string, error) { 1170 p := sqlite3.Xsqlite3_bind_parameter_name(c.tls, pstmt, int32(i)) 1171 return libc.GoString(p), nil 1172 } 1173 1174 // int sqlite3_bind_parameter_count(sqlite3_stmt*); 1175 func (c *conn) bindParameterCount(pstmt uintptr) (_ int, err error) { 1176 r := sqlite3.Xsqlite3_bind_parameter_count(c.tls, pstmt) 1177 return int(r), nil 1178 } 1179 1180 // int sqlite3_finalize(sqlite3_stmt *pStmt); 1181 func (c *conn) finalize(pstmt uintptr) error { 1182 if rc := sqlite3.Xsqlite3_finalize(c.tls, pstmt); rc != sqlite3.SQLITE_OK { 1183 return c.errstr(rc) 1184 } 1185 1186 return nil 1187 } 1188 1189 // int sqlite3_prepare_v2( 1190 // sqlite3 *db, /* Database handle */ 1191 // const char *zSql, /* SQL statement, UTF-8 encoded */ 1192 // int nByte, /* Maximum length of zSql in bytes. */ 1193 // sqlite3_stmt **ppStmt, /* OUT: Statement handle */ 1194 // const char **pzTail /* OUT: Pointer to unused portion of zSql */ 1195 // ); 1196 func (c *conn) prepareV2(zSQL *uintptr) (pstmt uintptr, err error) { 1197 var ppstmt, pptail uintptr 1198 1199 defer func() { 1200 c.free(ppstmt) 1201 c.free(pptail) 1202 }() 1203 1204 if ppstmt, err = c.malloc(int(ptrSize)); err != nil { 1205 return 0, err 1206 } 1207 1208 if pptail, err = c.malloc(int(ptrSize)); err != nil { 1209 return 0, err 1210 } 1211 1212 for { 1213 switch rc := sqlite3.Xsqlite3_prepare_v2(c.tls, c.db, *zSQL, -1, ppstmt, pptail); rc { 1214 case sqlite3.SQLITE_OK: 1215 *zSQL = *(*uintptr)(unsafe.Pointer(pptail)) 1216 return *(*uintptr)(unsafe.Pointer(ppstmt)), nil 1217 case sqliteLockedSharedcache: 1218 if err := c.retry(0); err != nil { 1219 return 0, err 1220 } 1221 default: 1222 return 0, c.errstr(rc) 1223 } 1224 } 1225 } 1226 1227 // void sqlite3_interrupt(sqlite3*); 1228 func (c *conn) interrupt(pdb uintptr) (err error) { 1229 c.Lock() // Defend against race with .Close invoked by context handling. 1230 1231 defer c.Unlock() 1232 1233 if c.tls != nil { 1234 sqlite3.Xsqlite3_interrupt(c.tls, pdb) 1235 } 1236 return nil 1237 } 1238 1239 // int sqlite3_extended_result_codes(sqlite3*, int onoff); 1240 func (c *conn) extendedResultCodes(on bool) error { 1241 if rc := sqlite3.Xsqlite3_extended_result_codes(c.tls, c.db, libc.Bool32(on)); rc != sqlite3.SQLITE_OK { 1242 return c.errstr(rc) 1243 } 1244 1245 return nil 1246 } 1247 1248 // int sqlite3_open_v2( 1249 // const char *filename, /* Database filename (UTF-8) */ 1250 // sqlite3 **ppDb, /* OUT: SQLite db handle */ 1251 // int flags, /* Flags */ 1252 // const char *zVfs /* Name of VFS module to use */ 1253 // ); 1254 func (c *conn) openV2(name, vfsName string, flags int32) (uintptr, error) { 1255 var p, s, vfs uintptr 1256 1257 defer func() { 1258 if p != 0 { 1259 c.free(p) 1260 } 1261 if s != 0 { 1262 c.free(s) 1263 } 1264 if vfs != 0 { 1265 c.free(vfs) 1266 } 1267 }() 1268 1269 p, err := c.malloc(int(ptrSize)) 1270 if err != nil { 1271 return 0, err 1272 } 1273 1274 if s, err = libc.CString(name); err != nil { 1275 return 0, err 1276 } 1277 1278 if vfsName != "" { 1279 if vfs, err = libc.CString(vfsName); err != nil { 1280 return 0, err 1281 } 1282 } 1283 1284 if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, vfs); rc != sqlite3.SQLITE_OK { 1285 return 0, c.errstr(rc) 1286 } 1287 1288 return *(*uintptr)(unsafe.Pointer(p)), nil 1289 } 1290 1291 func (c *conn) malloc(n int) (uintptr, error) { 1292 if p := libc.Xmalloc(c.tls, types.Size_t(n)); p != 0 || n == 0 { 1293 return p, nil 1294 } 1295 1296 return 0, fmt.Errorf("sqlite: cannot allocate %d bytes of memory", n) 1297 } 1298 1299 func (c *conn) free(p uintptr) { 1300 if p != 0 { 1301 libc.Xfree(c.tls, p) 1302 } 1303 } 1304 1305 // const char *sqlite3_errstr(int); 1306 func (c *conn) errstr(rc int32) error { 1307 p := sqlite3.Xsqlite3_errstr(c.tls, rc) 1308 str := libc.GoString(p) 1309 p = sqlite3.Xsqlite3_errmsg(c.tls, c.db) 1310 var s string 1311 if rc == sqlite3.SQLITE_BUSY { 1312 s = " (SQLITE_BUSY)" 1313 } 1314 switch msg := libc.GoString(p); { 1315 case msg == str: 1316 return &Error{msg: fmt.Sprintf("%s (%v)%s", str, rc, s), code: int(rc)} 1317 default: 1318 return &Error{msg: fmt.Sprintf("%s: %s (%v)%s", str, msg, rc, s), code: int(rc)} 1319 } 1320 } 1321 1322 // Begin starts a transaction. 1323 // 1324 // Deprecated: Drivers should implement ConnBeginTx instead (or additionally). 1325 func (c *conn) Begin() (dt driver.Tx, err error) { 1326 if dmesgs { 1327 defer func() { 1328 dmesg("conn %p: (driver.Tx %p, err %v)", c, dt, err) 1329 }() 1330 } 1331 return c.begin(context.Background(), driver.TxOptions{}) 1332 } 1333 1334 func (c *conn) begin(ctx context.Context, opts driver.TxOptions) (t driver.Tx, err error) { 1335 return newTx(c, opts) 1336 } 1337 1338 // Close invalidates and potentially stops any current prepared statements and 1339 // transactions, marking this connection as no longer in use. 1340 // 1341 // Because the sql package maintains a free pool of connections and only calls 1342 // Close when there's a surplus of idle connections, it shouldn't be necessary 1343 // for drivers to do their own connection caching. 1344 func (c *conn) Close() (err error) { 1345 if dmesgs { 1346 defer func() { 1347 dmesg("conn %p: err %v", c, err) 1348 }() 1349 } 1350 c.Lock() // Defend against race with .interrupt invoked by context handling. 1351 1352 defer c.Unlock() 1353 1354 if c.db != 0 { 1355 if err := c.closeV2(c.db); err != nil { 1356 return err 1357 } 1358 1359 c.db = 0 1360 } 1361 1362 if c.tls != nil { 1363 c.tls.Close() 1364 c.tls = nil 1365 } 1366 return nil 1367 } 1368 1369 // int sqlite3_close_v2(sqlite3*); 1370 func (c *conn) closeV2(db uintptr) error { 1371 if rc := sqlite3.Xsqlite3_close_v2(c.tls, db); rc != sqlite3.SQLITE_OK { 1372 return c.errstr(rc) 1373 } 1374 1375 return nil 1376 } 1377 1378 // FunctionImpl describes an [application-defined SQL function]. If Scalar is 1379 // set, it is treated as a scalar function; otherwise, it is treated as an 1380 // aggregate function using MakeAggregate. 1381 // 1382 // [application-defined SQL function]: https://sqlite.org/appfunc.html 1383 type FunctionImpl struct { 1384 // NArgs is the required number of arguments that the function accepts. 1385 // If NArgs is negative, then the function is variadic. 1386 NArgs int32 1387 1388 // If Deterministic is true, the function must always give the same 1389 // output when the input parameters are the same. This enables functions 1390 // to be used in additional contexts like the WHERE clause of partial 1391 // indexes and enables additional optimizations. 1392 // 1393 // See https://sqlite.org/c3ref/c_deterministic.html#sqlitedeterministic 1394 // for more details. 1395 Deterministic bool 1396 1397 // Scalar is called when a scalar function is invoked in SQL. The 1398 // argument Values are not valid past the return of the function. 1399 Scalar func(ctx *FunctionContext, args []driver.Value) (driver.Value, error) 1400 1401 // MakeAggregate is called at the beginning of each evaluation of an 1402 // aggregate function. 1403 MakeAggregate func(ctx FunctionContext) (AggregateFunction, error) 1404 } 1405 1406 // An AggregateFunction is an invocation of an aggregate or window function. See 1407 // the documentation for [aggregate function callbacks] and [application-defined 1408 // window functions] for an overview. 1409 // 1410 // [aggregate function callbacks]: https://www.sqlite.org/appfunc.html#the_aggregate_function_callbacks 1411 // [application-defined window functions]: https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions 1412 type AggregateFunction interface { 1413 // Step is called for each row of an aggregate function's SQL 1414 // invocation. The argument Values are not valid past the return of the 1415 // function. 1416 Step(ctx *FunctionContext, rowArgs []driver.Value) error 1417 1418 // WindowInverse is called to remove the oldest presently aggregated 1419 // result of Step from the current window. The arguments are those 1420 // passed to Step for the row being removed. The argument Values are not 1421 // valid past the return of the function. 1422 WindowInverse(ctx *FunctionContext, rowArgs []driver.Value) error 1423 1424 // WindowValue is called to get the current value of an aggregate 1425 // function. This is used to return the final value of the function, 1426 // whether it is used as a window function or not. 1427 WindowValue(ctx *FunctionContext) (driver.Value, error) 1428 1429 // Final is called after all of the aggregate function's input rows have 1430 // been stepped through. No other methods will be called on the 1431 // AggregateFunction after calling Final. WindowValue returns the value 1432 // from the function. 1433 Final(ctx *FunctionContext) 1434 } 1435 1436 type userDefinedFunction struct { 1437 zFuncName uintptr 1438 nArg int32 1439 eTextRep int32 1440 pApp uintptr 1441 1442 scalar bool 1443 freeOnce sync.Once 1444 } 1445 1446 func (c *conn) createFunctionInternal(fun *userDefinedFunction) error { 1447 var rc int32 1448 1449 if fun.scalar { 1450 rc = sqlite3.Xsqlite3_create_function( 1451 c.tls, 1452 c.db, 1453 fun.zFuncName, 1454 fun.nArg, 1455 fun.eTextRep, 1456 fun.pApp, 1457 cFuncPointer(funcTrampoline), 1458 0, 1459 0, 1460 ) 1461 } else { 1462 rc = sqlite3.Xsqlite3_create_window_function( 1463 c.tls, 1464 c.db, 1465 fun.zFuncName, 1466 fun.nArg, 1467 fun.eTextRep, 1468 fun.pApp, 1469 cFuncPointer(stepTrampoline), 1470 cFuncPointer(finalTrampoline), 1471 cFuncPointer(valueTrampoline), 1472 cFuncPointer(inverseTrampoline), 1473 0, 1474 ) 1475 } 1476 1477 if rc != sqlite3.SQLITE_OK { 1478 return c.errstr(rc) 1479 } 1480 return nil 1481 } 1482 1483 // Execer is an optional interface that may be implemented by a Conn. 1484 // 1485 // If a Conn does not implement Execer, the sql package's DB.Exec will first 1486 // prepare a query, execute the statement, and then close the statement. 1487 // 1488 // Exec may return ErrSkip. 1489 // 1490 // Deprecated: Drivers should implement ExecerContext instead. 1491 func (c *conn) Exec(query string, args []driver.Value) (dr driver.Result, err error) { 1492 if dmesgs { 1493 defer func() { 1494 dmesg("conn %p, query %q, args %v: (driver.Result %p, err %v)", c, query, args, dr, err) 1495 }() 1496 } 1497 return c.exec(context.Background(), query, toNamedValues(args)) 1498 } 1499 1500 func (c *conn) exec(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { 1501 s, err := c.prepare(ctx, query) 1502 if err != nil { 1503 return nil, err 1504 } 1505 1506 defer func() { 1507 if err2 := s.Close(); err2 != nil && err == nil { 1508 err = err2 1509 } 1510 }() 1511 1512 return s.(*stmt).exec(ctx, args) 1513 } 1514 1515 // Prepare returns a prepared statement, bound to this connection. 1516 func (c *conn) Prepare(query string) (ds driver.Stmt, err error) { 1517 if dmesgs { 1518 defer func() { 1519 dmesg("conn %p, query %q: (driver.Stmt %p, err %v)", c, query, ds, err) 1520 }() 1521 } 1522 return c.prepare(context.Background(), query) 1523 } 1524 1525 func (c *conn) prepare(ctx context.Context, query string) (s driver.Stmt, err error) { 1526 //TODO use ctx 1527 return newStmt(c, query) 1528 } 1529 1530 // Queryer is an optional interface that may be implemented by a Conn. 1531 // 1532 // If a Conn does not implement Queryer, the sql package's DB.Query will first 1533 // prepare a query, execute the statement, and then close the statement. 1534 // 1535 // Query may return ErrSkip. 1536 // 1537 // Deprecated: Drivers should implement QueryerContext instead. 1538 func (c *conn) Query(query string, args []driver.Value) (dr driver.Rows, err error) { 1539 if dmesgs { 1540 defer func() { 1541 dmesg("conn %p, query %q, args %v: (driver.Rows %p, err %v)", c, query, args, dr, err) 1542 }() 1543 } 1544 return c.query(context.Background(), query, toNamedValues(args)) 1545 } 1546 1547 func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) { 1548 s, err := c.prepare(ctx, query) 1549 if err != nil { 1550 return nil, err 1551 } 1552 1553 defer func() { 1554 if err2 := s.Close(); err2 != nil && err == nil { 1555 err = err2 1556 } 1557 }() 1558 1559 return s.(*stmt).query(ctx, args) 1560 } 1561 1562 // Driver implements database/sql/driver.Driver. 1563 type Driver struct { 1564 // user defined functions that are added to every new connection on Open 1565 udfs map[string]*userDefinedFunction 1566 } 1567 1568 var d = &Driver{udfs: make(map[string]*userDefinedFunction)} 1569 1570 func newDriver() *Driver { return d } 1571 1572 // Open returns a new connection to the database. The name is a string in a 1573 // driver-specific format. 1574 // 1575 // Open may return a cached connection (one previously closed), but doing so is 1576 // unnecessary; the sql package maintains a pool of idle connections for 1577 // efficient re-use. 1578 // 1579 // The returned connection is only used by one goroutine at a time. 1580 // 1581 // If name contains a '?', what follows is treated as a query string. This 1582 // driver supports the following query parameters: 1583 // 1584 // _pragma: Each value will be run as a "PRAGMA ..." statement (with the PRAGMA 1585 // keyword added for you). May be specified more than once. Example: 1586 // "_pragma=foreign_keys(1)" will enable foreign key enforcement. More 1587 // information on supported PRAGMAs is available from the SQLite documentation: 1588 // https://www.sqlite.org/pragma.html 1589 // 1590 // _time_format: The name of a format to use when writing time values to the 1591 // database. Currently the only supported value is "sqlite", which corresponds 1592 // to format 7 from https://www.sqlite.org/lang_datefunc.html#time_values, 1593 // including the timezone specifier. If this parameter is not specified, then 1594 // the default String() format will be used. 1595 // 1596 // _txlock: The locking behavior to use when beginning a transaction. May be 1597 // "deferred", "immediate", or "exclusive" (case insensitive). The default is to 1598 // not specify one, which SQLite maps to "deferred". More information is 1599 // available at 1600 // https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions 1601 func (d *Driver) Open(name string) (conn driver.Conn, err error) { 1602 if dmesgs { 1603 defer func() { 1604 dmesg("name %q: (driver.Conn %p, err %v)", name, conn, err) 1605 }() 1606 } 1607 c, err := newConn(name) 1608 if err != nil { 1609 return nil, err 1610 } 1611 1612 for _, udf := range d.udfs { 1613 if err = c.createFunctionInternal(udf); err != nil { 1614 c.Close() 1615 return nil, err 1616 } 1617 } 1618 return c, nil 1619 } 1620 1621 // FunctionContext represents the context user defined functions execute in. 1622 // Fields and/or methods of this type may get addedd in the future. 1623 type FunctionContext struct { 1624 tls *libc.TLS 1625 ctx uintptr 1626 } 1627 1628 const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{}) 1629 1630 // RegisterFunction registers a function named zFuncName with nArg arguments. 1631 // Passing -1 for nArg indicates the function is variadic. The FunctionImpl 1632 // determines whether the function is deterministic or not, and whether it is a 1633 // scalar function (when Scalar is defined) or an aggregate function (when 1634 // Scalar is not defined and MakeAggregate is defined). 1635 // 1636 // The new function will be available to all new connections opened after 1637 // executing RegisterFunction. 1638 func RegisterFunction( 1639 zFuncName string, 1640 impl *FunctionImpl, 1641 ) error { 1642 return registerFunction(zFuncName, impl) 1643 } 1644 1645 // MustRegisterFunction is like RegisterFunction but panics on error. 1646 func MustRegisterFunction( 1647 zFuncName string, 1648 impl *FunctionImpl, 1649 ) { 1650 if err := RegisterFunction(zFuncName, impl); err != nil { 1651 panic(err) 1652 } 1653 } 1654 1655 // RegisterScalarFunction registers a scalar function named zFuncName with nArg 1656 // arguments. Passing -1 for nArg indicates the function is variadic. 1657 // 1658 // The new function will be available to all new connections opened after 1659 // executing RegisterScalarFunction. 1660 func RegisterScalarFunction( 1661 zFuncName string, 1662 nArg int32, 1663 xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), 1664 ) (err error) { 1665 if dmesgs { 1666 defer func() { 1667 dmesg("zFuncName %q, nArg %v, xFunc %p: err %v", zFuncName, nArg, xFunc, err) 1668 }() 1669 } 1670 return registerFunction(zFuncName, &FunctionImpl{NArgs: nArg, Scalar: xFunc, Deterministic: false}) 1671 } 1672 1673 // MustRegisterScalarFunction is like RegisterScalarFunction but panics on 1674 // error. 1675 func MustRegisterScalarFunction( 1676 zFuncName string, 1677 nArg int32, 1678 xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), 1679 ) { 1680 if dmesgs { 1681 dmesg("zFuncName %q, nArg %v, xFunc %p", zFuncName, nArg, xFunc) 1682 } 1683 if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil { 1684 panic(err) 1685 } 1686 } 1687 1688 // MustRegisterDeterministicScalarFunction is like 1689 // RegisterDeterministicScalarFunction but panics on error. 1690 func MustRegisterDeterministicScalarFunction( 1691 zFuncName string, 1692 nArg int32, 1693 xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), 1694 ) { 1695 if dmesgs { 1696 dmesg("zFuncName %q, nArg %v, xFunc %p", zFuncName, nArg, xFunc) 1697 } 1698 if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil { 1699 panic(err) 1700 } 1701 } 1702 1703 // RegisterDeterministicScalarFunction registers a deterministic scalar 1704 // function named zFuncName with nArg arguments. Passing -1 for nArg indicates 1705 // the function is variadic. A deterministic function means that the function 1706 // always gives the same output when the input parameters are the same. 1707 // 1708 // The new function will be available to all new connections opened after 1709 // executing RegisterDeterministicScalarFunction. 1710 func RegisterDeterministicScalarFunction( 1711 zFuncName string, 1712 nArg int32, 1713 xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), 1714 ) (err error) { 1715 if dmesgs { 1716 defer func() { 1717 dmesg("zFuncName %q, nArg %v, xFunc %p: err %v", zFuncName, nArg, xFunc, err) 1718 }() 1719 } 1720 return registerFunction(zFuncName, &FunctionImpl{NArgs: nArg, Scalar: xFunc, Deterministic: true}) 1721 } 1722 1723 func registerFunction( 1724 zFuncName string, 1725 impl *FunctionImpl, 1726 ) error { 1727 1728 if _, ok := d.udfs[zFuncName]; ok { 1729 return fmt.Errorf("a function named %q is already registered", zFuncName) 1730 } 1731 1732 // dont free, functions registered on the driver live as long as the program 1733 name, err := libc.CString(zFuncName) 1734 if err != nil { 1735 return err 1736 } 1737 1738 var textrep int32 = sqlite3.SQLITE_UTF8 1739 1740 if impl.Deterministic { 1741 textrep |= sqlite3.SQLITE_DETERMINISTIC 1742 } 1743 1744 udf := &userDefinedFunction{ 1745 zFuncName: name, 1746 nArg: impl.NArgs, 1747 eTextRep: textrep, 1748 } 1749 1750 if impl.Scalar != nil { 1751 xFuncs.mu.Lock() 1752 id := xFuncs.ids.next() 1753 xFuncs.m[id] = impl.Scalar 1754 xFuncs.mu.Unlock() 1755 1756 udf.scalar = true 1757 udf.pApp = id 1758 } else { 1759 xAggregateFactories.mu.Lock() 1760 id := xAggregateFactories.ids.next() 1761 xAggregateFactories.m[id] = impl.MakeAggregate 1762 xAggregateFactories.mu.Unlock() 1763 1764 udf.pApp = id 1765 } 1766 1767 d.udfs[zFuncName] = udf 1768 1769 return nil 1770 } 1771 1772 func origin(skip int) string { 1773 pc, fn, fl, _ := runtime.Caller(skip) 1774 f := runtime.FuncForPC(pc) 1775 var fns string 1776 if f != nil { 1777 fns = f.Name() 1778 if x := strings.LastIndex(fns, "."); x > 0 { 1779 fns = fns[x+1:] 1780 } 1781 } 1782 return fmt.Sprintf("%s:%d:%s", fn, fl, fns) 1783 } 1784 1785 func errorResultFunction(tls *libc.TLS, ctx uintptr) func(error) { 1786 return func(res error) { 1787 errmsg, cerr := libc.CString(res.Error()) 1788 if cerr != nil { 1789 panic(cerr) 1790 } 1791 defer libc.Xfree(tls, errmsg) 1792 sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1) 1793 sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR) 1794 } 1795 } 1796 1797 func functionArgs(tls *libc.TLS, argc int32, argv uintptr) []driver.Value { 1798 args := make([]driver.Value, argc) 1799 for i := int32(0); i < argc; i++ { 1800 valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize)) 1801 1802 switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { 1803 case sqlite3.SQLITE_TEXT: 1804 args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) 1805 case sqlite3.SQLITE_INTEGER: 1806 args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) 1807 case sqlite3.SQLITE_FLOAT: 1808 args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr) 1809 case sqlite3.SQLITE_NULL: 1810 args[i] = nil 1811 case sqlite3.SQLITE_BLOB: 1812 size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) 1813 blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) 1814 v := make([]byte, size) 1815 copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) 1816 args[i] = v 1817 default: 1818 panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) 1819 } 1820 } 1821 1822 return args 1823 } 1824 1825 func functionReturnValue(tls *libc.TLS, ctx uintptr, res driver.Value) error { 1826 switch resTyped := res.(type) { 1827 case nil: 1828 sqlite3.Xsqlite3_result_null(tls, ctx) 1829 case int64: 1830 sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped) 1831 case float64: 1832 sqlite3.Xsqlite3_result_double(tls, ctx, resTyped) 1833 case bool: 1834 sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped)) 1835 case time.Time: 1836 sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix()) 1837 case string: 1838 size := int32(len(resTyped)) 1839 cstr, err := libc.CString(resTyped) 1840 if err != nil { 1841 panic(err) 1842 } 1843 defer libc.Xfree(tls, cstr) 1844 sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT) 1845 case []byte: 1846 size := int32(len(resTyped)) 1847 if size == 0 { 1848 sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0) 1849 return nil 1850 } 1851 p := libc.Xmalloc(tls, types.Size_t(size)) 1852 if p == 0 { 1853 panic(fmt.Sprintf("unable to allocate space for blob: %d", size)) 1854 } 1855 defer libc.Xfree(tls, p) 1856 copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped) 1857 1858 sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT) 1859 default: 1860 return fmt.Errorf("function did not return a valid driver.Value: %T", resTyped) 1861 } 1862 1863 return nil 1864 } 1865 1866 // The below is all taken from zombiezen.com/go/sqlite. Aggregate functions need 1867 // to maintain state (for instance, the count of values seen so far). We give 1868 // each aggregate function an ID, generated by idGen, and put that in the pApp 1869 // argument to sqlite3_create_function. We track this on the Go side in 1870 // xAggregateFactories. 1871 // 1872 // When (if) the function is called is called by a query, we call the 1873 // MakeAggregate factory function to set it up, and track that in 1874 // xAggregateContext, retrieving it via sqlite3_aggregate_context. 1875 // 1876 // We also need to ensure that, for both aggregate and scalar functions, the 1877 // function pointer we pass to SQLite meets certain rules on the Go side, so 1878 // that the pointer remains valid. 1879 var ( 1880 xFuncs = struct { 1881 mu sync.RWMutex 1882 m map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error) 1883 ids idGen 1884 }{ 1885 m: make(map[uintptr]func(*FunctionContext, []driver.Value) (driver.Value, error)), 1886 } 1887 1888 xAggregateFactories = struct { 1889 mu sync.RWMutex 1890 m map[uintptr]func(FunctionContext) (AggregateFunction, error) 1891 ids idGen 1892 }{ 1893 m: make(map[uintptr]func(FunctionContext) (AggregateFunction, error)), 1894 } 1895 1896 xAggregateContext = struct { 1897 mu sync.RWMutex 1898 m map[uintptr]AggregateFunction 1899 ids idGen 1900 }{ 1901 m: make(map[uintptr]AggregateFunction), 1902 } 1903 ) 1904 1905 type idGen struct { 1906 bitset []uint64 1907 } 1908 1909 func (gen *idGen) next() uintptr { 1910 base := uintptr(1) 1911 for i := 0; i < len(gen.bitset); i, base = i+1, base+64 { 1912 b := gen.bitset[i] 1913 if b != 1<<64-1 { 1914 n := uintptr(bits.TrailingZeros64(^b)) 1915 gen.bitset[i] |= 1 << n 1916 return base + n 1917 } 1918 } 1919 gen.bitset = append(gen.bitset, 1) 1920 return base 1921 } 1922 1923 func (gen *idGen) reclaim(id uintptr) { 1924 bit := id - 1 1925 gen.bitset[bit/64] &^= 1 << (bit % 64) 1926 } 1927 1928 func makeAggregate(tls *libc.TLS, ctx uintptr) (AggregateFunction, uintptr) { 1929 goCtx := FunctionContext{tls: tls, ctx: ctx} 1930 aggCtx := (*uintptr)(unsafe.Pointer(sqlite3.Xsqlite3_aggregate_context(tls, ctx, int32(ptrSize)))) 1931 setErrorResult := errorResultFunction(tls, ctx) 1932 if aggCtx == nil { 1933 setErrorResult(errors.New("insufficient memory for aggregate")) 1934 return nil, 0 1935 } 1936 if *aggCtx != 0 { 1937 // Already created. 1938 xAggregateContext.mu.RLock() 1939 f := xAggregateContext.m[*aggCtx] 1940 xAggregateContext.mu.RUnlock() 1941 return f, *aggCtx 1942 } 1943 1944 factoryID := sqlite3.Xsqlite3_user_data(tls, ctx) 1945 xAggregateFactories.mu.RLock() 1946 factory := xAggregateFactories.m[factoryID] 1947 xAggregateFactories.mu.RUnlock() 1948 1949 f, err := factory(goCtx) 1950 if err != nil { 1951 setErrorResult(err) 1952 return nil, 0 1953 } 1954 if f == nil { 1955 setErrorResult(errors.New("MakeAggregate function returned nil")) 1956 return nil, 0 1957 } 1958 1959 xAggregateContext.mu.Lock() 1960 *aggCtx = xAggregateContext.ids.next() 1961 xAggregateContext.m[*aggCtx] = f 1962 xAggregateContext.mu.Unlock() 1963 return f, *aggCtx 1964 } 1965 1966 // cFuncPointer converts a function defined by a function declaration to a C pointer. 1967 // The result of using cFuncPointer on closures is undefined. 1968 func cFuncPointer[T any](f T) uintptr { 1969 // This assumes the memory representation described in https://golang.org/s/go11func. 1970 // 1971 // cFuncPointer does its conversion by doing the following in order: 1972 // 1) Create a Go struct containing a pointer to a pointer to 1973 // the function. It is assumed that the pointer to the function will be 1974 // stored in the read-only data section and thus will not move. 1975 // 2) Convert the pointer to the Go struct to a pointer to uintptr through 1976 // unsafe.Pointer. This is permitted via Rule #1 of unsafe.Pointer. 1977 // 3) Dereference the pointer to uintptr to obtain the function value as a 1978 // uintptr. This is safe as long as function values are passed as pointers. 1979 return *(*uintptr)(unsafe.Pointer(&struct{ f T }{f})) 1980 } 1981 1982 func funcTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { 1983 id := sqlite3.Xsqlite3_user_data(tls, ctx) 1984 xFuncs.mu.RLock() 1985 xFunc := xFuncs.m[id] 1986 xFuncs.mu.RUnlock() 1987 1988 setErrorResult := errorResultFunction(tls, ctx) 1989 res, err := xFunc(&FunctionContext{}, functionArgs(tls, argc, argv)) 1990 1991 if err != nil { 1992 setErrorResult(err) 1993 return 1994 } 1995 1996 err = functionReturnValue(tls, ctx, res) 1997 if err != nil { 1998 setErrorResult(err) 1999 } 2000 } 2001 2002 func stepTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { 2003 impl, _ := makeAggregate(tls, ctx) 2004 if impl == nil { 2005 return 2006 } 2007 2008 setErrorResult := errorResultFunction(tls, ctx) 2009 err := impl.Step(&FunctionContext{}, functionArgs(tls, argc, argv)) 2010 if err != nil { 2011 setErrorResult(err) 2012 } 2013 } 2014 2015 func inverseTrampoline(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { 2016 impl, _ := makeAggregate(tls, ctx) 2017 if impl == nil { 2018 return 2019 } 2020 2021 setErrorResult := errorResultFunction(tls, ctx) 2022 err := impl.WindowInverse(&FunctionContext{}, functionArgs(tls, argc, argv)) 2023 if err != nil { 2024 setErrorResult(err) 2025 } 2026 } 2027 2028 func valueTrampoline(tls *libc.TLS, ctx uintptr) { 2029 impl, _ := makeAggregate(tls, ctx) 2030 if impl == nil { 2031 return 2032 } 2033 2034 setErrorResult := errorResultFunction(tls, ctx) 2035 res, err := impl.WindowValue(&FunctionContext{}) 2036 if err != nil { 2037 setErrorResult(err) 2038 } else { 2039 err = functionReturnValue(tls, ctx, res) 2040 if err != nil { 2041 setErrorResult(err) 2042 } 2043 } 2044 } 2045 2046 func finalTrampoline(tls *libc.TLS, ctx uintptr) { 2047 impl, id := makeAggregate(tls, ctx) 2048 if impl == nil { 2049 return 2050 } 2051 2052 setErrorResult := errorResultFunction(tls, ctx) 2053 res, err := impl.WindowValue(&FunctionContext{}) 2054 if err != nil { 2055 setErrorResult(err) 2056 } else { 2057 err = functionReturnValue(tls, ctx, res) 2058 if err != nil { 2059 setErrorResult(err) 2060 } 2061 } 2062 impl.Final(&FunctionContext{}) 2063 2064 xAggregateContext.mu.Lock() 2065 defer xAggregateContext.mu.Unlock() 2066 delete(xAggregateContext.m, id) 2067 xAggregateContext.ids.reclaim(id) 2068 } 2069 2070 // int sqlite3_limit(sqlite3*, int id, int newVal); 2071 func (c *conn) limit(id int, newVal int) int { 2072 return int(sqlite3.Xsqlite3_limit(c.tls, c.db, int32(id), int32(newVal))) 2073 } 2074 2075 // Limit calls sqlite3_limit, see the docs at 2076 // https://www.sqlite.org/c3ref/limit.html for details. 2077 // 2078 // To get a sql.Conn from a *sql.DB, use (*sql.DB).Conn(). Limits are bound to 2079 // the particular instance of 'c', so getting a new connection only to pass it 2080 // to Limit is possibly not useful above querying what are the various 2081 // configured default values. 2082 func Limit(c *sql.Conn, id int, newVal int) (r int, err error) { 2083 err = c.Raw(func(driverConn any) error { 2084 switch dc := driverConn.(type) { 2085 case *conn: 2086 r = dc.limit(id, newVal) 2087 return nil 2088 default: 2089 return fmt.Errorf("unexpected driverConn type: %T", driverConn) 2090 } 2091 }) 2092 return r, err 2093 2094 }