diff --git a/internal/config/config.go b/internal/config/config.go
index db43f89..2574945 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -21,59 +21,108 @@ package config
import (
"context"
"os"
+ "path/filepath"
"sync"
"github.com/pelletier/go-toml/v2"
+ "go.elara.ws/logger/log"
"plemya-x.ru/alr/internal/types"
"plemya-x.ru/alr/pkg/loggerctx"
)
-var defaultConfig = &types.Config{
- RootCmd: "sudo",
- PagerStyle: "native",
- IgnorePkgUpdates: []string{},
- Repos: []types.Repo{
- {
- Name: "default",
- URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git",
- },
- },
+type ALRConfig struct {
+ cfg *types.Config
+ paths *Paths
+
+ pathsOnce sync.Once
}
-var (
- configMtx sync.Mutex
- config *types.Config
-)
+func New() *ALRConfig {
+ return &ALRConfig{}
+}
-// Config returns a ALR configuration struct.
-// The first time it's called, it'll load the config from a file.
-// Subsequent calls will just return the same value.
-func Config(ctx context.Context) *types.Config {
- configMtx.Lock()
- defer configMtx.Unlock()
+func (c *ALRConfig) Load(ctx context.Context) {
+ cfgFl, err := os.Open(c.GetPaths(ctx).ConfigPath)
+ if err != nil {
+ log.Warn("Error opening config file, using defaults").Err(err).Send()
+ c.cfg = defaultConfig
+ return
+ }
+ defer cfgFl.Close()
+
+ // Copy the default configuration into config
+ defCopy := *defaultConfig
+ config := &defCopy
+ config.Repos = nil
+
+ err = toml.NewDecoder(cfgFl).Decode(config)
+ if err != nil {
+ log.Warn("Error decoding config file, using defaults").Err(err).Send()
+ c.cfg = defaultConfig
+ return
+ }
+ c.cfg = config
+}
+
+func (c *ALRConfig) initPaths(ctx context.Context) {
log := loggerctx.From(ctx)
+ paths := &Paths{}
- if config == nil {
- cfgFl, err := os.Open(GetPaths(ctx).ConfigPath)
- if err != nil {
- log.Warn("Error opening config file, using defaults").Err(err).Send()
- return defaultConfig
- }
- defer cfgFl.Close()
-
- // Copy the default configuration into config
- defCopy := *defaultConfig
- config = &defCopy
- config.Repos = nil
-
- err = toml.NewDecoder(cfgFl).Decode(config)
- if err != nil {
- log.Warn("Error decoding config file, using defaults").Err(err).Send()
- // Set config back to nil so that we try again next time
- config = nil
- return defaultConfig
- }
+ cfgDir, err := os.UserConfigDir()
+ if err != nil {
+ log.Fatal("Unable to detect user config directory").Err(err).Send()
}
- return config
+ paths.ConfigDir = filepath.Join(cfgDir, "alr")
+
+ err = os.MkdirAll(paths.ConfigDir, 0o755)
+ if err != nil {
+ log.Fatal("Unable to create ALR config directory").Err(err).Send()
+ }
+
+ paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml")
+
+ if _, err := os.Stat(paths.ConfigPath); err != nil {
+ cfgFl, err := os.Create(paths.ConfigPath)
+ if err != nil {
+ log.Fatal("Unable to create ALR config file").Err(err).Send()
+ }
+
+ err = toml.NewEncoder(cfgFl).Encode(&defaultConfig)
+ if err != nil {
+ log.Fatal("Error encoding default configuration").Err(err).Send()
+ }
+
+ cfgFl.Close()
+ }
+
+ cacheDir, err := os.UserCacheDir()
+ if err != nil {
+ log.Fatal("Unable to detect cache directory").Err(err).Send()
+ }
+
+ paths.CacheDir = filepath.Join(cacheDir, "alr")
+ paths.RepoDir = filepath.Join(paths.CacheDir, "repo")
+ paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs")
+
+ err = os.MkdirAll(paths.RepoDir, 0o755)
+ if err != nil {
+ log.Fatal("Unable to create repo cache directory").Err(err).Send()
+ }
+
+ err = os.MkdirAll(paths.PkgsDir, 0o755)
+ if err != nil {
+ log.Fatal("Unable to create package cache directory").Err(err).Send()
+ }
+
+ paths.DBPath = filepath.Join(paths.CacheDir, "db")
+
+ c.paths = paths
+}
+
+func (c *ALRConfig) GetPaths(ctx context.Context) *Paths {
+ c.pathsOnce.Do(func() {
+ c.initPaths(ctx)
+ })
+ return c.paths
}
diff --git a/internal/config/config_legacy.go b/internal/config/config_legacy.go
new file mode 100644
index 0000000..49ecf90
--- /dev/null
+++ b/internal/config/config_legacy.go
@@ -0,0 +1,65 @@
+/*
+ * ALR - Any Linux Repository
+ * Copyright (C) 2024 Евгений Храмов
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package config
+
+import (
+ "context"
+ "sync"
+
+ "plemya-x.ru/alr/internal/types"
+)
+
+var defaultConfig = &types.Config{
+ RootCmd: "sudo",
+ PagerStyle: "native",
+ IgnorePkgUpdates: []string{},
+ Repos: []types.Repo{
+ {
+ Name: "default",
+ URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git",
+ },
+ },
+}
+
+// Config returns a ALR configuration struct.
+// The first time it's called, it'll load the config from a file.
+// Subsequent calls will just return the same value.
+//
+// Deprecated: use struct method
+func Config(ctx context.Context) *types.Config {
+ return GetInstance(ctx).cfg
+}
+
+// =======================
+// FOR LEGACY ONLY
+// =======================
+
+var (
+ alrConfig *ALRConfig
+ alrConfigOnce sync.Once
+)
+
+func GetInstance(ctx context.Context) *ALRConfig {
+ alrConfigOnce.Do(func() {
+ alrConfig = New()
+ alrConfig.Load(ctx)
+ })
+
+ return alrConfig
+}
diff --git a/internal/config/paths.go b/internal/config/paths.go
index 8b10c35..ebdea31 100644
--- a/internal/config/paths.go
+++ b/internal/config/paths.go
@@ -20,12 +20,6 @@ package config
import (
"context"
- "os"
- "path/filepath"
- "sync"
-
- "github.com/pelletier/go-toml/v2"
- "plemya-x.ru/alr/pkg/loggerctx"
)
// Paths contains various paths used by ALR
@@ -38,71 +32,13 @@ type Paths struct {
DBPath string
}
-var (
- pathsMtx sync.Mutex
- paths *Paths
-)
-
// GetPaths returns a Paths struct.
// The first time it's called, it'll generate the struct
// using information from the system.
// Subsequent calls will return the same value.
+//
+// Depreacted: use struct API
func GetPaths(ctx context.Context) *Paths {
- pathsMtx.Lock()
- defer pathsMtx.Unlock()
-
- log := loggerctx.From(ctx)
- if paths == nil {
- paths = &Paths{}
-
- cfgDir, err := os.UserConfigDir()
- if err != nil {
- log.Fatal("Unable to detect user config directory").Err(err).Send()
- }
-
- paths.ConfigDir = filepath.Join(cfgDir, "alr")
-
- err = os.MkdirAll(paths.ConfigDir, 0o755)
- if err != nil {
- log.Fatal("Unable to create ALR config directory").Err(err).Send()
- }
-
- paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml")
-
- if _, err := os.Stat(paths.ConfigPath); err != nil {
- cfgFl, err := os.Create(paths.ConfigPath)
- if err != nil {
- log.Fatal("Unable to create ALR config file").Err(err).Send()
- }
-
- err = toml.NewEncoder(cfgFl).Encode(&defaultConfig)
- if err != nil {
- log.Fatal("Error encoding default configuration").Err(err).Send()
- }
-
- cfgFl.Close()
- }
-
- cacheDir, err := os.UserCacheDir()
- if err != nil {
- log.Fatal("Unable to detect cache directory").Err(err).Send()
- }
-
- paths.CacheDir = filepath.Join(cacheDir, "alr")
- paths.RepoDir = filepath.Join(paths.CacheDir, "repo")
- paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs")
-
- err = os.MkdirAll(paths.RepoDir, 0o755)
- if err != nil {
- log.Fatal("Unable to create repo cache directory").Err(err).Send()
- }
-
- err = os.MkdirAll(paths.PkgsDir, 0o755)
- if err != nil {
- log.Fatal("Unable to create package cache directory").Err(err).Send()
- }
-
- paths.DBPath = filepath.Join(paths.CacheDir, "db")
- }
- return paths
+ alrConfig := GetInstance(ctx)
+ return alrConfig.GetPaths(ctx)
}
diff --git a/internal/db/db.go b/internal/db/db.go
index 9e421cb..a4bd4a5 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -20,28 +20,16 @@ package db
import (
"context"
- "database/sql"
- "database/sql/driver"
- "encoding/json"
- "errors"
- "fmt"
- "sync"
"github.com/jmoiron/sqlx"
"plemya-x.ru/alr/internal/config"
"plemya-x.ru/alr/pkg/loggerctx"
- "golang.org/x/exp/slices"
- "modernc.org/sqlite"
)
// CurrentVersion is the current version of the database.
// The database is reset if its version doesn't match this.
const CurrentVersion = 2
-func init() {
- sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains)
-}
-
// Package is a ALR package's database representation
type Package struct {
Name string `sh:"name,required" db:"name"`
@@ -66,66 +54,47 @@ type version struct {
Version int `db:"version"`
}
-var (
- mu sync.Mutex
+type Config interface {
+ GetPaths(ctx context.Context) *config.Paths
+}
+
+type Database struct {
conn *sqlx.DB
- closed = true
-)
+ config Config
+}
-// DB returns the ALR database.
-// The first time it's called, it opens the SQLite database file.
-// Subsequent calls return the same connection.
-func DB(ctx context.Context) *sqlx.DB {
- log := loggerctx.From(ctx)
- if conn != nil && !closed {
- return getConn()
+func New(config Config) *Database {
+ return &Database{
+ config: config,
}
- _, err := open(ctx, config.GetPaths(ctx).DBPath)
+}
+
+func (d *Database) Init(ctx context.Context) error {
+ err := d.Connect(ctx)
if err != nil {
- log.Fatal("Error opening database").Err(err).Send()
+ return err
}
- return getConn()
+ return d.initDB(ctx)
}
-func getConn() *sqlx.DB {
- mu.Lock()
- defer mu.Unlock()
- return conn
-}
-
-func open(ctx context.Context, dsn string) (*sqlx.DB, error) {
+func (d *Database) Connect(ctx context.Context) error {
+ dsn := d.config.GetPaths(ctx).DBPath
db, err := sqlx.Open("sqlite", dsn)
if err != nil {
- return nil, err
+ return err
}
-
- mu.Lock()
- conn = db
- closed = false
- mu.Unlock()
-
- err = initDB(ctx, dsn)
- if err != nil {
- return nil, err
- }
-
- return db, nil
+ d.conn = db
+ return nil
}
-// Close closes the database
-func Close() error {
- closed = true
- if conn != nil {
- return conn.Close()
- } else {
- return nil
- }
+func (d *Database) GetConn() *sqlx.DB {
+ return d.conn
}
-// initDB initializes the database
-func initDB(ctx context.Context, dsn string) error {
+func (d *Database) initDB(ctx context.Context) error {
log := loggerctx.From(ctx)
- conn = conn.Unsafe()
+ d.conn = d.conn.Unsafe()
+ conn := d.conn
_, err := conn.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS pkgs (
name TEXT NOT NULL,
@@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error {
return err
}
- ver, ok := GetVersion(ctx)
+ ver, ok := d.GetVersion(ctx)
if ok && ver != CurrentVersion {
log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send()
- reset(ctx)
- return initDB(ctx, dsn)
+ d.reset(ctx)
+ return d.initDB(ctx)
} else if !ok {
log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send()
- return addVersion(ctx, CurrentVersion)
+ return d.addVersion(ctx, CurrentVersion)
}
return nil
}
-// reset drops all the database tables
-func reset(ctx context.Context) error {
- _, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
- if err != nil {
- return err
- }
- _, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
- return err
-}
-
-// IsEmpty returns true if the database has no packages in it, otherwise it returns false.
-func IsEmpty(ctx context.Context) bool {
- var count int
- err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
- if err != nil {
- return true
- }
- return count == 0
-}
-
-// GetVersion returns the database version and a boolean indicating
-// whether the database contained a version number
-func GetVersion(ctx context.Context) (int, bool) {
+func (d *Database) GetVersion(ctx context.Context) (int, bool) {
var ver version
- err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
+ err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
if err != nil {
return 0, false
}
return ver.Version, true
}
-func addVersion(ctx context.Context, ver int) error {
- _, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
+func (d *Database) addVersion(ctx context.Context, ver int) error {
+ _, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
return err
}
-// InsertPackage adds a package to the database
-func InsertPackage(ctx context.Context, pkg Package) error {
- _, err := DB(ctx).NamedExecContext(ctx, `
+func (d *Database) reset(ctx context.Context) error {
+ _, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
+ if err != nil {
+ return err
+ }
+ _, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
+ return err
+}
+
+func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
+ stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
+ if err != nil {
+ return nil, err
+ }
+ return stream, nil
+}
+
+func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
+ out := &Package{}
+ err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
+ return out, err
+}
+
+func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error {
+ _, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
+ return err
+}
+
+func (d *Database) IsEmpty(ctx context.Context) bool {
+ var count int
+ err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
+ if err != nil {
+ return true
+ }
+ return count == 0
+}
+
+func (d *Database) InsertPackage(ctx context.Context, pkg Package) error {
+ _, err := d.conn.NamedExecContext(ctx, `
INSERT OR REPLACE INTO pkgs (
name,
repository,
@@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error {
return err
}
-// GetPkgs returns a result containing packages that match the where conditions
-func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
- stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
- if err != nil {
- return nil, err
- }
- return stream, nil
-}
-
-// GetPkg returns a single package that matches the where conditions
-func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
- out := &Package{}
- err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
- return out, err
-}
-
-// DeletePkgs deletes all packages matching the where conditions
-func DeletePkgs(ctx context.Context, where string, args ...any) error {
- _, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
- return err
-}
-
-// jsonArrayContains is an SQLite function that checks if a JSON array
-// in the database contains a given value
-func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
- value, ok := args[0].(string)
- if !ok {
- return nil, errors.New("both arguments to json_array_contains must be strings")
- }
-
- item, ok := args[1].(string)
- if !ok {
- return nil, errors.New("both arguments to json_array_contains must be strings")
- }
-
- var array []string
- err := json.Unmarshal([]byte(value), &array)
- if err != nil {
- return nil, err
- }
-
- return slices.Contains(array, item), nil
-}
-
-// JSON represents a JSON value in the database
-type JSON[T any] struct {
- Val T
-}
-
-// NewJSON creates a new database JSON value
-func NewJSON[T any](v T) JSON[T] {
- return JSON[T]{Val: v}
-}
-
-func (s *JSON[T]) Scan(val any) error {
- if val == nil {
+func (d *Database) Close() error {
+ if d.conn != nil {
+ return d.conn.Close()
+ } else {
return nil
}
-
- switch val := val.(type) {
- case string:
- err := json.Unmarshal([]byte(val), &s.Val)
- if err != nil {
- return err
- }
- case sql.NullString:
- if val.Valid {
- err := json.Unmarshal([]byte(val.String), &s.Val)
- if err != nil {
- return err
- }
- }
- default:
- return errors.New("sqlite json types must be strings")
- }
-
- return nil
-}
-
-func (s JSON[T]) Value() (driver.Value, error) {
- data, err := json.Marshal(s.Val)
- if err != nil {
- return nil, err
- }
- return string(data), nil
-}
-
-func (s JSON[T]) MarshalYAML() (any, error) {
- return s.Val, nil
-}
-
-func (s JSON[T]) String() string {
- return fmt.Sprint(s.Val)
-}
-
-func (s JSON[T]) GoString() string {
- return fmt.Sprintf("%#v", s.Val)
}
diff --git a/internal/db/db_legacy.go b/internal/db/db_legacy.go
new file mode 100644
index 0000000..83b9128
--- /dev/null
+++ b/internal/db/db_legacy.go
@@ -0,0 +1,105 @@
+/*
+ * ALR - Any Linux Repository
+ * Copyright (C) 2024 Евгений Храмов
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package db
+
+import (
+ "context"
+ "sync"
+
+ "github.com/jmoiron/sqlx"
+ "plemya-x.ru/alr/internal/config"
+ "plemya-x.ru/alr/pkg/loggerctx"
+)
+
+// DB returns the ALR database.
+// The first time it's called, it opens the SQLite database file.
+// Subsequent calls return the same connection.
+//
+// Deprecated: use struct method
+func DB(ctx context.Context) *sqlx.DB {
+ return getInstance(ctx).GetConn()
+}
+
+// Close closes the database
+//
+// Deprecated: use struct method
+func Close() error {
+ if database != nil {
+ return database.Close()
+ }
+ return nil
+}
+
+// IsEmpty returns true if the database has no packages in it, otherwise it returns false.
+//
+// Deprecated: use struct method
+func IsEmpty(ctx context.Context) bool {
+ return getInstance(ctx).IsEmpty(ctx)
+}
+
+// InsertPackage adds a package to the database
+//
+// Deprecated: use struct method
+func InsertPackage(ctx context.Context, pkg Package) error {
+ return getInstance(ctx).InsertPackage(ctx, pkg)
+}
+
+// GetPkgs returns a result containing packages that match the where conditions
+//
+// Deprecated: use struct method
+func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
+ return getInstance(ctx).GetPkgs(ctx, where, args...)
+}
+
+// GetPkg returns a single package that matches the where conditions
+//
+// Deprecated: use struct method
+func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
+ return getInstance(ctx).GetPkg(ctx, where, args...)
+}
+
+// DeletePkgs deletes all packages matching the where conditions
+//
+// Deprecated: use struct method
+func DeletePkgs(ctx context.Context, where string, args ...any) error {
+ return getInstance(ctx).DeletePkgs(ctx, where, args...)
+}
+
+// =======================
+// FOR LEGACY ONLY
+// =======================
+
+var (
+ dbOnce sync.Once
+ database *Database
+)
+
+// For refactoring only
+func getInstance(ctx context.Context) *Database {
+ dbOnce.Do(func() {
+ log := loggerctx.From(ctx)
+ cfg := config.GetInstance(ctx)
+ database = New(cfg)
+ err := database.Init(ctx)
+ if err != nil {
+ log.Fatal("Error opening database").Err(err).Send()
+ }
+ })
+ return database
+}
diff --git a/internal/db/db_test.go b/internal/db/db_test.go
index 9e38ee6..4412181 100644
--- a/internal/db/db_test.go
+++ b/internal/db/db_test.go
@@ -19,14 +19,30 @@
package db_test
import (
+ "context"
"reflect"
"strings"
"testing"
"github.com/jmoiron/sqlx"
+ "plemya-x.ru/alr/internal/config"
"plemya-x.ru/alr/internal/db"
)
+type TestALRConfig struct{}
+
+func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths {
+ return &config.Paths{
+ DBPath: ":memory:",
+ }
+}
+
+func prepareDb() *db.Database {
+ database := db.New(&TestALRConfig{})
+ database.Init(context.Background())
+ return database
+}
+
var testPkg = db.Package{
Name: "test",
Version: "0.0.1",
@@ -59,18 +75,11 @@ var testPkg = db.Package{
}
func TestInit(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
- _, err = db.DB().Exec("SELECT * FROM pkgs")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
-
- ver, ok := db.GetVersion()
+ ver, ok := database.GetVersion(ctx)
if !ok {
t.Errorf("Expected version to be present")
} else if ver != db.CurrentVersion {
@@ -79,19 +88,17 @@ func TestInit(t *testing.T) {
}
func TestInsertPackage(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
- err = db.InsertPackage(testPkg)
+ err := database.InsertPackage(ctx, testPkg)
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
dbPkg := db.Package{}
- err = sqlx.Get(db.DB(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'")
+ err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
@@ -102,28 +109,26 @@ func TestInsertPackage(t *testing.T) {
}
func TestGetPkgs(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
x1 := testPkg
x1.Name = "x1"
x2 := testPkg
x2.Name = "x2"
- err = db.InsertPackage(x1)
+ err := database.InsertPackage(ctx, x1)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- err = db.InsertPackage(x2)
+ err = database.InsertPackage(ctx, x2)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- result, err := db.GetPkgs("name LIKE 'x%'")
+ result, err := database.GetPkgs(ctx, "name LIKE 'x%'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
@@ -142,28 +147,26 @@ func TestGetPkgs(t *testing.T) {
}
func TestGetPkg(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
x1 := testPkg
x1.Name = "x1"
x2 := testPkg
x2.Name = "x2"
- err = db.InsertPackage(x1)
+ err := database.InsertPackage(ctx, x1)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- err = db.InsertPackage(x2)
+ err = database.InsertPackage(ctx, x2)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- pkg, err := db.GetPkg("name LIKE 'x%' ORDER BY name")
+ pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
@@ -178,34 +181,32 @@ func TestGetPkg(t *testing.T) {
}
func TestDeletePkgs(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
x1 := testPkg
x1.Name = "x1"
x2 := testPkg
x2.Name = "x2"
- err = db.InsertPackage(x1)
+ err := database.InsertPackage(ctx, x1)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- err = db.InsertPackage(x2)
+ err = database.InsertPackage(ctx, x2)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- err = db.DeletePkgs("name = 'x1'")
+ err = database.DeletePkgs(ctx, "name = 'x1'")
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
var dbPkg db.Package
- err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;")
+ err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;")
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
@@ -216,11 +217,9 @@ func TestDeletePkgs(t *testing.T) {
}
func TestJsonArrayContains(t *testing.T) {
- _, err := db.Open(":memory:")
- if err != nil {
- t.Fatalf("Expected no error, got %s", err)
- }
- defer db.Close()
+ ctx := context.Background()
+ database := prepareDb()
+ defer database.Close()
x1 := testPkg
x1.Name = "x1"
@@ -228,18 +227,18 @@ func TestJsonArrayContains(t *testing.T) {
x2.Name = "x2"
x2.Provides.Val = append(x2.Provides.Val, "x")
- err = db.InsertPackage(x1)
+ err := database.InsertPackage(ctx, x1)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
- err = db.InsertPackage(x2)
+ err = database.InsertPackage(ctx, x2)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
var dbPkg db.Package
- err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');")
+ err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
diff --git a/internal/db/json.go b/internal/db/json.go
new file mode 100644
index 0000000..2b05693
--- /dev/null
+++ b/internal/db/json.go
@@ -0,0 +1,64 @@
+package db
+
+import (
+ "database/sql"
+ "database/sql/driver"
+ "encoding/json"
+ "errors"
+ "fmt"
+)
+
+// JSON represents a JSON value in the database
+type JSON[T any] struct {
+ Val T
+}
+
+// NewJSON creates a new database JSON value
+func NewJSON[T any](v T) JSON[T] {
+ return JSON[T]{Val: v}
+}
+
+func (s *JSON[T]) Scan(val any) error {
+ if val == nil {
+ return nil
+ }
+
+ switch val := val.(type) {
+ case string:
+ err := json.Unmarshal([]byte(val), &s.Val)
+ if err != nil {
+ return err
+ }
+ case sql.NullString:
+ if val.Valid {
+ err := json.Unmarshal([]byte(val.String), &s.Val)
+ if err != nil {
+ return err
+ }
+ }
+ default:
+ return errors.New("sqlite json types must be strings")
+ }
+
+ return nil
+}
+
+func (s JSON[T]) Value() (driver.Value, error) {
+ data, err := json.Marshal(s.Val)
+ if err != nil {
+ return nil, err
+ }
+ return string(data), nil
+}
+
+func (s JSON[T]) MarshalYAML() (any, error) {
+ return s.Val, nil
+}
+
+func (s JSON[T]) String() string {
+ return fmt.Sprint(s.Val)
+}
+
+func (s JSON[T]) GoString() string {
+ return fmt.Sprintf("%#v", s.Val)
+}
diff --git a/internal/db/utils.go b/internal/db/utils.go
new file mode 100644
index 0000000..3cbce79
--- /dev/null
+++ b/internal/db/utils.go
@@ -0,0 +1,36 @@
+package db
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "errors"
+
+ "golang.org/x/exp/slices"
+ "modernc.org/sqlite"
+)
+
+func init() {
+ sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains)
+}
+
+// jsonArrayContains is an SQLite function that checks if a JSON array
+// in the database contains a given value
+func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) {
+ value, ok := args[0].(string)
+ if !ok {
+ return nil, errors.New("both arguments to json_array_contains must be strings")
+ }
+
+ item, ok := args[1].(string)
+ if !ok {
+ return nil, errors.New("both arguments to json_array_contains must be strings")
+ }
+
+ var array []string
+ err := json.Unmarshal([]byte(value), &array)
+ if err != nil {
+ return nil, err
+ }
+
+ return slices.Contains(array, item), nil
+}