Add download/request speed limiter (#250)

This commit is contained in:
bytedream 2023-12-10 02:52:42 +01:00
parent f9e431e181
commit be3248a4f9
9 changed files with 230 additions and 76 deletions

90
Cargo.lock generated
View file

@ -106,6 +106,18 @@ version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "async-speed-limit"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d287ccbfb44ae20287d2f9c72ad9e560d50810883870697db5b320c541f183"
dependencies = [
"futures-core",
"futures-io",
"futures-timer",
"pin-project-lite",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.74" version = "0.1.74"
@ -385,6 +397,7 @@ name = "crunchy-cli-core"
version = "3.1.1" version = "3.1.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-speed-limit",
"async-trait", "async-trait",
"chrono", "chrono",
"clap", "clap",
@ -394,6 +407,8 @@ dependencies = [
"dialoguer", "dialoguer",
"dirs", "dirs",
"fs2", "fs2",
"futures-util",
"http",
"indicatif", "indicatif",
"lazy_static", "lazy_static",
"log", "log",
@ -409,13 +424,14 @@ dependencies = [
"sys-locale", "sys-locale",
"tempfile", "tempfile",
"tokio", "tokio",
"tower-service",
] ]
[[package]] [[package]]
name = "crunchyroll-rs" name = "crunchyroll-rs"
version = "0.7.0" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc4ce434784eee7892ad8c3d1ecaea0c0858db51bbb295b474db38c256e8d2fb" checksum = "56d0d1d6a75ed27b2dc93b84fa0667cffb92a513d206b8ccd1b895fab5ad2e9c"
dependencies = [ dependencies = [
"aes", "aes",
"async-trait", "async-trait",
@ -434,14 +450,15 @@ dependencies = [
"serde_urlencoded", "serde_urlencoded",
"smart-default", "smart-default",
"tokio", "tokio",
"webpki-roots", "tower-service",
"webpki-roots 0.26.0",
] ]
[[package]] [[package]]
name = "crunchyroll-rs-internal" name = "crunchyroll-rs-internal"
version = "0.7.0" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be840f8cf2ce6afc9a9eae268d41423093141ec88f664a515d5ed2a85a66fb60" checksum = "bdd0a20e750354408294b674a50b6d5dacec315ff9ead7c3a7c093f1e3594335"
dependencies = [ dependencies = [
"darling", "darling",
"quote", "quote",
@ -687,6 +704,23 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
[[package]]
name = "futures-io"
version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
[[package]]
name = "futures-macro"
version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.29" version = "0.3.29"
@ -699,6 +733,12 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
[[package]]
name = "futures-timer"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.29" version = "0.3.29"
@ -706,7 +746,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task", "futures-task",
"memchr",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"slab", "slab",
@ -1409,12 +1453,14 @@ dependencies = [
"tokio-native-tls", "tokio-native-tls",
"tokio-rustls", "tokio-rustls",
"tokio-socks", "tokio-socks",
"tokio-util",
"tower-service", "tower-service",
"url", "url",
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"wasm-streams",
"web-sys", "web-sys",
"webpki-roots", "webpki-roots 0.25.2",
"winreg", "winreg",
] ]
@ -1490,6 +1536,12 @@ dependencies = [
"base64", "base64",
] ]
[[package]]
name = "rustls-pki-types"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.101.7" version = "0.101.7"
@ -1813,9 +1865,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.34.0" version = "1.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
@ -2065,6 +2117,19 @@ version = "0.2.88"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b"
[[package]]
name = "wasm-streams"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.65" version = "0.3.65"
@ -2081,6 +2146,15 @@ version = "0.25.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc"
[[package]]
name = "webpki-roots"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"

View file

@ -13,21 +13,24 @@ openssl-tls-static = ["reqwest/native-tls", "reqwest/native-tls-alpn", "reqwest/
[dependencies] [dependencies]
anyhow = "1.0" anyhow = "1.0"
async-speed-limit = "0.4"
async-trait = "0.1" async-trait = "0.1"
clap = { version = "4.4", features = ["derive", "string"] } clap = { version = "4.4", features = ["derive", "string"] }
chrono = "0.4" chrono = "0.4"
crunchyroll-rs = { version = "0.7.0", features = ["dash-stream"] } crunchyroll-rs = { version = "0.8.0", features = ["dash-stream", "experimental-stabilizations", "tower"] }
ctrlc = "3.4" ctrlc = "3.4"
dialoguer = { version = "0.11", default-features = false } dialoguer = { version = "0.11", default-features = false }
dirs = "5.0" dirs = "5.0"
derive_setters = "0.1" derive_setters = "0.1"
futures-util = { version = "0.3", features = ["io"] }
fs2 = "0.4" fs2 = "0.4"
http = "0.2"
indicatif = "0.17" indicatif = "0.17"
lazy_static = "1.4" lazy_static = "1.4"
log = { version = "0.4", features = ["std"] } log = { version = "0.4", features = ["std"] }
num_cpus = "1.16" num_cpus = "1.16"
regex = "1.10" regex = "1.10"
reqwest = { version = "0.11", default-features = false, features = ["socks"] } reqwest = { version = "0.11", default-features = false, features = ["socks", "stream"] }
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_plain = "1.0" serde_plain = "1.0"
@ -35,6 +38,7 @@ shlex = "1.2"
sys-locale = "0.3" sys-locale = "0.3"
tempfile = "3.8" tempfile = "3.8"
tokio = { version = "1.34", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } tokio = { version = "1.34", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
tower-service = "0.3"
rustls-native-certs = { version = "0.6", optional = true } rustls-native-certs = { version = "0.6", optional = true }
[target.'cfg(not(target_os = "windows"))'.dependencies] [target.'cfg(not(target_os = "windows"))'.dependencies]

View file

@ -242,7 +242,7 @@ impl Execute for Archive {
format.visual_output(&path); format.visual_output(&path);
downloader.download(&ctx, &path).await? downloader.download(&path).await?
} }
} }

View file

@ -257,7 +257,7 @@ impl Execute for Download {
format.visual_output(&path); format.visual_output(&path);
downloader.download(&ctx, &path).await? downloader.download(&path).await?
} }
} }

View file

@ -17,6 +17,7 @@ mod login;
mod search; mod search;
mod utils; mod utils;
use crate::utils::rate_limit::RateLimiterService;
pub use archive::Archive; pub use archive::Archive;
use dialoguer::console::Term; use dialoguer::console::Term;
pub use download::Download; pub use download::Download;
@ -66,6 +67,15 @@ pub struct Cli {
#[arg(global = true, long)] #[arg(global = true, long)]
user_agent: Option<String>, user_agent: Option<String>,
#[arg(
help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of <number>[B|KB|MB]"
)]
#[arg(
long_help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of <number>[B|KB|MB] (e.g. 500KB or 10MB)"
)]
#[arg(global = true, long, value_parser = crate::utils::clap::clap_parse_speed_limit)]
speed_limit: Option<u32>,
#[clap(subcommand)] #[clap(subcommand)]
command: Command, command: Command,
} }
@ -264,9 +274,7 @@ async fn crunchyroll_session(cli: &mut Cli) -> Result<Crunchyroll> {
lang lang
}; };
let mut builder = Crunchyroll::builder() let client = {
.locale(locale)
.client({
let mut builder = CrunchyrollBuilder::predefined_client_builder(); let mut builder = CrunchyrollBuilder::predefined_client_builder();
if let Some(p) = &cli.proxy { if let Some(p) = &cli.proxy {
builder = builder.proxy(p.clone()) builder = builder.proxy(p.clone())
@ -291,12 +299,19 @@ async fn crunchyroll_session(cli: &mut Cli) -> Result<Crunchyroll> {
let client = builder.build().unwrap(); let client = builder.build().unwrap();
client client
}) };
let mut builder = Crunchyroll::builder()
.locale(locale)
.client(client.clone())
.stabilization_locales(cli.experimental_fixes) .stabilization_locales(cli.experimental_fixes)
.stabilization_season_number(cli.experimental_fixes); .stabilization_season_number(cli.experimental_fixes);
if let Command::Download(download) = &cli.command { if let Command::Download(download) = &cli.command {
builder = builder.preferred_audio_locale(download.audio.clone()) builder = builder.preferred_audio_locale(download.audio.clone())
} }
if let Some(speed_limit) = cli.speed_limit {
builder = builder.middleware(RateLimiterService::new(speed_limit, client));
}
let root_login_methods_count = cli.login_method.credentials.is_some() as u8 let root_login_methods_count = cli.login_method.credentials.is_some() as u8
+ cli.login_method.etp_rt.is_some() as u8 + cli.login_method.etp_rt.is_some() as u8

View file

@ -9,3 +9,20 @@ pub fn clap_parse_resolution(s: &str) -> Result<Resolution, String> {
pub fn clap_parse_proxy(s: &str) -> Result<Proxy, String> { pub fn clap_parse_proxy(s: &str) -> Result<Proxy, String> {
Proxy::all(s).map_err(|e| e.to_string()) Proxy::all(s).map_err(|e| e.to_string())
} }
pub fn clap_parse_speed_limit(s: &str) -> Result<u32, String> {
let quota = s.to_lowercase();
let bytes = if let Ok(b) = quota.parse() {
b
} else if let Ok(b) = quota.trim_end_matches('b').parse::<u32>() {
b
} else if let Ok(kb) = quota.trim_end_matches("kb").parse::<u32>() {
kb * 1024
} else if let Ok(mb) = quota.trim_end_matches("mb").parse::<u32>() {
mb * 1024 * 1024
} else {
return Err("Invalid speed limit".to_string());
};
Ok(bytes)
}

View file

@ -1,4 +1,3 @@
use crate::utils::context::Context;
use crate::utils::ffmpeg::FFmpegPreset; use crate::utils::ffmpeg::FFmpegPreset;
use crate::utils::os::{is_special_file, temp_directory, temp_named_pipe, tempfile}; use crate::utils::os::{is_special_file, temp_directory, temp_named_pipe, tempfile};
use anyhow::{bail, Result}; use anyhow::{bail, Result};
@ -8,7 +7,7 @@ use crunchyroll_rs::Locale;
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle}; use indicatif::{ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle};
use log::{debug, warn, LevelFilter}; use log::{debug, warn, LevelFilter};
use regex::Regex; use regex::Regex;
use std::borrow::{Borrow, BorrowMut}; use std::borrow::Borrow;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::env; use std::env;
@ -118,7 +117,7 @@ impl Downloader {
self.formats.push(format); self.formats.push(format);
} }
pub async fn download(mut self, ctx: &Context, dst: &Path) -> Result<()> { pub async fn download(mut self, dst: &Path) -> Result<()> {
// `.unwrap_or_default()` here unless https://doc.rust-lang.org/stable/std/path/fn.absolute.html // `.unwrap_or_default()` here unless https://doc.rust-lang.org/stable/std/path/fn.absolute.html
// gets stabilized as the function might throw error on weird file paths // gets stabilized as the function might throw error on weird file paths
let required = self.check_free_space(dst).await.unwrap_or_default(); let required = self.check_free_space(dst).await.unwrap_or_default();
@ -197,7 +196,6 @@ impl Downloader {
for (i, format) in self.formats.iter().enumerate() { for (i, format) in self.formats.iter().enumerate() {
let video_path = self let video_path = self
.download_video( .download_video(
ctx,
&format.video.0, &format.video.0,
format!("{:<1$}", format!("Downloading video #{}", i + 1), fmt_space), format!("{:<1$}", format!("Downloading video #{}", i + 1), fmt_space),
) )
@ -205,7 +203,6 @@ impl Downloader {
for (variant_data, locale) in format.audios.iter() { for (variant_data, locale) in format.audios.iter() {
let audio_path = self let audio_path = self
.download_audio( .download_audio(
ctx,
variant_data, variant_data,
format!("{:<1$}", format!("Downloading {} audio", locale), fmt_space), format!("{:<1$}", format!("Downloading {} audio", locale), fmt_space),
) )
@ -554,14 +551,13 @@ impl Downloader {
async fn download_video( async fn download_video(
&self, &self,
ctx: &Context,
variant_data: &VariantData, variant_data: &VariantData,
message: String, message: String,
) -> Result<TempPath> { ) -> Result<TempPath> {
let tempfile = tempfile(".mp4")?; let tempfile = tempfile(".mp4")?;
let (mut file, path) = tempfile.into_parts(); let (mut file, path) = tempfile.into_parts();
self.download_segments(ctx, &mut file, message, variant_data) self.download_segments(&mut file, message, variant_data)
.await?; .await?;
Ok(path) Ok(path)
@ -569,14 +565,13 @@ impl Downloader {
async fn download_audio( async fn download_audio(
&self, &self,
ctx: &Context,
variant_data: &VariantData, variant_data: &VariantData,
message: String, message: String,
) -> Result<TempPath> { ) -> Result<TempPath> {
let tempfile = tempfile(".m4a")?; let tempfile = tempfile(".m4a")?;
let (mut file, path) = tempfile.into_parts(); let (mut file, path) = tempfile.into_parts();
self.download_segments(ctx, &mut file, message, variant_data) self.download_segments(&mut file, message, variant_data)
.await?; .await?;
Ok(path) Ok(path)
@ -601,7 +596,6 @@ impl Downloader {
async fn download_segments( async fn download_segments(
&self, &self,
ctx: &Context,
writer: &mut impl Write, writer: &mut impl Write,
message: String, message: String,
variant_data: &VariantData, variant_data: &VariantData,
@ -609,7 +603,6 @@ impl Downloader {
let segments = variant_data.segments().await?; let segments = variant_data.segments().await?;
let total_segments = segments.len(); let total_segments = segments.len();
let client = Arc::new(ctx.crunchy.client());
let count = Arc::new(Mutex::new(0)); let count = Arc::new(Mutex::new(0));
let progress = if log::max_level() == LevelFilter::Info { let progress = if log::max_level() == LevelFilter::Info {
@ -643,7 +636,6 @@ impl Downloader {
let mut join_set: JoinSet<Result<()>> = JoinSet::new(); let mut join_set: JoinSet<Result<()>> = JoinSet::new();
for num in 0..cpus { for num in 0..cpus {
let thread_client = client.clone();
let thread_sender = sender.clone(); let thread_sender = sender.clone();
let thread_segments = segs.remove(0); let thread_segments = segs.remove(0);
let thread_count = count.clone(); let thread_count = count.clone();
@ -656,42 +648,21 @@ impl Downloader {
let download = || async move { let download = || async move {
for (i, segment) in thread_segments.into_iter().enumerate() { for (i, segment) in thread_segments.into_iter().enumerate() {
let mut retry_count = 0; let mut retry_count = 0;
let mut buf = loop { let buf = loop {
let request = thread_client let mut buf = vec![];
.get(&segment.url) match segment.write_to(&mut buf).await {
.timeout(Duration::from_secs(60)) Ok(_) => break buf,
.send();
let response = match request.await {
Ok(r) => r,
Err(e) => { Err(e) => {
if retry_count == 5 {
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
}
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count);
continue
}
};
match response.bytes().await {
Ok(b) => break b.to_vec(),
Err(e) => {
if e.is_body() {
if retry_count == 5 { if retry_count == 5 {
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e) bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
} }
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count) debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
} else {
bail!("{}", e)
}
} }
} }
retry_count += 1; retry_count += 1;
}; };
buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec();
let mut c = thread_count.lock().await; let mut c = thread_count.lock().await;
debug!( debug!(
"Downloaded and decrypted segment [{}/{} {:.2}%] {}", "Downloaded and decrypted segment [{}/{} {:.2}%] {}",

View file

@ -9,4 +9,5 @@ pub mod locale;
pub mod log; pub mod log;
pub mod os; pub mod os;
pub mod parse; pub mod parse;
pub mod rate_limit;
pub mod video; pub mod video;

View file

@ -0,0 +1,72 @@
use async_speed_limit::Limiter;
use crunchyroll_rs::error::Error;
use futures_util::TryStreamExt;
use reqwest::{Client, Request, Response, ResponseBuilderExt};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_service::Service;
pub struct RateLimiterService {
client: Arc<Client>,
rate_limiter: Limiter,
}
impl RateLimiterService {
pub fn new(bytes: u32, client: Client) -> Self {
Self {
client: Arc::new(client),
rate_limiter: Limiter::new(bytes as f64),
}
}
}
impl Service<Request> for RateLimiterService {
type Response = Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {
let client = self.client.clone();
let rate_limiter = self.rate_limiter.clone();
Box::pin(async move {
let mut body = vec![];
let res = client.execute(req).await?;
let _url = res.url().clone().to_string();
let url = _url.as_str();
let mut http_res = http::Response::builder()
.url(res.url().clone())
.status(res.status())
.version(res.version());
*http_res.headers_mut().unwrap() = res.headers().clone();
http_res
.extensions_ref()
.unwrap()
.clone_from(&res.extensions());
let limiter = rate_limiter.limit(
res.bytes_stream()
.map_err(io::Error::other)
.into_async_read(),
);
futures_util::io::copy(limiter, &mut body)
.await
.map_err(|e| Error::Request {
url: url.to_string(),
status: None,
message: e.to_string(),
})?;
Ok(Response::from(http_res.body(body).unwrap()))
})
}
}