Skip to content

Commit f968e3d

Browse files
committed
fix(baidu_netdisk): improve upload handling
1 parent 12efd82 commit f968e3d

2 files changed

Lines changed: 77 additions & 73 deletions

File tree

drivers/baidu_netdisk/driver.go

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
11
package baidu_netdisk
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/md5"
67
"encoding/hex"
78
"errors"
8-
"fmt"
99
"io"
10+
"mime/multipart"
11+
"net/http"
1012
"net/url"
1113
"os"
1214
stdpath "path"
1315
"strconv"
1416
"strings"
15-
"sync"
1617
"time"
1718

1819
"github.com/OpenListTeam/OpenList/v4/drivers/base"
20+
"github.com/OpenListTeam/OpenList/v4/internal/cache"
1921
"github.com/OpenListTeam/OpenList/v4/internal/conf"
2022
"github.com/OpenListTeam/OpenList/v4/internal/driver"
2123
"github.com/OpenListTeam/OpenList/v4/internal/errs"
2224
"github.com/OpenListTeam/OpenList/v4/internal/model"
25+
"github.com/OpenListTeam/OpenList/v4/internal/net"
2326
"github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
24-
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
2527
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
2628
"github.com/avast/retry-go"
27-
"github.com/go-resty/resty/v2"
2829
log "github.com/sirupsen/logrus"
2930
)
3031

@@ -35,11 +36,8 @@ type BaiduNetdisk struct {
3536
uploadThread int
3637
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
3738

38-
upClient *resty.Client // 上传文件使用的http客户端
39-
uploadUrlG singleflight.Group[string]
40-
uploadUrlMu sync.RWMutex
41-
uploadUrl string // 上传域名
42-
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
39+
uploadUrl string // 上传域名
40+
uploadUrlExpirationTime cache.Expirable
4341
}
4442

4543
var ErrUploadIDExpired = errors.New("uploadid expired")
@@ -53,11 +51,6 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
5351
}
5452

5553
func (d *BaiduNetdisk) Init(ctx context.Context) error {
56-
d.upClient = base.NewRestyClient().
57-
SetTimeout(UPLOAD_TIMEOUT).
58-
SetRetryCount(UPLOAD_RETRY_COUNT).
59-
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
60-
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
6154
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
6255
if d.uploadThread < 1 {
6356
d.uploadThread, d.UploadThread = 1, "1"
@@ -95,9 +88,10 @@ func (d *BaiduNetdisk) List(ctx context.Context, dir model.Obj, args model.ListA
9588
}
9689

9790
func (d *BaiduNetdisk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
98-
if d.DownloadAPI == "crack" {
91+
switch d.DownloadAPI {
92+
case "crack":
9993
return d.linkCrack(file, args)
100-
} else if d.DownloadAPI == "crack_video" {
94+
case "crack_video":
10195
return d.linkCrackVideo(file, args)
10296
}
10397
return d.linkOfficial(file, args)
@@ -214,7 +208,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
214208
tmpF *os.File
215209
err error
216210
)
217-
if _, ok := cache.(io.ReaderAt); !ok {
211+
if cache == nil {
218212
tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
219213
if err != nil {
220214
return nil, err
@@ -301,27 +295,25 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
301295

302296
// step.2 上传分片
303297
uploadLoop:
304-
for attempt := 0; attempt < 2; attempt++ {
298+
for range 2 {
305299
// 获取上传域名
306300
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
307301
// 并发上传
308302
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
309-
retry.Attempts(1),
310-
retry.Delay(time.Second),
303+
retry.Attempts(UPLOAD_RETRY_COUNT),
304+
retry.Delay(UPLOAD_RETRY_WAIT_TIME),
305+
retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME),
311306
retry.DelayType(retry.BackOffDelay))
312307

313-
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
314-
if !okReaderAt {
315-
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
316-
}
317-
318308
totalParts := len(precreateResp.BlockList)
319309

320310
for i, partseq := range precreateResp.BlockList {
321-
if utils.IsCanceled(upCtx) || partseq < 0 {
311+
if utils.IsCanceled(upCtx) {
312+
break
313+
}
314+
if partseq < 0 {
322315
continue
323316
}
324-
i, partseq := i, partseq
325317
offset, size := int64(partseq)*sliceSize, sliceSize
326318
if partseq+1 == count {
327319
size = lastBlockSize
@@ -335,15 +327,13 @@ uploadLoop:
335327
"uploadid": precreateResp.Uploadid,
336328
"partseq": strconv.Itoa(partseq),
337329
}
338-
section := io.NewSectionReader(cacheReaderAt, offset, size)
339-
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
330+
section := io.NewSectionReader(cache, offset, size)
331+
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), section)
340332
if err != nil {
341333
return err
342334
}
343335
precreateResp.BlockList[i] = -1
344-
// 当前goroutine还没退出,+1才是真正成功的数量
345-
success := threadG.Success() + 1
346-
progress := float64(success) * 100 / float64(totalParts)
336+
progress := float64(threadG.Success()+1) * 100 / float64(totalParts+1)
347337
up(progress)
348338
return nil
349339
})
@@ -430,27 +420,58 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in
430420
}
431421

432422
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
433-
res, err := d.upClient.R().
434-
SetContext(ctx).
435-
SetQueryParams(params).
436-
SetFileReader("file", fileName, file).
437-
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
423+
b := bytes.NewBuffer(nil)
424+
mw := multipart.NewWriter(b)
425+
_, err := mw.CreateFormFile("file", fileName)
426+
if err != nil {
427+
return err
428+
}
429+
headSize := b.Len()
430+
err = mw.Close()
431+
if err != nil {
432+
return err
433+
}
434+
head := bytes.NewReader(b.Bytes()[:headSize])
435+
tail := bytes.NewReader(b.Bytes()[headSize:])
436+
rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, file, tail))
437+
438+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd)
439+
if err != nil {
440+
return err
441+
}
442+
query := req.URL.Query()
443+
for k, v := range params {
444+
query.Set(k, v)
445+
}
446+
req.URL.RawQuery = query.Encode()
447+
req.Header.Set("Content-Type", mw.FormDataContentType())
448+
449+
client := net.NewHttpClient()
450+
client.Timeout = UPLOAD_TIMEOUT
451+
resp, err := client.Do(req)
452+
if err != nil {
453+
return err
454+
}
455+
defer resp.Body.Close()
456+
b.Reset()
457+
_, err = b.ReadFrom(resp.Body)
438458
if err != nil {
439459
return err
440460
}
441-
log.Debugln(res.RawResponse.Status + res.String())
442-
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
443-
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
444-
respStr := res.String()
461+
body := b.Bytes()
462+
respStr := string(body)
463+
log.Debugln(respStr)
445464
lower := strings.ToLower(respStr)
446465
// 合并 uploadid 过期检测逻辑
447466
if strings.Contains(lower, "uploadid") &&
448467
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
449468
return ErrUploadIDExpired
450469
}
451470

471+
errCode := utils.Json.Get(body, "error_code").ToInt()
472+
errNo := utils.Json.Get(body, "errno").ToInt()
452473
if errCode != 0 || errNo != 0 {
453-
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
474+
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr)
454475
}
455476
return nil
456477
}

drivers/baidu_netdisk/util.go

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ import (
1212
"unicode"
1313

1414
"github.com/OpenListTeam/OpenList/v4/drivers/base"
15+
"github.com/OpenListTeam/OpenList/v4/internal/cache"
1516
"github.com/OpenListTeam/OpenList/v4/internal/driver"
1617
"github.com/OpenListTeam/OpenList/v4/internal/errs"
1718
"github.com/OpenListTeam/OpenList/v4/internal/model"
1819
"github.com/OpenListTeam/OpenList/v4/internal/op"
20+
"github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
1921
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
2022
"github.com/avast/retry-go"
2123
"github.com/go-resty/resty/v2"
@@ -122,7 +124,7 @@ func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCall
122124
}
123125
}
124126

125-
if 31023 == errno && d.DownloadAPI == "crack_video" {
127+
if errno == 31023 && d.DownloadAPI == "crack_video" {
126128
result = res.Body()
127129
return nil
128130
}
@@ -247,7 +249,7 @@ func (d *BaiduNetdisk) linkCrack(file model.Obj, _ model.LinkArgs) (*model.Link,
247249
func (d *BaiduNetdisk) linkCrackVideo(file model.Obj, _ model.LinkArgs) (*model.Link, error) {
248250
param := map[string]string{
249251
"type": "VideoURL",
250-
"path": fmt.Sprintf("%s", file.GetPath()),
252+
"path": file.GetPath(),
251253
"fs_id": file.GetID(),
252254
"devuid": "0%1",
253255
"clienttype": "1",
@@ -400,45 +402,26 @@ func (d *BaiduNetdisk) getUploadUrl(path, uploadId string) string {
400402
if !d.UseDynamicUploadAPI {
401403
return d.UploadAPI
402404
}
403-
getCachedUrlFunc := func() string {
404-
d.uploadUrlMu.RLock()
405-
defer d.uploadUrlMu.RUnlock()
406-
if d.uploadUrl != "" && time.Since(d.uploadUrlUpdateTime) < UPLOAD_URL_EXPIRE_TIME {
407-
uploadUrl := d.uploadUrl
408-
return uploadUrl
409-
}
410-
return ""
411-
}
412405
// 检查地址缓存
413-
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
414-
return uploadUrl
406+
if d.uploadUrlExpirationTime != nil && !d.uploadUrlExpirationTime.Expired() {
407+
return d.uploadUrl
415408
}
416409

417-
uploadUrlGetFunc := func() (string, error) {
418-
// 双重检查缓存
419-
if uploadUrl := getCachedUrlFunc(); uploadUrl != "" {
420-
return uploadUrl, nil
421-
}
422-
410+
_, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("BaiduNetdisk.getUploadUrl_%p", d), func() (any, error) {
423411
uploadUrl, err := d.requestForUploadUrl(path, uploadId)
424412
if err != nil {
425-
return "", err
413+
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, d.UploadAPI)
414+
return nil, err
426415
}
427416

428-
d.uploadUrlMu.Lock()
429-
defer d.uploadUrlMu.Unlock()
430417
d.uploadUrl = uploadUrl
431-
d.uploadUrlUpdateTime = time.Now()
432-
return uploadUrl, nil
433-
}
434-
435-
uploadUrl, err, _ := d.uploadUrlG.Do("", uploadUrlGetFunc)
418+
d.uploadUrlExpirationTime = cache.ExpirationTime(time.Now().Add(UPLOAD_URL_EXPIRE_TIME))
419+
return nil, nil
420+
})
436421
if err != nil {
437-
fallback := d.UploadAPI
438-
log.Warnf("[baidu_netdisk] get upload URL failed (%v), will use fallback URL: %s", err, fallback)
439-
return fallback
422+
return d.UploadAPI
440423
}
441-
return uploadUrl
424+
return d.uploadUrl
442425
}
443426

444427
// requestForUploadUrl 请求获取上传地址。

0 commit comments

Comments
 (0)