refactor: migrate db and config packages to use struct-based API

Removed global variables in favor of instance variables. This makes the code more maintainable and making it easier to write unit tests without relying on global state.

Marked the old functions with global state as obsolete, redirecting them to use a new API based on struct in order to rewrite the code using these functions gradually.
This commit is contained in:
Maxim Slipenko 2025-01-14 10:11:17 +03:00
parent 5d1d3d7c45
commit 52d3ab7791
8 changed files with 491 additions and 345 deletions

@ -21,59 +21,108 @@ package config
import ( import (
"context" "context"
"os" "os"
"path/filepath"
"sync" "sync"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"go.elara.ws/logger/log"
"plemya-x.ru/alr/internal/types" "plemya-x.ru/alr/internal/types"
"plemya-x.ru/alr/pkg/loggerctx" "plemya-x.ru/alr/pkg/loggerctx"
) )
var defaultConfig = &types.Config{ type ALRConfig struct {
RootCmd: "sudo", cfg *types.Config
PagerStyle: "native", paths *Paths
IgnorePkgUpdates: []string{},
Repos: []types.Repo{ pathsOnce sync.Once
{
Name: "default",
URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git",
},
},
} }
var ( func New() *ALRConfig {
configMtx sync.Mutex return &ALRConfig{}
config *types.Config }
)
// Config returns a ALR configuration struct. func (c *ALRConfig) Load(ctx context.Context) {
// The first time it's called, it'll load the config from a file. cfgFl, err := os.Open(c.GetPaths(ctx).ConfigPath)
// Subsequent calls will just return the same value. if err != nil {
func Config(ctx context.Context) *types.Config { log.Warn("Error opening config file, using defaults").Err(err).Send()
configMtx.Lock() c.cfg = defaultConfig
defer configMtx.Unlock() return
}
defer cfgFl.Close()
// Copy the default configuration into config
defCopy := *defaultConfig
config := &defCopy
config.Repos = nil
err = toml.NewDecoder(cfgFl).Decode(config)
if err != nil {
log.Warn("Error decoding config file, using defaults").Err(err).Send()
c.cfg = defaultConfig
return
}
c.cfg = config
}
func (c *ALRConfig) initPaths(ctx context.Context) {
log := loggerctx.From(ctx) log := loggerctx.From(ctx)
paths := &Paths{}
if config == nil { cfgDir, err := os.UserConfigDir()
cfgFl, err := os.Open(GetPaths(ctx).ConfigPath) if err != nil {
if err != nil { log.Fatal("Unable to detect user config directory").Err(err).Send()
log.Warn("Error opening config file, using defaults").Err(err).Send()
return defaultConfig
}
defer cfgFl.Close()
// Copy the default configuration into config
defCopy := *defaultConfig
config = &defCopy
config.Repos = nil
err = toml.NewDecoder(cfgFl).Decode(config)
if err != nil {
log.Warn("Error decoding config file, using defaults").Err(err).Send()
// Set config back to nil so that we try again next time
config = nil
return defaultConfig
}
} }
return config paths.ConfigDir = filepath.Join(cfgDir, "alr")
err = os.MkdirAll(paths.ConfigDir, 0o755)
if err != nil {
log.Fatal("Unable to create ALR config directory").Err(err).Send()
}
paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml")
if _, err := os.Stat(paths.ConfigPath); err != nil {
cfgFl, err := os.Create(paths.ConfigPath)
if err != nil {
log.Fatal("Unable to create ALR config file").Err(err).Send()
}
err = toml.NewEncoder(cfgFl).Encode(&defaultConfig)
if err != nil {
log.Fatal("Error encoding default configuration").Err(err).Send()
}
cfgFl.Close()
}
cacheDir, err := os.UserCacheDir()
if err != nil {
log.Fatal("Unable to detect cache directory").Err(err).Send()
}
paths.CacheDir = filepath.Join(cacheDir, "alr")
paths.RepoDir = filepath.Join(paths.CacheDir, "repo")
paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs")
err = os.MkdirAll(paths.RepoDir, 0o755)
if err != nil {
log.Fatal("Unable to create repo cache directory").Err(err).Send()
}
err = os.MkdirAll(paths.PkgsDir, 0o755)
if err != nil {
log.Fatal("Unable to create package cache directory").Err(err).Send()
}
paths.DBPath = filepath.Join(paths.CacheDir, "db")
c.paths = paths
}
func (c *ALRConfig) GetPaths(ctx context.Context) *Paths {
c.pathsOnce.Do(func() {
c.initPaths(ctx)
})
return c.paths
} }

@ -0,0 +1,65 @@
/*
* ALR - Any Linux Repository
* Copyright (C) 2024 Евгений Храмов
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package config
import (
"context"
"sync"
"plemya-x.ru/alr/internal/types"
)
var defaultConfig = &types.Config{
RootCmd: "sudo",
PagerStyle: "native",
IgnorePkgUpdates: []string{},
Repos: []types.Repo{
{
Name: "default",
URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git",
},
},
}
// Config returns a ALR configuration struct.
// The first time it's called, it'll load the config from a file.
// Subsequent calls will just return the same value.
//
// Deprecated: use struct method
func Config(ctx context.Context) *types.Config {
return GetInstance(ctx).cfg
}
// =======================
// FOR LEGACY ONLY
// =======================
var (
alrConfig *ALRConfig
alrConfigOnce sync.Once
)
func GetInstance(ctx context.Context) *ALRConfig {
alrConfigOnce.Do(func() {
alrConfig = New()
alrConfig.Load(ctx)
})
return alrConfig
}

@ -20,12 +20,6 @@ package config
import ( import (
"context" "context"
"os"
"path/filepath"
"sync"
"github.com/pelletier/go-toml/v2"
"plemya-x.ru/alr/pkg/loggerctx"
) )
// Paths contains various paths used by ALR // Paths contains various paths used by ALR
@ -38,71 +32,13 @@ type Paths struct {
DBPath string DBPath string
} }
var (
pathsMtx sync.Mutex
paths *Paths
)
// GetPaths returns a Paths struct. // GetPaths returns a Paths struct.
// The first time it's called, it'll generate the struct // The first time it's called, it'll generate the struct
// using information from the system. // using information from the system.
// Subsequent calls will return the same value. // Subsequent calls will return the same value.
//
// Depreacted: use struct API
func GetPaths(ctx context.Context) *Paths { func GetPaths(ctx context.Context) *Paths {
pathsMtx.Lock() alrConfig := GetInstance(ctx)
defer pathsMtx.Unlock() return alrConfig.GetPaths(ctx)
log := loggerctx.From(ctx)
if paths == nil {
paths = &Paths{}
cfgDir, err := os.UserConfigDir()
if err != nil {
log.Fatal("Unable to detect user config directory").Err(err).Send()
}
paths.ConfigDir = filepath.Join(cfgDir, "alr")
err = os.MkdirAll(paths.ConfigDir, 0o755)
if err != nil {
log.Fatal("Unable to create ALR config directory").Err(err).Send()
}
paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml")
if _, err := os.Stat(paths.ConfigPath); err != nil {
cfgFl, err := os.Create(paths.ConfigPath)
if err != nil {
log.Fatal("Unable to create ALR config file").Err(err).Send()
}
err = toml.NewEncoder(cfgFl).Encode(&defaultConfig)
if err != nil {
log.Fatal("Error encoding default configuration").Err(err).Send()
}
cfgFl.Close()
}
cacheDir, err := os.UserCacheDir()
if err != nil {
log.Fatal("Unable to detect cache directory").Err(err).Send()
}
paths.CacheDir = filepath.Join(cacheDir, "alr")
paths.RepoDir = filepath.Join(paths.CacheDir, "repo")
paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs")
err = os.MkdirAll(paths.RepoDir, 0o755)
if err != nil {
log.Fatal("Unable to create repo cache directory").Err(err).Send()
}
err = os.MkdirAll(paths.PkgsDir, 0o755)
if err != nil {
log.Fatal("Unable to create package cache directory").Err(err).Send()
}
paths.DBPath = filepath.Join(paths.CacheDir, "db")
}
return paths
} }

@ -20,28 +20,16 @@ package db
import ( import (
"context" "context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"sync"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/internal/config"
"plemya-x.ru/alr/pkg/loggerctx" "plemya-x.ru/alr/pkg/loggerctx"
"golang.org/x/exp/slices"
"modernc.org/sqlite"
) )
// CurrentVersion is the current version of the database. // CurrentVersion is the current version of the database.
// The database is reset if its version doesn't match this. // The database is reset if its version doesn't match this.
const CurrentVersion = 2 const CurrentVersion = 2
func init() {
sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains)
}
// Package is a ALR package's database representation // Package is a ALR package's database representation
type Package struct { type Package struct {
Name string `sh:"name,required" db:"name"` Name string `sh:"name,required" db:"name"`
@ -66,66 +54,47 @@ type version struct {
Version int `db:"version"` Version int `db:"version"`
} }
var ( type Config interface {
mu sync.Mutex GetPaths(ctx context.Context) *config.Paths
}
type Database struct {
conn *sqlx.DB conn *sqlx.DB
closed = true config Config
) }
// DB returns the ALR database. func New(config Config) *Database {
// The first time it's called, it opens the SQLite database file. return &Database{
// Subsequent calls return the same connection. config: config,
func DB(ctx context.Context) *sqlx.DB {
log := loggerctx.From(ctx)
if conn != nil && !closed {
return getConn()
} }
_, err := open(ctx, config.GetPaths(ctx).DBPath) }
func (d *Database) Init(ctx context.Context) error {
err := d.Connect(ctx)
if err != nil { if err != nil {
log.Fatal("Error opening database").Err(err).Send() return err
} }
return getConn() return d.initDB(ctx)
} }
func getConn() *sqlx.DB { func (d *Database) Connect(ctx context.Context) error {
mu.Lock() dsn := d.config.GetPaths(ctx).DBPath
defer mu.Unlock()
return conn
}
func open(ctx context.Context, dsn string) (*sqlx.DB, error) {
db, err := sqlx.Open("sqlite", dsn) db, err := sqlx.Open("sqlite", dsn)
if err != nil { if err != nil {
return nil, err return err
} }
d.conn = db
mu.Lock() return nil
conn = db
closed = false
mu.Unlock()
err = initDB(ctx, dsn)
if err != nil {
return nil, err
}
return db, nil
} }
// Close closes the database func (d *Database) GetConn() *sqlx.DB {
func Close() error { return d.conn
closed = true
if conn != nil {
return conn.Close()
} else {
return nil
}
} }
// initDB initializes the database func (d *Database) initDB(ctx context.Context) error {
func initDB(ctx context.Context, dsn string) error {
log := loggerctx.From(ctx) log := loggerctx.From(ctx)
conn = conn.Unsafe() d.conn = d.conn.Unsafe()
conn := d.conn
_, err := conn.ExecContext(ctx, ` _, err := conn.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS pkgs ( CREATE TABLE IF NOT EXISTS pkgs (
name TEXT NOT NULL, name TEXT NOT NULL,
@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error {
return err return err
} }
ver, ok := GetVersion(ctx) ver, ok := d.GetVersion(ctx)
if ok && ver != CurrentVersion { if ok && ver != CurrentVersion {
log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send() log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send()
reset(ctx) d.reset(ctx)
return initDB(ctx, dsn) return d.initDB(ctx)
} else if !ok { } else if !ok {
log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send() log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send()
return addVersion(ctx, CurrentVersion) return d.addVersion(ctx, CurrentVersion)
} }
return nil return nil
} }
// reset drops all the database tables func (d *Database) GetVersion(ctx context.Context) (int, bool) {
func reset(ctx context.Context) error {
_, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
if err != nil {
return err
}
_, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
return err
}
// IsEmpty returns true if the database has no packages in it, otherwise it returns false.
func IsEmpty(ctx context.Context) bool {
var count int
err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
if err != nil {
return true
}
return count == 0
}
// GetVersion returns the database version and a boolean indicating
// whether the database contained a version number
func GetVersion(ctx context.Context) (int, bool) {
var ver version var ver version
err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
if err != nil { if err != nil {
return 0, false return 0, false
} }
return ver.Version, true return ver.Version, true
} }
func addVersion(ctx context.Context, ver int) error { func (d *Database) addVersion(ctx context.Context, ver int) error {
_, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) _, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
return err return err
} }
// InsertPackage adds a package to the database func (d *Database) reset(ctx context.Context) error {
func InsertPackage(ctx context.Context, pkg Package) error { _, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
_, err := DB(ctx).NamedExecContext(ctx, ` if err != nil {
return err
}
_, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
return err
}
func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
if err != nil {
return nil, err
}
return stream, nil
}
func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
out := &Package{}
err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
return out, err
}
func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error {
_, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
return err
}
func (d *Database) IsEmpty(ctx context.Context) bool {
var count int
err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
if err != nil {
return true
}
return count == 0
}
func (d *Database) InsertPackage(ctx context.Context, pkg Package) error {
_, err := d.conn.NamedExecContext(ctx, `
INSERT OR REPLACE INTO pkgs ( INSERT OR REPLACE INTO pkgs (
name, name,
repository, repository,
@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error {
return err return err
} }
// GetPkgs returns a result containing packages that match the where conditions func (d *Database) Close() error {
func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { if d.conn != nil {
stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) return d.conn.Close()
if err != nil { } else {
return nil, err
}
return stream, nil
}
// GetPkg returns a single package that matches the where conditions
func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
out := &Package{}
err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
return out, err
}
// DeletePkgs deletes all packages matching the where conditions
func DeletePkgs(ctx context.Context, where string, args ...any) error {
_, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
return err
}
// jsonArrayContains is an SQLite function that checks if a JSON array
// in the database contains a given value
func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
value, ok := args[0].(string)
if !ok {
return nil, errors.New("both arguments to json_array_contains must be strings")
}
item, ok := args[1].(string)
if !ok {
return nil, errors.New("both arguments to json_array_contains must be strings")
}
var array []string
err := json.Unmarshal([]byte(value), &array)
if err != nil {
return nil, err
}
return slices.Contains(array, item), nil
}
// JSON represents a JSON value in the database
type JSON[T any] struct {
Val T
}
// NewJSON creates a new database JSON value
func NewJSON[T any](v T) JSON[T] {
return JSON[T]{Val: v}
}
func (s *JSON[T]) Scan(val any) error {
if val == nil {
return nil return nil
} }
switch val := val.(type) {
case string:
err := json.Unmarshal([]byte(val), &s.Val)
if err != nil {
return err
}
case sql.NullString:
if val.Valid {
err := json.Unmarshal([]byte(val.String), &s.Val)
if err != nil {
return err
}
}
default:
return errors.New("sqlite json types must be strings")
}
return nil
}
func (s JSON[T]) Value() (driver.Value, error) {
data, err := json.Marshal(s.Val)
if err != nil {
return nil, err
}
return string(data), nil
}
func (s JSON[T]) MarshalYAML() (any, error) {
return s.Val, nil
}
func (s JSON[T]) String() string {
return fmt.Sprint(s.Val)
}
func (s JSON[T]) GoString() string {
return fmt.Sprintf("%#v", s.Val)
} }

105
internal/db/db_legacy.go Normal file

@ -0,0 +1,105 @@
/*
* ALR - Any Linux Repository
* Copyright (C) 2024 Евгений Храмов
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"sync"
"github.com/jmoiron/sqlx"
"plemya-x.ru/alr/internal/config"
"plemya-x.ru/alr/pkg/loggerctx"
)
// DB returns the ALR database.
// The first time it's called, it opens the SQLite database file.
// Subsequent calls return the same connection.
//
// Deprecated: use struct method
func DB(ctx context.Context) *sqlx.DB {
return getInstance(ctx).GetConn()
}
// Close closes the database
//
// Deprecated: use struct method
func Close() error {
if database != nil {
return database.Close()
}
return nil
}
// IsEmpty returns true if the database has no packages in it, otherwise it returns false.
//
// Deprecated: use struct method
func IsEmpty(ctx context.Context) bool {
return getInstance(ctx).IsEmpty(ctx)
}
// InsertPackage adds a package to the database
//
// Deprecated: use struct method
func InsertPackage(ctx context.Context, pkg Package) error {
return getInstance(ctx).InsertPackage(ctx, pkg)
}
// GetPkgs returns a result containing packages that match the where conditions
//
// Deprecated: use struct method
func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
return getInstance(ctx).GetPkgs(ctx, where, args...)
}
// GetPkg returns a single package that matches the where conditions
//
// Deprecated: use struct method
func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
return getInstance(ctx).GetPkg(ctx, where, args...)
}
// DeletePkgs deletes all packages matching the where conditions
//
// Deprecated: use struct method
func DeletePkgs(ctx context.Context, where string, args ...any) error {
return getInstance(ctx).DeletePkgs(ctx, where, args...)
}
// =======================
// FOR LEGACY ONLY
// =======================
var (
dbOnce sync.Once
database *Database
)
// For refactoring only
func getInstance(ctx context.Context) *Database {
dbOnce.Do(func() {
log := loggerctx.From(ctx)
cfg := config.GetInstance(ctx)
database = New(cfg)
err := database.Init(ctx)
if err != nil {
log.Fatal("Error opening database").Err(err).Send()
}
})
return database
}

@ -19,14 +19,30 @@
package db_test package db_test
import ( import (
"context"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"plemya-x.ru/alr/internal/config"
"plemya-x.ru/alr/internal/db" "plemya-x.ru/alr/internal/db"
) )
type TestALRConfig struct{}
func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths {
return &config.Paths{
DBPath: ":memory:",
}
}
func prepareDb() *db.Database {
database := db.New(&TestALRConfig{})
database.Init(context.Background())
return database
}
var testPkg = db.Package{ var testPkg = db.Package{
Name: "test", Name: "test",
Version: "0.0.1", Version: "0.0.1",
@ -59,18 +75,11 @@ var testPkg = db.Package{
} }
func TestInit(t *testing.T) { func TestInit(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
_, err = db.DB().Exec("SELECT * FROM pkgs") ver, ok := database.GetVersion(ctx)
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
ver, ok := db.GetVersion()
if !ok { if !ok {
t.Errorf("Expected version to be present") t.Errorf("Expected version to be present")
} else if ver != db.CurrentVersion { } else if ver != db.CurrentVersion {
@ -79,19 +88,17 @@ func TestInit(t *testing.T) {
} }
func TestInsertPackage(t *testing.T) { func TestInsertPackage(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
err = db.InsertPackage(testPkg) err := database.InsertPackage(ctx, testPkg)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %s", err) t.Fatalf("Expected no error, got %s", err)
} }
dbPkg := db.Package{} dbPkg := db.Package{}
err = sqlx.Get(db.DB(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'")
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %s", err) t.Fatalf("Expected no error, got %s", err)
} }
@ -102,28 +109,26 @@ func TestInsertPackage(t *testing.T) {
} }
func TestGetPkgs(t *testing.T) { func TestGetPkgs(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
x1 := testPkg x1 := testPkg
x1.Name = "x1" x1.Name = "x1"
x2 := testPkg x2 := testPkg
x2.Name = "x2" x2.Name = "x2"
err = db.InsertPackage(x1) err := database.InsertPackage(ctx, x1)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
err = db.InsertPackage(x2) err = database.InsertPackage(ctx, x2)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
result, err := db.GetPkgs("name LIKE 'x%'") result, err := database.GetPkgs(ctx, "name LIKE 'x%'")
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %s", err) t.Fatalf("Expected no error, got %s", err)
} }
@ -142,28 +147,26 @@ func TestGetPkgs(t *testing.T) {
} }
func TestGetPkg(t *testing.T) { func TestGetPkg(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
x1 := testPkg x1 := testPkg
x1.Name = "x1" x1.Name = "x1"
x2 := testPkg x2 := testPkg
x2.Name = "x2" x2.Name = "x2"
err = db.InsertPackage(x1) err := database.InsertPackage(ctx, x1)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
err = db.InsertPackage(x2) err = database.InsertPackage(ctx, x2)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
pkg, err := db.GetPkg("name LIKE 'x%' ORDER BY name") pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name")
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %s", err) t.Fatalf("Expected no error, got %s", err)
} }
@ -178,34 +181,32 @@ func TestGetPkg(t *testing.T) {
} }
func TestDeletePkgs(t *testing.T) { func TestDeletePkgs(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
x1 := testPkg x1 := testPkg
x1.Name = "x1" x1.Name = "x1"
x2 := testPkg x2 := testPkg
x2.Name = "x2" x2.Name = "x2"
err = db.InsertPackage(x1) err := database.InsertPackage(ctx, x1)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
err = db.InsertPackage(x2) err = database.InsertPackage(ctx, x2)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
err = db.DeletePkgs("name = 'x1'") err = database.DeletePkgs(ctx, "name = 'x1'")
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
var dbPkg db.Package var dbPkg db.Package
err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;")
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
@ -216,11 +217,9 @@ func TestDeletePkgs(t *testing.T) {
} }
func TestJsonArrayContains(t *testing.T) { func TestJsonArrayContains(t *testing.T) {
_, err := db.Open(":memory:") ctx := context.Background()
if err != nil { database := prepareDb()
t.Fatalf("Expected no error, got %s", err) defer database.Close()
}
defer db.Close()
x1 := testPkg x1 := testPkg
x1.Name = "x1" x1.Name = "x1"
@ -228,18 +227,18 @@ func TestJsonArrayContains(t *testing.T) {
x2.Name = "x2" x2.Name = "x2"
x2.Provides.Val = append(x2.Provides.Val, "x") x2.Provides.Val = append(x2.Provides.Val, "x")
err = db.InsertPackage(x1) err := database.InsertPackage(ctx, x1)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
err = db.InsertPackage(x2) err = database.InsertPackage(ctx, x2)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %s", err) t.Errorf("Expected no error, got %s", err)
} }
var dbPkg db.Package var dbPkg db.Package
err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');")
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %s", err) t.Fatalf("Expected no error, got %s", err)
} }

64
internal/db/json.go Normal file

@ -0,0 +1,64 @@
package db
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
// JSON represents a JSON value in the database
type JSON[T any] struct {
Val T
}
// NewJSON creates a new database JSON value
func NewJSON[T any](v T) JSON[T] {
return JSON[T]{Val: v}
}
func (s *JSON[T]) Scan(val any) error {
if val == nil {
return nil
}
switch val := val.(type) {
case string:
err := json.Unmarshal([]byte(val), &s.Val)
if err != nil {
return err
}
case sql.NullString:
if val.Valid {
err := json.Unmarshal([]byte(val.String), &s.Val)
if err != nil {
return err
}
}
default:
return errors.New("sqlite json types must be strings")
}
return nil
}
func (s JSON[T]) Value() (driver.Value, error) {
data, err := json.Marshal(s.Val)
if err != nil {
return nil, err
}
return string(data), nil
}
func (s JSON[T]) MarshalYAML() (any, error) {
return s.Val, nil
}
func (s JSON[T]) String() string {
return fmt.Sprint(s.Val)
}
func (s JSON[T]) GoString() string {
return fmt.Sprintf("%#v", s.Val)
}

36
internal/db/utils.go Normal file

@ -0,0 +1,36 @@
package db
import (
"database/sql/driver"
"encoding/json"
"errors"
"golang.org/x/exp/slices"
"modernc.org/sqlite"
)
func init() {
sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains)
}
// jsonArrayContains is an SQLite function that checks if a JSON array
// in the database contains a given value
func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
value, ok := args[0].(string)
if !ok {
return nil, errors.New("both arguments to json_array_contains must be strings")
}
item, ok := args[1].(string)
if !ok {
return nil, errors.New("both arguments to json_array_contains must be strings")
}
var array []string
err := json.Unmarshal([]byte(value), &array)
if err != nil {
return nil, err
}
return slices.Contains(array, item), nil
}