forked from Plemya-x/ALR
fix: add download cancel via context and update progressbar
This commit is contained in:
@ -134,7 +134,7 @@ type Manifest struct {
|
||||
type Downloader interface {
|
||||
Name() string
|
||||
MatchURL(string) bool
|
||||
Download(Options) (Type, string, error)
|
||||
Download(context.Context, Options) (Type, string, error)
|
||||
}
|
||||
|
||||
// Интерфейс UpdatingDownloader расширяет Downloader методом Update
|
||||
@ -157,7 +157,7 @@ func Download(ctx context.Context, opts Options) (err error) {
|
||||
d := getDownloader(opts.URL)
|
||||
|
||||
if opts.CacheDisabled {
|
||||
_, _, err = d.Download(opts)
|
||||
_, _, err = d.Download(ctx, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -226,7 +226,7 @@ func Download(ctx context.Context, opts Options) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
t, name, err := d.Download(Options{
|
||||
t, name, err := d.Download(ctx, Options{
|
||||
Hash: opts.Hash,
|
||||
HashAlgorithm: opts.HashAlgorithm,
|
||||
Name: opts.Name,
|
||||
|
@ -22,6 +22,7 @@ package dl
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
@ -30,12 +31,8 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/archiver/v4"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
|
||||
"gitea.plemya-x.ru/Plemya-x/ALR/internal/shutils/handlers"
|
||||
)
|
||||
|
||||
// FileDownloader загружает файлы с использованием HTTP
|
||||
@ -54,7 +51,7 @@ func (FileDownloader) MatchURL(string) bool {
|
||||
|
||||
// Download загружает файл с использованием HTTP. Если файл
|
||||
// сжат в поддерживаемом формате, он будет распакован
|
||||
func (FileDownloader) Download(opts Options) (Type, string, error) {
|
||||
func (FileDownloader) Download(ctx context.Context, opts Options) (Type, string, error) {
|
||||
// Разбор URL
|
||||
u, err := url.Parse(opts.URL)
|
||||
if err != nil {
|
||||
@ -94,8 +91,12 @@ func (FileDownloader) Download(opts Options) (Type, string, error) {
|
||||
}
|
||||
r = localFl
|
||||
} else {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
// Выполнение HTTP GET запроса
|
||||
res, err := http.Get(u.String())
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
@ -114,29 +115,13 @@ func (FileDownloader) Download(opts Options) (Type, string, error) {
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
defer fl.Close()
|
||||
|
||||
var bar io.WriteCloser
|
||||
var out io.WriteCloser
|
||||
// Настройка индикатора прогресса
|
||||
if opts.Progress != nil {
|
||||
bar = progressbar.NewOptions64(
|
||||
size,
|
||||
progressbar.OptionSetDescription(name),
|
||||
progressbar.OptionSetWriter(opts.Progress),
|
||||
progressbar.OptionShowBytes(true),
|
||||
progressbar.OptionSetWidth(10),
|
||||
progressbar.OptionThrottle(65*time.Millisecond),
|
||||
progressbar.OptionShowCount(),
|
||||
progressbar.OptionOnCompletion(func() {
|
||||
_, _ = io.WriteString(opts.Progress, "\n")
|
||||
}),
|
||||
progressbar.OptionSpinnerType(14),
|
||||
progressbar.OptionFullWidth(),
|
||||
progressbar.OptionSetRenderBlankState(true),
|
||||
)
|
||||
defer bar.Close()
|
||||
out = NewProgressWriter(fl, size, name, opts.Progress)
|
||||
} else {
|
||||
bar = handlers.NopRWC{}
|
||||
out = fl
|
||||
}
|
||||
|
||||
h, err := opts.NewHash()
|
||||
@ -147,9 +132,9 @@ func (FileDownloader) Download(opts Options) (Type, string, error) {
|
||||
var w io.Writer
|
||||
// Настройка MultiWriter для записи в файл, хеш и индикатор прогресса
|
||||
if opts.Hash != nil {
|
||||
w = io.MultiWriter(fl, h, bar)
|
||||
w = io.MultiWriter(h, out)
|
||||
} else {
|
||||
w = io.MultiWriter(fl, bar)
|
||||
w = io.MultiWriter(out)
|
||||
}
|
||||
|
||||
// Копирование содержимого из источника в файл назначения
|
||||
@ -158,6 +143,7 @@ func (FileDownloader) Download(opts Options) (Type, string, error) {
|
||||
return 0, "", err
|
||||
}
|
||||
r.Close()
|
||||
out.Close()
|
||||
|
||||
// Проверка контрольной суммы
|
||||
if opts.Hash != nil {
|
||||
|
@ -20,6 +20,7 @@
|
||||
package dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"path"
|
||||
@ -47,7 +48,7 @@ func (GitDownloader) MatchURL(u string) bool {
|
||||
// Download uses git to clone the repository from the specified URL.
|
||||
// It allows specifying the revision, depth and recursion options
|
||||
// via query string
|
||||
func (GitDownloader) Download(opts Options) (Type, string, error) {
|
||||
func (GitDownloader) Download(ctx context.Context, opts Options) (Type, string, error) {
|
||||
u, err := url.Parse(opts.URL)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
@ -89,7 +90,7 @@ func (GitDownloader) Download(opts Options) (Type, string, error) {
|
||||
co.RecurseSubmodules = git.DefaultSubmoduleRecursionDepth
|
||||
}
|
||||
|
||||
r, err := git.PlainClone(opts.Destination, false, co)
|
||||
r, err := git.PlainCloneContext(ctx, opts.Destination, false, co)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
246
internal/dl/progress_tui.go
Normal file
246
internal/dl/progress_tui.go
Normal file
@ -0,0 +1,246 @@
|
||||
// ALR - Any Linux Repository
|
||||
// Copyright (C) 2025 Евгений Храмов
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package dl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/progress"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/leonelquinteros/gotext"
|
||||
)
|
||||
|
||||
type model struct {
|
||||
progress progress.Model
|
||||
spinner spinner.Model
|
||||
percent float64
|
||||
speed float64
|
||||
done bool
|
||||
useSpinner bool
|
||||
filename string
|
||||
|
||||
total int64
|
||||
downloaded int64
|
||||
elapsed time.Duration
|
||||
remaining time.Duration
|
||||
|
||||
width int
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
if m.useSpinner {
|
||||
return m.spinner.Tick
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.done {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case progressUpdate:
|
||||
m.percent = msg.percent
|
||||
m.speed = msg.speed
|
||||
m.downloaded = msg.downloaded
|
||||
m.total = msg.total
|
||||
m.elapsed = time.Duration(msg.elapsed) * time.Second
|
||||
m.remaining = time.Duration(msg.remaining) * time.Second
|
||||
if m.percent >= 1.0 {
|
||||
m.done = true
|
||||
}
|
||||
return m, nil
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
return m, nil
|
||||
case progress.FrameMsg:
|
||||
if !m.useSpinner {
|
||||
progressModel, cmd := m.progress.Update(msg)
|
||||
m.progress = progressModel.(progress.Model)
|
||||
return m, cmd
|
||||
}
|
||||
case spinner.TickMsg:
|
||||
if m.useSpinner {
|
||||
spinnerModel, cmd := m.spinner.Update(msg)
|
||||
m.spinner = spinnerModel
|
||||
return m, cmd
|
||||
}
|
||||
case tea.KeyMsg:
|
||||
if msg.String() == "q" {
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.done {
|
||||
return gotext.Get("%s: done!\n", m.filename)
|
||||
}
|
||||
if m.useSpinner {
|
||||
return gotext.Get(
|
||||
"%s %s downloading at %s/s\n",
|
||||
m.filename,
|
||||
m.spinner.View(),
|
||||
prettyByteSize(int64(m.speed)),
|
||||
)
|
||||
}
|
||||
|
||||
leftPart := m.filename
|
||||
|
||||
rightPart := fmt.Sprintf("%.2f%% (%s/%s, %s/s) [%v:%v]\n", m.percent*100,
|
||||
prettyByteSize(m.downloaded),
|
||||
prettyByteSize(m.total),
|
||||
prettyByteSize(int64(m.speed)),
|
||||
m.elapsed,
|
||||
m.remaining,
|
||||
)
|
||||
|
||||
m.progress.Width = m.width - len(leftPart) - len(rightPart) - 6
|
||||
bar := m.progress.ViewAs(m.percent)
|
||||
return fmt.Sprintf(
|
||||
"%s %s %s",
|
||||
leftPart,
|
||||
bar,
|
||||
rightPart,
|
||||
)
|
||||
}
|
||||
|
||||
func prettyByteSize(b int64) string {
|
||||
bf := float64(b)
|
||||
for _, unit := range []string{"", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"} {
|
||||
if math.Abs(bf) < 1024.0 {
|
||||
return fmt.Sprintf("%3.1f%sB", bf, unit)
|
||||
}
|
||||
bf /= 1024.0
|
||||
}
|
||||
return fmt.Sprintf("%.1fYiB", bf)
|
||||
}
|
||||
|
||||
type progressUpdate struct {
|
||||
percent float64
|
||||
speed float64
|
||||
total int64
|
||||
|
||||
downloaded int64
|
||||
elapsed float64
|
||||
remaining float64
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
baseWriter io.WriteCloser
|
||||
total int64
|
||||
downloaded int64
|
||||
startTime time.Time
|
||||
onProgress func(progressUpdate)
|
||||
lastReported time.Time
|
||||
doneChan chan struct{}
|
||||
}
|
||||
|
||||
func (pw *ProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.baseWriter.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
pw.downloaded += int64(n)
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(pw.startTime).Seconds()
|
||||
speed := float64(pw.downloaded) / elapsed
|
||||
var remaining, percent float64
|
||||
if pw.total > 0 {
|
||||
remaining = (float64(pw.total) - float64(pw.downloaded)) / speed
|
||||
percent = float64(pw.downloaded) / float64(pw.total)
|
||||
}
|
||||
|
||||
if now.Sub(pw.lastReported) > 100*time.Millisecond {
|
||||
pw.onProgress(progressUpdate{
|
||||
percent: percent,
|
||||
speed: speed,
|
||||
total: pw.total,
|
||||
downloaded: pw.downloaded,
|
||||
elapsed: elapsed,
|
||||
remaining: remaining,
|
||||
})
|
||||
pw.lastReported = now
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (pw *ProgressWriter) Close() error {
|
||||
pw.onProgress(progressUpdate{
|
||||
percent: 1,
|
||||
speed: 0,
|
||||
downloaded: pw.downloaded,
|
||||
})
|
||||
<-pw.doneChan
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewProgressWriter(base io.WriteCloser, max int64, filename string, out io.Writer) *ProgressWriter {
|
||||
var m *model
|
||||
if max == -1 {
|
||||
m = &model{
|
||||
spinner: spinner.New(),
|
||||
useSpinner: true,
|
||||
filename: filename,
|
||||
}
|
||||
m.spinner.Spinner = spinner.Dot
|
||||
} else {
|
||||
m = &model{
|
||||
progress: progress.New(
|
||||
progress.WithDefaultGradient(),
|
||||
progress.WithoutPercentage(),
|
||||
),
|
||||
useSpinner: false,
|
||||
filename: filename,
|
||||
}
|
||||
}
|
||||
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(nil),
|
||||
tea.WithOutput(out),
|
||||
)
|
||||
|
||||
pw := &ProgressWriter{
|
||||
baseWriter: base,
|
||||
total: max,
|
||||
startTime: time.Now(),
|
||||
doneChan: make(chan struct{}),
|
||||
onProgress: func(update progressUpdate) {
|
||||
p.Send(update)
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(pw.doneChan)
|
||||
if _, err := p.Run(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error running progress writer: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
return pw
|
||||
}
|
@ -20,6 +20,7 @@
|
||||
package dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@ -49,7 +50,7 @@ func (TorrentDownloader) MatchURL(u string) bool {
|
||||
}
|
||||
|
||||
// Download downloads a file over the BitTorrent protocol.
|
||||
func (TorrentDownloader) Download(opts Options) (Type, string, error) {
|
||||
func (TorrentDownloader) Download(ctx context.Context, opts Options) (Type, string, error) {
|
||||
aria2Path, err := exec.LookPath("aria2c")
|
||||
if err != nil {
|
||||
return 0, "", ErrAria2NotFound
|
||||
@ -57,7 +58,7 @@ func (TorrentDownloader) Download(opts Options) (Type, string, error) {
|
||||
|
||||
opts.URL = strings.TrimPrefix(opts.URL, "torrent+")
|
||||
|
||||
cmd := exec.Command(aria2Path, "--summary-interval=0", "--log-level=warn", "--seed-time=0", "--dir="+opts.Destination, opts.URL)
|
||||
cmd := exec.CommandContext(ctx, aria2Path, "--summary-interval=0", "--log-level=warn", "--seed-time=0", "--dir="+opts.Destination, opts.URL)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err = cmd.Run()
|
||||
|
Reference in New Issue
Block a user