refactor: migrate to xorm

This commit is contained in:
2025-06-10 14:12:40 +03:00
parent 65ab4de561
commit e259184a89
19 changed files with 261 additions and 481 deletions

View File

@ -23,41 +23,39 @@ import (
"context"
"log/slog"
"github.com/jmoiron/sqlx"
"github.com/leonelquinteros/gotext"
_ "modernc.org/sqlite"
"xorm.io/xorm"
"gitea.plemya-x.ru/Plemya-x/ALR/internal/config"
)
// CurrentVersion is the current version of the database.
// The database is reset if its version doesn't match this.
const CurrentVersion = 4
const CurrentVersion = 5
// Package is a ALR package's database representation
type Package struct {
BasePkgName string `sh:"base" db:"basepkg_name"`
Name string `sh:"name,required" db:"name"`
Version string `sh:"version,required" db:"version"`
Release int `sh:"release,required" db:"release"`
Epoch uint `sh:"epoch" db:"epoch"`
Summary JSON[map[string]string] `db:"summary"`
Description JSON[map[string]string] `db:"description"`
Group JSON[map[string]string] `db:"group_name"`
Homepage JSON[map[string]string] `db:"homepage"`
Maintainer JSON[map[string]string] `db:"maintainer"`
Architectures JSON[[]string] `sh:"architectures" db:"architectures"`
Licenses JSON[[]string] `sh:"license" db:"licenses"`
Provides JSON[[]string] `sh:"provides" db:"provides"`
Conflicts JSON[[]string] `sh:"conflicts" db:"conflicts"`
Replaces JSON[[]string] `sh:"replaces" db:"replaces"`
Depends JSON[map[string][]string] `db:"depends"`
BuildDepends JSON[map[string][]string] `db:"builddepends"`
OptDepends JSON[map[string][]string] `db:"optdepends"`
Repository string `db:"repository"`
BasePkgName string `sh:"basepkg_name" xorm:"notnull 'basepkg_name'"`
Name string `sh:"name,required" xorm:"notnull unique(name_repo) 'name'"`
Version string `sh:"version,required" xorm:"notnull 'version'"`
Release int `sh:"release" xorm:"notnull 'release'"`
Epoch uint `sh:"epoch" xorm:"'epoch'"`
Summary map[string]string `xorm:"json 'summary'"`
Description map[string]string `xorm:"json 'description'"`
Group map[string]string `xorm:"json 'group_name'"`
Homepage map[string]string `xorm:"json 'homepage'"`
Maintainer map[string]string `xorm:"json 'maintainer'"`
Architectures []string `sh:"architectures" xorm:"json 'architectures'"`
Licenses []string `sh:"license" xorm:"json 'licenses'"`
Provides []string `sh:"provides" xorm:"json 'provides'"`
Conflicts []string `sh:"conflicts" xorm:"json 'conflicts'"`
Replaces []string `sh:"replaces" xorm:"json 'replaces'"`
Depends map[string][]string `xorm:"json 'depends'"`
BuildDepends map[string][]string `xorm:"json 'builddepends'"`
OptDepends map[string][]string `xorm:"json 'optdepends'"`
Repository string `xorm:"notnull unique(name_repo) 'repository'"`
}
type version struct {
Version int `db:"version"`
type Version struct {
Version int `xorm:"'version'"`
}
type Config interface {
@ -65,7 +63,7 @@ type Config interface {
}
type Database struct {
conn *sqlx.DB
engine *xorm.Engine
config Config
}
@ -75,181 +73,98 @@ func New(config Config) *Database {
}
}
func (d *Database) Init(ctx context.Context) error {
err := d.Connect(ctx)
if err != nil {
return err
}
return d.initDB(ctx)
}
func (d *Database) Connect(ctx context.Context) error {
func (d *Database) Connect() error {
dsn := d.config.GetPaths().DBPath
db, err := sqlx.Open("sqlite", dsn)
engine, err := xorm.NewEngine("sqlite", dsn)
if err != nil {
return err
}
d.conn = db
d.engine = engine
return nil
}
func (d *Database) GetConn() *sqlx.DB {
return d.conn
}
func (d *Database) initDB(ctx context.Context) error {
d.conn = d.conn.Unsafe()
conn := d.conn
_, err := conn.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS pkgs (
basepkg_name TEXT NOT NULL,
name TEXT NOT NULL,
repository TEXT NOT NULL,
version TEXT NOT NULL,
release INT NOT NULL,
epoch INT,
summary TEXT CHECK(summary = 'null' OR (JSON_VALID(summary) AND JSON_TYPE(summary) = 'object')),
description TEXT CHECK(description = 'null' OR (JSON_VALID(description) AND JSON_TYPE(description) = 'object')),
group_name TEXT CHECK(group_name = 'null' OR (JSON_VALID(group_name) AND JSON_TYPE(group_name) = 'object')),
homepage TEXT CHECK(homepage = 'null' OR (JSON_VALID(homepage) AND JSON_TYPE(homepage) = 'object')),
maintainer TEXT CHECK(maintainer = 'null' OR (JSON_VALID(maintainer) AND JSON_TYPE(maintainer) = 'object')),
architectures TEXT CHECK(architectures = 'null' OR (JSON_VALID(architectures) AND JSON_TYPE(architectures) = 'array')),
licenses TEXT CHECK(licenses = 'null' OR (JSON_VALID(licenses) AND JSON_TYPE(licenses) = 'array')),
provides TEXT CHECK(provides = 'null' OR (JSON_VALID(provides) AND JSON_TYPE(provides) = 'array')),
conflicts TEXT CHECK(conflicts = 'null' OR (JSON_VALID(conflicts) AND JSON_TYPE(conflicts) = 'array')),
replaces TEXT CHECK(replaces = 'null' OR (JSON_VALID(replaces) AND JSON_TYPE(replaces) = 'array')),
depends TEXT CHECK(depends = 'null' OR (JSON_VALID(depends) AND JSON_TYPE(depends) = 'object')),
builddepends TEXT CHECK(builddepends = 'null' OR (JSON_VALID(builddepends) AND JSON_TYPE(builddepends) = 'object')),
optdepends TEXT CHECK(optdepends = 'null' OR (JSON_VALID(optdepends) AND JSON_TYPE(optdepends) = 'object')),
UNIQUE(name, repository)
);
CREATE TABLE IF NOT EXISTS alr_db_version (
version INT NOT NULL
);
`)
if err != nil {
func (d *Database) Init(ctx context.Context) error {
if err := d.Connect(); err != nil {
return err
}
if err := d.engine.Sync2(new(Package), new(Version)); err != nil {
return err
}
ver, ok := d.GetVersion(ctx)
if ok && ver != CurrentVersion {
slog.Warn(gotext.Get("Database version mismatch; resetting"), "version", ver, "expected", CurrentVersion)
err = d.reset(ctx)
if err != nil {
if err := d.reset(); err != nil {
return err
}
return d.initDB(ctx)
return d.Init(ctx)
} else if !ok {
slog.Warn(gotext.Get("Database version does not exist. Run alr fix if something isn't working."), "version", ver, "expected", CurrentVersion)
return d.addVersion(ctx, CurrentVersion)
slog.Warn(gotext.Get("Database version does not exist. Run alr fix if something isn't working."))
return d.addVersion(CurrentVersion)
}
return nil
}
func (d *Database) GetVersion(ctx context.Context) (int, bool) {
var ver version
err := d.conn.GetContext(ctx, &ver, "SELECT * FROM alr_db_version LIMIT 1;")
if err != nil {
var v Version
has, err := d.engine.Get(&v)
if err != nil || !has {
return 0, false
}
return ver.Version, true
return v.Version, true
}
func (d *Database) addVersion(ctx context.Context, ver int) error {
_, err := d.conn.ExecContext(ctx, `INSERT INTO alr_db_version(version) VALUES (?);`, ver)
func (d *Database) addVersion(ver int) error {
_, err := d.engine.Insert(&Version{Version: ver})
return err
}
func (d *Database) reset(ctx context.Context) error {
_, err := d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS pkgs;")
if err != nil {
return err
}
_, err = d.conn.ExecContext(ctx, "DROP TABLE IF EXISTS alr_db_version;")
return err
}
func (d *Database) GetPkgs(ctx context.Context, where string, args ...any) (*sqlx.Rows, error) {
stream, err := d.conn.QueryxContext(ctx, "SELECT * FROM pkgs WHERE "+where, args...)
if err != nil {
return nil, err
}
return stream, nil
}
func (d *Database) GetPkg(ctx context.Context, where string, args ...any) (*Package, error) {
out := &Package{}
err := d.conn.GetContext(ctx, out, "SELECT * FROM pkgs WHERE "+where+" LIMIT 1", args...)
return out, err
}
func (d *Database) DeletePkgs(ctx context.Context, where string, args ...any) error {
_, err := d.conn.ExecContext(ctx, "DELETE FROM pkgs WHERE "+where, args...)
return err
}
func (d *Database) IsEmpty(ctx context.Context) bool {
var count int
err := d.conn.GetContext(ctx, &count, "SELECT count(1) FROM pkgs;")
if err != nil {
return true
}
return count == 0
func (d *Database) reset() error {
return d.engine.DropTables(new(Package), new(Version))
}
func (d *Database) InsertPackage(ctx context.Context, pkg Package) error {
_, err := d.conn.NamedExecContext(ctx, `
INSERT OR REPLACE INTO pkgs (
basepkg_name,
name,
repository,
version,
release,
epoch,
summary,
description,
group_name,
homepage,
maintainer,
architectures,
licenses,
provides,
conflicts,
replaces,
depends,
builddepends,
optdepends
) VALUES (
:basepkg_name,
:name,
:repository,
:version,
:release,
:epoch,
:summary,
:description,
:group_name,
:homepage,
:maintainer,
:architectures,
:licenses,
:provides,
:conflicts,
:replaces,
:depends,
:builddepends,
:optdepends
);
`, pkg)
session := d.engine.Context(ctx)
affected, err := session.Where("name = ? AND repository = ?", pkg.Name, pkg.Repository).Update(&pkg)
if err != nil {
return err
}
if affected == 0 {
_, err = session.Insert(&pkg)
if err != nil {
return err
}
}
return nil
}
func (d *Database) GetPkgs(_ context.Context, where string, args ...any) ([]Package, error) {
var pkgs []Package
err := d.engine.Where(where, args...).Find(&pkgs)
return pkgs, err
}
func (d *Database) GetPkg(where string, args ...any) (*Package, error) {
var pkg Package
has, err := d.engine.Where(where, args...).Get(&pkg)
if err != nil || !has {
return nil, err
}
return &pkg, nil
}
func (d *Database) DeletePkgs(_ context.Context, where string, args ...any) error {
_, err := d.engine.Where(where, args...).Delete(&Package{})
return err
}
func (d *Database) Close() error {
if d.conn != nil {
return d.conn.Close()
} else {
return nil
}
func (d *Database) IsEmpty() bool {
count, err := d.engine.Count(new(Package))
return err != nil || count == 0
}
func (d *Database) Close() error {
return d.engine.Close()
}

View File

@ -25,8 +25,6 @@ import (
"strings"
"testing"
"github.com/jmoiron/sqlx"
"gitea.plemya-x.ru/Plemya-x/ALR/internal/config"
"gitea.plemya-x.ru/Plemya-x/ALR/internal/db"
)
@ -50,29 +48,29 @@ var testPkg = db.Package{
Version: "0.0.1",
Release: 1,
Epoch: 2,
Description: db.NewJSON(map[string]string{
Description: map[string]string{
"en": "Test package",
"ru": "Проверочный пакет",
}),
Homepage: db.NewJSON(map[string]string{
},
Homepage: map[string]string{
"en": "https://gitea.plemya-x.ru/xpamych/ALR",
}),
Maintainer: db.NewJSON(map[string]string{
},
Maintainer: map[string]string{
"en": "Evgeniy Khramov <xpamych@yandex.ru>",
"ru": "Евгений Храмов <xpamych@yandex.ru>",
}),
Architectures: db.NewJSON([]string{"arm64", "amd64"}),
Licenses: db.NewJSON([]string{"GPL-3.0-or-later"}),
Provides: db.NewJSON([]string{"test"}),
Conflicts: db.NewJSON([]string{"test"}),
Replaces: db.NewJSON([]string{"test-old"}),
Depends: db.NewJSON(map[string][]string{
},
Architectures: []string{"arm64", "amd64"},
Licenses: []string{"GPL-3.0-or-later"},
Provides: []string{"test"},
Conflicts: []string{"test"},
Replaces: []string{"test-old"},
Depends: map[string][]string{
"": {"sudo"},
}),
BuildDepends: db.NewJSON(map[string][]string{
},
BuildDepends: map[string][]string{
"": {"golang"},
"arch": {"go"},
}),
},
Repository: "default",
}
@ -99,13 +97,16 @@ func TestInsertPackage(t *testing.T) {
t.Fatalf("Expected no error, got %s", err)
}
dbPkg := db.Package{}
err = sqlx.Get(database.GetConn(), &dbPkg, "SELECT * FROM pkgs WHERE name = 'test' AND repository = 'default'")
pkgs, err := database.GetPkgs(ctx, "name = 'test' AND repository = 'default'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
if !reflect.DeepEqual(testPkg, dbPkg) {
if len(pkgs) != 1 {
t.Fatalf("Expected 1 package, got %d", len(pkgs))
}
if !reflect.DeepEqual(testPkg, pkgs[0]) {
t.Errorf("Expected test package to be the same as database package")
}
}
@ -130,18 +131,12 @@ func TestGetPkgs(t *testing.T) {
t.Errorf("Expected no error, got %s", err)
}
result, err := database.GetPkgs(ctx, "name LIKE 'x%'")
pkgs, err := database.GetPkgs(ctx, "name LIKE 'x%'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
for result.Next() {
var dbPkg db.Package
err = result.StructScan(&dbPkg)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
for _, dbPkg := range pkgs {
if !strings.HasPrefix(dbPkg.Name, "x") {
t.Errorf("Expected package name to start with 'x', got %s", dbPkg.Name)
}
@ -168,7 +163,7 @@ func TestGetPkg(t *testing.T) {
t.Errorf("Expected no error, got %s", err)
}
pkg, err := database.GetPkg(ctx, "name LIKE 'x%' ORDER BY name")
pkg, err := database.GetPkg("name LIKE 'x%'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
@ -206,16 +201,6 @@ func TestDeletePkgs(t *testing.T) {
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
var dbPkg db.Package
err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE name LIKE 'x%' ORDER BY name LIMIT 1;")
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
if dbPkg.Name != "x2" {
t.Errorf("Expected x2 package, got %s", dbPkg.Name)
}
}
func TestJsonArrayContains(t *testing.T) {
@ -227,7 +212,7 @@ func TestJsonArrayContains(t *testing.T) {
x1.Name = "x1"
x2 := testPkg
x2.Name = "x2"
x2.Provides.Val = append(x2.Provides.Val, "x")
x2.Provides = append(x2.Provides, "x")
err := database.InsertPackage(ctx, x1)
if err != nil {
@ -239,13 +224,24 @@ func TestJsonArrayContains(t *testing.T) {
t.Errorf("Expected no error, got %s", err)
}
var dbPkg db.Package
err = database.GetConn().Get(&dbPkg, "SELECT * FROM pkgs WHERE json_array_contains(provides, 'x');")
pkgs, err := database.GetPkgs(ctx, "name = 'x2'")
if err != nil {
t.Fatalf("Expected no error, got %s", err)
}
if dbPkg.Name != "x2" {
t.Errorf("Expected x2 package, got %s", dbPkg.Name)
if len(pkgs) != 1 || pkgs[0].Name != "x2" {
t.Errorf("Expected x2 package, got %v", pkgs)
}
// Verify the provides field contains 'x'
found := false
for _, p := range pkgs[0].Provides {
if p == "x" {
found = true
break
}
}
if !found {
t.Errorf("Expected provides to contain 'x'")
}
}

View File

@ -1,80 +0,0 @@
// ALR - Any Linux Repository
// Copyright (C) 2025 The ALR Authors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
// JSON represents a JSON value in the database
type JSON[T any] struct {
Val T
}
// NewJSON creates a new database JSON value
func NewJSON[T any](v T) JSON[T] {
return JSON[T]{Val: v}
}
func (s *JSON[T]) Scan(val any) error {
if val == nil {
return nil
}
switch val := val.(type) {
case string:
err := json.Unmarshal([]byte(val), &s.Val)
if err != nil {
return err
}
case sql.NullString:
if val.Valid {
err := json.Unmarshal([]byte(val.String), &s.Val)
if err != nil {
return err
}
}
default:
return errors.New("sqlite json types must be strings")
}
return nil
}
func (s JSON[T]) Value() (driver.Value, error) {
data, err := json.Marshal(s.Val)
if err != nil {
return nil, err
}
return string(data), nil
}
func (s JSON[T]) MarshalYAML() (any, error) {
return s.Val, nil
}
func (s JSON[T]) String() string {
return fmt.Sprint(s.Val)
}
func (s JSON[T]) GoString() string {
return fmt.Sprintf("%#v", s.Val)
}