database.go
139 lines1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package database
import (
"database/sql"
"fmt"
"strings"
"github.com/google/uuid"
)
// Database wraps a *sql.DB together with an optional remote syncer.
// It embeds *sql.DB so all standard database methods (Query, QueryRow,
// Exec, Begin, ...) are available directly on a *Database.
type Database struct {
	*sql.DB
	// Sync syncs this database with a remote, for databases that support it.
	// NOTE(review): presumably nil for purely local databases — confirm at the
	// construction site and nil-check before calling Sync.Sync().
	Sync Syncer
}
// Syncer is an optional interface for databases that support syncing with a remote.
type Syncer interface {
	// Sync synchronizes the database with its remote and returns any failure
	// as an error.
	Sync() error
}
// GenerateID returns a newly generated random UUID in its canonical string
// form. It never touches the underlying database; it is a method only for
// convenience at call sites that already hold a *Database.
func (db *Database) GenerateID() string {
	id := uuid.New()
	return id.String()
}
// TableExists reports whether a table called name is listed in
// sqlite_master. Any failure, whether no matching row or a query/scan error,
// is reported as false.
func (db *Database) TableExists(name string) bool {
	const lookup = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
	var found string
	if err := db.QueryRow(lookup, name).Scan(&found); err != nil {
		return false
	}
	return true
}
// Column describes a database column for table creation and migration.
type Column struct {
	// Name is the column name. It is wrapped in double quotes as an
	// identifier in generated SQL.
	Name string
	// Type is the SQL type: TEXT, INTEGER, REAL, DATETIME, BLOB.
	// It is inserted verbatim into generated SQL.
	Type string
	// Primary marks the column as PRIMARY KEY in CreateTable. When it is set,
	// CreateTable does not emit Default for this column.
	Primary bool
	// Default is the default value as a raw SQL literal: '', 0, CURRENT_TIMESTAMP.
	// It is inserted verbatim into generated SQL, so it must come from trusted
	// code, never from user input.
	Default string
}
// CreateTable creates the table name with the given columns if it does not
// already exist.
//
// Each column is emitted as `"Name" Type`. It is followed by PRIMARY KEY when
// col.Primary is set, or otherwise by DEFAULT col.Default when that is
// non-empty. Table and column names are quoted as SQL identifiers, with
// embedded double quotes doubled. Type and Default are inserted verbatim and
// must therefore come from trusted code.
//
// It returns an error if columns is empty or if executing the CREATE TABLE
// statement fails. Driver errors are wrapped with %w.
func (db *Database) CreateTable(name string, columns []Column) error {
	if len(columns) == 0 {
		return fmt.Errorf("create table %q: no columns given", name)
	}
	// quote renders s as a double-quoted SQL identifier. Doubling embedded
	// quotes stops a name from terminating the quoting early.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	defs := make([]string, 0, len(columns))
	for _, col := range columns {
		def := quote(col.Name) + " " + col.Type
		switch {
		case col.Primary:
			def += " PRIMARY KEY"
		case col.Default != "":
			def += " DEFAULT " + col.Default
		}
		defs = append(defs, def)
	}
	query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", quote(name), strings.Join(defs, ", "))
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("create table %q: %w", name, err)
	}
	return nil
}
// GetColumns returns the names of all columns in a table.
func (db *Database) GetColumns(table string) []string {
rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info("%s")`, table))
if err != nil {
return nil
}
defer rows.Close()
var columns []string
for rows.Next() {
var cid int
var name, colType string
var notNull, pk int
var dfltValue any
if err := rows.Scan(&cid, &name, &colType, ¬Null, &dfltValue, &pk); err != nil {
continue
}
columns = append(columns, name)
}
if err := rows.Err(); err != nil {
return nil
}
return columns
}
// AddColumn adds col to an existing table using ALTER TABLE ... ADD COLUMN.
//
// SQLite only accepts constant defaults for ADD COLUMN, so the default is
// adjusted first:
//   - CURRENT_TIMESTAMP, CURRENT_DATE and CURRENT_TIME (matched
//     case-insensitively) are replaced with fixed epoch literals;
//   - an empty or whitespace-only Default is replaced with zeroDefault(col.Type).
//
// col.Primary is ignored. Table and column names are quoted as SQL
// identifiers, with embedded double quotes doubled. Type and Default are
// inserted verbatim. Driver errors are wrapped with %w.
func (db *Database) AddColumn(table string, col Column) error {
	defaultVal := strings.TrimSpace(col.Default)
	switch strings.ToUpper(defaultVal) {
	case "":
		defaultVal = zeroDefault(col.Type)
	case "CURRENT_TIMESTAMP":
		defaultVal = "'1970-01-01 00:00:00'"
	case "CURRENT_DATE":
		defaultVal = "'1970-01-01'"
	case "CURRENT_TIME":
		defaultVal = "'00:00:00'"
	}
	// quote renders s as a double-quoted SQL identifier. Doubling embedded
	// quotes stops a name from terminating the quoting early.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	query := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s DEFAULT %s",
		quote(table), quote(col.Name), col.Type, defaultVal)
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("add column %q to table %q: %w", col.Name, table, err)
	}
	return nil
}
// zeroDefault returns an SQL literal holding the zero value for a column of
// the given SQL type. The type is classified by SQLite's column type-affinity
// rules, applied case-insensitively and in this order:
//
//   - contains "INT" (INTEGER, INT, BIGINT, ...): INTEGER affinity, so "0"
//   - contains "CHAR", "CLOB" or "TEXT": TEXT affinity, so "''"
//   - contains "BLOB", or is empty: "''"
//   - contains "REAL", "FLOA" or "DOUB": REAL affinity, so "0.0"
//   - anything else (e.g. DATETIME, NUMERIC): "''"
//
// A '' default in an INTEGER/REAL-affinity column would be stored as TEXT and
// fail to scan into a Go numeric type, which is why numeric spellings beyond
// the exact INTEGER/REAL names must be recognized. Following SQLite's order
// means e.g. "FLOATING POINT" maps to "0", because it contains "INT".
func zeroDefault(sqlType string) string {
	t := strings.ToUpper(sqlType)
	switch {
	case strings.Contains(t, "INT"):
		return "0"
	case strings.Contains(t, "CHAR"), strings.Contains(t, "CLOB"),
		strings.Contains(t, "TEXT"), strings.Contains(t, "BLOB"):
		return "''"
	case strings.Contains(t, "REAL"), strings.Contains(t, "FLOA"), strings.Contains(t, "DOUB"):
		return "0.0"
	default:
		return "''"
	}
}
// Transaction runs fn inside a database transaction.
//
// The outcome depends on fn:
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails, the returned error wraps fn's
//     error and includes the rollback error's message.
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//     This keeps the connection from holding an open transaction if a caller
//     further up recovers.
//   - fn succeeds: the transaction is committed and any commit error is
//     returned wrapped.
func (db *Database) Transaction(fn func(tx *sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	defer func() {
		if p := recover(); p != nil {
			// Best-effort cleanup: the panic is the failure that matters, so a
			// rollback error here is deliberately discarded.
			_ = tx.Rollback()
			panic(p)
		}
	}()
	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
		}
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}