Added context to downloader and moved method from utils to downloader

This commit is contained in:
bytedream 2022-02-13 15:01:05 +01:00
parent c557486089
commit 0a4b9ec96e
2 changed files with 53 additions and 39 deletions

View file

@ -1,6 +1,7 @@
package crunchyroll
import (
"context"
"crypto/aes"
"crypto/cipher"
"fmt"
@ -23,9 +24,14 @@ type Downloader struct {
// If IgnoreExisting is true, existing Filename's and TempDir's may be
// overwritten or deleted
IgnoreExisting bool
// If DeleteTempAfter is true, the temp directory gets deleted afterwards
// If DeleteTempAfter is true, the temp directory gets deleted afterwards.
// Note that in case of a hard signal exit (os.Interrupt, ...) the directory
// will NOT be deleted. In such situations try to catch the signal and
// cancel Context
DeleteTempAfter bool
Context context.Context
// Goroutines is the number of goroutines to download segments with
Goroutines int
@ -43,6 +49,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme
Filename: filename,
TempDir: tmp,
DeleteTempAfter: true,
Context: context.Background(),
Goroutines: goroutines,
OnSegmentDownload: onSegmentDownload,
}
@ -57,7 +64,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme
// The actual crunchyroll video is split up in multiple segments (or video files) which have to be downloaded and merged after to generate a single video file.
// And this function just downloads each of this segment into the given directory.
// See https://en.wikipedia.org/wiki/MPEG_transport_stream for more information
func download(format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error {
func download(context context.Context, format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error {
resp, err := format.crunchy.Client.Get(format.Video.URI)
if err != nil {
return err
@ -100,15 +107,19 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload
i := i
go func() {
defer wg.Done()
for j, segment := range segments[i:end] {
select {
case <-context.Done():
return
case <-quit:
break
return
default:
var file *os.File
k := 1
for ; k < 4; k++ {
file, err = downloadSegment(format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv)
file, err = downloadSegment(context, format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv)
if err == nil {
break
}
@ -129,12 +140,13 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload
file.Close()
}
}
wg.Done()
}()
}
wg.Wait()
select {
case <-context.Done():
return context.Err()
case <-quit:
return err
default:
@ -166,9 +178,9 @@ func getCrypt(format *Format, segment *m3u8.MediaSegment) (block cipher.Block, i
}
// downloadSegment downloads a segment, decrypts it and names it after the given index
func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) {
func downloadSegment(context context.Context, format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) {
// every segment is aes-128 encrypted and has to be decrypted when downloaded
content, err := decryptSegment(format.crunchy.Client, segment, block, iv)
content, err := decryptSegment(context, format.crunchy.Client, segment, block, iv)
if err != nil {
return nil, err
}
@ -184,3 +196,37 @@ func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string
return file, nil
}
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L25
func decryptSegment(context context.Context, client *http.Client, segment *m3u8.MediaSegment, block cipher.Block, iv []byte) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, segment.URI, nil)
if err != nil {
return nil, err
}
req.WithContext(context)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
blockMode := cipher.NewCBCDecrypter(block, iv[:block.BlockSize()])
decrypted := make([]byte, len(raw))
blockMode.CryptBlocks(decrypted, raw)
raw = pkcs5UnPadding(decrypted)
return raw, nil
}
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L47
func pkcs5UnPadding(origData []byte) []byte {
length := len(origData)
unPadding := int(origData[length-1])
return origData[:(length - unPadding)]
}