diff --git a/internal/config/config.go b/internal/config/config.go index db43f89..9c27b93 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,6 +21,7 @@ package config import ( "context" "os" + "path/filepath" "sync" "github.com/pelletier/go-toml/v2" @@ -28,6 +29,14 @@ import ( "plemya-x.ru/alr/pkg/loggerctx" ) +type ALRConfig struct { + cfg *types.Config + paths *Paths + + cfgOnce sync.Once + pathsOnce sync.Once +} + var defaultConfig = &types.Config{ RootCmd: "sudo", PagerStyle: "native", @@ -40,40 +49,107 @@ var defaultConfig = &types.Config{ }, } -var ( - configMtx sync.Mutex - config *types.Config -) +func New() *ALRConfig { + return &ALRConfig{} +} -// Config returns a ALR configuration struct. -// The first time it's called, it'll load the config from a file. -// Subsequent calls will just return the same value. -func Config(ctx context.Context) *types.Config { - configMtx.Lock() - defer configMtx.Unlock() +func (c *ALRConfig) Load(ctx context.Context) { log := loggerctx.From(ctx) + cfgFl, err := os.Open(c.GetPaths(ctx).ConfigPath) + if err != nil { + log.Warn("Error opening config file, using defaults").Err(err).Send() + c.cfg = defaultConfig + return + } + defer cfgFl.Close() - if config == nil { - cfgFl, err := os.Open(GetPaths(ctx).ConfigPath) - if err != nil { - log.Warn("Error opening config file, using defaults").Err(err).Send() - return defaultConfig - } - defer cfgFl.Close() + // Copy the default configuration into config + defCopy := *defaultConfig + config := &defCopy + config.Repos = nil - // Copy the default configuration into config - defCopy := *defaultConfig - config = &defCopy - config.Repos = nil + err = toml.NewDecoder(cfgFl).Decode(config) + if err != nil { + log.Warn("Error decoding config file, using defaults").Err(err).Send() + c.cfg = defaultConfig + return + } + c.cfg = config +} - err = toml.NewDecoder(cfgFl).Decode(config) - if err != nil { - log.Warn("Error decoding config file, using defaults").Err(err).Send() - // Set config back to nil so that we try again next time - config = nil - return defaultConfig - } +func (c *ALRConfig) initPaths(ctx context.Context) { + log := loggerctx.From(ctx) + paths := &Paths{} + + cfgDir, err := os.UserConfigDir() + if err != nil { + log.Fatal("Unable to detect user config directory").Err(err).Send() } - return config + paths.ConfigDir = filepath.Join(cfgDir, "alr") + + err = os.MkdirAll(paths.ConfigDir, 0o755) + if err != nil { + log.Fatal("Unable to create ALR config directory").Err(err).Send() + } + + paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") + + if _, err := os.Stat(paths.ConfigPath); err != nil { + cfgFl, err := os.Create(paths.ConfigPath) + if err != nil { + log.Fatal("Unable to create ALR config file").Err(err).Send() + } + + err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) + if err != nil { + log.Fatal("Error encoding default configuration").Err(err).Send() + } + + cfgFl.Close() + } + + cacheDir, err := os.UserCacheDir() + if err != nil { + log.Fatal("Unable to detect cache directory").Err(err).Send() + } + + paths.CacheDir = filepath.Join(cacheDir, "alr") + paths.RepoDir = filepath.Join(paths.CacheDir, "repo") + paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") + + err = os.MkdirAll(paths.RepoDir, 0o755) + if err != nil { + log.Fatal("Unable to create repo cache directory").Err(err).Send() + } + + err = os.MkdirAll(paths.PkgsDir, 0o755) + if err != nil { + log.Fatal("Unable to create package cache directory").Err(err).Send() + } + + paths.DBPath = filepath.Join(paths.CacheDir, "db") + + c.paths = paths +} + +func (c *ALRConfig) GetPaths(ctx context.Context) *Paths { + c.pathsOnce.Do(func() { + c.initPaths(ctx) + }) + return c.paths +} + +func (c *ALRConfig) Repos(ctx context.Context) []types.Repo { + c.cfgOnce.Do(func() { + c.Load(ctx) + }) + return c.cfg.Repos +} + +func (c *ALRConfig) IgnorePkgUpdates(ctx context.Context) []string { + c.cfgOnce.Do(func() { + c.Load(ctx) + }) + return c.cfg.IgnorePkgUpdates } diff --git a/internal/config/config_legacy.go b/internal/config/config_legacy.go new file mode 100644 index 0000000..46fe18c --- /dev/null +++ b/internal/config/config_legacy.go @@ -0,0 +1,54 @@ +/* + * ALR - Any Linux Repository + * Copyright (C) 2024 Евгений Храмов + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "context" + "sync" + + "plemya-x.ru/alr/internal/types" +) + +// Config returns a ALR configuration struct. +// The first time it's called, it'll load the config from a file. +// Subsequent calls will just return the same value. +// +// Deprecated: use struct method +func Config(ctx context.Context) *types.Config { + return GetInstance(ctx).cfg +} + +// ======================= +// FOR LEGACY ONLY +// ======================= + +var ( + alrConfig *ALRConfig + alrConfigOnce sync.Once +) + +// Deprecated: For legacy only +func GetInstance(ctx context.Context) *ALRConfig { + alrConfigOnce.Do(func() { + alrConfig = New() + alrConfig.Load(ctx) + }) + + return alrConfig +} diff --git a/internal/config/paths.go b/internal/config/paths.go index 8b10c35..ebdea31 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -20,12 +20,6 @@ package config import ( "context" - "os" - "path/filepath" - "sync" - - "github.com/pelletier/go-toml/v2" - "plemya-x.ru/alr/pkg/loggerctx" ) // Paths contains various paths used by ALR @@ -38,71 +32,13 @@ type Paths struct { DBPath string } -var ( - pathsMtx sync.Mutex - paths *Paths -) - // GetPaths returns a Paths struct. // The first time it's called, it'll generate the struct // using information from the system. // Subsequent calls will return the same value. +// +// Depreacted: use struct API func GetPaths(ctx context.Context) *Paths { - pathsMtx.Lock() - defer pathsMtx.Unlock() - - log := loggerctx.From(ctx) - if paths == nil { - paths = &Paths{} - - cfgDir, err := os.UserConfigDir() - if err != nil { - log.Fatal("Unable to detect user config directory").Err(err).Send() - } - - paths.ConfigDir = filepath.Join(cfgDir, "alr") - - err = os.MkdirAll(paths.ConfigDir, 0o755) - if err != nil { - log.Fatal("Unable to create ALR config directory").Err(err).Send() - } - - paths.ConfigPath = filepath.Join(paths.ConfigDir, "alr.toml") - - if _, err := os.Stat(paths.ConfigPath); err != nil { - cfgFl, err := os.Create(paths.ConfigPath) - if err != nil { - log.Fatal("Unable to create ALR config file").Err(err).Send() - } - - err = toml.NewEncoder(cfgFl).Encode(&defaultConfig) - if err != nil { - log.Fatal("Error encoding default configuration").Err(err).Send() - } - - cfgFl.Close() - } - - cacheDir, err := os.UserCacheDir() - if err != nil { - log.Fatal("Unable to detect cache directory").Err(err).Send() - } - - paths.CacheDir = filepath.Join(cacheDir, "alr") - paths.RepoDir = filepath.Join(paths.CacheDir, "repo") - paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") - - err = os.MkdirAll(paths.RepoDir, 0o755) - if err != nil { - log.Fatal("Unable to create repo cache directory").Err(err).Send() - } - - err = os.MkdirAll(paths.PkgsDir, 0o755) - if err != nil { - log.Fatal("Unable to create package cache directory").Err(err).Send() - } - - paths.DBPath = filepath.Join(paths.CacheDir, "db") - } - return paths + alrConfig := GetInstance(ctx) + return alrConfig.GetPaths(ctx) } diff --git a/internal/db/db.go b/internal/db/db.go index 9e421cb..a4bd4a5 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -20,28 +20,16 @@ package db import ( "context" - "database/sql" - "database/sql/driver" - "encoding/json" - "errors" - "fmt" - "sync" "github.com/jmoiron/sqlx" "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/pkg/loggerctx" - "golang.org/x/exp/slices" - "modernc.org/sqlite" ) // CurrentVersion is the current version of the database. // The database is reset if its version doesn't match this. const CurrentVersion = 2 -func init() { - sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) -} - // Package is a ALR package's database representation type Package struct { Name string `sh:"name,required" db:"name"` @@ -66,66 +54,47 @@ type version struct { Version int `db:"version"` } -var ( - mu sync.Mutex +type Config interface { + GetPaths(ctx context.Context) *config.Paths +} + +type Database struct { conn *sqlx.DB - closed = true -) + config Config +} -// DB returns the ALR database. -// The first time it's called, it opens the SQLite database file. -// Subsequent calls return the same connection. -func DB(ctx context.Context) *sqlx.DB { - log := loggerctx.From(ctx) - if conn != nil && !closed { - return getConn() +func New(config Config) *Database { + return &Database{ + config: config, } - _, err := open(ctx, config.GetPaths(ctx).DBPath) +} + +func (d *Database) Init(ctx context.Context) error { + err := d.Connect(ctx) if err != nil { - log.Fatal("Error opening database").Err(err).Send() + return err } - return getConn() + return d.initDB(ctx) } -func getConn() *sqlx.DB { - mu.Lock() - defer mu.Unlock() - return conn -} - -func open(ctx context.Context, dsn string) (*sqlx.DB, error) { +func (d *Database) Connect(ctx context.Context) error { + dsn := d.config.GetPaths(ctx).DBPath db, err := sqlx.Open("sqlite", dsn) if err != nil { - return nil, err + return err } - - mu.Lock() - conn = db - closed = false - mu.Unlock() - - err = initDB(ctx, dsn) - if err != nil { - return nil, err - } - - return db, nil + d.conn = db + return nil } -// Close closes the database -func Close() error { - closed = true - if conn != nil { - return conn.Close() - } else { - return nil - } +func (d *Database) GetConn() *sqlx.DB { + return d.conn } -// initDB initializes the database -func initDB(ctx context.Context, dsn string) error { +func (d *Database) initDB(ctx context.Context) error { log := loggerctx.From(ctx) - conn = conn.Unsafe() + d.conn = d.conn.Unsafe() + conn := d.conn _, err := conn.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS pkgs ( name TEXT NOT NULL, @@ -155,58 +124,72 @@ func initDB(ctx context.Context, dsn string) error { return err } - ver, ok := GetVersion(ctx) + ver, ok := d.GetVersion(ctx) if ok && ver != CurrentVersion { log.Warn("Database version mismatch; resetting").Int("version", ver).Int("expected", CurrentVersion).Send() - reset(ctx) - return initDB(ctx, dsn) + d.reset(ctx) + return d.initDB(ctx) } else if !ok { log.Warn("Database version does not exist. Run alr fix if something isn't working.").Send() - return addVersion(ctx, CurrentVersion) + return d.addVersion(ctx, CurrentVersion) } return nil } -// reset drops all the database tables -func reset(ctx context.Context) error { - _, err := DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") - if err != nil { - return err - } - _, err = DB(ctx).ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") - return err -} - -// IsEmpty returns true if the database has no packages in it, otherwise it returns false. -func IsEmpty(ctx context.Context) bool { - var count int - err := DB(ctx).GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") - if err != nil { - return true - } - return count == 0 -} - -// GetVersion returns the database version and a boolean indicating -// whether the database contained a version number -func GetVersion(ctx context.Context) (int, bool) { +func (d *Database) GetVersion(ctx context.Context) (int, bool) { var ver version - err := DB(ctx).GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") + err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;") if err != nil { return 0, false } return ver.Version, true } -func addVersion(ctx context.Context, ver int) error { - _, err := DB(ctx).ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) +func (d *Database) addVersion(ctx context.Context, ver int) error { + _, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver) return err } -// InsertPackage adds a package to the database -func InsertPackage(ctx context.Context, pkg Package) error { - _, err := DB(ctx).NamedExecContext(ctx, ` +func (d *Database) reset(ctx context.Context) error { + _, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;") + if err != nil { + return err + } + _, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;") + return err +} + +func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { + stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) + if err != nil { + return nil, err + } + return stream, nil +} + +func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { + out := &Package{} + err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) + return out, err +} + +func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error { + _, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) + return err +} + +func (d *Database) IsEmpty(ctx context.Context) bool { + var count int + err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;") + if err != nil { + return true + } + return count == 0 +} + +func (d *Database) InsertPackage(ctx context.Context, pkg Package) error { + _, err := d.conn.NamedExecContext(ctx, ` INSERT OR REPLACE INTO pkgs ( name, repository, @@ -246,101 +229,10 @@ func InsertPackage(ctx context.Context, pkg Package) error { return err } -// GetPkgs returns a result containing packages that match the where conditions -func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { - stream, err := DB(ctx).QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...) - if err != nil { - return nil, err - } - return stream, nil -} - -// GetPkg returns a single package that matches the where conditions -func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { - out := &Package{} - err := DB(ctx).GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...) - return out, err -} - -// DeletePkgs deletes all packages matching the where conditions -func DeletePkgs(ctx context.Context, where string, args ...any) error { - _, err := DB(ctx).ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...) - return err -} - -// jsonArrayContains is an SQLite function that checks if a JSON array -// in the database contains a given value -func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { - value, ok := args[0].(string) - if !ok { - return nil, errors.New("both arguments to json_array_contains must be strings") - } - - item, ok := args[1].(string) - if !ok { - return nil, errors.New("both arguments to json_array_contains must be strings") - } - - var array []string - err := json.Unmarshal([]byte(value), &array) - if err != nil { - return nil, err - } - - return slices.Contains(array, item), nil -} - -// JSON represents a JSON value in the database -type JSON[T any] struct { - Val T -} - -// NewJSON creates a new database JSON value -func NewJSON[T any](v T) JSON[T] { - return JSON[T]{Val: v} -} - -func (s *JSON[T]) Scan(val any) error { - if val == nil { +func (d *Database) Close() error { + if d.conn != nil { + return d.conn.Close() + } else { return nil } - - switch val := val.(type) { - case string: - err := json.Unmarshal([]byte(val), &s.Val) - if err != nil { - return err - } - case sql.NullString: - if val.Valid { - err := json.Unmarshal([]byte(val.String), &s.Val) - if err != nil { - return err - } - } - default: - return errors.New("sqlite json types must be strings") - } - - return nil -} - -func (s JSON[T]) Value() (driver.Value, error) { - data, err := json.Marshal(s.Val) - if err != nil { - return nil, err - } - return string(data), nil -} - -func (s JSON[T]) MarshalYAML() (any, error) { - return s.Val, nil -} - -func (s JSON[T]) String() string { - return fmt.Sprint(s.Val) -} - -func (s JSON[T]) GoString() string { - return fmt.Sprintf("%#v", s.Val) } diff --git a/internal/db/db_legacy.go b/internal/db/db_legacy.go new file mode 100644 index 0000000..22e6fb8 --- /dev/null +++ b/internal/db/db_legacy.go @@ -0,0 +1,105 @@ +/* + * ALR - Any Linux Repository + * Copyright (C) 2024 Евгений Храмов + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package db + +import ( + "context" + "sync" + + "github.com/jmoiron/sqlx" + "plemya-x.ru/alr/internal/config" + "plemya-x.ru/alr/pkg/loggerctx" +) + +// DB returns the ALR database. +// The first time it's called, it opens the SQLite database file. +// Subsequent calls return the same connection. +// +// Deprecated: use struct method +func DB(ctx context.Context) *sqlx.DB { + return GetInstance(ctx).GetConn() +} + +// Close closes the database +// +// Deprecated: use struct method +func Close() error { + if database != nil { + return database.Close() + } + return nil +} + +// IsEmpty returns true if the database has no packages in it, otherwise it returns false. +// +// Deprecated: use struct method +func IsEmpty(ctx context.Context) bool { + return GetInstance(ctx).IsEmpty(ctx) +} + +// InsertPackage adds a package to the database +// +// Deprecated: use struct method +func InsertPackage(ctx context.Context, pkg Package) error { + return GetInstance(ctx).InsertPackage(ctx, pkg) +} + +// GetPkgs returns a result containing packages that match the where conditions +// +// Deprecated: use struct method +func GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) { + return GetInstance(ctx).GetPkgs(ctx, where, args...) +} + +// GetPkg returns a single package that matches the where conditions +// +// Deprecated: use struct method +func GetPkg(ctx context.Context, where string, args ...any) (*Package, error) { + return GetInstance(ctx).GetPkg(ctx, where, args...) +} + +// DeletePkgs deletes all packages matching the where conditions +// +// Deprecated: use struct method +func DeletePkgs(ctx context.Context, where string, args ...any) error { + return GetInstance(ctx).DeletePkgs(ctx, where, args...) +} + +// ======================= +// FOR LEGACY ONLY +// ======================= + +var ( + dbOnce sync.Once + database *Database +) + +// Deprecated: For legacy only +func GetInstance(ctx context.Context) *Database { + dbOnce.Do(func() { + log := loggerctx.From(ctx) + cfg := config.GetInstance(ctx) + database = New(cfg) + err := database.Init(ctx) + if err != nil { + log.Fatal("Error opening database").Err(err).Send() + } + }) + return database +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 9e38ee6..4412181 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -19,14 +19,30 @@ package db_test import ( + "context" "reflect" "strings" "testing" "github.com/jmoiron/sqlx" + "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/internal/db" ) +type TestALRConfig struct{} + +func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths { + return &config.Paths{ + DBPath: ":memory:", + } +} + +func prepareDb() *db.Database { + database := db.New(&TestALRConfig{}) + database.Init(context.Background()) + return database +} + var testPkg = db.Package{ Name: "test", Version: "0.0.1", @@ -59,18 +75,11 @@ var testPkg = db.Package{ } func TestInit(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() - _, err = db.DB().Exec("SELECT * FROM pkgs") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - - ver, ok := db.GetVersion() + ver, ok := database.GetVersion(ctx) if !ok { t.Errorf("Expected version to be present") } else if ver != db.CurrentVersion { @@ -79,19 +88,17 @@ func TestInit(t *testing.T) { } func TestInsertPackage(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() - err = db.InsertPackage(testPkg) + err := database.InsertPackage(ctx, testPkg) if err != nil { t.Fatalf("Expected no error, got %s", err) } dbPkg := db.Package{} - err = sqlx.Get(db.DB(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") + err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -102,28 +109,26 @@ func TestInsertPackage(t *testing.T) { } func TestGetPkgs(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - result, err := db.GetPkgs("name LIKE 'x%'") + result, err := database.GetPkgs(ctx, "name LIKE 'x%'") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -142,28 +147,26 @@ func TestGetPkgs(t *testing.T) { } func TestGetPkg(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - pkg, err := db.GetPkg("name LIKE 'x%' ORDER BY name") + pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -178,34 +181,32 @@ func TestGetPkg(t *testing.T) { } func TestDeletePkgs(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" x2 := testPkg x2.Name = "x2" - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.DeletePkgs("name = 'x1'") + err = database.DeletePkgs(ctx, "name = 'x1'") if err != nil { t.Errorf("Expected no error, got %s", err) } var dbPkg db.Package - err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") + err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;") if err != nil { t.Errorf("Expected no error, got %s", err) } @@ -216,11 +217,9 @@ func TestDeletePkgs(t *testing.T) { } func TestJsonArrayContains(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + ctx := context.Background() + database := prepareDb() + defer database.Close() x1 := testPkg x1.Name = "x1" @@ -228,18 +227,18 @@ func TestJsonArrayContains(t *testing.T) { x2.Name = "x2" x2.Provides.Val = append(x2.Provides.Val, "x") - err = db.InsertPackage(x1) + err := database.InsertPackage(ctx, x1) if err != nil { t.Errorf("Expected no error, got %s", err) } - err = db.InsertPackage(x2) + err = database.InsertPackage(ctx, x2) if err != nil { t.Errorf("Expected no error, got %s", err) } var dbPkg db.Package - err = db.DB().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") + err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');") if err != nil { t.Fatalf("Expected no error, got %s", err) } diff --git a/internal/db/json.go b/internal/db/json.go new file mode 100644 index 0000000..2b05693 --- /dev/null +++ b/internal/db/json.go @@ -0,0 +1,64 @@ +package db + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" +) + +// JSON represents a JSON value in the database +type JSON[T any] struct { + Val T +} + +// NewJSON creates a new database JSON value +func NewJSON[T any](v T) JSON[T] { + return JSON[T]{Val: v} +} + +func (s *JSON[T]) Scan(val any) error { + if val == nil { + return nil + } + + switch val := val.(type) { + case string: + err := json.Unmarshal([]byte(val), &s.Val) + if err != nil { + return err + } + case sql.NullString: + if val.Valid { + err := json.Unmarshal([]byte(val.String), &s.Val) + if err != nil { + return err + } + } + default: + return errors.New("sqlite json types must be strings") + } + + return nil +} + +func (s JSON[T]) Value() (driver.Value, error) { + data, err := json.Marshal(s.Val) + if err != nil { + return nil, err + } + return string(data), nil +} + +func (s JSON[T]) MarshalYAML() (any, error) { + return s.Val, nil +} + +func (s JSON[T]) String() string { + return fmt.Sprint(s.Val) +} + +func (s JSON[T]) GoString() string { + return fmt.Sprintf("%#v", s.Val) +} diff --git a/internal/db/utils.go b/internal/db/utils.go new file mode 100644 index 0000000..3cbce79 --- /dev/null +++ b/internal/db/utils.go @@ -0,0 +1,36 @@ +package db + +import ( + "database/sql/driver" + "encoding/json" + "errors" + + "golang.org/x/exp/slices" + "modernc.org/sqlite" +) + +func init() { + sqlite.MustRegisterScalarFunction("json_array_contains", 2, jsonArrayContains) +} + +// jsonArrayContains is an SQLite function that checks if a JSON array +// in the database contains a given value +func jsonArrayContains(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + value, ok := args[0].(string) + if !ok { + return nil, errors.New("both arguments to json_array_contains must be strings") + } + + item, ok := args[1].(string) + if !ok { + return nil, errors.New("both arguments to json_array_contains must be strings") + } + + var array []string + err := json.Unmarshal([]byte(value), &array) + if err != nil { + return nil, err + } + + return slices.Contains(array, item), nil +} diff --git a/internal/dl/dl.go b/internal/dl/dl.go index a2a6781..6a82bca 100644 --- a/internal/dl/dl.go +++ b/internal/dl/dl.go @@ -14,7 +14,7 @@ * * You should have received a copy of the GNU General Public License * along with this program. If not, see . -*/ + */ // Пакет dl содержит абстракции для загрузки файлов и каталогов // из различных источников. @@ -39,6 +39,7 @@ import ( "golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2s" "golang.org/x/exp/slices" + "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/internal/dlcache" "plemya-x.ru/alr/pkg/loggerctx" ) @@ -142,6 +143,9 @@ type UpdatingDownloader interface { // Функция Download загружает файл или каталог с использованием указанных параметров func Download(ctx context.Context, opts Options) (err error) { log := loggerctx.From(ctx) + cfg := config.GetInstance(ctx) + dc := dlcache.New(cfg) + normalized, err := normalizeURL(opts.URL) if err != nil { return err @@ -156,7 +160,7 @@ func Download(ctx context.Context, opts Options) (err error) { } var t Type - cacheDir, ok := dlcache.Get(ctx, opts.URL) + cacheDir, ok := dc.Get(ctx, opts.URL) if ok { var updated bool if d, ok := d.(UpdatingDownloader); ok { @@ -203,7 +207,7 @@ func Download(ctx context.Context, opts Options) (err error) { log.Info("Downloading source").Str("source", opts.Name).Str("downloader", d.Name()).Send() - cacheDir, err = dlcache.New(ctx, opts.URL) + cacheDir, err = dc.New(ctx, opts.URL) if err != nil { return err } @@ -299,8 +303,6 @@ func linkDir(src, dest string) error { return nil } - - rel, err := filepath.Rel(src, path) if err != nil { return err diff --git a/internal/dlcache/dlcache.go b/internal/dlcache/dlcache.go index ccbf57b..77eb5f4 100644 --- a/internal/dlcache/dlcache.go +++ b/internal/dlcache/dlcache.go @@ -20,29 +20,41 @@ package dlcache import ( "context" - "crypto/sha1" - "encoding/hex" - "io" "os" "path/filepath" "plemya-x.ru/alr/internal/config" ) -// BasePath returns the base path of the download cache -func BasePath(ctx context.Context) string { - return filepath.Join(config.GetPaths(ctx).CacheDir, "dl") +type Config interface { + GetPaths(ctx context.Context) *config.Paths +} + +type DownloadCache struct { + cfg Config +} + +func New(cfg Config) *DownloadCache { + return &DownloadCache{ + cfg, + } +} + +func (dc *DownloadCache) BasePath(ctx context.Context) string { + return filepath.Join( + dc.cfg.GetPaths(ctx).CacheDir, "dl", + ) } // New creates a new directory with the given ID in the cache. // If a directory with the same ID already exists, // it will be deleted before creating a new one. -func New(ctx context.Context, id string) (string, error) { +func (dc *DownloadCache) New(ctx context.Context, id string) (string, error) { h, err := hashID(id) if err != nil { return "", err } - itemPath := filepath.Join(BasePath(ctx), h) + itemPath := filepath.Join(dc.BasePath(ctx), h) fi, err := os.Stat(itemPath) if err == nil || (fi != nil && !fi.IsDir()) { @@ -65,12 +77,12 @@ func New(ctx context.Context, id string) (string, error) { // returns the directory and true. If it // does not exist, it returns an empty string // and false. -func Get(ctx context.Context, id string) (string, bool) { +func (dc *DownloadCache) Get(ctx context.Context, id string) (string, bool) { h, err := hashID(id) if err != nil { return "", false } - itemPath := filepath.Join(BasePath(ctx), h) + itemPath := filepath.Join(dc.BasePath(ctx), h) _, err = os.Stat(itemPath) if err != nil { @@ -79,15 +91,3 @@ func Get(ctx context.Context, id string) (string, bool) { return itemPath, true } - -// hashID hashes the input ID with SHA1 -// and returns the hex string of the hashed -// ID. -func hashID(id string) (string, error) { - h := sha1.New() - _, err := io.WriteString(h, id) - if err != nil { - return "", err - } - return hex.EncodeToString(h.Sum(nil)), nil -} diff --git a/internal/dlcache/dlcache_test.go b/internal/dlcache/dlcache_test.go index d347a83..7cd78cc 100644 --- a/internal/dlcache/dlcache_test.go +++ b/internal/dlcache/dlcache_test.go @@ -39,14 +39,49 @@ func init() { config.GetPaths(context.Background()).RepoDir = dir } +type TestALRConfig struct { + CacheDir string +} + +func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths { + return &config.Paths{ + CacheDir: c.CacheDir, + } +} + +func prepare(t *testing.T) *TestALRConfig { + t.Helper() + + dir, err := os.MkdirTemp("/tmp", "alr-dlcache-test.*") + if err != nil { + panic(err) + } + + return &TestALRConfig{ + CacheDir: dir, + } +} + +func cleanup(t *testing.T, cfg *TestALRConfig) { + t.Helper() + os.Remove(cfg.CacheDir) +} + func TestNew(t *testing.T) { + cfg := prepare(t) + defer cleanup(t, cfg) + + dc := dlcache.New(cfg) + + ctx := context.Background() + const id = "https://example.com" - dir, err := dlcache.New(id) + dir, err := dc.New(ctx, id) if err != nil { t.Errorf("Expected no error, got %s", err) } - exp := filepath.Join(dlcache.BasePath(), sha1sum(id)) + exp := filepath.Join(dc.BasePath(ctx), sha1sum(id)) if dir != exp { t.Errorf("Expected %s, got %s", exp, dir) } @@ -60,7 +95,7 @@ func TestNew(t *testing.T) { t.Errorf("Expected cache item to be a directory") } - dir2, ok := dlcache.Get(id) + dir2, ok := dc.Get(ctx, id) if !ok { t.Errorf("Expected Get() to return valid value") } diff --git a/internal/dlcache/utils.go b/internal/dlcache/utils.go new file mode 100644 index 0000000..4b7a913 --- /dev/null +++ b/internal/dlcache/utils.go @@ -0,0 +1,16 @@ +package dlcache + +import ( + "crypto/sha1" + "encoding/hex" + "io" +) + +func hashID(id string) (string, error) { + h := sha1.New() + _, err := io.WriteString(h, id) + if err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} diff --git a/internal/shutils/decoder/decoder_test.go b/internal/shutils/decoder/decoder_test.go index b757581..f703764 100644 --- a/internal/shutils/decoder/decoder_test.go +++ b/internal/shutils/decoder/decoder_test.go @@ -27,10 +27,10 @@ import ( "strings" "testing" - "plemya-x.ru/alr/internal/shutils/decoder" - "plemya-x.ru/alr/pkg/distro" "mvdan.cc/sh/v3/interp" "mvdan.cc/sh/v3/syntax" + "plemya-x.ru/alr/internal/shutils/decoder" + "plemya-x.ru/alr/pkg/distro" ) type BuildVars struct { @@ -56,7 +56,7 @@ const testScript = ` release=1 epoch=2 desc="Test package" - homepage='//https://gitea.plemya-x.ru/xpamych/ALR' + homepage='https://gitea.plemya-x.ru/xpamych/ALR' maintainer='Евгений Храмов ' architectures=('arm64' 'amd64') license=('GPL-3.0-or-later') diff --git a/internal/shutils/handlers/exec_test.go b/internal/shutils/handlers/exec_test.go index 56f38cd..de86030 100644 --- a/internal/shutils/handlers/exec_test.go +++ b/internal/shutils/handlers/exec_test.go @@ -23,11 +23,11 @@ import ( "strings" "testing" - "plemya-x.ru/alr/internal/shutils/handlers" - "plemya-x.ru/alr/internal/shutils/decoder" - "plemya-x.ru/alr/pkg/distro" "mvdan.cc/sh/v3/interp" "mvdan.cc/sh/v3/syntax" + "plemya-x.ru/alr/internal/shutils/decoder" + "plemya-x.ru/alr/internal/shutils/handlers" + "plemya-x.ru/alr/pkg/distro" ) const testScript = ` @@ -89,7 +89,7 @@ func TestExecFuncs(t *testing.T) { t.Fatalf("Expected test() function to exist") } - eh := shutils.ExecFuncs{ + eh := handlers.ExecFuncs{ "test-cmd": func(hc interp.HandlerContext, name string, args []string) error { if name != "test-cmd" { t.Errorf("Expected name to be 'test-cmd', got '%s'", name) diff --git a/list.go b/list.go index a9ebf29..d89081b 100644 --- a/list.go +++ b/list.go @@ -22,12 +22,12 @@ import ( "fmt" "github.com/urfave/cli/v2" + "golang.org/x/exp/slices" "plemya-x.ru/alr/internal/config" - "plemya-x.ru/alr/internal/db" + database "plemya-x.ru/alr/internal/db" "plemya-x.ru/alr/pkg/loggerctx" "plemya-x.ru/alr/pkg/manager" "plemya-x.ru/alr/pkg/repos" - "golang.org/x/exp/slices" ) var listCmd = &cli.Command{ @@ -43,8 +43,14 @@ var listCmd = &cli.Command{ Action: func(c *cli.Context) error { ctx := c.Context log := loggerctx.From(ctx) - - err := repos.Pull(ctx, config.Config(ctx).Repos) + cfg := config.New() + db := database.New(cfg) + err := db.Init(ctx) + if err != nil { + log.Fatal("Error initialization database").Err(err).Send() + } + rs := repos.New(cfg, db) + err = rs.Pull(ctx, cfg.Repos(ctx)) if err != nil { log.Fatal("Error pulling repositories").Err(err).Send() } @@ -76,13 +82,13 @@ var listCmd = &cli.Command{ } for result.Next() { - var pkg db.Package + var pkg database.Package err := result.StructScan(&pkg) if err != nil { return err } - if slices.Contains(config.Config(ctx).IgnorePkgUpdates, pkg.Name) { + if slices.Contains(cfg.IgnorePkgUpdates(ctx), pkg.Name) { continue } diff --git a/pkg/repos/find.go b/pkg/repos/find.go index faf6c30..20394de 100644 --- a/pkg/repos/find.go +++ b/pkg/repos/find.go @@ -15,7 +15,6 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ - package repos import ( @@ -24,10 +23,7 @@ import ( "plemya-x.ru/alr/internal/db" ) -// FindPkgs looks for packages matching the inputs inside the database. -// It returns a map that maps the package name input to any packages found for it. -// It also returns a slice that contains the names of all packages that were not found. -func FindPkgs(ctx context.Context, pkgs []string) (map[string][]db.Package, []string, error) { +func (rs *Repos) FindPkgs(ctx context.Context, pkgs []string) (map[string][]db.Package, []string, error) { found := map[string][]db.Package{} notFound := []string(nil) @@ -36,7 +32,7 @@ func FindPkgs(ctx context.Context, pkgs []string) (map[string][]db.Package, []st continue } - result, err := db.GetPkgs(ctx, "json_array_contains(provides, ?)", pkgName) + result, err := rs.db.GetPkgs(ctx, "json_array_contains(provides, ?)", pkgName) if err != nil { return nil, nil, err } @@ -55,7 +51,7 @@ func FindPkgs(ctx context.Context, pkgs []string) (map[string][]db.Package, []st result.Close() if added == 0 { - result, err := db.GetPkgs(ctx, "name LIKE ?", pkgName) + result, err := rs.db.GetPkgs(ctx, "name LIKE ?", pkgName) if err != nil { return nil, nil, err } diff --git a/pkg/repos/find_test.go b/pkg/repos/find_test.go index f435489..790dff3 100644 --- a/pkg/repos/find_test.go +++ b/pkg/repos/find_test.go @@ -19,7 +19,6 @@ package repos_test import ( - "context" "reflect" "strings" "testing" @@ -30,18 +29,15 @@ import ( ) func TestFindPkgs(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + e := prepare(t) + defer cleanup(t, e) - setCfgDirs(t) - defer removeCacheDir(t) + rs := repos.New( + e.Cfg, + e.Db, + ) - ctx := context.Background() - - err = repos.Pull(ctx, []types.Repo{ + err := rs.Pull(e.Ctx, []types.Repo{ { Name: "default", URL: "https://gitea.plemya-x.ru/xpamych/xpamych-alr-repo.git", @@ -51,7 +47,10 @@ func TestFindPkgs(t *testing.T) { t.Fatalf("Expected no error, got %s", err) } - found, notFound, err := repos.FindPkgs([]string{"itd", "nonexistentpackage1", "nonexistentpackage2"}) + found, notFound, err := rs.FindPkgs( + e.Ctx, + []string{"alr", "nonexistentpackage1", "nonexistentpackage2"}, + ) if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -64,33 +63,32 @@ func TestFindPkgs(t *testing.T) { t.Errorf("Expected 1 package found, got %d", len(found)) } - itdPkgs, ok := found["itd"] + alrPkgs, ok := found["alr"] if !ok { - t.Fatalf("Expected 'itd' packages to be found") + t.Fatalf("Expected 'alr' packages to be found") } - if len(itdPkgs) < 2 { - t.Errorf("Expected two 'itd' packages to be found") + if len(alrPkgs) < 2 { + t.Errorf("Expected two 'alr' packages to be found") } - for i, pkg := range itdPkgs { - if !strings.HasPrefix(pkg.Name, "itd") { - t.Errorf("Expected package name of all found packages to start with 'itd', got %s on element %d", pkg.Name, i) + for i, pkg := range alrPkgs { + if !strings.HasPrefix(pkg.Name, "alr") { + t.Errorf("Expected package name of all found packages to start with 'alr', got %s on element %d", pkg.Name, i) } } } func TestFindPkgsEmpty(t *testing.T) { - _, err := db.Open(":memory:") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - defer db.Close() + e := prepare(t) + defer cleanup(t, e) - setCfgDirs(t) - defer removeCacheDir(t) + rs := repos.New( + e.Cfg, + e.Db, + ) - err = db.InsertPackage(db.Package{ + err := e.Db.InsertPackage(e.Ctx, db.Package{ Name: "test1", Repository: "default", Version: "0.0.1", @@ -105,7 +103,7 @@ func TestFindPkgsEmpty(t *testing.T) { t.Fatalf("Expected no error, got %s", err) } - err = db.InsertPackage(db.Package{ + err = e.Db.InsertPackage(e.Ctx, db.Package{ Name: "test2", Repository: "default", Version: "0.0.1", @@ -120,7 +118,7 @@ func TestFindPkgsEmpty(t *testing.T) { t.Fatalf("Expected no error, got %s", err) } - found, notFound, err := repos.FindPkgs([]string{"test", ""}) + found, notFound, err := rs.FindPkgs(e.Ctx, []string{"test", ""}) if err != nil { t.Fatalf("Expected no error, got %s", err) } diff --git a/pkg/repos/pull.go b/pkg/repos/pull.go index 516d7d8..f151121 100644 --- a/pkg/repos/pull.go +++ b/pkg/repos/pull.go @@ -21,41 +21,48 @@ package repos import ( "context" "errors" - "io" "net/url" "os" "path/filepath" - "reflect" "strings" "github.com/go-git/go-billy/v5" "github.com/go-git/go-billy/v5/osfs" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" - "github.com/go-git/go-git/v5/plumbing/format/diff" "github.com/pelletier/go-toml/v2" "go.elara.ws/vercmp" - "plemya-x.ru/alr/internal/config" - "plemya-x.ru/alr/internal/db" - "plemya-x.ru/alr/internal/shutils/decoder" - "plemya-x.ru/alr/internal/shutils/handlers" - "plemya-x.ru/alr/internal/types" - "plemya-x.ru/alr/pkg/distro" - "plemya-x.ru/alr/pkg/loggerctx" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/interp" "mvdan.cc/sh/v3/syntax" + "plemya-x.ru/alr/internal/config" + "plemya-x.ru/alr/internal/db" + "plemya-x.ru/alr/internal/shutils/handlers" + "plemya-x.ru/alr/internal/types" + "plemya-x.ru/alr/pkg/loggerctx" ) +type actionType uint8 + +const ( + actionDelete actionType = iota + actionUpdate +) + +type action struct { + Type actionType + File string +} + // Pull pulls the provided repositories. If a repo doesn't exist, it will be cloned // and its packages will be written to the DB. If it does exist, it will be pulled. // In this case, only changed packages will be processed if possible. // If repos is set to nil, the repos in the ALR config will be used. -func Pull(ctx context.Context, repos []types.Repo) error { +func (rs *Repos) Pull(ctx context.Context, repos []types.Repo) error { log := loggerctx.From(ctx) if repos == nil { - repos = config.Config(ctx).Repos + repos = rs.cfg.Repos(ctx) } for _, repo := range repos { @@ -95,7 +102,7 @@ func Pull(ctx context.Context, repos []types.Repo) error { repoFS = w.Filesystem // Make sure the DB is created even if the repo is up to date - if !errors.Is(err, git.NoErrAlreadyUpToDate) || db.IsEmpty(ctx) { + if !errors.Is(err, git.NoErrAlreadyUpToDate) || rs.db.IsEmpty(ctx) { new, err := r.Head() if err != nil { return err @@ -104,13 +111,13 @@ func Pull(ctx context.Context, repos []types.Repo) error { // If the DB was not present at startup, that means it's // empty. In this case, we need to update the DB fully // rather than just incrementally. - if db.IsEmpty(ctx) { - err = processRepoFull(ctx, repo, repoDir) + if rs.db.IsEmpty(ctx) { + err = rs.processRepoFull(ctx, repo, repoDir) if err != nil { return err } } else { - err = processRepoChanges(ctx, repo, r, w, old, new) + err = rs.processRepoChanges(ctx, repo, r, w, old, new) if err != nil { return err } @@ -135,7 +142,7 @@ func Pull(ctx context.Context, repos []types.Repo) error { return err } - err = processRepoFull(ctx, repo, repoDir) + err = rs.processRepoFull(ctx, repo, repoDir) if err != nil { return err } @@ -169,19 +176,7 @@ func Pull(ctx context.Context, repos []types.Repo) error { return nil } -type actionType uint8 - -const ( - actionDelete actionType = iota - actionUpdate -) - -type action struct { - Type actionType - File string -} - -func processRepoChanges(ctx context.Context, repo types.Repo, r *git.Repository, w *git.Worktree, old, new *plumbing.Reference) error { +func (rs *Repos) processRepoChanges(ctx context.Context, repo types.Repo, r *git.Repository, w *git.Worktree, old, new *plumbing.Reference) error { oldCommit, err := r.CommitObject(old.Hash()) if err != nil { return err @@ -275,7 +270,7 @@ func processRepoChanges(ctx context.Context, repo types.Repo, r *git.Repository, return err } - err = db.DeletePkgs(ctx, "name = ? AND repository = ?", pkg.Name, repo.Name) + err = rs.db.DeletePkgs(ctx, "name = ? AND repository = ?", pkg.Name, repo.Name) if err != nil { return err } @@ -310,7 +305,7 @@ func processRepoChanges(ctx context.Context, repo types.Repo, r *git.Repository, resolveOverrides(runner, &pkg) - err = db.InsertPackage(ctx, pkg) + err = rs.db.InsertPackage(ctx, pkg) if err != nil { return err } @@ -320,23 +315,7 @@ func processRepoChanges(ctx context.Context, repo types.Repo, r *git.Repository, return nil } -// isValid makes sure the path of the file being updated is valid. -// It checks to make sure the file is not within a nested directory -// and that it is called alr.sh. -func isValid(from, to diff.File) bool { - var path string - if from != nil { - path = from.Path() - } - if to != nil { - path = to.Path() - } - - match, _ := filepath.Match("*/*.sh", path) - return match -} - -func processRepoFull(ctx context.Context, repo types.Repo, repoDir string) error { +func (rs *Repos) processRepoFull(ctx context.Context, repo types.Repo, repoDir string) error { glob := filepath.Join(repoDir, "/*/alr.sh") matches, err := filepath.Glob(glob) if err != nil { @@ -380,7 +359,7 @@ func processRepoFull(ctx context.Context, repo types.Repo, repoDir string) error resolveOverrides(runner, &pkg) - err = db.InsertPackage(ctx, pkg) + err = rs.db.InsertPackage(ctx, pkg) if err != nil { return err } @@ -388,54 +367,3 @@ func processRepoFull(ctx context.Context, repo types.Repo, repoDir string) error return nil } - -func parseScript(ctx context.Context, parser *syntax.Parser, runner *interp.Runner, r io.ReadCloser, pkg *db.Package) error { - defer r.Close() - fl, err := parser.Parse(r, "alr.sh") - if err != nil { - return err - } - - runner.Reset() - err = runner.Run(ctx, fl) - if err != nil { - return err - } - - d := decoder.New(&distro.OSRelease{}, runner) - d.Overrides = false - d.LikeDistros = false - return d.DecodeVars(pkg) -} - -var overridable = map[string]string{ - "deps": "Depends", - "build_deps": "BuildDepends", - "desc": "Description", - "homepage": "Homepage", - "maintainer": "Maintainer", -} - -func resolveOverrides(runner *interp.Runner, pkg *db.Package) { - pkgVal := reflect.ValueOf(pkg).Elem() - for name, val := range runner.Vars { - for prefix, field := range overridable { - if strings.HasPrefix(name, prefix) { - override := strings.TrimPrefix(name, prefix) - override = strings.TrimPrefix(override, "_") - - field := pkgVal.FieldByName(field) - varVal := field.FieldByName("Val") - varType := varVal.Type() - - switch varType.Elem().String() { - case "[]string": - varVal.SetMapIndex(reflect.ValueOf(override), reflect.ValueOf(val.List)) - case "string": - varVal.SetMapIndex(reflect.ValueOf(override), reflect.ValueOf(val.Str)) - } - break - } - } - } -} diff --git a/pkg/repos/pull_test.go b/pkg/repos/pull_test.go index af1db65..44075e3 100644 --- a/pkg/repos/pull_test.go +++ b/pkg/repos/pull_test.go @@ -26,69 +26,104 @@ import ( "plemya-x.ru/alr/internal/config" "plemya-x.ru/alr/internal/db" + database "plemya-x.ru/alr/internal/db" "plemya-x.ru/alr/internal/types" "plemya-x.ru/alr/pkg/repos" ) -func setCfgDirs(t *testing.T) { - t.Helper() - - paths := config.GetPaths() - - var err error - paths.CacheDir, err = os.MkdirTemp("/tmp", "alr-pull-test.*") - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - - paths.RepoDir = filepath.Join(paths.CacheDir, "repo") - paths.PkgsDir = filepath.Join(paths.CacheDir, "pkgs") - - err = os.MkdirAll(paths.RepoDir, 0o755) - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - - err = os.MkdirAll(paths.PkgsDir, 0o755) - if err != nil { - t.Fatalf("Expected no error, got %s", err) - } - - paths.DBPath = filepath.Join(paths.CacheDir, "db") +type TestEnv struct { + Ctx context.Context + Cfg *TestALRConfig + Db *db.Database } -func removeCacheDir(t *testing.T) { - t.Helper() +type TestALRConfig struct { + CacheDir string + RepoDir string + PkgsDir string +} - err := os.RemoveAll(config.GetPaths().CacheDir) - if err != nil { - t.Fatalf("Expected no error, got %s", err) +func (c *TestALRConfig) GetPaths(ctx context.Context) *config.Paths { + return &config.Paths{ + DBPath: ":memory:", + CacheDir: c.CacheDir, + RepoDir: c.RepoDir, + PkgsDir: c.PkgsDir, } } -func TestPull(t *testing.T) { - _, err := db.Open(":memory:") +func (c *TestALRConfig) Repos(ctx context.Context) []types.Repo { + return []types.Repo{} +} + +func prepare(t *testing.T) *TestEnv { + t.Helper() + + cacheDir, err := os.MkdirTemp("/tmp", "alr-pull-test.*") if err != nil { t.Fatalf("Expected no error, got %s", err) } - defer db.Close() - setCfgDirs(t) - defer removeCacheDir(t) + repoDir := filepath.Join(cacheDir, "repo") + err = os.MkdirAll(repoDir, 0o755) + if err != nil { + t.Fatalf("Expected no error, got %s", err) + } + + pkgsDir := filepath.Join(cacheDir, "pkgs") + err = os.MkdirAll(pkgsDir, 0o755) + if err != nil { + t.Fatalf("Expected no error, got %s", err) + } + + cfg := &TestALRConfig{ + CacheDir: cacheDir, + RepoDir: repoDir, + PkgsDir: pkgsDir, + } ctx := context.Background() - err = repos.Pull(ctx, []types.Repo{ + db := database.New(cfg) + db.Init(ctx) + + return &TestEnv{ + Cfg: cfg, + Db: db, + Ctx: ctx, + } +} + +func cleanup(t *testing.T, e *TestEnv) { + t.Helper() + + err := os.RemoveAll(e.Cfg.CacheDir) + if err != nil { + t.Fatalf("Expected no error, got %s", err) + } + e.Db.Close() +} + +func TestPull(t *testing.T) { + e := prepare(t) + defer cleanup(t, e) + + rs := repos.New( + e.Cfg, + e.Db, + ) + + err := rs.Pull(e.Ctx, []types.Repo{ { Name: "default", - URL: "https://gitea.plemya-x.ru/xpamych/ALR.git", + URL: "https://gitea.plemya-x.ru/Plemya-x/xpamych-alr-repo.git", }, }) if err != nil { t.Fatalf("Expected no error, got %s", err) } - result, err := db.GetPkgs("name LIKE 'itd%'") + result, err := e.Db.GetPkgs(e.Ctx, "true") if err != nil { t.Fatalf("Expected no error, got %s", err) } @@ -103,7 +138,7 @@ func TestPull(t *testing.T) { pkgAmt++ } - if pkgAmt < 2 { - t.Errorf("Expected 2 packages to match, got %d", pkgAmt) + if pkgAmt == 0 { + t.Errorf("Expected at least 1 matching package, but got %d", pkgAmt) } } diff --git a/pkg/repos/repos.go b/pkg/repos/repos.go new file mode 100644 index 0000000..5dd420b --- /dev/null +++ b/pkg/repos/repos.go @@ -0,0 +1,29 @@ +package repos + +import ( + "context" + + "plemya-x.ru/alr/internal/config" + database "plemya-x.ru/alr/internal/db" + "plemya-x.ru/alr/internal/types" +) + +type Config interface { + GetPaths(ctx context.Context) *config.Paths + Repos(ctx context.Context) []types.Repo +} + +type Repos struct { + cfg Config + db *database.Database +} + +func New( + cfg Config, + db *database.Database, +) *Repos { + return &Repos{ + cfg, + db, + } +} diff --git a/pkg/repos/repos_legacy.go b/pkg/repos/repos_legacy.go new file mode 100644 index 0000000..892e272 --- /dev/null +++ b/pkg/repos/repos_legacy.go @@ -0,0 +1,54 @@ +package repos + +import ( + "context" + "sync" + + "plemya-x.ru/alr/internal/config" + "plemya-x.ru/alr/internal/db" + database "plemya-x.ru/alr/internal/db" + "plemya-x.ru/alr/internal/types" +) + +// Pull pulls the provided repositories. If a repo doesn't exist, it will be cloned +// and its packages will be written to the DB. If it does exist, it will be pulled. +// In this case, only changed packages will be processed if possible. +// If repos is set to nil, the repos in the ALR config will be used. +// +// Deprecated: use struct method +func Pull(ctx context.Context, repos []types.Repo) error { + return GetInstance(ctx).Pull(ctx, repos) +} + +// FindPkgs looks for packages matching the inputs inside the database. +// It returns a map that maps the package name input to any packages found for it. +// It also returns a slice that contains the names of all packages that were not found. +// +// Deprecated: use struct method +func FindPkgs(ctx context.Context, pkgs []string) (map[string][]db.Package, []string, error) { + return GetInstance(ctx).FindPkgs(ctx, pkgs) +} + +// ======================= +// FOR LEGACY ONLY +// ======================= + +var ( + reposInstance *Repos + alrConfigOnce sync.Once +) + +// Deprecated: For legacy only +func GetInstance(ctx context.Context) *Repos { + alrConfigOnce.Do(func() { + cfg := config.GetInstance(ctx) + db := database.GetInstance(ctx) + + reposInstance = New( + cfg, + db, + ) + }) + + return reposInstance +} diff --git a/pkg/repos/utils.go b/pkg/repos/utils.go new file mode 100644 index 0000000..b4a6d3d --- /dev/null +++ b/pkg/repos/utils.go @@ -0,0 +1,83 @@ +package repos + +import ( + "context" + "io" + "path/filepath" + "reflect" + "strings" + + "github.com/go-git/go-git/v5/plumbing/format/diff" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" + "plemya-x.ru/alr/internal/db" + "plemya-x.ru/alr/internal/shutils/decoder" + "plemya-x.ru/alr/pkg/distro" +) + +// isValid makes sure the path of the file being updated is valid. +// It checks to make sure the file is not within a nested directory +// and that it is called alr.sh. +func isValid(from, to diff.File) bool { + var path string + if from != nil { + path = from.Path() + } + if to != nil { + path = to.Path() + } + + match, _ := filepath.Match("*/*.sh", path) + return match +} + +func parseScript(ctx context.Context, parser *syntax.Parser, runner *interp.Runner, r io.ReadCloser, pkg *db.Package) error { + defer r.Close() + fl, err := parser.Parse(r, "alr.sh") + if err != nil { + return err + } + + runner.Reset() + err = runner.Run(ctx, fl) + if err != nil { + return err + } + + d := decoder.New(&distro.OSRelease{}, runner) + d.Overrides = false + d.LikeDistros = false + return d.DecodeVars(pkg) +} + +var overridable = map[string]string{ + "deps": "Depends", + "build_deps": "BuildDepends", + "desc": "Description", + "homepage": "Homepage", + "maintainer": "Maintainer", +} + +func resolveOverrides(runner *interp.Runner, pkg *db.Package) { + pkgVal := reflect.ValueOf(pkg).Elem() + for name, val := range runner.Vars { + for prefix, field := range overridable { + if strings.HasPrefix(name, prefix) { + override := strings.TrimPrefix(name, prefix) + override = strings.TrimPrefix(override, "_") + + field := pkgVal.FieldByName(field) + varVal := field.FieldByName("Val") + varType := varVal.Type() + + switch varType.Elem().String() { + case "[]string": + varVal.SetMapIndex(reflect.ValueOf(override), reflect.ValueOf(val.List)) + case "string": + varVal.SetMapIndex(reflect.ValueOf(override), reflect.ValueOf(val.Str)) + } + break + } + } + } +}