diff --git a/internal/config/config.go b/internal/config/config.go index db43f89..2574945 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,59 +21,108 @@ package config import ( "context" "os" + "path/filepath" "sync" "github.com/pelletier/go-toml/v2" + "go.elara.ws/logger/log" "plemya-x.ru/alr/internal/types" "plemya-x.ru/alr/pkg/loggerctx" ) -var defaultConfig = &types.Config{ - RootCmd: "sudo", - PagerStyle: "native", - IgnorePkgUpdates: []string{}, - Repos: []types.Repo{ - { - Name: "default", - URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git", - }, - }, +type ALRConfig struct { + cfg *types.Config + paths *Paths + + pathsOnce sync.Once } -var ( - configMtx sync.Mutex - config *types.Config -) +func New() *ALRConfig { + return &ALRConfig{} +} -// Config returns a ALR configuration struct. -// The first time it's called, it'll load the config from a file. -// Subsequent calls will just return the same value. -func Config(ctx context.Context) *types.Config { - configMtx.Lock() - defer configMtx.Unlock() +func (c *ALRConfig) Load(ctx context.Context) { + cfgFl, err := os.Open(c.GetPaths(ctx).ConfigPath) + if err != nil { + log.Warn("Error opening config file, using defaults").Err(err).Send() + c.cfg = defaultConfig + return + } + defer cfgFl.Close() + + // Copy the default configuration into config + defCopy := *defaultConfig + config := &defCopy + config.Repos = nil + + err = toml.NewDecoder(cfgFl).Decode(config) + if err != nil { + log.Warn("Error decoding config file, using defaults").Err(err).Send() + c.cfg = defaultConfig + return + } + c.cfg = config +} + +func (c *ALRConfig) initPaths(ctx context.Context) { log := loggerctx.From(ctx) + paths := &Paths{} - if config == nil { - cfgFl, err := os.Open(GetPaths(ctx).ConfigPath) - if err != nil { - log.Warn("Error opening config file, using defaults").Err(err).Send() - return defaultConfig - } - defer cfgFl.Close() - - // Copy the default configuration into config - defCopy := *defaultConfig - config = &defCopy - config.Repos = nil - - err = toml.NewDecoder(cfgFl).Decode(config) - if err != nil { - log.Warn("Error decoding config file, using defaults").Err(err).Send() - // Set config back to nil so that we try again next time - config = nil - return defaultConfig - } + cfgDir, err := os.UserConfigDir() + if err != nil { + log.Fatal("Unable to detect user config directory").Err(err).Send() } - return config + paths.ConfigDir = filepath.Join(cfgDir, "alr") + + err = os.MkdirAll(paths.ConfigDir, 0o755) + if err != nil { + log.Fatal("Unable to create ALR config directory").Err(err).Send() + } + + paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") + + if _, err := os.Stat(paths.ConfigPath); err != nil { + cfgFl, err := os.Create(paths.ConfigPath) + if err != nil { + log.Fatal("Unable to create ALR config file").Err(err).Send() + } + + err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) + if err != nil { + log.Fatal("Error encoding default configuration").Err(err).Send() + } + + cfgFl.Close() + } + + cacheDir, err := os.UserCacheDir() + if err != nil { + log.Fatal("Unable to detect cache directory").Err(err).Send() + } + + paths.CacheDir = filepath.Join(cacheDir, "alr") + paths.RepoDir = filepath.Join(paths.CacheDir, "repo") + paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") + + err = os.MkdirAll(paths.RepoDir, 0o755) + if err != nil { + log.Fatal("Unable to create repo cache directory").Err(err).Send() + } + + err = os.MkdirAll(paths.PkgsDir, 0o755) + if err != nil { + log.Fatal("Unable to create package cache directory").Err(err).Send() + } + + paths.DBPath = filepath.Join(paths.CacheDir, "db") + + c.paths = paths +} + +func (c *ALRConfig) GetPaths(ctx context.Context) *Paths { + c.pathsOnce.Do(func() { + c.initPaths(ctx) + }) + return c.paths } diff --git a/internal/config/config_legacy.go b/internal/config/config_legacy.go new file mode 100644 index 0000000..49ecf90 --- /dev/null +++ b/internal/config/config_legacy.go @@ -0,0 +1,65 @@ +/* + * ALR - Any Linux Repository + * Copyright (C) 2024 Евгений Храмов + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "context" + "sync" + + "plemya-x.ru/alr/internal/types" +) + +var defaultConfig = &types.Config{ + RootCmd: "sudo", + PagerStyle: "native", + IgnorePkgUpdates: []string{}, + Repos: []types.Repo{ + { + Name: "default", + URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git", + }, + }, +} + +// Config returns a ALR configuration struct. +// The first time it's called, it'll load the config from a file. +// Subsequent calls will just return the same value. +// +// Deprecated: use struct method +func Config(ctx context.Context) *types.Config { + return GetInstance(ctx).cfg +} + +// ======================= +// FOR LEGACY ONLY +// ======================= + +var ( + alrConfig *ALRConfig + alrConfigOnce sync.Once +) + +func GetInstance(ctx context.Context) *ALRConfig { + alrConfigOnce.Do(func() { + alrConfig = New() + alrConfig.Load(ctx) + }) + + return alrConfig +} diff --git a/internal/config/paths.go b/internal/config/paths.go index 8b10c35..ebdea31 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -20,12 +20,6 @@ package config import ( "context" - "os" - "path/filepath" - "sync" - - "github.com/pelletier/go-toml/v2" - "plemya-x.ru/alr/pkg/loggerctx" ) // Paths contains various paths used by ALR @@ -38,71 +32,13 @@ type Paths struct { DBPath string } -var ( - pathsMtx sync.Mutex - paths *Paths -) - // GetPaths returns a Paths struct. // The first time it's called, it'll generate the struct // using information from the system. // Subsequent calls will return the same value. +// +// Depreacted: use struct API func GetPaths(ctx context.Context) *Paths { - pathsMtx.Lock() - defer pathsMtx.Unlock() - - log := loggerctx.From(ctx) - if paths == nil { - paths = &Paths{} - - cfgDir, err := os.UserConfigDir() - if err != nil { - log.Fatal("Unable to detect user config directory").Err(err).Send() - } - - paths.ConfigDir = filepath.Join(cfgDir, "alr") - - err = os.MkdirAll(paths.ConfigDir, 0o755) - if err != nil { - log.Fatal("Unable to create ALR config directory").Err(err).Send() - } - - paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") - - if _, err := os.Stat(paths.ConfigPath); err != nil { - cfgFl, err := os.Create(paths.ConfigPath) - if err != nil { - log.Fatal("Unable to create ALR config file").Err(err).Send() - } - - err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) - if err != nil { - log.Fatal("Error encoding default configuration").Err(err).Send() - } - - cfgFl.Close() - } - - cacheDir, err := os.UserCacheDir() - if err != nil { - log.Fatal("Unable to detect cache directory").Err(err).Send() - } - - paths.CacheDir = filepath.Join(cacheDir, "alr") - paths.RepoDir = filepath.Join(paths.CacheDir, "repo") - paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") - - err = os.MkdirAll(paths.RepoDir, 0o755) - if err != nil { - log.Fatal("Unable to create repo cache directory").Err(err).Send() - } - - err = os.MkdirAll(paths.PkgsDir, 0o755) - if err != nil { - log.Fatal("Unable to create package cache directory").Err(err).Send() - } - - paths.DBPath = filepath.Join(paths.CacheDir, "db") - } - return paths + alrConfig := GetInstance(ctx) + return alrConfig.GetPaths(ctx) } diff --git a/internal/db/db.go b/internal/db/db.go index 9e421cb..a4bd4a5 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -20,28 +20,16 @@ package db import ( "context" - "database/sql" - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "sync" "github.com/jmoiron/sqlx" "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/pkg/loggerctx" - "golang.org/x/exp/slices" - "modernc.org/sqlite" ) // CurrentVersion is the current version of the database. // The database is reset if its version doesn't match this. const CurrentVersion = 2 -func init() { - sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) -} - // Package is a ALR package's database representation type Package struct { Name string `sh:"name,required" db:"name"` @@ -66,66 +54,47 @@ type version struct { Version int `db:"version"` } -var ( - mu sync.Mutex +type Config interface { + GetPaths(ctx context.Context) *config.Paths +} + +type Database struct { conn *sqlx.DB - closed = true -) + config Config +} -// DB returns the ALR database. -// The first time it's called, it opens the SQLite database file. -// Subsequent calls return the same connection. -func DB(ctx context.Context) *sqlx.DB { - log := loggerctx.From(ctx) - if conn != nil && !closed { - return getConn() +func New(config Config) *Database { + return &Database{ + config: config, } - _, err := open(ctx, config.GetPaths(ctx).DBPath) +} + +func (d *Database) Init(ctx context.Context) error { + err := d.Connect(ctx) if err != nil { - log.Fatal("Error opening database").Err(err).Send() + return err } - return getConn() + return d.initDB(ctx) } -func getConn() *sqlx.DB { - mu.Lock() - defer mu.Unlock() - return conn -} - -func open(ctx context.Context, dsn string) (*sqlx.DB, error) { +func (d *Database) Connect(ctx context.Context) error { + dsn := d.config.GetPaths(ctx).DBPath db, err := sqlx.Open("sqlite", dsn) if err != nil { - return nil, err + return err } - - mu.Lock() - conn = db - closed = false - mu.Unlock() - - err = initDB(ctx, dsn) - if err != nil { - return nil, err - } - - return db, nil + d.conn = db + return nil } -// Close closes the database -func Close() error { - closed = true - if conn != nil { - return conn.Close() - } else { - return nil - } +func (d *Database) GetConn() *sqlx.DB { + return d.conn } -// initDB initializes the database -func initDB(ctx context.Context, dsn string) error { +func (d *Database) initDB(ctx context.Context) error { log := loggerctx.From(ctx) - conn = conn.Unsafe() + d.conn = d.conn.Unsafe() + conn := d.conn _, err := conn.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS pkgs ( name TEXT NOT NULL, @@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error { return err } - ver, ok := GetVersion(ctx) + ver, ok := d.GetVersion(ctx) if ok && ver != CurrentVersion { log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send() - reset(ctx) - return initDB(ctx, dsn) + d.reset(ctx) + return d.initDB(ctx) } else if !ok { log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send() - return addVersion(ctx, CurrentVersion) + return d.addVersion(ctx, CurrentVersion) } return nil } -// reset drops all the database tables -func reset(ctx context.Context) error { - _, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") - if err != nil { - return err - } - _, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") - return err -} - -// IsEmpty returns true if the database has no packages in it, otherwise it returns false. -func IsEmpty(ctx context.Context) bool { - var count int - err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") - if err != nil { - return true - } - return count == 0 -} - -// GetVersion returns the database version and a boolean indicating -// whether the database contained a version number -func GetVersion(ctx context.Context) (int, bool) { +func (d *Database) GetVersion(ctx context.Context) (int, bool) { var ver version - err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") + err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") if err != nil { return 0, false } return ver.Version, true } -func addVersion(ctx context.Context, ver int) error { - _, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) +func (d *Database) addVersion(ctx context.Context, ver int) error { + _, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) return err } -// InsertPackage adds a package to the database -func InsertPackage(ctx context.Context, pkg Package) error { - _, err := DB(ctx).NamedExecContext(ctx, ` +func (d *Database) reset(ctx context.Context) error { + _, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") + if err != nil { + return err + } + _, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") + return err +} + +func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { + stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) + if err != nil { + return nil, err + } + return stream, nil +} + +func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { + out := &Package{} + err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) + return out, err +} + +func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error { + _, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) + return err +} + +func (d *Database) IsEmpty(ctx context.Context) bool { + var count int + err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") + if err != nil { + return true + } + return count == 0 +} + +func (d *Database) InsertPackage(ctx context.Context, pkg Package) error { + _, err := d.conn.NamedExecContext(ctx, ` INSERT OR REPLACE INTO pkgs ( name, repository, @@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error { return err } -// GetPkgs returns a result containing packages that match the where conditions -func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { - stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) - if err != nil { - return nil, err - } - return stream, nil -} - -// GetPkg returns a single package that matches the where conditions -func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { - out := &Package{} - err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) - return out, err -} - -// DeletePkgs deletes all packages matching the where conditions -func DeletePkgs(ctx context.Context, where string, args ...any) error { - _, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) - return err -} - -// jsonArrayContains is an SQLite function that checks if a JSON array -// in the database contains a given value -func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { - value, ok := args[0].(string) - if !ok { - return nil, errors.New("both arguments to json_array_contains must be strings") - } - - item, ok := args[1].(string) - if !ok { - return nil, errors.New("both arguments to json_array_contains must be strings") - } - - var array []string - err := json.Unmarshal([]byte(value), &array) - if err != nil { - return nil, err - } - - return slices.Contains(array, item), nil -} - -// JSON represents a JSON value in the database -type JSON[T any] struct { - Val T -} - -// NewJSON creates a new database JSON value -func NewJSON[T any](v T) JSON[T] { - return JSON[T]{Val: v} -} - -func (s *JSON[T]) Scan(val any) error { - if val == nil { +func (d *Database) Close() error { + if d.conn != nil { + return d.conn.Close() + } else { return nil } - - switch val := val.(type) { - case string: - err := json.Unmarshal([]byte(val), &s.Val) - if err != nil { - return err - } - case sql.NullString: - if val.Valid { - err := json.Unmarshal([]byte(val.String), &s.Val) - if err != nil { - return err - } - } - default: - return errors.New("sqlite json types must be strings") - } - - return nil -} - -func (s JSON[T]) Value() (driver.Value, error) { - data, err := json.Marshal(s.Val) - if err != nil { - return nil, err - } - return string(data), nil -} - -func (s JSON[T]) MarshalYAML() (any, error) { - return s.Val, nil -} - -func (s JSON[T]) String() string { - return fmt.Sprint(s.Val) -} - -func (s JSON[T]) GoString() string { - return fmt.Sprintf("%#v", s.Val) } diff --git a/internal/db/db_legacy.go b/internal/db/db_legacy.go new file mode 100644 index 0000000..83b9128 --- /dev/null +++ b/internal/db/db_legacy.go @@ -0,0 +1,105 @@ +/* + * ALR - Any Linux Repository + * Copyright (C) 2024 Евгений Храмов + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package db + +import ( + "context" + "sync" + + "github.com/jmoiron/sqlx" + "plemya-x.ru/alr/internal/config" + "plemya-x.ru/alr/pkg/loggerctx" +) + +// DB returns the ALR database. +// The first time it's called, it opens the SQLite database file. +// Subsequent calls return the same connection. +// +// Deprecated: use struct method +func DB(ctx context.Context) *sqlx.DB { + return getInstance(ctx).GetConn() +} + +// Close closes the database +// +// Deprecated: use struct method +func Close() error { + if database != nil { + return database.Close() + } + return nil +} + +// IsEmpty returns true if the database has no packages in it, otherwise it returns false. +// +// Deprecated: use struct method +func IsEmpty(ctx context.Context) bool { + return getInstance(ctx).IsEmpty(ctx) +} + +// InsertPackage adds a package to the database +// +// Deprecated: use struct method +func InsertPackage(ctx context.Context, pkg Package) error { + return getInstance(ctx).InsertPackage(ctx, pkg) +} + +// GetPkgs returns a result containing packages that match the where conditions +// +// Deprecated: use struct method +func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { + return getInstance(ctx).GetPkgs(ctx, where, args...) +} + +// GetPkg returns a single package that matches the where conditions +// +// Deprecated: use struct method +func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { + return getInstance(ctx).GetPkg(ctx, where, args...) +} + +// DeletePkgs deletes all packages matching the where conditions +// +// Deprecated: use struct method +func DeletePkgs(ctx context.Context, where string, args ...any) error { + return getInstance(ctx).DeletePkgs(ctx, where, args...) +} + +// ======================= +// FOR LEGACY ONLY +// ======================= + +var ( + dbOnce sync.Once + database *Database +) + +// For refactoring only +func getInstance(ctx context.Context) *Database { + dbOnce.Do(func() { + log := loggerctx.From(ctx) + cfg := config.GetInstance(ctx) + database = New(cfg) + err := database.Init(ctx) + if err != nil { + log.Fatal("Error opening database").Err(err).Send() + } + }) + return database +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 9e38ee6..4412181 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -19,14 +19,30 @@ package db_test import ( + "context" "reflect" "strings" "testing" "github.com/jmoiron/sqlx" + "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/internal/db" ) +type TestALRConfig struct{} + +func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths { + return &config.Paths{ + DBPath: ":memory:", + } +} + +func prepareDb() *db.Database { + database := db.New(&TestALRConfig{}) + database.Init(context.Background()) + return database +} + var testPkg = db.Package{ Name: "test", Version: "0.0.1", @@ -59,18 +75,11 @@ var testPkg = db.Package{ } func TestInit(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() - _, err = db.DB().Exec("SELECT * FROM pkgs") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - - ver, ok := db.GetVersion() + ver, ok := database.GetVersion(ctx) if !ok { t.Errorf("Expected version to be present") } else if ver != db.CurrentVersion { @@ -79,19 +88,17 @@ func TestInit(t *testing.T) { } func TestInsertPackage(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() - err = db.InsertPackage(testPkg) + err := database.InsertPackage(ctx, testPkg) if err != nil { t.Fatalf("Expected no error, got %s", err) } dbPkg := db.Package{} - err = sqlx.Get(db.DB(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") + err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -102,28 +109,26 @@ func TestInsertPackage(t *testing.T) { } func TestGetPkgs(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - result, err := db.GetPkgs("name LIKE 'x%'") + result, err := database.GetPkgs(ctx, "name LIKE 'x%'") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -142,28 +147,26 @@ func TestGetPkgs(t *testing.T) { } func TestGetPkg(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - pkg, err := db.GetPkg("name LIKE 'x%' ORDER BY name") + pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -178,34 +181,32 @@ func TestGetPkg(t *testing.T) { } func TestDeletePkgs(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.DeletePkgs("name = 'x1'") + err = database.DeletePkgs(ctx, "name = 'x1'") if err != nil { t.Errorf("Expected no error, got %s", err) } var dbPkg db.Package - err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") + err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") if err != nil { t.Errorf("Expected no error, got %s", err) } @@ -216,11 +217,9 @@ func TestDeletePkgs(t *testing.T) { } func TestJsonArrayContains(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" @@ -228,18 +227,18 @@ func TestJsonArrayContains(t *testing.T) { x2.Name = "x2" x2.Provides.Val = append(x2.Provides.Val, "x") - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } var dbPkg db.Package - err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") + err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") if err != nil { t.Fatalf("Expected no error, got %s", err) } diff --git a/internal/db/json.go b/internal/db/json.go new file mode 100644 index 0000000..2b05693 --- /dev/null +++ b/internal/db/json.go @@ -0,0 +1,64 @@ +package db + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) + +// JSON represents a JSON value in the database +type JSON[T any] struct { + Val T +} + +// NewJSON creates a new database JSON value +func NewJSON[T any](v T) JSON[T] { + return JSON[T]{Val: v} +} + +func (s *JSON[T]) Scan(val any) error { + if val == nil { + return nil + } + + switch val := val.(type) { + case string: + err := json.Unmarshal([]byte(val), &s.Val) + if err != nil { + return err + } + case sql.NullString: + if val.Valid { + err := json.Unmarshal([]byte(val.String), &s.Val) + if err != nil { + return err + } + } + default: + return errors.New("sqlite json types must be strings") + } + + return nil +} + +func (s JSON[T]) Value() (driver.Value, error) { + data, err := json.Marshal(s.Val) + if err != nil { + return nil, err + } + return string(data), nil +} + +func (s JSON[T]) MarshalYAML() (any, error) { + return s.Val, nil +} + +func (s JSON[T]) String() string { + return fmt.Sprint(s.Val) +} + +func (s JSON[T]) GoString() string { + return fmt.Sprintf("%#v", s.Val) +} diff --git a/internal/db/utils.go b/internal/db/utils.go new file mode 100644 index 0000000..3cbce79 --- /dev/null +++ b/internal/db/utils.go @@ -0,0 +1,36 @@ +package db + +import ( + "database/sql/driver" + "encoding/json" + "errors" + + "golang.org/x/exp/slices" + "modernc.org/sqlite" +) + +func init() { + sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) +} + +// jsonArrayContains is an SQLite function that checks if a JSON array +// in the database contains a given value +func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + value, ok := args[0].(string) + if !ok { + return nil, errors.New("both arguments to json_array_contains must be strings") + } + + item, ok := args[1].(string) + if !ok { + return nil, errors.New("both arguments to json_array_contains must be strings") + } + + var array []string + err := json.Unmarshal([]byte(value), &array) + if err != nil { + return nil, err + } + + return slices.Contains(array, item), nil +}