download
This commit is contained in:
parent
dd5a9ae0b0
commit
cf933a4030
|
@ -0,0 +1,208 @@
|
|||
package goutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type progressWriter struct {
|
||||
onProgress func(totalSize int64, processSize int64)
|
||||
totalSize atomic.Int64
|
||||
processSize atomic.Int64
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (p *progressWriter) Reset() {
|
||||
p.totalSize.Store(0)
|
||||
p.processSize.Store(0)
|
||||
}
|
||||
func (p *progressWriter) SetTotal(totalSize int64) {
|
||||
p.totalSize.Store(totalSize)
|
||||
}
|
||||
func (p *progressWriter) Add(size int64) {
|
||||
p.processSize.Add(size)
|
||||
}
|
||||
func (p *progressWriter) Write(dat []byte) (n int, err error) {
|
||||
n = len(dat)
|
||||
p.processSize.Add(int64(n))
|
||||
if p.onProgress != nil {
|
||||
p.onProgress(p.totalSize.Load(), p.processSize.Load())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Downloader struct {
|
||||
concurrency int
|
||||
resume bool
|
||||
progress *progressWriter
|
||||
}
|
||||
|
||||
func NewDownloader(concurrency int, resume bool) *Downloader {
|
||||
return &Downloader{concurrency: concurrency, resume: resume}
|
||||
}
|
||||
|
||||
func (d *Downloader) Download(strURL, filename string, onProgress func(totalSize int64, processSize int64)) error {
|
||||
|
||||
resp, err := http.Head(strURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode >= 300 && resp.StatusCode <= 399 {
|
||||
strURL = resp.Header.Get("Location")
|
||||
resp, err = http.Head(strURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if filename == "" {
|
||||
u, _ := url.Parse(strURL)
|
||||
filename = filepath.Base(u.Path)
|
||||
}
|
||||
if filename == "" {
|
||||
contentType := resp.Header.Get("Accept-Ranges")
|
||||
panic(errors.New(contentType)) //todo
|
||||
}
|
||||
|
||||
d.progress = &progressWriter{onProgress: onProgress}
|
||||
if d.concurrency > 1 && resp.StatusCode == http.StatusOK && resp.Header.Get("Accept-Ranges") == "bytes" {
|
||||
return d.multiDownload(strURL, filename, int(resp.ContentLength))
|
||||
}
|
||||
|
||||
return d.singleDownload(strURL, filename)
|
||||
}
|
||||
|
||||
func (d *Downloader) multiDownload(strURL, filename string, contentLen int) error {
|
||||
d.progress.SetTotal(int64(contentLen))
|
||||
tempFilename := fmt.Sprintf("%d", time.Now().Unix())
|
||||
partSize := contentLen / d.concurrency
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(d.concurrency)
|
||||
|
||||
rangeStart := 0
|
||||
|
||||
for i := 0; i < d.concurrency; i++ {
|
||||
go func(i, rangeStart int) {
|
||||
defer wg.Done()
|
||||
|
||||
rangeEnd := rangeStart + partSize
|
||||
// 最后一部分,总长度不能超过 ContentLength
|
||||
if i == d.concurrency-1 {
|
||||
rangeEnd = contentLen
|
||||
}
|
||||
|
||||
downloaded := 0
|
||||
partFileName := d.getPartFilename(tempFilename, i)
|
||||
if d.resume {
|
||||
content, err := os.ReadFile(partFileName)
|
||||
if err == nil {
|
||||
downloaded = len(content)
|
||||
}
|
||||
d.progress.Add(int64(downloaded))
|
||||
}
|
||||
|
||||
d.downloadPartial(strURL, partFileName, rangeStart+downloaded, rangeEnd, i)
|
||||
|
||||
}(i, rangeStart)
|
||||
|
||||
rangeStart += partSize + 1
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
d.merge(filename, tempFilename)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Downloader) downloadPartial(strURL, partFilename string, rangeStart, rangeEnd, i int) {
|
||||
if rangeStart >= rangeEnd {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", strURL, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", rangeStart, rangeEnd))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
flags := os.O_CREATE | os.O_WRONLY
|
||||
if d.resume {
|
||||
flags |= os.O_APPEND
|
||||
}
|
||||
|
||||
partFile, err := os.OpenFile(partFilename, flags, 0666)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer partFile.Close()
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err = io.CopyBuffer(io.MultiWriter(partFile, d.progress), resp.Body, buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Downloader) merge(filename, tempFilename string) error {
|
||||
destFile, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
for i := 0; i < d.concurrency; i++ {
|
||||
partFileName := d.getPartFilename(tempFilename, i)
|
||||
partFile, err := os.Open(partFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
io.Copy(destFile, partFile)
|
||||
partFile.Close()
|
||||
os.Remove(partFileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Downloader) getPartFilename(filename string, partNum int) string {
|
||||
return fmt.Sprintf("%s/%s-%d", os.TempDir(), filepath.Base(filename), partNum)
|
||||
}
|
||||
|
||||
func (d *Downloader) singleDownload(strURL, filename string) error {
|
||||
resp, err := http.Get(strURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
d.progress.SetTotal(int64(resp.ContentLength))
|
||||
|
||||
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err = io.CopyBuffer(io.MultiWriter(f, d.progress), resp.Body, buf)
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue