refactor: migrate db and config packages to use struct-based API
Removed global variables in favor of instance variables. This makes the code more maintainable and making it easier to write unit tests without relying on global state. Marked the old functions with global state as obsolete, redirecting them to use a new API based on struct in order to rewrite the code using these functions gradually.
This commit is contained in:
@ -20,28 +20,16 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"plemya-x.ru/alr/internal/config"
|
||||
"plemya-x.ru/alr/pkg/loggerctx"
|
||||
"golang.org/x/exp/slices"
|
||||
"modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// CurrentVersion is the current version of the database.
|
||||
// The database is reset if its version doesn't match this.
|
||||
const CurrentVersion = 2
|
||||
|
||||
func init() {
|
||||
sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains)
|
||||
}
|
||||
|
||||
// Package is a ALR package's database representation
|
||||
type Package struct {
|
||||
Name string `sh:"name,required" db:"name"`
|
||||
@ -66,66 +54,47 @@ type version struct {
|
||||
Version int `db:"version"`
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
type Config interface {
|
||||
GetPaths(ctx context.Context) *config.Paths
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
conn *sqlx.DB
|
||||
closed = true
|
||||
)
|
||||
config Config
|
||||
}
|
||||
|
||||
// DB returns the ALR database.
|
||||
// The first time it's called, it opens the SQLite database file.
|
||||
// Subsequent calls return the same connection.
|
||||
func DB(ctx context.Context) *sqlx.DB {
|
||||
log := loggerctx.From(ctx)
|
||||
if conn != nil && !closed {
|
||||
return getConn()
|
||||
func New(config Config) *Database {
|
||||
return &Database{
|
||||
config: config,
|
||||
}
|
||||
_, err := open(ctx, config.GetPaths(ctx).DBPath)
|
||||
}
|
||||
|
||||
func (d *Database) Init(ctx context.Context) error {
|
||||
err := d.Connect(ctx)
|
||||
if err != nil {
|
||||
log.Fatal("Error opening database").Err(err).Send()
|
||||
return err
|
||||
}
|
||||
return getConn()
|
||||
return d.initDB(ctx)
|
||||
}
|
||||
|
||||
func getConn() *sqlx.DB {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return conn
|
||||
}
|
||||
|
||||
func open(ctx context.Context, dsn string) (*sqlx.DB, error) {
|
||||
func (d *Database) Connect(ctx context.Context) error {
|
||||
dsn := d.config.GetPaths(ctx).DBPath
|
||||
db, err := sqlx.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
conn = db
|
||||
closed = false
|
||||
mu.Unlock()
|
||||
|
||||
err = initDB(ctx, dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
d.conn = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database
|
||||
func Close() error {
|
||||
closed = true
|
||||
if conn != nil {
|
||||
return conn.Close()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
func (d *Database) GetConn() *sqlx.DB {
|
||||
return d.conn
|
||||
}
|
||||
|
||||
// initDB initializes the database
|
||||
func initDB(ctx context.Context, dsn string) error {
|
||||
func (d *Database) initDB(ctx context.Context) error {
|
||||
log := loggerctx.From(ctx)
|
||||
conn = conn.Unsafe()
|
||||
d.conn = d.conn.Unsafe()
|
||||
conn := d.conn
|
||||
_, err := conn.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS pkgs (
|
||||
name TEXT NOT NULL,
|
||||
@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ver, ok := GetVersion(ctx)
|
||||
ver, ok := d.GetVersion(ctx)
|
||||
if ok && ver != CurrentVersion {
|
||||
log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send()
|
||||
reset(ctx)
|
||||
return initDB(ctx, dsn)
|
||||
d.reset(ctx)
|
||||
return d.initDB(ctx)
|
||||
} else if !ok {
|
||||
log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send()
|
||||
return addVersion(ctx, CurrentVersion)
|
||||
return d.addVersion(ctx, CurrentVersion)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reset drops all the database tables
|
||||
func reset(ctx context.Context) error {
|
||||
_, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
|
||||
return err
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the database has no packages in it, otherwise it returns false.
|
||||
func IsEmpty(ctx context.Context) bool {
|
||||
var count int
|
||||
err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return count == 0
|
||||
}
|
||||
|
||||
// GetVersion returns the database version and a boolean indicating
|
||||
// whether the database contained a version number
|
||||
func GetVersion(ctx context.Context) (int, bool) {
|
||||
func (d *Database) GetVersion(ctx context.Context) (int, bool) {
|
||||
var ver version
|
||||
err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
|
||||
err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return ver.Version, true
|
||||
}
|
||||
|
||||
func addVersion(ctx context.Context, ver int) error {
|
||||
_, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
|
||||
func (d *Database) addVersion(ctx context.Context, ver int) error {
|
||||
_, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
|
||||
return err
|
||||
}
|
||||
|
||||
// InsertPackage adds a package to the database
|
||||
func InsertPackage(ctx context.Context, pkg Package) error {
|
||||
_, err := DB(ctx).NamedExecContext(ctx, `
|
||||
func (d *Database) reset(ctx context.Context) error {
|
||||
_, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
|
||||
stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
|
||||
out := &Package{}
|
||||
err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error {
|
||||
_, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) IsEmpty(ctx context.Context) bool {
|
||||
var count int
|
||||
err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return count == 0
|
||||
}
|
||||
|
||||
func (d *Database) InsertPackage(ctx context.Context, pkg Package) error {
|
||||
_, err := d.conn.NamedExecContext(ctx, `
|
||||
INSERT OR REPLACE INTO pkgs (
|
||||
name,
|
||||
repository,
|
||||
@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPkgs returns a result containing packages that match the where conditions
|
||||
func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
|
||||
stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// GetPkg returns a single package that matches the where conditions
|
||||
func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
|
||||
out := &Package{}
|
||||
err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
|
||||
return out, err
|
||||
}
|
||||
|
||||
// DeletePkgs deletes all packages matching the where conditions
|
||||
func DeletePkgs(ctx context.Context, where string, args ...any) error {
|
||||
_, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// jsonArrayContains is an SQLite function that checks if a JSON array
|
||||
// in the database contains a given value
|
||||
func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
|
||||
value, ok := args[0].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("both arguments to json_array_contains must be strings")
|
||||
}
|
||||
|
||||
item, ok := args[1].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("both arguments to json_array_contains must be strings")
|
||||
}
|
||||
|
||||
var array []string
|
||||
err := json.Unmarshal([]byte(value), &array)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slices.Contains(array, item), nil
|
||||
}
|
||||
|
||||
// JSON represents a JSON value in the database
|
||||
type JSON[T any] struct {
|
||||
Val T
|
||||
}
|
||||
|
||||
// NewJSON creates a new database JSON value
|
||||
func NewJSON[T any](v T) JSON[T] {
|
||||
return JSON[T]{Val: v}
|
||||
}
|
||||
|
||||
func (s *JSON[T]) Scan(val any) error {
|
||||
if val == nil {
|
||||
func (d *Database) Close() error {
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
err := json.Unmarshal([]byte(val), &s.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case sql.NullString:
|
||||
if val.Valid {
|
||||
err := json.Unmarshal([]byte(val.String), &s.Val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return errors.New("sqlite json types must be strings")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s JSON[T]) Value() (driver.Value, error) {
|
||||
data, err := json.Marshal(s.Val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func (s JSON[T]) MarshalYAML() (any, error) {
|
||||
return s.Val, nil
|
||||
}
|
||||
|
||||
func (s JSON[T]) String() string {
|
||||
return fmt.Sprint(s.Val)
|
||||
}
|
||||
|
||||
func (s JSON[T]) GoString() string {
|
||||
return fmt.Sprintf("%#v", s.Val)
|
||||
}
|
||||
|
Reference in New Issue
Block a user