Added context to downloader and moved method from utils to downloader

This commit is contained in:
bytedream 2022-02-13 15:01:05 +01:00
parent c557486089
commit 0a4b9ec96e
2 changed files with 53 additions and 39 deletions

View file

@ -1,6 +1,7 @@
package crunchyroll package crunchyroll
import ( import (
"context"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"fmt" "fmt"
@ -23,9 +24,14 @@ type Downloader struct {
// If IgnoreExisting is true, existing Filename's and TempDir's may be // If IgnoreExisting is true, existing Filename's and TempDir's may be
// overwritten or deleted // overwritten or deleted
IgnoreExisting bool IgnoreExisting bool
// If DeleteTempAfter is true, the temp directory gets deleted afterwards // If DeleteTempAfter is true, the temp directory gets deleted afterwards.
// Note that in case of a hard signal exit (os.Interrupt, ...) the directory
// will NOT be deleted. In such situations try to catch the signal and
// cancel Context
DeleteTempAfter bool DeleteTempAfter bool
Context context.Context
// Goroutines is the number of goroutines to download segments with // Goroutines is the number of goroutines to download segments with
Goroutines int Goroutines int
@ -43,6 +49,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme
Filename: filename, Filename: filename,
TempDir: tmp, TempDir: tmp,
DeleteTempAfter: true, DeleteTempAfter: true,
Context: context.Background(),
Goroutines: goroutines, Goroutines: goroutines,
OnSegmentDownload: onSegmentDownload, OnSegmentDownload: onSegmentDownload,
} }
@ -57,7 +64,7 @@ func NewDownloader(filename string, goroutines int, onSegmentDownload func(segme
// The actual crunchyroll video is split up in multiple segments (or video files) which have to be downloaded and merged after to generate a single video file. // The actual crunchyroll video is split up in multiple segments (or video files) which have to be downloaded and merged after to generate a single video file.
// And this function just downloads each of this segment into the given directory. // And this function just downloads each of this segment into the given directory.
// See https://en.wikipedia.org/wiki/MPEG_transport_stream for more information // See https://en.wikipedia.org/wiki/MPEG_transport_stream for more information
func download(format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error { func download(context context.Context, format *Format, tempDir string, goroutines int, onSegmentDownload func(segment *m3u8.MediaSegment, current, total int, file *os.File) error) error {
resp, err := format.crunchy.Client.Get(format.Video.URI) resp, err := format.crunchy.Client.Get(format.Video.URI)
if err != nil { if err != nil {
return err return err
@ -100,15 +107,19 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload
i := i i := i
go func() { go func() {
defer wg.Done()
for j, segment := range segments[i:end] { for j, segment := range segments[i:end] {
select { select {
case <-context.Done():
return
case <-quit: case <-quit:
break return
default: default:
var file *os.File var file *os.File
k := 1 k := 1
for ; k < 4; k++ { for ; k < 4; k++ {
file, err = downloadSegment(format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv) file, err = downloadSegment(context, format, segment, filepath.Join(tempDir, fmt.Sprintf("%d.ts", i+j)), block, iv)
if err == nil { if err == nil {
break break
} }
@ -129,12 +140,13 @@ func download(format *Format, tempDir string, goroutines int, onSegmentDownload
file.Close() file.Close()
} }
} }
wg.Done()
}() }()
} }
wg.Wait() wg.Wait()
select { select {
case <-context.Done():
return context.Err()
case <-quit: case <-quit:
return err return err
default: default:
@ -166,9 +178,9 @@ func getCrypt(format *Format, segment *m3u8.MediaSegment) (block cipher.Block, i
} }
// downloadSegment downloads a segment, decrypts it and names it after the given index // downloadSegment downloads a segment, decrypts it and names it after the given index
func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) { func downloadSegment(context context.Context, format *Format, segment *m3u8.MediaSegment, filename string, block cipher.Block, iv []byte) (*os.File, error) {
// every segment is aes-128 encrypted and has to be decrypted when downloaded // every segment is aes-128 encrypted and has to be decrypted when downloaded
content, err := decryptSegment(format.crunchy.Client, segment, block, iv) content, err := decryptSegment(context, format.crunchy.Client, segment, block, iv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,3 +196,37 @@ func downloadSegment(format *Format, segment *m3u8.MediaSegment, filename string
return file, nil return file, nil
} }
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L25
func decryptSegment(context context.Context, client *http.Client, segment *m3u8.MediaSegment, block cipher.Block, iv []byte) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, segment.URI, nil)
if err != nil {
return nil, err
}
req.WithContext(context)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
blockMode := cipher.NewCBCDecrypter(block, iv[:block.BlockSize()])
decrypted := make([]byte, len(raw))
blockMode.CryptBlocks(decrypted, raw)
raw = pkcs5UnPadding(decrypted)
return raw, nil
}
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L47
func pkcs5UnPadding(origData []byte) []byte {
length := len(origData)
unPadding := int(origData[length-1])
return origData[:(length - unPadding)]
}

View file

@ -1,11 +1,7 @@
package crunchyroll package crunchyroll
import ( import (
"crypto/cipher"
"encoding/json" "encoding/json"
"github.com/grafov/m3u8"
"io/ioutil"
"net/http"
) )
func decodeMapToStruct(m interface{}, s interface{}) error { func decodeMapToStruct(m interface{}, s interface{}) error {
@ -16,34 +12,6 @@ func decodeMapToStruct(m interface{}, s interface{}) error {
return json.Unmarshal(jsonBody, s) return json.Unmarshal(jsonBody, s)
} }
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L25
func decryptSegment(client *http.Client, segment *m3u8.MediaSegment, block cipher.Block, iv []byte) ([]byte, error) {
resp, err := client.Get(segment.URI)
if err != nil {
return nil, err
}
defer resp.Body.Close()
raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
blockMode := cipher.NewCBCDecrypter(block, iv[:block.BlockSize()])
decrypted := make([]byte, len(raw))
blockMode.CryptBlocks(decrypted, raw)
raw = pkcs5UnPadding(decrypted)
return raw, nil
}
// https://github.com/oopsguy/m3u8/blob/4150e93ec8f4f8718875a02973f5d792648ecb97/tool/crypt.go#L47
func pkcs5UnPadding(origData []byte) []byte {
length := len(origData)
unPadding := int(origData[length-1])
return origData[:(length - unPadding)]
}
func regexGroups(parsed [][]string, subexpNames ...string) map[string]string { func regexGroups(parsed [][]string, subexpNames ...string) map[string]string {
groups := map[string]string{} groups := map[string]string{}
for _, match := range parsed { for _, match := range parsed {