mirror of
https://github.com/crunchy-labs/crunchy-cli.git
synced 2026-01-21 04:02:00 -06:00
Add download/request speed limiter (#250)
This commit is contained in:
parent
f9e431e181
commit
be3248a4f9
9 changed files with 230 additions and 76 deletions
|
|
@ -13,21 +13,24 @@ openssl-tls-static = ["reqwest/native-tls", "reqwest/native-tls-alpn", "reqwest/
|
|||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
async-speed-limit = "0.4"
|
||||
async-trait = "0.1"
|
||||
clap = { version = "4.4", features = ["derive", "string"] }
|
||||
chrono = "0.4"
|
||||
crunchyroll-rs = { version = "0.7.0", features = ["dash-stream"] }
|
||||
crunchyroll-rs = { version = "0.8.0", features = ["dash-stream", "experimental-stabilizations", "tower"] }
|
||||
ctrlc = "3.4"
|
||||
dialoguer = { version = "0.11", default-features = false }
|
||||
dirs = "5.0"
|
||||
derive_setters = "0.1"
|
||||
futures-util = { version = "0.3", features = ["io"] }
|
||||
fs2 = "0.4"
|
||||
http = "0.2"
|
||||
indicatif = "0.17"
|
||||
lazy_static = "1.4"
|
||||
log = { version = "0.4", features = ["std"] }
|
||||
num_cpus = "1.16"
|
||||
regex = "1.10"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["socks"] }
|
||||
reqwest = { version = "0.11", default-features = false, features = ["socks", "stream"] }
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
serde_plain = "1.0"
|
||||
|
|
@ -35,6 +38,7 @@ shlex = "1.2"
|
|||
sys-locale = "0.3"
|
||||
tempfile = "3.8"
|
||||
tokio = { version = "1.34", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
||||
tower-service = "0.3"
|
||||
rustls-native-certs = { version = "0.6", optional = true }
|
||||
|
||||
[target.'cfg(not(target_os = "windows"))'.dependencies]
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ impl Execute for Archive {
|
|||
|
||||
format.visual_output(&path);
|
||||
|
||||
downloader.download(&ctx, &path).await?
|
||||
downloader.download(&path).await?
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ impl Execute for Download {
|
|||
|
||||
format.visual_output(&path);
|
||||
|
||||
downloader.download(&ctx, &path).await?
|
||||
downloader.download(&path).await?
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ mod login;
|
|||
mod search;
|
||||
mod utils;
|
||||
|
||||
use crate::utils::rate_limit::RateLimiterService;
|
||||
pub use archive::Archive;
|
||||
use dialoguer::console::Term;
|
||||
pub use download::Download;
|
||||
|
|
@ -66,6 +67,15 @@ pub struct Cli {
|
|||
#[arg(global = true, long)]
|
||||
user_agent: Option<String>,
|
||||
|
||||
#[arg(
|
||||
help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of <number>[B|KB|MB]"
|
||||
)]
|
||||
#[arg(
|
||||
long_help = "Maximal speed to download/request (may be a bit off here and there). Must be in format of <number>[B|KB|MB] (e.g. 500KB or 10MB)"
|
||||
)]
|
||||
#[arg(global = true, long, value_parser = crate::utils::clap::clap_parse_speed_limit)]
|
||||
speed_limit: Option<u32>,
|
||||
|
||||
#[clap(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
|
@ -264,39 +274,44 @@ async fn crunchyroll_session(cli: &mut Cli) -> Result<Crunchyroll> {
|
|||
lang
|
||||
};
|
||||
|
||||
let client = {
|
||||
let mut builder = CrunchyrollBuilder::predefined_client_builder();
|
||||
if let Some(p) = &cli.proxy {
|
||||
builder = builder.proxy(p.clone())
|
||||
}
|
||||
if let Some(ua) = &cli.user_agent {
|
||||
builder = builder.user_agent(ua)
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "openssl-tls", feature = "openssl-tls-static"))]
|
||||
let client = {
|
||||
let mut builder = builder.use_native_tls().tls_built_in_root_certs(false);
|
||||
|
||||
for certificate in rustls_native_certs::load_native_certs().unwrap() {
|
||||
builder = builder.add_root_certificate(
|
||||
reqwest::Certificate::from_der(certificate.0.as_slice()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
builder.build().unwrap()
|
||||
};
|
||||
#[cfg(not(any(feature = "openssl-tls", feature = "openssl-tls-static")))]
|
||||
let client = builder.build().unwrap();
|
||||
|
||||
client
|
||||
};
|
||||
|
||||
let mut builder = Crunchyroll::builder()
|
||||
.locale(locale)
|
||||
.client({
|
||||
let mut builder = CrunchyrollBuilder::predefined_client_builder();
|
||||
if let Some(p) = &cli.proxy {
|
||||
builder = builder.proxy(p.clone())
|
||||
}
|
||||
if let Some(ua) = &cli.user_agent {
|
||||
builder = builder.user_agent(ua)
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "openssl-tls", feature = "openssl-tls-static"))]
|
||||
let client = {
|
||||
let mut builder = builder.use_native_tls().tls_built_in_root_certs(false);
|
||||
|
||||
for certificate in rustls_native_certs::load_native_certs().unwrap() {
|
||||
builder = builder.add_root_certificate(
|
||||
reqwest::Certificate::from_der(certificate.0.as_slice()).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
builder.build().unwrap()
|
||||
};
|
||||
#[cfg(not(any(feature = "openssl-tls", feature = "openssl-tls-static")))]
|
||||
let client = builder.build().unwrap();
|
||||
|
||||
client
|
||||
})
|
||||
.client(client.clone())
|
||||
.stabilization_locales(cli.experimental_fixes)
|
||||
.stabilization_season_number(cli.experimental_fixes);
|
||||
if let Command::Download(download) = &cli.command {
|
||||
builder = builder.preferred_audio_locale(download.audio.clone())
|
||||
}
|
||||
if let Some(speed_limit) = cli.speed_limit {
|
||||
builder = builder.middleware(RateLimiterService::new(speed_limit, client));
|
||||
}
|
||||
|
||||
let root_login_methods_count = cli.login_method.credentials.is_some() as u8
|
||||
+ cli.login_method.etp_rt.is_some() as u8
|
||||
|
|
|
|||
|
|
@ -9,3 +9,20 @@ pub fn clap_parse_resolution(s: &str) -> Result<Resolution, String> {
|
|||
pub fn clap_parse_proxy(s: &str) -> Result<Proxy, String> {
|
||||
Proxy::all(s).map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
pub fn clap_parse_speed_limit(s: &str) -> Result<u32, String> {
|
||||
let quota = s.to_lowercase();
|
||||
|
||||
let bytes = if let Ok(b) = quota.parse() {
|
||||
b
|
||||
} else if let Ok(b) = quota.trim_end_matches('b').parse::<u32>() {
|
||||
b
|
||||
} else if let Ok(kb) = quota.trim_end_matches("kb").parse::<u32>() {
|
||||
kb * 1024
|
||||
} else if let Ok(mb) = quota.trim_end_matches("mb").parse::<u32>() {
|
||||
mb * 1024 * 1024
|
||||
} else {
|
||||
return Err("Invalid speed limit".to_string());
|
||||
};
|
||||
Ok(bytes)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
use crate::utils::context::Context;
|
||||
use crate::utils::ffmpeg::FFmpegPreset;
|
||||
use crate::utils::os::{is_special_file, temp_directory, temp_named_pipe, tempfile};
|
||||
use anyhow::{bail, Result};
|
||||
|
|
@ -8,7 +7,7 @@ use crunchyroll_rs::Locale;
|
|||
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle};
|
||||
use log::{debug, warn, LevelFilter};
|
||||
use regex::Regex;
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use std::borrow::Borrow;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
|
|
@ -118,7 +117,7 @@ impl Downloader {
|
|||
self.formats.push(format);
|
||||
}
|
||||
|
||||
pub async fn download(mut self, ctx: &Context, dst: &Path) -> Result<()> {
|
||||
pub async fn download(mut self, dst: &Path) -> Result<()> {
|
||||
// `.unwrap_or_default()` here unless https://doc.rust-lang.org/stable/std/path/fn.absolute.html
|
||||
// gets stabilized as the function might throw error on weird file paths
|
||||
let required = self.check_free_space(dst).await.unwrap_or_default();
|
||||
|
|
@ -197,7 +196,6 @@ impl Downloader {
|
|||
for (i, format) in self.formats.iter().enumerate() {
|
||||
let video_path = self
|
||||
.download_video(
|
||||
ctx,
|
||||
&format.video.0,
|
||||
format!("{:<1$}", format!("Downloading video #{}", i + 1), fmt_space),
|
||||
)
|
||||
|
|
@ -205,7 +203,6 @@ impl Downloader {
|
|||
for (variant_data, locale) in format.audios.iter() {
|
||||
let audio_path = self
|
||||
.download_audio(
|
||||
ctx,
|
||||
variant_data,
|
||||
format!("{:<1$}", format!("Downloading {} audio", locale), fmt_space),
|
||||
)
|
||||
|
|
@ -554,14 +551,13 @@ impl Downloader {
|
|||
|
||||
async fn download_video(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
variant_data: &VariantData,
|
||||
message: String,
|
||||
) -> Result<TempPath> {
|
||||
let tempfile = tempfile(".mp4")?;
|
||||
let (mut file, path) = tempfile.into_parts();
|
||||
|
||||
self.download_segments(ctx, &mut file, message, variant_data)
|
||||
self.download_segments(&mut file, message, variant_data)
|
||||
.await?;
|
||||
|
||||
Ok(path)
|
||||
|
|
@ -569,14 +565,13 @@ impl Downloader {
|
|||
|
||||
async fn download_audio(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
variant_data: &VariantData,
|
||||
message: String,
|
||||
) -> Result<TempPath> {
|
||||
let tempfile = tempfile(".m4a")?;
|
||||
let (mut file, path) = tempfile.into_parts();
|
||||
|
||||
self.download_segments(ctx, &mut file, message, variant_data)
|
||||
self.download_segments(&mut file, message, variant_data)
|
||||
.await?;
|
||||
|
||||
Ok(path)
|
||||
|
|
@ -601,7 +596,6 @@ impl Downloader {
|
|||
|
||||
async fn download_segments(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
writer: &mut impl Write,
|
||||
message: String,
|
||||
variant_data: &VariantData,
|
||||
|
|
@ -609,7 +603,6 @@ impl Downloader {
|
|||
let segments = variant_data.segments().await?;
|
||||
let total_segments = segments.len();
|
||||
|
||||
let client = Arc::new(ctx.crunchy.client());
|
||||
let count = Arc::new(Mutex::new(0));
|
||||
|
||||
let progress = if log::max_level() == LevelFilter::Info {
|
||||
|
|
@ -643,7 +636,6 @@ impl Downloader {
|
|||
|
||||
let mut join_set: JoinSet<Result<()>> = JoinSet::new();
|
||||
for num in 0..cpus {
|
||||
let thread_client = client.clone();
|
||||
let thread_sender = sender.clone();
|
||||
let thread_segments = segs.remove(0);
|
||||
let thread_count = count.clone();
|
||||
|
|
@ -656,42 +648,21 @@ impl Downloader {
|
|||
let download = || async move {
|
||||
for (i, segment) in thread_segments.into_iter().enumerate() {
|
||||
let mut retry_count = 0;
|
||||
let mut buf = loop {
|
||||
let request = thread_client
|
||||
.get(&segment.url)
|
||||
.timeout(Duration::from_secs(60))
|
||||
.send();
|
||||
|
||||
let response = match request.await {
|
||||
Ok(r) => r,
|
||||
let buf = loop {
|
||||
let mut buf = vec![];
|
||||
match segment.write_to(&mut buf).await {
|
||||
Ok(_) => break buf,
|
||||
Err(e) => {
|
||||
if retry_count == 5 {
|
||||
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
|
||||
}
|
||||
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count);
|
||||
continue
|
||||
}
|
||||
};
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(b) => break b.to_vec(),
|
||||
Err(e) => {
|
||||
if e.is_body() {
|
||||
if retry_count == 5 {
|
||||
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
|
||||
}
|
||||
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
|
||||
} else {
|
||||
bail!("{}", e)
|
||||
}
|
||||
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
|
||||
}
|
||||
}
|
||||
|
||||
retry_count += 1;
|
||||
};
|
||||
|
||||
buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec();
|
||||
|
||||
let mut c = thread_count.lock().await;
|
||||
debug!(
|
||||
"Downloaded and decrypted segment [{}/{} {:.2}%] {}",
|
||||
|
|
|
|||
|
|
@ -9,4 +9,5 @@ pub mod locale;
|
|||
pub mod log;
|
||||
pub mod os;
|
||||
pub mod parse;
|
||||
pub mod rate_limit;
|
||||
pub mod video;
|
||||
|
|
|
|||
72
crunchy-cli-core/src/utils/rate_limit.rs
Normal file
72
crunchy-cli-core/src/utils/rate_limit.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use async_speed_limit::Limiter;
|
||||
use crunchyroll_rs::error::Error;
|
||||
use futures_util::TryStreamExt;
|
||||
use reqwest::{Client, Request, Response, ResponseBuilderExt};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tower_service::Service;
|
||||
|
||||
pub struct RateLimiterService {
|
||||
client: Arc<Client>,
|
||||
rate_limiter: Limiter,
|
||||
}
|
||||
|
||||
impl RateLimiterService {
|
||||
pub fn new(bytes: u32, client: Client) -> Self {
|
||||
Self {
|
||||
client: Arc::new(client),
|
||||
rate_limiter: Limiter::new(bytes as f64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Service<Request> for RateLimiterService {
|
||||
type Response = Response;
|
||||
type Error = Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request) -> Self::Future {
|
||||
let client = self.client.clone();
|
||||
let rate_limiter = self.rate_limiter.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let mut body = vec![];
|
||||
let res = client.execute(req).await?;
|
||||
let _url = res.url().clone().to_string();
|
||||
let url = _url.as_str();
|
||||
|
||||
let mut http_res = http::Response::builder()
|
||||
.url(res.url().clone())
|
||||
.status(res.status())
|
||||
.version(res.version());
|
||||
*http_res.headers_mut().unwrap() = res.headers().clone();
|
||||
http_res
|
||||
.extensions_ref()
|
||||
.unwrap()
|
||||
.clone_from(&res.extensions());
|
||||
|
||||
let limiter = rate_limiter.limit(
|
||||
res.bytes_stream()
|
||||
.map_err(io::Error::other)
|
||||
.into_async_read(),
|
||||
);
|
||||
|
||||
futures_util::io::copy(limiter, &mut body)
|
||||
.await
|
||||
.map_err(|e| Error::Request {
|
||||
url: url.to_string(),
|
||||
status: None,
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
Ok(Response::from(http_res.body(body).unwrap()))
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue