refactor(db, config, repos): migrate from functions to struct #9
| @@ -21,59 +21,108 @@ package config | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/pelletier/go-toml/v2" | 	"github.com/pelletier/go-toml/v2" | ||||||
|  | 	"go.elara.ws/logger/log" | ||||||
| 	"plemya-x.ru/alr/internal/types" | 	"plemya-x.ru/alr/internal/types" | ||||||
| 	"plemya-x.ru/alr/pkg/loggerctx" | 	"plemya-x.ru/alr/pkg/loggerctx" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var defaultConfig = &types.Config{ | type ALRConfig struct { | ||||||
| 	RootCmd:          "sudo", | 	cfg   *types.Config | ||||||
| 	PagerStyle:       "native", | 	paths *Paths | ||||||
| 	IgnorePkgUpdates: []string{}, |  | ||||||
| 	Repos: []types.Repo{ | 	pathsOnce sync.Once | ||||||
| 		{ |  | ||||||
| 			Name: "default", |  | ||||||
| 			URL:  "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git", |  | ||||||
| 		}, |  | ||||||
| 	}, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( | func New() *ALRConfig { | ||||||
| 	configMtx sync.Mutex | 	return &ALRConfig{} | ||||||
| 	config    *types.Config | } | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Config returns a ALR configuration struct. | func (c *ALRConfig) Load(ctx context.Context) { | ||||||
| // The first time it's called, it'll load the config from a file. | 	cfgFl, err := os.Open(c.GetPaths(ctx).ConfigPath) | ||||||
| // Subsequent calls will just return the same value. | 	if err != nil { | ||||||
| func Config(ctx context.Context) *types.Config { | 		log.Warn("Error opening config file, using defaults").Err(err).Send() | ||||||
| 	configMtx.Lock() | 		c.cfg = defaultConfig | ||||||
| 	defer configMtx.Unlock() | 		return | ||||||
|  | 	} | ||||||
|  | 	defer cfgFl.Close() | ||||||
|  |  | ||||||
|  | 	// Copy the default configuration into config | ||||||
|  | 	defCopy := *defaultConfig | ||||||
|  | 	config := &defCopy | ||||||
|  | 	config.Repos = nil | ||||||
|  |  | ||||||
|  | 	err = toml.NewDecoder(cfgFl).Decode(config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Warn("Error decoding config file, using defaults").Err(err).Send() | ||||||
|  | 		c.cfg = defaultConfig | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	c.cfg = config | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *ALRConfig) initPaths(ctx context.Context) { | ||||||
| 	log := loggerctx.From(ctx) | 	log := loggerctx.From(ctx) | ||||||
|  | 	paths := &Paths{} | ||||||
|  |  | ||||||
| 	if config == nil { | 	cfgDir, err := os.UserConfigDir() | ||||||
| 		cfgFl, err := os.Open(GetPaths(ctx).ConfigPath) | 	if err != nil { | ||||||
| 		if err != nil { | 		log.Fatal("Unable to detect user config directory").Err(err).Send() | ||||||
| 			log.Warn("Error opening config file, using defaults").Err(err).Send() |  | ||||||
| 			return defaultConfig |  | ||||||
| 		} |  | ||||||
| 		defer cfgFl.Close() |  | ||||||
|  |  | ||||||
| 		// Copy the default configuration into config |  | ||||||
| 		defCopy := *defaultConfig |  | ||||||
| 		config = &defCopy |  | ||||||
| 		config.Repos = nil |  | ||||||
|  |  | ||||||
| 		err = toml.NewDecoder(cfgFl).Decode(config) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Warn("Error decoding config file, using defaults").Err(err).Send() |  | ||||||
| 			// Set config back to nil so that we try again next time |  | ||||||
| 			config = nil |  | ||||||
| 			return defaultConfig |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return config | 	paths.ConfigDir = filepath.Join(cfgDir, "alr") | ||||||
|  |  | ||||||
|  | 	err = os.MkdirAll(paths.ConfigDir, 0o755) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal("Unable to create ALR config directory").Err(err).Send() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") | ||||||
|  |  | ||||||
|  | 	if _, err := os.Stat(paths.ConfigPath); err != nil { | ||||||
|  | 		cfgFl, err := os.Create(paths.ConfigPath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Fatal("Unable to create ALR config file").Err(err).Send() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Fatal("Error encoding default configuration").Err(err).Send() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		cfgFl.Close() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cacheDir, err := os.UserCacheDir() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal("Unable to detect cache directory").Err(err).Send() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	paths.CacheDir = filepath.Join(cacheDir, "alr") | ||||||
|  | 	paths.RepoDir = filepath.Join(paths.CacheDir, "repo") | ||||||
|  | 	paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") | ||||||
|  |  | ||||||
|  | 	err = os.MkdirAll(paths.RepoDir, 0o755) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal("Unable to create repo cache directory").Err(err).Send() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = os.MkdirAll(paths.PkgsDir, 0o755) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatal("Unable to create package cache directory").Err(err).Send() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	paths.DBPath = filepath.Join(paths.CacheDir, "db") | ||||||
|  |  | ||||||
|  | 	c.paths = paths | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *ALRConfig) GetPaths(ctx context.Context) *Paths { | ||||||
|  | 	c.pathsOnce.Do(func() { | ||||||
|  | 		c.initPaths(ctx) | ||||||
|  | 	}) | ||||||
|  | 	return c.paths | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										65
									
								
								internal/config/config_legacy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								internal/config/config_legacy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | |||||||
|  | /* | ||||||
|  |  * ALR - Any Linux Repository | ||||||
|  |  * Copyright (C) 2024 Евгений Храмов | ||||||
|  |  * | ||||||
|  |  * This program is free software: you can redistribute it and/or modify | ||||||
|  |  * it under the terms of the GNU General Public License as published by | ||||||
|  |  * the Free Software Foundation, either version 3 of the License, or | ||||||
|  |  * (at your option) any later version. | ||||||
|  |  * | ||||||
|  |  * This program is distributed in the hope that it will be useful, | ||||||
|  |  * but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |  * GNU General Public License for more details. | ||||||
|  |  * | ||||||
|  |  * You should have received a copy of the GNU General Public License | ||||||
|  |  * along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | package config | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
|  | 	"plemya-x.ru/alr/internal/types" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var defaultConfig = &types.Config{ | ||||||
|  | 	RootCmd:          "sudo", | ||||||
|  | 	PagerStyle:       "native", | ||||||
|  | 	IgnorePkgUpdates: []string{}, | ||||||
|  | 	Repos: []types.Repo{ | ||||||
|  | 		{ | ||||||
|  | 			Name: "default", | ||||||
|  | 			URL:  "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git", | ||||||
|  | 		}, | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Config returns a ALR configuration struct. | ||||||
|  | // The first time it's called, it'll load the config from a file. | ||||||
|  | // Subsequent calls will just return the same value. | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func Config(ctx context.Context) *types.Config { | ||||||
|  | 	return GetInstance(ctx).cfg | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ======================= | ||||||
|  | // FOR LEGACY ONLY | ||||||
|  | // ======================= | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	alrConfig     *ALRConfig | ||||||
|  | 	alrConfigOnce sync.Once | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func GetInstance(ctx context.Context) *ALRConfig { | ||||||
|  | 	alrConfigOnce.Do(func() { | ||||||
|  | 		alrConfig = New() | ||||||
|  | 		alrConfig.Load(ctx) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return alrConfig | ||||||
|  | } | ||||||
| @@ -20,12 +20,6 @@ package config | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 	"sync" |  | ||||||
|  |  | ||||||
| 	"github.com/pelletier/go-toml/v2" |  | ||||||
| 	"plemya-x.ru/alr/pkg/loggerctx" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Paths contains various paths used by ALR | // Paths contains various paths used by ALR | ||||||
| @@ -38,71 +32,13 @@ type Paths struct { | |||||||
| 	DBPath     string | 	DBPath     string | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( |  | ||||||
| 	pathsMtx sync.Mutex |  | ||||||
| 	paths    *Paths |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // GetPaths returns a Paths struct. | // GetPaths returns a Paths struct. | ||||||
| // The first time it's called, it'll generate the struct | // The first time it's called, it'll generate the struct | ||||||
| // using information from the system. | // using information from the system. | ||||||
| // Subsequent calls will return the same value. | // Subsequent calls will return the same value. | ||||||
|  | // | ||||||
|  | // Depreacted: use struct API | ||||||
| func GetPaths(ctx context.Context) *Paths { | func GetPaths(ctx context.Context) *Paths { | ||||||
| 	pathsMtx.Lock() | 	alrConfig := GetInstance(ctx) | ||||||
| 	defer pathsMtx.Unlock() | 	return alrConfig.GetPaths(ctx) | ||||||
|  |  | ||||||
| 	log := loggerctx.From(ctx) |  | ||||||
| 	if paths == nil { |  | ||||||
| 		paths = &Paths{} |  | ||||||
|  |  | ||||||
| 		cfgDir, err := os.UserConfigDir() |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("Unable to detect user config directory").Err(err).Send() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		paths.ConfigDir = filepath.Join(cfgDir, "alr") |  | ||||||
|  |  | ||||||
| 		err = os.MkdirAll(paths.ConfigDir, 0o755) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("Unable to create ALR config directory").Err(err).Send() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") |  | ||||||
|  |  | ||||||
| 		if _, err := os.Stat(paths.ConfigPath); err != nil { |  | ||||||
| 			cfgFl, err := os.Create(paths.ConfigPath) |  | ||||||
| 			if err != nil { |  | ||||||
| 				log.Fatal("Unable to create ALR config file").Err(err).Send() |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) |  | ||||||
| 			if err != nil { |  | ||||||
| 				log.Fatal("Error encoding default configuration").Err(err).Send() |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			cfgFl.Close() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		cacheDir, err := os.UserCacheDir() |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("Unable to detect cache directory").Err(err).Send() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		paths.CacheDir = filepath.Join(cacheDir, "alr") |  | ||||||
| 		paths.RepoDir = filepath.Join(paths.CacheDir, "repo") |  | ||||||
| 		paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") |  | ||||||
|  |  | ||||||
| 		err = os.MkdirAll(paths.RepoDir, 0o755) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("Unable to create repo cache directory").Err(err).Send() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		err = os.MkdirAll(paths.PkgsDir, 0o755) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Fatal("Unable to create package cache directory").Err(err).Send() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		paths.DBPath = filepath.Join(paths.CacheDir, "db") |  | ||||||
| 	} |  | ||||||
| 	return paths |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -20,28 +20,16 @@ package db | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" |  | ||||||
| 	"database/sql/driver" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"sync" |  | ||||||
|  |  | ||||||
| 	"github.com/jmoiron/sqlx" | 	"github.com/jmoiron/sqlx" | ||||||
| 	"plemya-x.ru/alr/internal/config" | 	"plemya-x.ru/alr/internal/config" | ||||||
| 	"plemya-x.ru/alr/pkg/loggerctx" | 	"plemya-x.ru/alr/pkg/loggerctx" | ||||||
| 	"golang.org/x/exp/slices" |  | ||||||
| 	"modernc.org/sqlite" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // CurrentVersion is the current version of the database. | // CurrentVersion is the current version of the database. | ||||||
| // The database is reset if its version doesn't match this. | // The database is reset if its version doesn't match this. | ||||||
| const CurrentVersion = 2 | const CurrentVersion = 2 | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Package is a ALR package's database representation | // Package is a ALR package's database representation | ||||||
| type Package struct { | type Package struct { | ||||||
| 	Name          string                    `sh:"name,required" db:"name"` | 	Name          string                    `sh:"name,required" db:"name"` | ||||||
| @@ -66,66 +54,47 @@ type version struct { | |||||||
| 	Version int `db:"version"` | 	Version int `db:"version"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( | type Config interface { | ||||||
| 	mu     sync.Mutex | 	GetPaths(ctx context.Context) *config.Paths | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Database struct { | ||||||
| 	conn   *sqlx.DB | 	conn   *sqlx.DB | ||||||
| 	closed = true | 	config Config | ||||||
| ) | } | ||||||
|  |  | ||||||
| // DB returns the ALR database. | func New(config Config) *Database { | ||||||
| // The first time it's called, it opens the SQLite database file. | 	return &Database{ | ||||||
| // Subsequent calls return the same connection. | 		config: config, | ||||||
| func DB(ctx context.Context) *sqlx.DB { |  | ||||||
| 	log := loggerctx.From(ctx) |  | ||||||
| 	if conn != nil && !closed { |  | ||||||
| 		return getConn() |  | ||||||
| 	} | 	} | ||||||
| 	_, err := open(ctx, config.GetPaths(ctx).DBPath) | } | ||||||
|  |  | ||||||
|  | func (d *Database) Init(ctx context.Context) error { | ||||||
|  | 	err := d.Connect(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal("Error opening database").Err(err).Send() | 		return err | ||||||
| 	} | 	} | ||||||
| 	return getConn() | 	return d.initDB(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
| func getConn() *sqlx.DB { | func (d *Database) Connect(ctx context.Context) error { | ||||||
| 	mu.Lock() | 	dsn := d.config.GetPaths(ctx).DBPath | ||||||
| 	defer mu.Unlock() |  | ||||||
| 	return conn |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func open(ctx context.Context, dsn string) (*sqlx.DB, error) { |  | ||||||
| 	db, err := sqlx.Open("sqlite", dsn) | 	db, err := sqlx.Open("sqlite", dsn) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	d.conn = db | ||||||
| 	mu.Lock() | 	return nil | ||||||
| 	conn = db |  | ||||||
| 	closed = false |  | ||||||
| 	mu.Unlock() |  | ||||||
|  |  | ||||||
| 	err = initDB(ctx, dsn) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return db, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Close closes the database | func (d *Database) GetConn() *sqlx.DB { | ||||||
| func Close() error { | 	return d.conn | ||||||
| 	closed = true |  | ||||||
| 	if conn != nil { |  | ||||||
| 		return conn.Close() |  | ||||||
| 	} else { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // initDB initializes the database | func (d *Database) initDB(ctx context.Context) error { | ||||||
| func initDB(ctx context.Context, dsn string) error { |  | ||||||
| 	log := loggerctx.From(ctx) | 	log := loggerctx.From(ctx) | ||||||
| 	conn = conn.Unsafe() | 	d.conn = d.conn.Unsafe() | ||||||
|  | 	conn := d.conn | ||||||
| 	_, err := conn.ExecContext(ctx, ` | 	_, err := conn.ExecContext(ctx, ` | ||||||
| 		CREATE TABLE IF NOT EXISTS pkgs ( | 		CREATE TABLE IF NOT EXISTS pkgs ( | ||||||
| 			name          TEXT NOT NULL, | 			name          TEXT NOT NULL, | ||||||
| @@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ver, ok := GetVersion(ctx) | 	ver, ok := d.GetVersion(ctx) | ||||||
| 	if ok && ver != CurrentVersion { | 	if ok && ver != CurrentVersion { | ||||||
| 		log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send() | 		log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send() | ||||||
| 		reset(ctx) | 		d.reset(ctx) | ||||||
| 		return initDB(ctx, dsn) | 		return d.initDB(ctx) | ||||||
| 	} else if !ok { | 	} else if !ok { | ||||||
| 		log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send() | 		log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send() | ||||||
| 		return addVersion(ctx, CurrentVersion) | 		return d.addVersion(ctx, CurrentVersion) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // reset drops all the database tables | func (d *Database) GetVersion(ctx context.Context) (int, bool) { | ||||||
| func reset(ctx context.Context) error { |  | ||||||
| 	_, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	_, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // IsEmpty returns true if the database has no packages in it, otherwise it returns false. |  | ||||||
| func IsEmpty(ctx context.Context) bool { |  | ||||||
| 	var count int |  | ||||||
| 	err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return count == 0 |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // GetVersion returns the database version and a boolean indicating |  | ||||||
| // whether the database contained a version number |  | ||||||
| func GetVersion(ctx context.Context) (int, bool) { |  | ||||||
| 	var ver version | 	var ver version | ||||||
| 	err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") | 	err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, false | 		return 0, false | ||||||
| 	} | 	} | ||||||
| 	return ver.Version, true | 	return ver.Version, true | ||||||
| } | } | ||||||
|  |  | ||||||
| func addVersion(ctx context.Context, ver int) error { | func (d *Database) addVersion(ctx context.Context, ver int) error { | ||||||
| 	_, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) | 	_, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| // InsertPackage adds a package to the database | func (d *Database) reset(ctx context.Context) error { | ||||||
| func InsertPackage(ctx context.Context, pkg Package) error { | 	_, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") | ||||||
| 	_, err := DB(ctx).NamedExecContext(ctx, ` | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	_, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { | ||||||
|  | 	stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return stream, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { | ||||||
|  | 	out := &Package{} | ||||||
|  | 	err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) | ||||||
|  | 	return out, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error { | ||||||
|  | 	_, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Database) IsEmpty(ctx context.Context) bool { | ||||||
|  | 	var count int | ||||||
|  | 	err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return count == 0 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Database) InsertPackage(ctx context.Context, pkg Package) error { | ||||||
|  | 	_, err := d.conn.NamedExecContext(ctx, ` | ||||||
| 		INSERT OR REPLACE INTO pkgs ( | 		INSERT OR REPLACE INTO pkgs ( | ||||||
| 			name, | 			name, | ||||||
| 			repository, | 			repository, | ||||||
| @@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| // GetPkgs returns a result containing packages that match the where conditions | func (d *Database) Close() error { | ||||||
| func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { | 	if d.conn != nil { | ||||||
| 	stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) | 		return d.conn.Close() | ||||||
| 	if err != nil { | 	} else { | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return stream, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // GetPkg returns a single package that matches the where conditions |  | ||||||
| func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { |  | ||||||
| 	out := &Package{} |  | ||||||
| 	err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) |  | ||||||
| 	return out, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // DeletePkgs deletes all packages matching the where conditions |  | ||||||
| func DeletePkgs(ctx context.Context, where string, args ...any) error { |  | ||||||
| 	_, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // jsonArrayContains is an SQLite function that checks if a JSON array |  | ||||||
| // in the database contains a given value |  | ||||||
| func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { |  | ||||||
| 	value, ok := args[0].(string) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, errors.New("both arguments to json_array_contains must be strings") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	item, ok := args[1].(string) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, errors.New("both arguments to json_array_contains must be strings") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var array []string |  | ||||||
| 	err := json.Unmarshal([]byte(value), &array) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return slices.Contains(array, item), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // JSON represents a JSON value in the database |  | ||||||
| type JSON[T any] struct { |  | ||||||
| 	Val T |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewJSON creates a new database JSON value |  | ||||||
| func NewJSON[T any](v T) JSON[T] { |  | ||||||
| 	return JSON[T]{Val: v} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s *JSON[T]) Scan(val any) error { |  | ||||||
| 	if val == nil { |  | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	switch val := val.(type) { |  | ||||||
| 	case string: |  | ||||||
| 		err := json.Unmarshal([]byte(val), &s.Val) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	case sql.NullString: |  | ||||||
| 		if val.Valid { |  | ||||||
| 			err := json.Unmarshal([]byte(val.String), &s.Val) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	default: |  | ||||||
| 		return errors.New("sqlite json types must be strings") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s JSON[T]) Value() (driver.Value, error) { |  | ||||||
| 	data, err := json.Marshal(s.Val) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	return string(data), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s JSON[T]) MarshalYAML() (any, error) { |  | ||||||
| 	return s.Val, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s JSON[T]) String() string { |  | ||||||
| 	return fmt.Sprint(s.Val) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (s JSON[T]) GoString() string { |  | ||||||
| 	return fmt.Sprintf("%#v", s.Val) |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										105
									
								
								internal/db/db_legacy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								internal/db/db_legacy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | /* | ||||||
|  |  * ALR - Any Linux Repository | ||||||
|  |  * Copyright (C) 2024 Евгений Храмов | ||||||
|  |  * | ||||||
|  |  * This program is free software: you can redistribute it and/or modify | ||||||
|  |  * it under the terms of the GNU General Public License as published by | ||||||
|  |  * the Free Software Foundation, either version 3 of the License, or | ||||||
|  |  * (at your option) any later version. | ||||||
|  |  * | ||||||
|  |  * This program is distributed in the hope that it will be useful, | ||||||
|  |  * but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |  * GNU General Public License for more details. | ||||||
|  |  * | ||||||
|  |  * You should have received a copy of the GNU General Public License | ||||||
|  |  * along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | package db | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
|  | 	"github.com/jmoiron/sqlx" | ||||||
|  | 	"plemya-x.ru/alr/internal/config" | ||||||
|  | 	"plemya-x.ru/alr/pkg/loggerctx" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // DB returns the ALR database. | ||||||
|  | // The first time it's called, it opens the SQLite database file. | ||||||
|  | // Subsequent calls return the same connection. | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func DB(ctx context.Context) *sqlx.DB { | ||||||
|  | 	return getInstance(ctx).GetConn() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Close closes the database | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func Close() error { | ||||||
|  | 	if database != nil { | ||||||
|  | 		return database.Close() | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // IsEmpty returns true if the database has no packages in it, otherwise it returns false. | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func IsEmpty(ctx context.Context) bool { | ||||||
|  | 	return getInstance(ctx).IsEmpty(ctx) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // InsertPackage adds a package to the database | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func InsertPackage(ctx context.Context, pkg Package) error { | ||||||
|  | 	return getInstance(ctx).InsertPackage(ctx, pkg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GetPkgs returns a result containing packages that match the where conditions | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { | ||||||
|  | 	return getInstance(ctx).GetPkgs(ctx, where, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GetPkg returns a single package that matches the where conditions | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { | ||||||
|  | 	return getInstance(ctx).GetPkg(ctx, where, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DeletePkgs deletes all packages matching the where conditions | ||||||
|  | // | ||||||
|  | // Deprecated: use struct method | ||||||
|  | func DeletePkgs(ctx context.Context, where string, args ...any) error { | ||||||
|  | 	return getInstance(ctx).DeletePkgs(ctx, where, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ======================= | ||||||
|  | // FOR LEGACY ONLY | ||||||
|  | // ======================= | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	dbOnce   sync.Once | ||||||
|  | 	database *Database | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // For refactoring only | ||||||
|  | func getInstance(ctx context.Context) *Database { | ||||||
|  | 	dbOnce.Do(func() { | ||||||
|  | 		log := loggerctx.From(ctx) | ||||||
|  | 		cfg := config.GetInstance(ctx) | ||||||
|  | 		database = New(cfg) | ||||||
|  | 		err := database.Init(ctx) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Fatal("Error opening database").Err(err).Send() | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	return database | ||||||
|  | } | ||||||
| @@ -19,14 +19,30 @@ | |||||||
| package db_test | package db_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/jmoiron/sqlx" | 	"github.com/jmoiron/sqlx" | ||||||
|  | 	"plemya-x.ru/alr/internal/config" | ||||||
| 	"plemya-x.ru/alr/internal/db" | 	"plemya-x.ru/alr/internal/db" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type TestALRConfig struct{} | ||||||
|  |  | ||||||
|  | func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths { | ||||||
|  | 	return &config.Paths{ | ||||||
|  | 		DBPath: ":memory:", | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func prepareDb() *db.Database { | ||||||
|  | 	database := db.New(&TestALRConfig{}) | ||||||
|  | 	database.Init(context.Background()) | ||||||
|  | 	return database | ||||||
|  | } | ||||||
|  |  | ||||||
| var testPkg = db.Package{ | var testPkg = db.Package{ | ||||||
| 	Name:    "test", | 	Name:    "test", | ||||||
| 	Version: "0.0.1", | 	Version: "0.0.1", | ||||||
| @@ -59,18 +75,11 @@ var testPkg = db.Package{ | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestInit(t *testing.T) { | func TestInit(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	_, err = db.DB().Exec("SELECT * FROM pkgs") | 	ver, ok := database.GetVersion(ctx) | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatalf("Expected no error, got %s", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ver, ok := db.GetVersion() |  | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		t.Errorf("Expected version to be present") | 		t.Errorf("Expected version to be present") | ||||||
| 	} else if ver != db.CurrentVersion { | 	} else if ver != db.CurrentVersion { | ||||||
| @@ -79,19 +88,17 @@ func TestInit(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestInsertPackage(t *testing.T) { | func TestInsertPackage(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(testPkg) | 	err := database.InsertPackage(ctx, testPkg) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 		t.Fatalf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	dbPkg := db.Package{} | 	dbPkg := db.Package{} | ||||||
| 	err = sqlx.Get(db.DB(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") | 	err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 		t.Fatalf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -102,28 +109,26 @@ func TestInsertPackage(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestGetPkgs(t *testing.T) { | func TestGetPkgs(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	x1 := testPkg | 	x1 := testPkg | ||||||
| 	x1.Name = "x1" | 	x1.Name = "x1" | ||||||
| 	x2 := testPkg | 	x2 := testPkg | ||||||
| 	x2.Name = "x2" | 	x2.Name = "x2" | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x1) | 	err := database.InsertPackage(ctx, x1) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x2) | 	err = database.InsertPackage(ctx, x2) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	result, err := db.GetPkgs("name LIKE 'x%'") | 	result, err := database.GetPkgs(ctx, "name LIKE 'x%'") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 		t.Fatalf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -142,28 +147,26 @@ func TestGetPkgs(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestGetPkg(t *testing.T) { | func TestGetPkg(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	x1 := testPkg | 	x1 := testPkg | ||||||
| 	x1.Name = "x1" | 	x1.Name = "x1" | ||||||
| 	x2 := testPkg | 	x2 := testPkg | ||||||
| 	x2.Name = "x2" | 	x2.Name = "x2" | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x1) | 	err := database.InsertPackage(ctx, x1) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x2) | 	err = database.InsertPackage(ctx, x2) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	pkg, err := db.GetPkg("name LIKE 'x%' ORDER BY name") | 	pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 		t.Fatalf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -178,34 +181,32 @@ func TestGetPkg(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestDeletePkgs(t *testing.T) { | func TestDeletePkgs(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	x1 := testPkg | 	x1 := testPkg | ||||||
| 	x1.Name = "x1" | 	x1.Name = "x1" | ||||||
| 	x2 := testPkg | 	x2 := testPkg | ||||||
| 	x2.Name = "x2" | 	x2.Name = "x2" | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x1) | 	err := database.InsertPackage(ctx, x1) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x2) | 	err = database.InsertPackage(ctx, x2) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.DeletePkgs("name = 'x1'") | 	err = database.DeletePkgs(ctx, "name = 'x1'") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var dbPkg db.Package | 	var dbPkg db.Package | ||||||
| 	err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") | 	err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
| @@ -216,11 +217,9 @@ func TestDeletePkgs(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestJsonArrayContains(t *testing.T) { | func TestJsonArrayContains(t *testing.T) { | ||||||
| 	_, err := db.Open(":memory:") | 	ctx := context.Background() | ||||||
| 	if err != nil { | 	database := prepareDb() | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 	defer database.Close() | ||||||
| 	} |  | ||||||
| 	defer db.Close() |  | ||||||
|  |  | ||||||
| 	x1 := testPkg | 	x1 := testPkg | ||||||
| 	x1.Name = "x1" | 	x1.Name = "x1" | ||||||
| @@ -228,18 +227,18 @@ func TestJsonArrayContains(t *testing.T) { | |||||||
| 	x2.Name = "x2" | 	x2.Name = "x2" | ||||||
| 	x2.Provides.Val = append(x2.Provides.Val, "x") | 	x2.Provides.Val = append(x2.Provides.Val, "x") | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x1) | 	err := database.InsertPackage(ctx, x1) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = db.InsertPackage(x2) | 	err = database.InsertPackage(ctx, x2) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("Expected no error, got %s", err) | 		t.Errorf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var dbPkg db.Package | 	var dbPkg db.Package | ||||||
| 	err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") | 	err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Expected no error, got %s", err) | 		t.Fatalf("Expected no error, got %s", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										64
									
								
								internal/db/json.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								internal/db/json.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | |||||||
|  | package db | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // JSON represents a JSON value in the database | ||||||
|  | type JSON[T any] struct { | ||||||
|  | 	Val T | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewJSON creates a new database JSON value | ||||||
|  | func NewJSON[T any](v T) JSON[T] { | ||||||
|  | 	return JSON[T]{Val: v} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *JSON[T]) Scan(val any) error { | ||||||
|  | 	if val == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch val := val.(type) { | ||||||
|  | 	case string: | ||||||
|  | 		err := json.Unmarshal([]byte(val), &s.Val) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	case sql.NullString: | ||||||
|  | 		if val.Valid { | ||||||
|  | 			err := json.Unmarshal([]byte(val.String), &s.Val) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		return errors.New("sqlite json types must be strings") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s JSON[T]) Value() (driver.Value, error) { | ||||||
|  | 	data, err := json.Marshal(s.Val) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return string(data), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s JSON[T]) MarshalYAML() (any, error) { | ||||||
|  | 	return s.Val, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s JSON[T]) String() string { | ||||||
|  | 	return fmt.Sprint(s.Val) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s JSON[T]) GoString() string { | ||||||
|  | 	return fmt.Sprintf("%#v", s.Val) | ||||||
|  | } | ||||||
							
								
								
									
										36
									
								
								internal/db/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/db/utils.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | |||||||
|  | package db | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/exp/slices" | ||||||
|  | 	"modernc.org/sqlite" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // jsonArrayContains is an SQLite function that checks if a JSON array | ||||||
|  | // in the database contains a given value | ||||||
|  | func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { | ||||||
|  | 	value, ok := args[0].(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errors.New("both arguments to json_array_contains must be strings") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	item, ok := args[1].(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errors.New("both arguments to json_array_contains must be strings") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var array []string | ||||||
|  | 	err := json.Unmarshal([]byte(value), &array) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return slices.Contains(array, item), nil | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user