Add single-threaded option to downloading commands

Valentine Briese 2023-10-13 12:21:38 -07:00
parent 13335c020b
commit 9ace4f3476
3 changed files with 171 additions and 149 deletions
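At a glance, the change threads one boolean from the CLI command structs through DownloadBuilder into Downloader, where it caps the number of parallel download workers at one. Below is a minimal, dependency-free sketch of that flow using made-up stand-in types (not the crate's real API); the actual code uses num_cpus::get() where the sketch falls back to the standard library's available_parallelism.

use std::thread;

// Hypothetical, simplified stand-ins for the DownloadBuilder / Downloader types in the diff below.
#[derive(Default)]
struct DownloadBuilder {
    single_threaded: bool,
}

struct Downloader {
    single_threaded: bool,
}

impl DownloadBuilder {
    fn new() -> Self {
        Self::default()
    }

    // Mirrors the setter the commands chain onto the builder: `.single_threaded(self.single_threaded)`.
    fn single_threaded(mut self, single_threaded: bool) -> Self {
        self.single_threaded = single_threaded;
        self
    }

    fn build(self) -> Downloader {
        Downloader {
            single_threaded: self.single_threaded,
        }
    }
}

impl Downloader {
    // Mirrors the worker-count decision added to download_segments: one worker when the flag is
    // set, otherwise one per available CPU (the real code uses num_cpus::get()).
    fn worker_count(&self) -> usize {
        if self.single_threaded {
            1
        } else {
            thread::available_parallelism().map(|n| n.get()).unwrap_or(1)
        }
    }
}

fn main() {
    let downloader = DownloadBuilder::new().single_threaded(true).build();
    assert_eq!(downloader.worker_count(), 1);
    println!("workers: {}", downloader.worker_count());
}

Keeping the decision inside the downloader means the commands only have to forward the flag; the rest of the download pipeline stays unchanged.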

@@ -98,6 +98,10 @@ pub struct Archive {
     #[arg(short, long, default_value_t = false)]
     pub(crate) yes: bool,
 
+    #[arg(help = "Download using only one thread")]
+    #[arg(short = 't', long, default_value_t = false)]
+    pub(crate) single_threaded: bool,
+
     #[arg(help = "Crunchyroll series url(s)")]
     #[arg(required = true)]
     pub(crate) urls: Vec<String>,
@@ -158,7 +162,8 @@ impl Execute for Archive {
             .ffmpeg_preset(self.ffmpeg_preset.clone().unwrap_or_default())
             .output_format(Some("matroska".to_string()))
             .audio_sort(Some(self.audio.clone()))
-            .subtitle_sort(Some(self.subtitle.clone()));
+            .subtitle_sort(Some(self.subtitle.clone()))
+            .single_threaded(self.single_threaded);
 
         for single_formats in single_format_collection.into_iter() {
             let (download_formats, mut format) = get_format(&self, &single_formats).await?;

@@ -80,6 +80,10 @@ pub struct Download {
     #[arg(long, default_value_t = false)]
     pub(crate) force_hardsub: bool,
 
+    #[arg(help = "Download using only one thread")]
+    #[arg(short = 't', long, default_value_t = false)]
+    pub(crate) single_threaded: bool,
+
     #[arg(help = "Url(s) to Crunchyroll episodes or series")]
     #[arg(required = true)]
     pub(crate) urls: Vec<String>,
@@ -149,7 +153,8 @@ impl Execute for Download {
             } else {
                 None
             })
-            .ffmpeg_preset(self.ffmpeg_preset.clone().unwrap_or_default());
+            .ffmpeg_preset(self.ffmpeg_preset.clone().unwrap_or_default())
+            .single_threaded(self.single_threaded);
 
         for mut single_formats in single_format_collection.into_iter() {
             // the vec contains always only one item
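For reference, this is roughly what the new flag looks like from the command line. The struct below is a hypothetical stand-in that reuses the clap attributes added above, not the real Download command, and the URL is a placeholder; clap derives --single-threaded from the field name and -t from the explicit short.

use clap::Parser;

/// Hypothetical stand-in for the command structs above (not the real CLI), reusing the same
/// clap attributes so the flag's spelling is visible.
#[derive(Parser, Debug)]
struct Download {
    #[arg(help = "Download using only one thread")]
    #[arg(short = 't', long, default_value_t = false)]
    single_threaded: bool,

    #[arg(help = "Url(s) to Crunchyroll episodes or series")]
    #[arg(required = true)]
    urls: Vec<String>,
}

fn main() {
    // Both spellings set the flag.
    let args = Download::parse_from(["download", "-t", "https://example.com/episode"]);
    assert!(args.single_threaded);

    let args = Download::parse_from(["download", "--single-threaded", "https://example.com/episode"]);
    assert!(args.single_threaded);
    println!("{args:?}");
}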

@@ -50,6 +50,7 @@ pub struct DownloadBuilder {
     audio_sort: Option<Vec<Locale>>,
     subtitle_sort: Option<Vec<Locale>>,
     force_hardsub: bool,
+    single_threaded: bool,
 }
 
 impl DownloadBuilder {
@@ -61,6 +62,7 @@ impl DownloadBuilder {
             audio_sort: None,
             subtitle_sort: None,
             force_hardsub: false,
+            single_threaded: false,
         }
     }
@@ -73,6 +75,7 @@ impl DownloadBuilder {
             subtitle_sort: self.subtitle_sort,
             force_hardsub: self.force_hardsub,
+            single_threaded: self.single_threaded,
 
             formats: vec![],
         }
@@ -99,6 +102,7 @@ pub struct Downloader {
     subtitle_sort: Option<Vec<Locale>>,
     force_hardsub: bool,
+    single_threaded: bool,
 
     formats: Vec<DownloadFormat>,
 }
@@ -502,7 +506,8 @@ impl Downloader {
         let tempfile = tempfile(".mp4")?;
         let (mut file, path) = tempfile.into_parts();
 
-        download_segments(ctx, &mut file, message, variant_data).await?;
+        self.download_segments(ctx, &mut file, message, variant_data)
+            .await?;
 
         Ok(path)
     }
@@ -516,7 +521,8 @@ impl Downloader {
         let tempfile = tempfile(".m4a")?;
         let (mut file, path) = tempfile.into_parts();
 
-        download_segments(ctx, &mut file, message, variant_data).await?;
+        self.download_segments(ctx, &mut file, message, variant_data)
+            .await?;
 
         Ok(path)
     }
@@ -537,188 +543,194 @@ impl Downloader {
         Ok(path)
     }
-}
-
-pub async fn download_segments(
-    ctx: &Context,
-    writer: &mut impl Write,
-    message: String,
-    variant_data: &VariantData,
-) -> Result<()> {
-    let segments = variant_data.segments().await?;
-    let total_segments = segments.len();
-
-    let client = Arc::new(ctx.crunchy.client());
-    let count = Arc::new(Mutex::new(0));
-
-    let progress = if log::max_level() == LevelFilter::Info {
-        let estimated_file_size = estimate_variant_file_size(variant_data, &segments);
-
-        let progress = ProgressBar::new(estimated_file_size)
-            .with_style(
-                ProgressStyle::with_template(
-                    ":: {msg} {bytes:>10} {bytes_per_sec:>12} [{wide_bar}] {percent:>3}%",
-                )
-                .unwrap()
-                .progress_chars("##-"),
-            )
-            .with_message(message)
-            .with_finish(ProgressFinish::Abandon);
-        Some(progress)
-    } else {
-        None
-    };
-
-    let cpus = num_cpus::get();
-    let mut segs: Vec<Vec<VariantSegment>> = Vec::with_capacity(cpus);
-    for _ in 0..cpus {
-        segs.push(vec![])
-    }
-    for (i, segment) in segments.clone().into_iter().enumerate() {
-        segs[i - ((i / cpus) * cpus)].push(segment);
-    }
-
-    let (sender, mut receiver) = unbounded_channel();
-
-    let mut join_set: JoinSet<Result<()>> = JoinSet::new();
-    for num in 0..cpus {
-        let thread_client = client.clone();
-        let thread_sender = sender.clone();
-        let thread_segments = segs.remove(0);
-        let thread_count = count.clone();
-        join_set.spawn(async move {
-            let after_download_sender = thread_sender.clone();
-
-            // the download process is encapsulated in its own function. this is done to easily
-            // catch errors which get returned with `...?` and `bail!(...)` and that the thread
-            // itself can report that an error has occurred
-            let download = || async move {
-                for (i, segment) in thread_segments.into_iter().enumerate() {
-                    let mut retry_count = 0;
-                    let mut buf = loop {
-                        let request = thread_client
-                            .get(&segment.url)
-                            .timeout(Duration::from_secs(60))
-                            .send();
-
-                        let response = match request.await {
-                            Ok(r) => r,
-                            Err(e) => {
-                                if retry_count == 5 {
-                                    bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
-                                }
-                                debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count);
-
-                                continue
-                            }
-                        };
-
-                        match response.bytes().await {
-                            Ok(b) => break b.to_vec(),
-                            Err(e) => {
-                                if e.is_body() {
-                                    if retry_count == 5 {
-                                        bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
-                                    }
-                                    debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
-                                } else {
-                                    bail!("{}", e)
-                                }
-                            }
-                        }
-
-                        retry_count += 1;
-                    };
-
-                    buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec();
-
-                    let mut c = thread_count.lock().await;
-                    debug!(
-                        "Downloaded and decrypted segment [{}/{} {:.2}%] {}",
-                        num + (i * cpus) + 1,
-                        total_segments,
-                        ((*c + 1) as f64 / total_segments as f64) * 100f64,
-                        segment.url
-                    );
-
-                    thread_sender.send((num as i32 + (i * cpus) as i32, buf))?;
-
-                    *c += 1;
-                }
-                Ok(())
-            };
-
-            let result = download().await;
-            if result.is_err() {
-                after_download_sender.send((-1 as i32, vec![]))?;
-            }
-
-            result
-        });
-    }
-    // drop the sender already here so it does not outlive all download threads which are the only
-    // real consumers of it
-    drop(sender);
-
-    // this is the main loop which writes the data. it uses a BTreeMap as a buffer as the write
-    // happens synchronized. the download consist of multiple segments. the map keys are representing
-    // the segment number and the values the corresponding bytes
-    let mut data_pos = 0;
-    let mut buf: BTreeMap<i32, Vec<u8>> = BTreeMap::new();
-    while let Some((pos, bytes)) = receiver.recv().await {
-        // if the position is lower than 0, an error occurred in the sending download thread
-        if pos < 0 {
-            break;
-        }
-
-        if let Some(p) = &progress {
-            let progress_len = p.length().unwrap();
-            let estimated_segment_len =
-                (variant_data.bandwidth / 8) * segments.get(pos as usize).unwrap().length.as_secs();
-            let bytes_len = bytes.len() as u64;
-
-            p.set_length(progress_len - estimated_segment_len + bytes_len);
-            p.inc(bytes_len)
-        }
-
-        // check if the currently sent bytes are the next in the buffer. if so, write them directly
-        // to the target without first adding them to the buffer.
-        // if not, add them to the buffer
-        if data_pos == pos {
-            writer.write_all(bytes.borrow())?;
-            data_pos += 1;
-        } else {
-            buf.insert(pos, bytes);
-        }
-        // check if the buffer contains the next segment(s)
-        while let Some(b) = buf.remove(&data_pos) {
-            writer.write_all(b.borrow())?;
-            data_pos += 1;
-        }
-    }
-
-    // if any error has occurred while downloading it gets returned here
-    while let Some(joined) = join_set.join_next().await {
-        joined??
-    }
-
-    // write the remaining buffer, if existent
-    while let Some(b) = buf.remove(&data_pos) {
-        writer.write_all(b.borrow())?;
-        data_pos += 1;
-    }
-
-    if !buf.is_empty() {
-        bail!(
-            "Download buffer is not empty. Remaining segments: {}",
-            buf.into_keys()
-                .map(|k| k.to_string())
-                .collect::<Vec<String>>()
-                .join(", ")
-        )
-    }
-
-    Ok(())
-}
+
+    async fn download_segments(
+        &self,
+        ctx: &Context,
+        writer: &mut impl Write,
+        message: String,
+        variant_data: &VariantData,
+    ) -> Result<()> {
+        let segments = variant_data.segments().await?;
+        let total_segments = segments.len();
+
+        let client = Arc::new(ctx.crunchy.client());
+        let count = Arc::new(Mutex::new(0));
+
+        let progress = if log::max_level() == LevelFilter::Info {
+            let estimated_file_size = estimate_variant_file_size(variant_data, &segments);
+
+            let progress = ProgressBar::new(estimated_file_size)
+                .with_style(
+                    ProgressStyle::with_template(
+                        ":: {msg} {bytes:>10} {bytes_per_sec:>12} [{wide_bar}] {percent:>3}%",
+                    )
+                    .unwrap()
+                    .progress_chars("##-"),
+                )
+                .with_message(message)
+                .with_finish(ProgressFinish::Abandon);
+            Some(progress)
+        } else {
+            None
+        };
+
+        // Only use 1 CPU (core?) if `single-threaded` option is enabled
+        let cpus = if self.single_threaded {
+            1
+        } else {
+            num_cpus::get()
+        };
+
+        let mut segs: Vec<Vec<VariantSegment>> = Vec::with_capacity(cpus);
+        for _ in 0..cpus {
+            segs.push(vec![])
+        }
+        for (i, segment) in segments.clone().into_iter().enumerate() {
+            segs[i - ((i / cpus) * cpus)].push(segment);
+        }
+
+        let (sender, mut receiver) = unbounded_channel();
+
+        let mut join_set: JoinSet<Result<()>> = JoinSet::new();
+        for num in 0..cpus {
+            let thread_client = client.clone();
+            let thread_sender = sender.clone();
+            let thread_segments = segs.remove(0);
+            let thread_count = count.clone();
+            join_set.spawn(async move {
+                let after_download_sender = thread_sender.clone();
+
+                // the download process is encapsulated in its own function. this is done to easily
+                // catch errors which get returned with `...?` and `bail!(...)` and that the thread
+                // itself can report that an error has occurred
+                let download = || async move {
+                    for (i, segment) in thread_segments.into_iter().enumerate() {
+                        let mut retry_count = 0;
+                        let mut buf = loop {
+                            let request = thread_client
+                                .get(&segment.url)
+                                .timeout(Duration::from_secs(60))
+                                .send();
+
+                            let response = match request.await {
+                                Ok(r) => r,
+                                Err(e) => {
+                                    if retry_count == 5 {
+                                        bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
+                                    }
+                                    debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count);
+
+                                    continue
+                                }
+                            };
+
+                            match response.bytes().await {
+                                Ok(b) => break b.to_vec(),
+                                Err(e) => {
+                                    if e.is_body() {
+                                        if retry_count == 5 {
+                                            bail!("Max retry count reached ({}), multiple errors occurred while receiving segment {}: {}", retry_count, num + (i * cpus), e)
+                                        }
+                                        debug!("Failed to download segment {} ({}). Retrying, {} out of 5 retries left", num + (i * cpus), e, 5 - retry_count)
+                                    } else {
+                                        bail!("{}", e)
+                                    }
+                                }
+                            }
+
+                            retry_count += 1;
+                        };
+
+                        buf = VariantSegment::decrypt(buf.borrow_mut(), segment.key)?.to_vec();
+
+                        let mut c = thread_count.lock().await;
+                        debug!(
+                            "Downloaded and decrypted segment [{}/{} {:.2}%] {}",
+                            num + (i * cpus) + 1,
+                            total_segments,
+                            ((*c + 1) as f64 / total_segments as f64) * 100f64,
+                            segment.url
+                        );
+
+                        thread_sender.send((num as i32 + (i * cpus) as i32, buf))?;
+
+                        *c += 1;
+                    }
+                    Ok(())
+                };
+
+                let result = download().await;
+                if result.is_err() {
+                    after_download_sender.send((-1 as i32, vec![]))?;
+                }
+
+                result
+            });
+        }
+        // drop the sender already here so it does not outlive all download threads which are the only
+        // real consumers of it
+        drop(sender);
+
+        // this is the main loop which writes the data. it uses a BTreeMap as a buffer as the write
+        // happens synchronized. the download consist of multiple segments. the map keys are representing
+        // the segment number and the values the corresponding bytes
+        let mut data_pos = 0;
+        let mut buf: BTreeMap<i32, Vec<u8>> = BTreeMap::new();
+        while let Some((pos, bytes)) = receiver.recv().await {
+            // if the position is lower than 0, an error occurred in the sending download thread
+            if pos < 0 {
+                break;
+            }
+
+            if let Some(p) = &progress {
+                let progress_len = p.length().unwrap();
+                let estimated_segment_len = (variant_data.bandwidth / 8)
+                    * segments.get(pos as usize).unwrap().length.as_secs();
+                let bytes_len = bytes.len() as u64;
+
+                p.set_length(progress_len - estimated_segment_len + bytes_len);
+                p.inc(bytes_len)
+            }
+
+            // check if the currently sent bytes are the next in the buffer. if so, write them directly
+            // to the target without first adding them to the buffer.
+            // if not, add them to the buffer
+            if data_pos == pos {
+                writer.write_all(bytes.borrow())?;
+                data_pos += 1;
+            } else {
+                buf.insert(pos, bytes);
+            }
+            // check if the buffer contains the next segment(s)
+            while let Some(b) = buf.remove(&data_pos) {
+                writer.write_all(b.borrow())?;
+                data_pos += 1;
+            }
+        }
+
+        // if any error has occurred while downloading it gets returned here
+        while let Some(joined) = join_set.join_next().await {
+            joined??
+        }
+
+        // write the remaining buffer, if existent
+        while let Some(b) = buf.remove(&data_pos) {
+            writer.write_all(b.borrow())?;
+            data_pos += 1;
+        }
+
+        if !buf.is_empty() {
+            bail!(
+                "Download buffer is not empty. Remaining segments: {}",
+                buf.into_keys()
+                    .map(|k| k.to_string())
+                    .collect::<Vec<String>>()
+                    .join(", ")
+            )
+        }
+
+        Ok(())
+    }
+}
 
 fn estimate_variant_file_size(variant_data: &VariantData, segments: &Vec<VariantSegment>) -> u64 {
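The part of the new download_segments that is hardest to follow in diff form is the ordering logic, so here is a small, dependency-free sketch of the same idea with fake in-memory data (std threads and an mpsc channel standing in for tokio tasks and the unbounded channel): segments are dealt out round-robin with `i - ((i / cpus) * cpus)`, which is simply `i % cpus`, each worker sends back (position, payload) pairs as it finishes, and the writer drains a BTreeMap buffer so output is emitted strictly in segment order. With the new --single-threaded option, cpus becomes 1 and the same machinery degrades to a single sequential worker.

use std::collections::BTreeMap;
use std::sync::mpsc;
use std::thread;

fn main() {
    // Fake "segments": the payload is just the segment index, which makes ordering easy to verify.
    let segments: Vec<i32> = (0..10).collect();
    let cpus = 3; // with --single-threaded this would be 1

    // Round-robin partitioning, as in the diff: `i - ((i / cpus) * cpus)` is just `i % cpus`.
    let mut segs: Vec<Vec<(i32, i32)>> = vec![vec![]; cpus];
    for (i, segment) in segments.iter().enumerate() {
        segs[i % cpus].push((i as i32, *segment));
    }

    // Workers send (position, payload) pairs; arrival order across workers is not guaranteed.
    let (sender, receiver) = mpsc::channel::<(i32, i32)>();
    let mut handles = Vec::new();
    for chunk in segs {
        let tx = sender.clone();
        handles.push(thread::spawn(move || {
            for (pos, payload) in chunk {
                tx.send((pos, payload)).unwrap();
            }
        }));
    }
    // As in the diff: drop the original sender so the receive loop ends once all workers finish.
    drop(sender);

    // Writer side: a BTreeMap buffers out-of-order results so they are written in segment order.
    let mut data_pos = 0;
    let mut buf: BTreeMap<i32, i32> = BTreeMap::new();
    let mut written = Vec::new();
    while let Ok((pos, payload)) = receiver.recv() {
        if data_pos == pos {
            written.push(payload);
            data_pos += 1;
        } else {
            buf.insert(pos, payload);
        }
        while let Some(payload) = buf.remove(&data_pos) {
            written.push(payload);
            data_pos += 1;
        }
    }
    for handle in handles {
        handle.join().unwrap();
    }

    assert_eq!(written, segments); // every payload lands in order, regardless of arrival order
    println!("{written:?}");
}

Because ordering is restored on the writer side, reducing the worker count is the only change the single-threaded mode needs.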