From 0a4b9ec96e2a803b69521caa0b06e2b895b1f5ad Mon Sep 17 00:00:00 2001 From: bytedream Date: Sun, 13 Feb 2022 15:01:05 +0100 Subject: [PATCH] Added context to downloader and moved method from utils to downloader --- downloader.go | 60 +++++++++++++++++++++++++++++++++++++++++++++------ utils.go | 32 --------------------------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/downloader.go b/downloader.go index 882904d..8e4d1a1 100644 --- a/downloader.go +++ b/downloader.go @@ -1,6 +1,7 @@ package crunchyroll import ( + "context" "crypto/aes" "crypto/cipher" "fmt" @@ -23,9 +24,14 @@ type Downloader struct { // If IgnoreExisting is true, existing Filename's and TempDir's may be // overwritten or deleted IgnoreExisting bool - // If DeleteTempAfter is true, the temp directory gets deleted afterwards + // If DeleteTempAfter is true, the temp directory gets deleted afterwards. + // Note that in case of a hard signal exit (os.Interrupt, ...) the directory + // will NOT be deleted. In such situations try to catch the signal and + // cancel Context DeleteTempAfter bool + Context context.Context + // Goroutines is the number of goroutines to download segments with Goroutines int @@ -43,6 +49,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme Filename: filename, TempDir: tmp, DeleteTempAfter: true, + Context: context.Background(), Goroutines: goroutines, OnSegmentDownload: onSegmentDownload, } @@ -57,7 +64,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme // The actual crunchyroll video is split up in multiple segments (or video files) which have to be downloaded and merged after to generate a single video file. // And this function just downloads each of this segment into the given directory. // See https://en.wikipedia.org/wiki/MPEG_transport_stream for more information -func download(format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error { +func download(context context.Context, format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error { resp, err := format.crunchy.Client.Get(format.Video.URI) if err != nil { return err @@ -100,15 +107,19 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload i := i go func() { + defer wg.Done() + for j, segment := range segments[i:end] { select { + case <-context.Done(): + return case <-quit: - break + return default: var file *os.File k := 1 for ; k < 4; k++ { - file, err = downloadSegment(format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv) + file, err = downloadSegment(context, format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv) if err == nil { break } @@ -129,12 +140,13 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload file.Close() } } - wg.Done() }() } wg.Wait() select { + case <-context.Done(): + return context.Err() case <-quit: return err default: @@ -166,9 +178,9 @@ func getCrypt(format *Format, segment *m3u8.MediaSegment) (block cipher.Block, i } // downloadSegment downloads a segment, decrypts it and names it after the given index -func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) { +func downloadSegment(context context.Context, format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) { // every segment is aes-128 encrypted and has to be decrypted when downloaded - content, err := decryptSegment(format.crunchy.Client, segment, block, iv) + content, err := decryptSegment(context, format.crunchy.Client, segment, block, iv) if err != nil { return nil, err } @@ -184,3 +196,37 @@ func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string return file, nil } + +// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L25 +func decryptSegment(context context.Context, client *http.Client, segment *m3u8.MediaSegment, block cipher.Block, iv []byte) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, segment.URI, nil) + if err != nil { + return nil, err + } + req.WithContext(context) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + raw, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + blockMode := cipher.NewCBCDecrypter(block, iv[:block.BlockSize()]) + decrypted := make([]byte, len(raw)) + blockMode.CryptBlocks(decrypted, raw) + raw = pkcs5UnPadding(decrypted) + + return raw, nil +} + +// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L47 +func pkcs5UnPadding(origData []byte) []byte { + length := len(origData) + unPadding := int(origData[length-1]) + return origData[:(length - unPadding)] +} diff --git a/utils.go b/utils.go index 5983a60..a3d4191 100644 --- a/utils.go +++ b/utils.go @@ -1,11 +1,7 @@ package crunchyroll import ( - "crypto/cipher" "encoding/json" - "github.com/grafov/m3u8" - "io/ioutil" - "net/http" ) func decodeMapToStruct(m interface{}, s interface{}) error { @@ -16,34 +12,6 @@ func decodeMapToStruct(m interface{}, s interface{}) error { return json.Unmarshal(jsonBody, s) } -// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L25 -func decryptSegment(client *http.Client, segment *m3u8.MediaSegment, block cipher.Block, iv []byte) ([]byte, error) { - resp, err := client.Get(segment.URI) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - raw, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - blockMode := cipher.NewCBCDecrypter(block, iv[:block.BlockSize()]) - decrypted := make([]byte, len(raw)) - blockMode.CryptBlocks(decrypted, raw) - raw = pkcs5UnPadding(decrypted) - - return raw, nil -} - -// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L47 -func pkcs5UnPadding(origData []byte) []byte { - length := len(origData) - unPadding := int(origData[length-1]) - return origData[:(length - unPadding)] -} - func regexGroups(parsed [][]string, subexpNames ...string) map[string]string { groups := map[string]string{} for _, match := range parsed {