From cf933a40301620651bb9c1b443f9980d822db3ef Mon Sep 17 00:00:00 2001 From: jiangyong27 Date: Sat, 27 Jul 2024 12:19:12 +0800 Subject: [PATCH] download --- download.go | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 2 +- 2 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 download.go diff --git a/download.go b/download.go new file mode 100644 index 0000000..5726b5e --- /dev/null +++ b/download.go @@ -0,0 +1,208 @@ +package goutil + +import ( + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "path/filepath" + "sync" + "sync/atomic" + "time" +) + +type progressWriter struct { + onProgress func(totalSize int64, processSize int64) + totalSize atomic.Int64 + processSize atomic.Int64 + lock sync.Mutex +} + +func (p *progressWriter) Reset() { + p.totalSize.Store(0) + p.processSize.Store(0) +} +func (p *progressWriter) SetTotal(totalSize int64) { + p.totalSize.Store(totalSize) +} +func (p *progressWriter) Add(size int64) { + p.processSize.Add(size) +} +func (p *progressWriter) Write(dat []byte) (n int, err error) { + n = len(dat) + p.processSize.Add(int64(n)) + if p.onProgress != nil { + p.onProgress(p.totalSize.Load(), p.processSize.Load()) + } + return +} + +type Downloader struct { + concurrency int + resume bool + progress *progressWriter +} + +func NewDownloader(concurrency int, resume bool) *Downloader { + return &Downloader{concurrency: concurrency, resume: resume} +} + +func (d *Downloader) Download(strURL, filename string, onProgress func(totalSize int64, processSize int64)) error { + + resp, err := http.Head(strURL) + if err != nil { + return err + } + if resp.StatusCode >= 300 && resp.StatusCode <= 399 { + strURL = resp.Header.Get("Location") + resp, err = http.Head(strURL) + if err != nil { + return err + } + } + if filename == "" { + u, _ := url.Parse(strURL) + filename = filepath.Base(u.Path) + } + if filename == "" { + contentType := resp.Header.Get("Accept-Ranges") + panic(errors.New(contentType)) //todo + } + + d.progress = &progressWriter{onProgress: onProgress} + if d.concurrency > 1 && resp.StatusCode == http.StatusOK && resp.Header.Get("Accept-Ranges") == "bytes" { + return d.multiDownload(strURL, filename, int(resp.ContentLength)) + } + + return d.singleDownload(strURL, filename) +} + +func (d *Downloader) multiDownload(strURL, filename string, contentLen int) error { + d.progress.SetTotal(int64(contentLen)) + tempFilename := fmt.Sprintf("%d", time.Now().Unix()) + partSize := contentLen / d.concurrency + + var wg sync.WaitGroup + wg.Add(d.concurrency) + + rangeStart := 0 + + for i := 0; i < d.concurrency; i++ { + go func(i, rangeStart int) { + defer wg.Done() + + rangeEnd := rangeStart + partSize + // 最后一部分,总长度不能超过 ContentLength + if i == d.concurrency-1 { + rangeEnd = contentLen + } + + downloaded := 0 + partFileName := d.getPartFilename(tempFilename, i) + if d.resume { + content, err := os.ReadFile(partFileName) + if err == nil { + downloaded = len(content) + } + d.progress.Add(int64(downloaded)) + } + + d.downloadPartial(strURL, partFileName, rangeStart+downloaded, rangeEnd, i) + + }(i, rangeStart) + + rangeStart += partSize + 1 + } + + wg.Wait() + + d.merge(filename, tempFilename) + + return nil +} + +func (d *Downloader) downloadPartial(strURL, partFilename string, rangeStart, rangeEnd, i int) { + if rangeStart >= rangeEnd { + return + } + + req, err := http.NewRequest("GET", strURL, nil) + if err != nil { + log.Fatal(err) + } + + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", rangeStart, rangeEnd)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + flags := os.O_CREATE | os.O_WRONLY + if d.resume { + flags |= os.O_APPEND + } + + partFile, err := os.OpenFile(partFilename, flags, 0666) + if err != nil { + log.Fatal(err) + } + defer partFile.Close() + + buf := make([]byte, 32*1024) + _, err = io.CopyBuffer(io.MultiWriter(partFile, d.progress), resp.Body, buf) + if err != nil { + if err == io.EOF { + return + } + log.Fatal(err) + } +} + +func (d *Downloader) merge(filename, tempFilename string) error { + destFile, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + return err + } + defer destFile.Close() + + for i := 0; i < d.concurrency; i++ { + partFileName := d.getPartFilename(tempFilename, i) + partFile, err := os.Open(partFileName) + if err != nil { + return err + } + io.Copy(destFile, partFile) + partFile.Close() + os.Remove(partFileName) + } + + return nil +} + +func (d *Downloader) getPartFilename(filename string, partNum int) string { + return fmt.Sprintf("%s/%s-%d", os.TempDir(), filepath.Base(filename), partNum) +} + +func (d *Downloader) singleDownload(strURL, filename string) error { + resp, err := http.Get(strURL) + if err != nil { + return err + } + defer resp.Body.Close() + + d.progress.SetTotal(int64(resp.ContentLength)) + + f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + return err + } + defer f.Close() + + buf := make([]byte, 32*1024) + _, err = io.CopyBuffer(io.MultiWriter(f, d.progress), resp.Body, buf) + return err +} diff --git a/go.mod b/go.mod index d479d2e..87ea875 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/smbrave/goutil -go 1.18 +go 1.21.4 require ( github.com/sirupsen/logrus v1.9.0