11package baidu_netdisk
22
33import (
4+ "bytes"
45 "context"
56 "crypto/md5"
67 "encoding/hex"
78 "errors"
8- "fmt"
99 "io"
10+ "mime/multipart"
11+ "net/http"
1012 "net/url"
1113 "os"
1214 stdpath "path"
1315 "strconv"
1416 "strings"
15- "sync"
1617 "time"
1718
1819 "github.com/OpenListTeam/OpenList/v4/drivers/base"
20+ "github.com/OpenListTeam/OpenList/v4/internal/cache"
1921 "github.com/OpenListTeam/OpenList/v4/internal/conf"
2022 "github.com/OpenListTeam/OpenList/v4/internal/driver"
2123 "github.com/OpenListTeam/OpenList/v4/internal/errs"
2224 "github.com/OpenListTeam/OpenList/v4/internal/model"
25+ "github.com/OpenListTeam/OpenList/v4/internal/net"
2326 "github.com/OpenListTeam/OpenList/v4/pkg/errgroup"
24- "github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
2527 "github.com/OpenListTeam/OpenList/v4/pkg/utils"
2628 "github.com/avast/retry-go"
27- "github.com/go-resty/resty/v2"
2829 log "github.com/sirupsen/logrus"
2930)
3031
@@ -35,11 +36,8 @@ type BaiduNetdisk struct {
3536 uploadThread int
3637 vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
3738
38- upClient * resty.Client // 上传文件使用的http客户端
39- uploadUrlG singleflight.Group [string ]
40- uploadUrlMu sync.RWMutex
41- uploadUrl string // 上传域名
42- uploadUrlUpdateTime time.Time // 上传域名上次更新时间
39+ uploadUrl string // 上传域名
40+ uploadUrlExpirationTime cache.Expirable
4341}
4442
4543var ErrUploadIDExpired = errors .New ("uploadid expired" )
@@ -53,11 +51,6 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
5351}
5452
5553func (d * BaiduNetdisk ) Init (ctx context.Context ) error {
56- d .upClient = base .NewRestyClient ().
57- SetTimeout (UPLOAD_TIMEOUT ).
58- SetRetryCount (UPLOAD_RETRY_COUNT ).
59- SetRetryWaitTime (UPLOAD_RETRY_WAIT_TIME ).
60- SetRetryMaxWaitTime (UPLOAD_RETRY_MAX_WAIT_TIME )
6154 d .uploadThread , _ = strconv .Atoi (d .UploadThread )
6255 if d .uploadThread < 1 {
6356 d .uploadThread , d .UploadThread = 1 , "1"
@@ -95,9 +88,10 @@ func (d *BaiduNetdisk) List(ctx context.Context, dir model.Obj, args model.ListA
9588}
9689
9790func (d * BaiduNetdisk ) Link (ctx context.Context , file model.Obj , args model.LinkArgs ) (* model.Link , error ) {
98- if d .DownloadAPI == "crack" {
91+ switch d .DownloadAPI {
92+ case "crack" :
9993 return d .linkCrack (file , args )
100- } else if d . DownloadAPI == "crack_video" {
94+ case "crack_video" :
10195 return d .linkCrackVideo (file , args )
10296 }
10397 return d .linkOfficial (file , args )
@@ -214,7 +208,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
214208 tmpF * os.File
215209 err error
216210 )
217- if _ , ok := cache .(io. ReaderAt ); ! ok {
211+ if cache == nil {
218212 tmpF , err = os .CreateTemp (conf .Conf .TempDir , "file-*" )
219213 if err != nil {
220214 return nil , err
@@ -301,27 +295,25 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
301295
302296 // step.2 上传分片
303297uploadLoop:
304- for attempt := 0 ; attempt < 2 ; attempt ++ {
298+ for range 2 {
305299 // 获取上传域名
306300 uploadUrl := d .getUploadUrl (path , precreateResp .Uploadid )
307301 // 并发上传
308302 threadG , upCtx := errgroup .NewGroupWithContext (ctx , d .uploadThread ,
309- retry .Attempts (1 ),
310- retry .Delay (time .Second ),
303+ retry .Attempts (UPLOAD_RETRY_COUNT ),
304+ retry .Delay (UPLOAD_RETRY_WAIT_TIME ),
305+ retry .MaxDelay (UPLOAD_RETRY_MAX_WAIT_TIME ),
311306 retry .DelayType (retry .BackOffDelay ))
312307
313- cacheReaderAt , okReaderAt := cache .(io.ReaderAt )
314- if ! okReaderAt {
315- return nil , fmt .Errorf ("cache object must implement io.ReaderAt interface for upload operations" )
316- }
317-
318308 totalParts := len (precreateResp .BlockList )
319309
320310 for i , partseq := range precreateResp .BlockList {
321- if utils .IsCanceled (upCtx ) || partseq < 0 {
311+ if utils .IsCanceled (upCtx ) {
312+ break
313+ }
314+ if partseq < 0 {
322315 continue
323316 }
324- i , partseq := i , partseq
325317 offset , size := int64 (partseq )* sliceSize , sliceSize
326318 if partseq + 1 == count {
327319 size = lastBlockSize
@@ -335,15 +327,13 @@ uploadLoop:
335327 "uploadid" : precreateResp .Uploadid ,
336328 "partseq" : strconv .Itoa (partseq ),
337329 }
338- section := io .NewSectionReader (cacheReaderAt , offset , size )
339- err := d .uploadSlice (ctx , uploadUrl , params , stream .GetName (), driver . NewLimitedUploadStream ( ctx , section ) )
330+ section := io .NewSectionReader (cache , offset , size )
331+ err := d .uploadSlice (ctx , uploadUrl , params , stream .GetName (), section )
340332 if err != nil {
341333 return err
342334 }
343335 precreateResp .BlockList [i ] = - 1
344- // 当前goroutine还没退出,+1才是真正成功的数量
345- success := threadG .Success () + 1
346- progress := float64 (success ) * 100 / float64 (totalParts )
336+ progress := float64 (threadG .Success ()+ 1 ) * 100 / float64 (totalParts + 1 )
347337 up (progress )
348338 return nil
349339 })
@@ -430,27 +420,58 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in
430420}
431421
432422func (d * BaiduNetdisk ) uploadSlice (ctx context.Context , uploadUrl string , params map [string ]string , fileName string , file io.Reader ) error {
433- res , err := d .upClient .R ().
434- SetContext (ctx ).
435- SetQueryParams (params ).
436- SetFileReader ("file" , fileName , file ).
437- Post (uploadUrl + "/rest/2.0/pcs/superfile2" )
423+ b := bytes .NewBuffer (nil )
424+ mw := multipart .NewWriter (b )
425+ _ , err := mw .CreateFormFile ("file" , fileName )
426+ if err != nil {
427+ return err
428+ }
429+ headSize := b .Len ()
430+ err = mw .Close ()
431+ if err != nil {
432+ return err
433+ }
434+ head := bytes .NewReader (b .Bytes ()[:headSize ])
435+ tail := bytes .NewReader (b .Bytes ()[headSize :])
436+ rateLimitedRd := driver .NewLimitedUploadStream (ctx , io .MultiReader (head , file , tail ))
437+
438+ req , err := http .NewRequestWithContext (ctx , http .MethodPost , uploadUrl + "/rest/2.0/pcs/superfile2" , rateLimitedRd )
439+ if err != nil {
440+ return err
441+ }
442+ query := req .URL .Query ()
443+ for k , v := range params {
444+ query .Set (k , v )
445+ }
446+ req .URL .RawQuery = query .Encode ()
447+ req .Header .Set ("Content-Type" , mw .FormDataContentType ())
448+
449+ client := net .NewHttpClient ()
450+ client .Timeout = UPLOAD_TIMEOUT
451+ resp , err := client .Do (req )
452+ if err != nil {
453+ return err
454+ }
455+ defer resp .Body .Close ()
456+ b .Reset ()
457+ _ , err = b .ReadFrom (resp .Body )
438458 if err != nil {
439459 return err
440460 }
441- log .Debugln (res .RawResponse .Status + res .String ())
442- errCode := utils .Json .Get (res .Body (), "error_code" ).ToInt ()
443- errNo := utils .Json .Get (res .Body (), "errno" ).ToInt ()
444- respStr := res .String ()
461+ body := b .Bytes ()
462+ respStr := string (body )
463+ log .Debugln (respStr )
445464 lower := strings .ToLower (respStr )
446465 // 合并 uploadid 过期检测逻辑
447466 if strings .Contains (lower , "uploadid" ) &&
448467 (strings .Contains (lower , "invalid" ) || strings .Contains (lower , "expired" ) || strings .Contains (lower , "not found" )) {
449468 return ErrUploadIDExpired
450469 }
451470
471+ errCode := utils .Json .Get (body , "error_code" ).ToInt ()
472+ errNo := utils .Json .Get (body , "errno" ).ToInt ()
452473 if errCode != 0 || errNo != 0 {
453- return errs .NewErr (errs .StreamIncomplete , "error uploading to baidu, response=%s" , res . String () )
474+ return errs .NewErr (errs .StreamIncomplete , "error uploading to baidu, response=%s" , respStr )
454475 }
455476 return nil
456477}
0 commit comments