From be3248a4f9a1d8480558eaf78c8365f7695d5c1b Mon Sep 17 00:00:00 2001 From: bytedream Date: Sun, 10 Dec 2023 02:52:42 +0100 Subject: [PATCH] Add download/request speed limiter (#250) --- Cargo.lock | 90 +++++++++++++++++++++--- crunchy-cli-core/Cargo.toml | 8 ++- crunchy-cli-core/src/archive/command.rs | 2 +- crunchy-cli-core/src/download/command.rs | 2 +- crunchy-cli-core/src/lib.rs | 67 +++++++++++------- crunchy-cli-core/src/utils/clap.rs | 17 +++++ crunchy-cli-core/src/utils/download.rs | 47 +++---------- crunchy-cli-core/src/utils/mod.rs | 1 + crunchy-cli-core/src/utils/rate_limit.rs | 72 +++++++++++++++++++ 9 files changed, 230 insertions(+), 76 deletions(-) create mode 100644 crunchy-cli-core/src/utils/rate_limit.rs diff --git a/Cargo.lock b/Cargo.lock index fa5723f..037a4ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,18 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "async-speed-limit" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d287ccbfb44ae20287d2f9c72ad9e560d50810883870697db5b320c541f183" +dependencies = [ + "futures-core", + "futures-io", + "futures-timer", + "pin-project-lite", +] + [[package]] name = "async-trait" version = "0.1.74" @@ -385,6 +397,7 @@ name = "crunchy-cli-core" version = "3.1.1" dependencies = [ "anyhow", + "async-speed-limit", "async-trait", "chrono", "clap", @@ -394,6 +407,8 @@ dependencies = [ "dialoguer", "dirs", "fs2", + "futures-util", + "http", "indicatif", "lazy_static", "log", @@ -409,13 +424,14 @@ dependencies = [ "sys-locale", "tempfile", "tokio", + "tower-service", ] [[package]] name = "crunchyroll-rs" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc4ce434784eee7892ad8c3d1ecaea0c0858db51bbb295b474db38c256e8d2fb" +checksum = "56d0d1d6a75ed27b2dc93b84fa0667cffb92a513d206b8ccd1b895fab5ad2e9c" dependencies = [ "aes", "async-trait", @@ -434,14 +450,15 @@ dependencies = [ "serde_urlencoded", "smart-default", "tokio", - "webpki-roots", + "tower-service", + "webpki-roots 0.26.0", ] [[package]] name = "crunchyroll-rs-internal" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be840f8cf2ce6afc9a9eae268d41423093141ec88f664a515d5ed2a85a66fb60" +checksum = "bdd0a20e750354408294b674a50b6d5dacec315ff9ead7c3a7c093f1e3594335" dependencies = [ "darling", "quote", @@ -687,6 +704,23 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +[[package]] +name = "futures-io" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" + +[[package]] +name = "futures-macro" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.29" @@ -699,6 +733,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.29" @@ -706,7 +746,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -1409,12 +1453,14 @@ dependencies = [ "tokio-native-tls", "tokio-rustls", "tokio-socks", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 0.25.2", "winreg", ] @@ -1490,6 +1536,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustls-pki-types" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7673e0aa20ee4937c6aacfc12bb8341cfbf054cdd21df6bec5fd0629fe9339b" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -1813,9 +1865,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.34.0" +version = "1.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" +checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" dependencies = [ "backtrace", "bytes", @@ -2065,6 +2117,19 @@ version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +[[package]] +name = "wasm-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.65" @@ -2081,6 +2146,15 @@ version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +[[package]] +name = "webpki-roots" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2cfda980f21be5a7ed2eadb3e6fe074d56022bea2cdeb1a62eb220fc04188" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/crunchy-cli-core/Cargo.toml b/crunchy-cli-core/Cargo.toml index 4f0912f..173c3ac 100644 --- a/crunchy-cli-core/Cargo.toml +++ b/crunchy-cli-core/Cargo.toml @@ -13,21 +13,24 @@ openssl-tls-static = ["reqwest/native-tls", "reqwest/native-tls-alpn", "reqwest/ [dependencies] anyhow = "1.0" +async-speed-limit = "0.4" async-trait = "0.1" clap = { version = "4.4", features = ["derive", "string"] } chrono = "0.4" -crunchyroll-rs = { version = "0.7.0", features = ["dash-stream"] } +crunchyroll-rs = { version = "0.8.0", features = ["dash-stream", "experimental-stabilizations", "tower"] } ctrlc = "3.4" dialoguer = { version = "0.11", default-features = false } dirs = "5.0" derive_setters = "0.1" +futures-util = { version = "0.3", features = ["io"] } fs2 = "0.4" +http = "0.2" indicatif = "0.17" lazy_static = "1.4" log = { version = "0.4", features = ["std"] } num_cpus = "1.16" regex = "1.10" -reqwest = { version = "0.11", default-features = false, features = ["socks"] } +reqwest = { version = "0.11", default-features = false, features = ["socks", "stream"] } serde = "1.0" serde_json = "1.0" serde_plain = "1.0" @@ -35,6 +38,7 @@ shlex = "1.2" sys-locale = "0.3" tempfile = "3.8" tokio = { version = "1.34", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } +tower-service = "0.3" rustls-native-certs = { version = "0.6", optional = true } [target.'cfg(not(target_os = "windows"))'.dependencies] diff --git a/crunchy-cli-core/src/archive/command.rs b/crunchy-cli-core/src/archive/command.rs index 1c0723a..ea81d15 100644 --- a/crunchy-cli-core/src/archive/command.rs +++ b/crunchy-cli-core/src/archive/command.rs @@ -242,7 +242,7 @@ impl Execute for Archive { format.visual_output(&path); - downloader.download(&ctx, &path).await? + downloader.download(&path).await? } } diff --git a/crunchy-cli-core/src/download/command.rs b/crunchy-cli-core/src/download/command.rs index a45354b..ade6adf 100644 --- a/crunchy-cli-core/src/download/command.rs +++ b/crunchy-cli-core/src/download/command.rs @@ -257,7 +257,7 @@ impl Execute for Download { format.visual_output(&path); - downloader.download(&ctx, &path).await? + downloader.download(&path).await? } } diff --git a/crunchy-cli-core/src/lib.rs b/crunchy-cli-core/src/lib.rs index 366f153..3299b31 100644 --- a/crunchy-cli-core/src/lib.rs +++ b/crunchy-cli-core/src/lib.rs @@ -17,6 +17,7 @@ mod login; mod search; mod utils; +use crate::utils::rate_limit::RateLimiterService; pub use archive::Archive; use dialoguer::console::Term; pub use download::Download; @@ -66,6 +67,15 @@ pub struct Cli { #[arg(global = true, long)] user_agent: Option, + #[arg( + help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of [B|KB|MB]" + )] + #[arg( + long_help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of [B|KB|MB] (e.g. 500KB or 10MB)" + )] + #[arg(global = true, long, value_parser = crate::utils::clap::clap_parse_speed_limit)] + speed_limit: Option, + #[clap(subcommand)] command: Command, } @@ -264,39 +274,44 @@ async fn crunchyroll_session(cli: &mut Cli) -> Result { lang }; + let client = { + let mut builder = CrunchyrollBuilder::predefined_client_builder(); + if let Some(p) = &cli.proxy { + builder = builder.proxy(p.clone()) + } + if let Some(ua) = &cli.user_agent { + builder = builder.user_agent(ua) + } + + #[cfg(any(feature = "openssl-tls", feature = "openssl-tls-static"))] + let client = { + let mut builder = builder.use_native_tls().tls_built_in_root_certs(false); + + for certificate in rustls_native_certs::load_native_certs().unwrap() { + builder = builder.add_root_certificate( + reqwest::Certificate::from_der(certificate.0.as_slice()).unwrap(), + ) + } + + builder.build().unwrap() + }; + #[cfg(not(any(feature = "openssl-tls", feature = "openssl-tls-static")))] + let client = builder.build().unwrap(); + + client + }; + let mut builder = Crunchyroll::builder() .locale(locale) - .client({ - let mut builder = CrunchyrollBuilder::predefined_client_builder(); - if let Some(p) = &cli.proxy { - builder = builder.proxy(p.clone()) - } - if let Some(ua) = &cli.user_agent { - builder = builder.user_agent(ua) - } - - #[cfg(any(feature = "openssl-tls", feature = "openssl-tls-static"))] - let client = { - let mut builder = builder.use_native_tls().tls_built_in_root_certs(false); - - for certificate in rustls_native_certs::load_native_certs().unwrap() { - builder = builder.add_root_certificate( - reqwest::Certificate::from_der(certificate.0.as_slice()).unwrap(), - ) - } - - builder.build().unwrap() - }; - #[cfg(not(any(feature = "openssl-tls", feature = "openssl-tls-static")))] - let client = builder.build().unwrap(); - - client - }) + .client(client.clone()) .stabilization_locales(cli.experimental_fixes) .stabilization_season_number(cli.experimental_fixes); if let Command::Download(download) = &cli.command { builder = builder.preferred_audio_locale(download.audio.clone()) } + if let Some(speed_limit) = cli.speed_limit { + builder = builder.middleware(RateLimiterService::new(speed_limit, client)); + } let root_login_methods_count = cli.login_method.credentials.is_some() as u8 + cli.login_method.etp_rt.is_some() as u8 diff --git a/crunchy-cli-core/src/utils/clap.rs b/crunchy-cli-core/src/utils/clap.rs index c3088d8..37a34d3 100644 --- a/crunchy-cli-core/src/utils/clap.rs +++ b/crunchy-cli-core/src/utils/clap.rs @@ -9,3 +9,20 @@ pub fn clap_parse_resolution(s: &str) -> Result { pub fn clap_parse_proxy(s: &str) -> Result { Proxy::all(s).map_err(|e| e.to_string()) } + +pub fn clap_parse_speed_limit(s: &str) -> Result { + let quota = s.to_lowercase(); + + let bytes = if let Ok(b) = quota.parse() { + b + } else if let Ok(b) = quota.trim_end_matches('b').parse::() { + b + } else if let Ok(kb) = quota.trim_end_matches("kb").parse::() { + kb * 1024 + } else if let Ok(mb) = quota.trim_end_matches("mb").parse::() { + mb * 1024 * 1024 + } else { + return Err("Invalid speed limit".to_string()); + }; + Ok(bytes) +} diff --git a/crunchy-cli-core/src/utils/download.rs b/crunchy-cli-core/src/utils/download.rs index c715362..bdd76f2 100644 --- a/crunchy-cli-core/src/utils/download.rs +++ b/crunchy-cli-core/src/utils/download.rs @@ -1,4 +1,3 @@ -use crate::utils::context::Context; use crate::utils::ffmpeg::FFmpegPreset; use crate::utils::os::{is_special_file, temp_directory, temp_named_pipe, tempfile}; use anyhow::{bail, Result}; @@ -8,7 +7,7 @@ use crunchyroll_rs::Locale; use indicatif::{ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle}; use log::{debug, warn, LevelFilter}; use regex::Regex; -use std::borrow::{Borrow, BorrowMut}; +use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::BTreeMap; use std::env; @@ -118,7 +117,7 @@ impl Downloader { self.formats.push(format); } - pub async fn download(mut self, ctx: &Context, dst: &Path) -> Result<()> { + pub async fn download(mut self, dst: &Path) -> Result<()> { // `.unwrap_or_default()` here unless https://doc.rust-lang.org/stable/std/path/fn.absolute.html // gets stabilized as the function might throw error on weird file paths let required = self.check_free_space(dst).await.unwrap_or_default(); @@ -197,7 +196,6 @@ impl Downloader { for (i, format) in self.formats.iter().enumerate() { let video_path = self .download_video( - ctx, &format.video.0, format!("{:<1$}", format!("Downloading video #{}", i + 1), fmt_space), ) @@ -205,7 +203,6 @@ impl Downloader { for (variant_data, locale) in format.audios.iter() { let audio_path = self .download_audio( - ctx, variant_data, format!("{:<1$}", format!("Downloading {} audio", locale), fmt_space), ) @@ -554,14 +551,13 @@ impl Downloader { async fn download_video( &self, - ctx: &Context, variant_data: &VariantData, message: String, ) -> Result { let tempfile = tempfile(".mp4")?; let (mut file, path) = tempfile.into_parts(); - self.download_segments(ctx, &mut file, message, variant_data) + self.download_segments(&mut file, message, variant_data) .await?; Ok(path) @@ -569,14 +565,13 @@ impl Downloader { async fn download_audio( &self, - ctx: &Context, variant_data: &VariantData, message: String, ) -> Result { let tempfile = tempfile(".m4a")?; let (mut file, path) = tempfile.into_parts(); - self.download_segments(ctx, &mut file, message, variant_data) + self.download_segments(&mut file, message, variant_data) .await?; Ok(path) @@ -601,7 +596,6 @@ impl Downloader { async fn download_segments( &self, - ctx: &Context, writer: &mut impl Write, message: String, variant_data: &VariantData, @@ -609,7 +603,6 @@ impl Downloader { let segments = variant_data.segments().await?; let total_segments = segments.len(); - let client = Arc::new(ctx.crunchy.client()); let count = Arc::new(Mutex::new(0)); let progress = if log::max_level() == LevelFilter::Info { @@ -643,7 +636,6 @@ impl Downloader { let mut join_set: JoinSet> = JoinSet::new(); for num in 0..cpus { - let thread_client = client.clone(); let thread_sender = sender.clone(); let thread_segments = segs.remove(0); let thread_count = count.clone(); @@ -656,42 +648,21 @@ impl Downloader { let download = || async move { for (i, segment) in thread_segments.into_iter().enumerate() { let mut retry_count = 0; - let mut buf = loop { - let request = thread_client - .get(&segment.url) - .timeout(Duration::from_secs(60)) - .send(); - - let response = match request.await { - Ok(r) => r, + let buf = loop { + let mut buf = vec![]; + match segment.write_to(&mut buf).await { + Ok(_) => break buf, Err(e) => { if retry_count == 5 { bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e) } - debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count); - continue - } - }; - - match response.bytes().await { - Ok(b) => break b.to_vec(), - Err(e) => { - if e.is_body() { - if retry_count == 5 { - bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e) - } - debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count) - } else { - bail!("{}", e) - } + debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count) } } retry_count += 1; }; - buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec(); - let mut c = thread_count.lock().await; debug!( "Downloaded and decrypted segment [{}/{} {:.2}%] {}", diff --git a/crunchy-cli-core/src/utils/mod.rs b/crunchy-cli-core/src/utils/mod.rs index d46cc33..e5c4894 100644 --- a/crunchy-cli-core/src/utils/mod.rs +++ b/crunchy-cli-core/src/utils/mod.rs @@ -9,4 +9,5 @@ pub mod locale; pub mod log; pub mod os; pub mod parse; +pub mod rate_limit; pub mod video; diff --git a/crunchy-cli-core/src/utils/rate_limit.rs b/crunchy-cli-core/src/utils/rate_limit.rs new file mode 100644 index 0000000..16b22b3 --- /dev/null +++ b/crunchy-cli-core/src/utils/rate_limit.rs @@ -0,0 +1,72 @@ +use async_speed_limit::Limiter; +use crunchyroll_rs::error::Error; +use futures_util::TryStreamExt; +use reqwest::{Client, Request, Response, ResponseBuilderExt}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower_service::Service; + +pub struct RateLimiterService { + client: Arc, + rate_limiter: Limiter, +} + +impl RateLimiterService { + pub fn new(bytes: u32, client: Client) -> Self { + Self { + client: Arc::new(client), + rate_limiter: Limiter::new(bytes as f64), + } + } +} + +impl Service for RateLimiterService { + type Response = Response; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client = self.client.clone(); + let rate_limiter = self.rate_limiter.clone(); + + Box::pin(async move { + let mut body = vec![]; + let res = client.execute(req).await?; + let _url = res.url().clone().to_string(); + let url = _url.as_str(); + + let mut http_res = http::Response::builder() + .url(res.url().clone()) + .status(res.status()) + .version(res.version()); + *http_res.headers_mut().unwrap() = res.headers().clone(); + http_res + .extensions_ref() + .unwrap() + .clone_from(&res.extensions()); + + let limiter = rate_limiter.limit( + res.bytes_stream() + .map_err(io::Error::other) + .into_async_read(), + ); + + futures_util::io::copy(limiter, &mut body) + .await + .map_err(|e| Error::Request { + url: url.to_string(), + status: None, + message: e.to_string(), + })?; + + Ok(Response::from(http_res.body(body).unwrap())) + }) + } +}