From 85878f69d30651a92052f35dff2a5848f6898b3b Mon Sep 17 00:00:00 2001 From: Maxim Slipenko Date: Fri, 20 Jun 2025 20:12:43 +0300 Subject: [PATCH] feat: add checksum for torrent downloader --- assets/coverage-badge.svg | 4 +-- pkg/dl/git.go | 31 ++-------------- pkg/dl/torrent.go | 12 ++++++- pkg/dl/utils.go | 74 ++++++++++++++++++++++++++++++--------- 4 files changed, 72 insertions(+), 49 deletions(-) diff --git a/assets/coverage-badge.svg b/assets/coverage-badge.svg index 665bcb6..ff1ebbf 100644 --- a/assets/coverage-badge.svg +++ b/assets/coverage-badge.svg @@ -11,7 +11,7 @@ coverage coverage - 20.2% - 20.2% + 20.1% + 20.1% diff --git a/pkg/dl/git.go b/pkg/dl/git.go index 0de751f..b5380b1 100644 --- a/pkg/dl/git.go +++ b/pkg/dl/git.go @@ -20,11 +20,8 @@ package dl import ( - "bytes" "context" - "encoding/hex" "errors" - "log/slog" "net/url" "path" "strconv" @@ -127,7 +124,7 @@ func (d *GitDownloader) Download(ctx context.Context, opts Options) (Type, strin } } - err = d.verifyHash(opts) + err = VerifyHashFromLocal("", opts) if err != nil { return 0, "", err } @@ -139,30 +136,6 @@ func (d *GitDownloader) Download(ctx context.Context, opts Options) (Type, strin return TypeDir, name, nil } -func (GitDownloader) verifyHash(opts Options) error { - if opts.Hash != nil { - h, err := opts.NewHash() - if err != nil { - return err - } - - err = HashDir(opts.Destination, h) - if err != nil { - return err - } - - sum := h.Sum(nil) - - slog.Warn("validate checksum", "real", hex.EncodeToString(sum), "expected", hex.EncodeToString(opts.Hash)) - - if !bytes.Equal(sum, opts.Hash) { - return ErrChecksumMismatch - } - } - - return nil -} - // Update uses git to pull the repository and update it // to the latest revision. It allows specifying the depth // and recursion options via query string. It returns @@ -225,7 +198,7 @@ func (d *GitDownloader) Update(opts Options) (bool, error) { return false, err } - err = d.verifyHash(opts) + err = VerifyHashFromLocal("", opts) if err != nil { return false, err } diff --git a/pkg/dl/torrent.go b/pkg/dl/torrent.go index 31a2a62..42427f0 100644 --- a/pkg/dl/torrent.go +++ b/pkg/dl/torrent.go @@ -71,7 +71,17 @@ func (TorrentDownloader) Download(ctx context.Context, opts Options) (Type, stri return 0, "", err } - return determineType(opts.Destination) + dlType, name, err := determineType(opts.Destination) + if err != nil { + return 0, "", err + } + + err = VerifyHashFromLocal(name, opts) + if err != nil { + return 0, "", err + } + + return dlType, name, nil } func removeTorrentFiles(path string) error { diff --git a/pkg/dl/utils.go b/pkg/dl/utils.go index a4e4a62..a4fc8f1 100644 --- a/pkg/dl/utils.go +++ b/pkg/dl/utils.go @@ -17,39 +17,79 @@ package dl import ( + "bytes" + "encoding/hex" + "fmt" "hash" "io" + "log/slog" "os" "path/filepath" ) -func HashDir(dirPath string, h hash.Hash) error { - err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error { +// If the checksum does not match, returns ErrChecksumMismatch +func VerifyHashFromLocal(path string, opts Options) error { + if opts.Hash != nil { + h, err := opts.NewHash() if err != nil { return err } - // Skip .git directory - if info.IsDir() && info.Name() == ".git" { - return filepath.SkipDir + + err = HashLocal(filepath.Join(opts.Destination, path), h) + if err != nil { + return err } - // Skip directories (only process files) - if !info.Mode().IsRegular() { - return nil + + sum := h.Sum(nil) + + slog.Debug("validate checksum", "real", hex.EncodeToString(sum), "expected", hex.EncodeToString(opts.Hash)) + + if !bytes.Equal(sum, opts.Hash) { + return ErrChecksumMismatch } - // Open file + } + + return nil +} + +func HashLocal(path string, h hash.Hash) error { + info, err := os.Stat(path) + if err != nil { + return err + } + + if info.Mode().IsRegular() { + // Single file f, err := os.Open(path) if err != nil { return err } defer f.Close() - // Write file content to hasher - if _, err := io.Copy(h, f); err != nil { - return err - } - return nil - }) - if err != nil { + _, err = io.Copy(h, f) return err } - return nil + + if info.IsDir() { + // Walk directory + return filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() && info.Name() == ".git" { + return filepath.SkipDir + } + if !info.Mode().IsRegular() { + return nil + } + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(h, f) + return err + }) + } + + return fmt.Errorf("unsupported file type: %s", path) }