Add download/request speed limiter (#250)

This commit is contained in:
bytedream 2023-12-10 02:52:42 +01:00
parent f9e431e181
commit be3248a4f9
9 changed files with 230 additions and 76 deletions

View file

@ -9,3 +9,20 @@ pub fn clap_parse_resolution(s: &str) -> Result<Resolution, String> {
pub fn clap_parse_proxy(s: &str) -> Result<Proxy, String> {
Proxy::all(s).map_err(|e| e.to_string())
}
pub fn clap_parse_speed_limit(s: &str) -> Result<u32, String> {
let quota = s.to_lowercase();
let bytes = if let Ok(b) = quota.parse() {
b
} else if let Ok(b) = quota.trim_end_matches('b').parse::<u32>() {
b
} else if let Ok(kb) = quota.trim_end_matches("kb").parse::<u32>() {
kb * 1024
} else if let Ok(mb) = quota.trim_end_matches("mb").parse::<u32>() {
mb * 1024 * 1024
} else {
return Err("Invalid speed limit".to_string());
};
Ok(bytes)
}

View file

@ -1,4 +1,3 @@
use crate::utils::context::Context;
use crate::utils::ffmpeg::FFmpegPreset;
use crate::utils::os::{is_special_file, temp_directory, temp_named_pipe, tempfile};
use anyhow::{bail, Result};
@ -8,7 +7,7 @@ use crunchyroll_rs::Locale;
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressFinish, ProgressStyle};
use log::{debug, warn, LevelFilter};
use regex::Regex;
use std::borrow::{Borrow, BorrowMut};
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::env;
@ -118,7 +117,7 @@ impl Downloader {
self.formats.push(format);
}
pub async fn download(mut self, ctx: &Context, dst: &Path) -> Result<()> {
pub async fn download(mut self, dst: &Path) -> Result<()> {
// `.unwrap_or_default()` here unless https://doc.rust-lang.org/stable/std/path/fn.absolute.html
// gets stabilized as the function might throw error on weird file paths
let required = self.check_free_space(dst).await.unwrap_or_default();
@ -197,7 +196,6 @@ impl Downloader {
for (i, format) in self.formats.iter().enumerate() {
let video_path = self
.download_video(
ctx,
&format.video.0,
format!("{:<1$}", format!("Downloading video #{}", i + 1), fmt_space),
)
@ -205,7 +203,6 @@ impl Downloader {
for (variant_data, locale) in format.audios.iter() {
let audio_path = self
.download_audio(
ctx,
variant_data,
format!("{:<1$}", format!("Downloading {} audio", locale), fmt_space),
)
@ -554,14 +551,13 @@ impl Downloader {
async fn download_video(
&self,
ctx: &Context,
variant_data: &VariantData,
message: String,
) -> Result<TempPath> {
let tempfile = tempfile(".mp4")?;
let (mut file, path) = tempfile.into_parts();
self.download_segments(ctx, &mut file, message, variant_data)
self.download_segments(&mut file, message, variant_data)
.await?;
Ok(path)
@ -569,14 +565,13 @@ impl Downloader {
async fn download_audio(
&self,
ctx: &Context,
variant_data: &VariantData,
message: String,
) -> Result<TempPath> {
let tempfile = tempfile(".m4a")?;
let (mut file, path) = tempfile.into_parts();
self.download_segments(ctx, &mut file, message, variant_data)
self.download_segments(&mut file, message, variant_data)
.await?;
Ok(path)
@ -601,7 +596,6 @@ impl Downloader {
async fn download_segments(
&self,
ctx: &Context,
writer: &mut impl Write,
message: String,
variant_data: &VariantData,
@ -609,7 +603,6 @@ impl Downloader {
let segments = variant_data.segments().await?;
let total_segments = segments.len();
let client = Arc::new(ctx.crunchy.client());
let count = Arc::new(Mutex::new(0));
let progress = if log::max_level() == LevelFilter::Info {
@ -643,7 +636,6 @@ impl Downloader {
let mut join_set: JoinSet<Result<()>> = JoinSet::new();
for num in 0..cpus {
let thread_client = client.clone();
let thread_sender = sender.clone();
let thread_segments = segs.remove(0);
let thread_count = count.clone();
@ -656,42 +648,21 @@ impl Downloader {
let download = || async move {
for (i, segment) in thread_segments.into_iter().enumerate() {
let mut retry_count = 0;
let mut buf = loop {
let request = thread_client
.get(&segment.url)
.timeout(Duration::from_secs(60))
.send();
let response = match request.await {
Ok(r) => r,
let buf = loop {
let mut buf = vec![];
match segment.write_to(&mut buf).await {
Ok(_) => break buf,
Err(e) => {
if retry_count == 5 {
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
}
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count);
continue
}
};
match response.bytes().await {
Ok(b) => break b.to_vec(),
Err(e) => {
if e.is_body() {
if retry_count == 5 {
bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
}
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
} else {
bail!("{}", e)
}
debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
}
}
retry_count += 1;
};
buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec();
let mut c = thread_count.lock().await;
debug!(
"Downloaded and decrypted segment [{}/{} {:.2}%] {}",

View file

@ -9,4 +9,5 @@ pub mod locale;
pub mod log;
pub mod os;
pub mod parse;
pub mod rate_limit;
pub mod video;

View file

@ -0,0 +1,72 @@
use async_speed_limit::Limiter;
use crunchyroll_rs::error::Error;
use futures_util::TryStreamExt;
use reqwest::{Client, Request, Response, ResponseBuilderExt};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower_service::Service;
pub struct RateLimiterService {
client: Arc<Client>,
rate_limiter: Limiter,
}
impl RateLimiterService {
pub fn new(bytes: u32, client: Client) -> Self {
Self {
client: Arc::new(client),
rate_limiter: Limiter::new(bytes as f64),
}
}
}
impl Service<Request> for RateLimiterService {
type Response = Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {
let client = self.client.clone();
let rate_limiter = self.rate_limiter.clone();
Box::pin(async move {
let mut body = vec![];
let res = client.execute(req).await?;
let _url = res.url().clone().to_string();
let url = _url.as_str();
let mut http_res = http::Response::builder()
.url(res.url().clone())
.status(res.status())
.version(res.version());
*http_res.headers_mut().unwrap() = res.headers().clone();
http_res
.extensions_ref()
.unwrap()
.clone_from(&res.extensions());
let limiter = rate_limiter.limit(
res.bytes_stream()
.map_err(io::Error::other)
.into_async_read(),
);
futures_util::io::copy(limiter, &mut body)
.await
.map_err(|e| Error::Request {
url: url.to_string(),
status: None,
message: e.to_string(),
})?;
Ok(Response::from(http_res.body(body).unwrap()))
})
}
}