// database.go

package database

import (
	"database/sql"
	"fmt"
	"strings"

	"github.com/google/uuid"
)

// Database wraps a *sql.DB with SQLite-oriented helpers for schema creation,
// column migration, and transactions.
// It embeds *sql.DB so all standard database methods are available directly.
type Database struct {
	*sql.DB
	// Sync is an optional remote-sync hook (see Syncer). Because it is
	// optional it may be nil, so check it before calling.
	Sync Syncer
}

// Syncer is an optional interface for databases that support syncing with a remote.
// An implementation, when available, is held in Database.Sync.
type Syncer interface {
	// Sync performs a sync with the remote and reports any failure. The
	// direction and granularity of the sync are implementation-defined.
	Sync() error
}

// GenerateID returns a newly generated random UUID rendered in its canonical
// hyphenated string form. Like uuid.New, it panics if the random source fails.
func (db *Database) GenerateID() string {
	id := uuid.New()
	return id.String()
}

// TableExists reports whether a table named name is listed in sqlite_master.
//
// SQLite treats identifiers case-insensitively, so the lookup does too.
// TableExists("Users") is true for a table created as "users". This matches
// how CREATE TABLE IF NOT EXISTS and PRAGMA table_info resolve the same name.
//
// Any error yields false. That includes sql.ErrNoRows for a missing table and
// failures such as a closed connection.
func (db *Database) TableExists(name string) bool {
	var found int
	err := db.QueryRow(
		"SELECT 1 FROM sqlite_master WHERE type='table' AND name=? COLLATE NOCASE LIMIT 1",
		name,
	).Scan(&found)
	return err == nil
}

// Column describes a database column for table creation and migration.
// It is consumed by Database.CreateTable and Database.AddColumn. Type and
// Default are spliced into the generated SQL verbatim, so they must be trusted.
type Column struct {
	// Name is the column name; it is emitted double-quoted in generated SQL.
	Name    string
	Type    string // SQL type: TEXT, INTEGER, REAL, DATETIME, BLOB
	// Primary marks the column PRIMARY KEY in CreateTable, which then ignores
	// Default. AddColumn does not consult Primary.
	Primary bool
	Default string // Default value as SQL literal: '', 0, CURRENT_TIMESTAMP
}

// CreateTable creates a table with the given columns.
func (db *Database) CreateTable(name string, columns []Column) error {
	var cols []string
	for _, col := range columns {
		def := fmt.Sprintf(`"%s" %s`, col.Name, col.Type)
		if col.Primary {
			def += " PRIMARY KEY"
		} else if col.Default != "" {
			def += " DEFAULT " + col.Default
		}
		cols = append(cols, def)
	}

	query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "%s" (%s)`, name, strings.Join(cols, ", "))
	_, err := db.Exec(query)
	return err
}

// GetColumns returns the names of table's columns in the order reported by
// SQLite's PRAGMA table_info.
//
// It returns nil in two situations:
//   - The query, any row scan, or row iteration fails.
//   - No columns are reported. SQLite returns no rows for an unknown table.
//
// Callers that must tell these apart can check TableExists first.
//
// The table name is quoted as an SQL identifier, with embedded double quotes
// doubled.
func (db *Database) GetColumns(table string) []string {
	quoted := strings.ReplaceAll(table, `"`, `""`)
	rows, err := db.Query(fmt.Sprintf(`PRAGMA table_info("%s")`, quoted))
	if err != nil {
		return nil
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		// table_info yields (cid, name, type, notnull, dflt_value, pk). Only
		// name is kept; the rest are scanned into placeholders.
		var (
			cid, notNull, pk int
			name, colType    string
			dfltValue        any
		)
		if err := rows.Scan(&cid, &name, &colType, &notNull, &dfltValue, &pk); err != nil {
			// A partial list would look like "columns are missing" to migration
			// code. Report failure the same way as a failed query instead.
			return nil
		}
		columns = append(columns, name)
	}
	if err := rows.Err(); err != nil {
		return nil
	}
	return columns
}

// AddColumn adds col to an existing table via ALTER TABLE ... ADD COLUMN,
// always with a DEFAULT clause so that existing rows receive a value.
//
// Defaults are chosen as follows:
//   - SQLite only accepts constant defaults in ALTER TABLE ADD COLUMN, so the
//     time-dependent keywords are replaced with fixed Unix-epoch literals.
//     This covers CURRENT_TIMESTAMP, CURRENT_DATE and CURRENT_TIME, matched
//     case-insensitively and ignoring surrounding whitespace.
//   - When col.Default is empty or blank, zeroDefault(col.Type) supplies the
//     default.
//
// col.Primary is ignored. Table and column names are quoted as SQL identifiers,
// with embedded double quotes doubled. col.Type and col.Default are inserted
// verbatim.
func (db *Database) AddColumn(table string, col Column) error {
	defaultVal := col.Default
	switch strings.ToUpper(strings.TrimSpace(defaultVal)) {
	case "":
		defaultVal = zeroDefault(col.Type)
	case "CURRENT_TIMESTAMP":
		defaultVal = "'1970-01-01 00:00:00'"
	case "CURRENT_DATE":
		defaultVal = "'1970-01-01'"
	case "CURRENT_TIME":
		defaultVal = "'00:00:00'"
	}
	query := fmt.Sprintf(`ALTER TABLE "%s" ADD COLUMN "%s" %s DEFAULT %s`,
		strings.ReplaceAll(table, `"`, `""`),
		strings.ReplaceAll(col.Name, `"`, `""`),
		col.Type, defaultVal)
	_, err := db.Exec(query)
	return err
}

// zeroDefault returns an SQL literal holding the Go zero value for a column of
// the given declared type. AddColumn uses it when no default is specified.
//
// The type is classified with SQLite's column-affinity substring rules:
// case-insensitive and checked in this order. This gives aliases such as INT,
// BIGINT, FLOAT and DOUBLE a numeric zero. An empty string would be stored as
// TEXT in a numeric column, and Go numeric scans would then fail on it.
//
//   - contains "INT"                    -> "0"
//   - contains "CHAR", "CLOB" or "TEXT" -> "''"
//   - contains "BLOB"                   -> "''"
//   - contains "REAL", "FLOA" or "DOUB" -> "0.0"
//   - contains "BOOL"                   -> "0" (Go false; not an affinity rule)
//   - anything else, e.g. DATETIME      -> "''"
func zeroDefault(sqlType string) string {
	t := strings.ToUpper(sqlType)
	switch {
	case strings.Contains(t, "INT"):
		return "0"
	case strings.Contains(t, "CHAR"), strings.Contains(t, "CLOB"), strings.Contains(t, "TEXT"),
		strings.Contains(t, "BLOB"):
		return "''"
	case strings.Contains(t, "REAL"), strings.Contains(t, "FLOA"), strings.Contains(t, "DOUB"):
		return "0.0"
	case strings.Contains(t, "BOOL"):
		return "0"
	default:
		return "''"
	}
}

// Transaction runs fn inside a database transaction.
//
// Outcomes:
//   - fn returns nil: the transaction is committed.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails, fn's error is still returned
//     (wrapped with %w), with the rollback error's message included.
//   - fn panics: the transaction is rolled back and the panic is re-raised, so
//     the connection is not left holding an open transaction.
//
// Errors from beginning or committing the transaction are wrapped with %w and
// can be inspected with errors.Is / errors.As.
func (db *Database) Transaction(fn func(tx *sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}

	// Without this, a panic in fn would skip both Rollback and Commit and leak
	// the transaction and its connection. The rollback error is discarded
	// because the panic is the failure that matters.
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback()
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}