Add fetch task cancellation

master
Nick Krichevsky 2022-10-16 20:38:02 -04:00
parent 9984684ae6
commit 44de0b43ce
7 changed files with 346 additions and 11 deletions

10
Cargo.lock generated
View File

@ -2072,6 +2072,15 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "uuid"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83"
dependencies = [
"getrandom 0.2.7",
]
[[package]]
name = "value-bag"
version = "1.0.0-alpha.9"
@ -2335,4 +2344,5 @@ dependencies = [
"thiserror",
"tokio",
"url",
"uuid",
]

View File

@ -25,6 +25,7 @@ regex = "1.6"
reqwest = { version = "0.11.12", features = ["json"] }
url = "2.2"
anyhow = "1.0"
uuid = { version = "1.2", features = ["v4"] }
# For annoying reasons, we must pin exactly the same versions as async-imap if we want to use
# their types.

View File

@ -16,7 +16,7 @@ YNABifier, while functional, is still a bit of a work in progress. There are som
Planned improvements:
- [ ] Streamline the configuration process to not require fetching YNAB IDs by hand.
- [ ] IMAP sessions need to be more properly cleaned up.
- [x] IMAP sessions need to be more properly cleaned up.
- [ ] Various QoL features, such as allowing alternate configuration paths
## Setup

View File

@ -1,5 +1,9 @@
use futures::{
channel::mpsc::{self, Receiver},
channel::{
mpsc::{self, Receiver},
oneshot,
},
lock::Mutex,
Sink, SinkExt, Stream, StreamExt,
};
use mailparse::{MailParseError, ParsedMail};
@ -14,7 +18,7 @@ use stop_token::{StopSource, StopToken};
use thiserror::Error;
use crate::{
task::{self, ResolveOrStop, Spawn, SpawnError},
task::{self, Join, Registry, ResolveOrStop, Spawn, SpawnError},
CloseableStream, CHANNEL_SIZE,
};
@ -162,6 +166,7 @@ where
F::Error: Sync + Send,
S: Spawn + Send + Sync + 'static,
S::Handle: 'static,
<<S as Spawn>::Handle as Join>::Error: Send,
{
let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
@ -199,6 +204,7 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
) where
S: Spawn + Send + Sync + 'static,
S::Handle: 'static,
<<S as Spawn>::Handle as Join>::Error: Send,
N: CloseableStream<Item = SequenceNumber> + Send + Unpin,
F: MessageFetcher + Send + Sync + 'static,
F::Error: Send + Sync,
@ -206,6 +212,7 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
O::Error: Debug,
{
let fetcher_arc = Arc::new(fetcher);
let task_registry = Arc::new(Mutex::new(Registry::new()));
while let Some(sequence_number) = sequence_number_stream
.next()
.resolve_or_stop(stop_token)
@ -243,16 +250,40 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
sequence_number
);
// TODO: this needs to be able to be handled by Close
let spawn_res = spawn.spawn(fetch_future);
if let Err(spawn_err) = spawn_res {
error!(
"failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}"
);
let (token_tx, token_rx) = oneshot::channel();
let registry_weak = Arc::downgrade(&task_registry);
let spawn_res = spawn.spawn(async move {
let task_registry = registry_weak;
// This cannot happen, as the tx channel cannot have been dropped.
let task_token = token_rx.await.expect("failed to get task token");
fetch_future.await;
// Unregister ourselves from the registry.
// We don't really care if the registry has been dropped, as all we're trying to signal is that we're done.
if let Some(registry) = task_registry.upgrade() {
registry.lock().await.unregister_handle(task_token);
}
});
match spawn_res {
Ok(task_handle) => {
let token = task_registry.lock().await.register_handle(task_handle);
// This cannot happen, as the task cannot have been dropped.
token_tx.send(token).expect("failed to send task token");
}
Err(spawn_err) => {
error!(
"failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}"
);
}
}
}
sequence_number_stream.close().await;
debug!("Joining all remaining fetch tasks...");
if let Err(err) = task_registry.lock().await.join_all().await {
error!("failed to join all fetch tasks: {:?}", err);
};
}
/// Fetch a message with the given sequence number, and send its output to this Task's

View File

@ -19,7 +19,7 @@ use email::{
message::RawFetcher,
};
use futures::Stream;
use task::{Spawn, SpawnError};
use task::{Join, Spawn, SpawnError};
const CHANNEL_SIZE: usize = 16;
@ -76,6 +76,7 @@ pub async fn stream_new_messages<S>(
where
S: Spawn + Send + Sync + Unpin + 'static,
S::Handle: Unpin + 'static,
<<S as Spawn>::Handle as Join>::Error: Send,
{
let session_generator_arc = Arc::new(ConfigSessionGenerator::new(imap_config.clone()));
let watcher =

View File

@ -1,12 +1,13 @@
use async_trait::async_trait;
pub(crate) use interrupt::ResolveOrStop;
use std::error::Error;
use std::fmt::{Display, Formatter};
pub(crate) use {interrupt::ResolveOrStop, register::Registry};
use futures::Future;
use thiserror::Error;
mod interrupt;
mod register;
/// `SpawnError` describes why a spawn may have failed to occur.
#[derive(Error, Debug)]

291
src/task/register.rs Normal file
View File

@ -0,0 +1,291 @@
//! Provides the [`Registry`] type, which can be used to manage a set of tasks
use std::collections::HashMap;
use std::mem;
use uuid::Uuid;
use super::Handle;
/// Opaque identifier for a single task held by a [`Registry`].
///
/// A token is returned when a handle is registered and is passed back to unregister that same
/// handle. Each token wraps a freshly generated version 4 UUID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskToken(Uuid);
/// A set of tasks that can be stopped as a whole. Each task can be registered individually, and
/// unregistered at any point.
pub struct Registry<H> {
    // Handles of the currently registered tasks, keyed by the token handed out at registration.
    handles: HashMap<TaskToken, H>,
}
impl<H: Handle> Registry<H> {
    /// Create a registry with no tasks in it.
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a task handle, and get back a unique [`TaskToken`] that identifies it.
    ///
    /// The token can later be passed to [`Self::unregister_handle`] to remove the handle again.
    pub fn register_handle(&mut self, handle: H) -> TaskToken {
        let token = TaskToken(Uuid::new_v4());
        self.handles.insert(token, handle);
        token
    }

    /// Unregister a handle using its task token. Once a task has been unregistered,
    /// [`Self::cancel_all`] and [`Self::join_all`] will have no effect on this task.
    ///
    /// Unregistering a token that is not (or is no longer) in the registry does nothing.
    pub fn unregister_handle(&mut self, token: TaskToken) {
        self.handles.remove(&token);
    }

    /// Cancel all of the tasks in this registry.
    ///
    /// The handles are removed from the registry before they are cancelled, so a later call to
    /// [`Self::cancel_all`] or [`Self::join_all`] will not act on these tasks again.
    pub fn cancel_all(&mut self) {
        // Take ownership of every handle, leaving an empty map behind.
        let handles = mem::take(&mut self.handles);
        for handle in handles.into_values() {
            handle.cancel();
        }
    }

    /// Join all of the tasks in this registry, one after another.
    ///
    /// The handles are removed from the registry before they are joined, so a later call to
    /// [`Self::cancel_all`] or [`Self::join_all`] will not act on these tasks again.
    ///
    /// This call keeps `&mut self` borrowed until every task has been joined. If the registry
    /// lives behind a lock that the joined tasks also need to acquire (for example, to unregister
    /// themselves), holding that lock across this call prevents those tasks from finishing.
    ///
    /// # Errors
    ///
    /// A list of all the join errors that occurred will be returned. A failed join does not stop
    /// the remaining tasks from being joined.
    pub async fn join_all(&mut self) -> Result<(), Vec<H::Error>> {
        // Take ownership of every handle, leaving an empty map behind.
        let handles = mem::take(&mut self.handles);
        let mut errors = Vec::new();
        for handle in handles.into_values() {
            if let Err(err) = handle.join().await {
                errors.push(err);
            }
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}

impl<H> Default for Registry<H> {
    /// Create a registry with no tasks in it.
    fn default() -> Self {
        Self {
            handles: HashMap::new(),
        }
    }
}
#[cfg(test)]
mod tests {
use std::{ops::AddAssign, sync::Arc, time::Duration};
use super::*;
use futures::{channel::oneshot, lock::Mutex};
use tokio::time;
use crate::{task::Spawn, testutil::TokioSpawner};
/// Split an iterator of pairs into a vector of the first elements and a vector of the second
/// elements, keeping the iteration order in both.
fn unpack_pairs<T, U, I: Iterator<Item = (T, U)>>(iter: I) -> (Vec<T>, Vec<U>) {
    iter.unzip()
}
/// Tasks that are still registered when `cancel_all` is called must never reach their final send.
#[tokio::test]
async fn cancel_call_stops_all_tasks() {
    // Per task: one channel to release it, and one on which it reports that it ran to the end.
    let (release_txs, release_rxs) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let (done_txs, done_rxs) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let spawner = TokioSpawner;

    let mut registry = Registry::new();
    for (release_rx, done_tx) in release_rxs.into_iter().zip(done_txs) {
        let handle = spawner
            .spawn(async move {
                release_rx.await.expect("failed to rx in task");
                done_tx.send(()).expect("failed to tx in task");
            })
            .expect("failed to spawn task");
        registry.register_handle(handle);
    }

    registry.cancel_all();

    // Release every task. A failed send is fine here: the task might already have been cancelled.
    for release_tx in release_txs {
        let _res = release_tx.send(());
    }

    // None of the tasks may have reported completion.
    for done_rx in done_rxs {
        time::timeout(Duration::from_secs(5), done_rx)
            .await
            .expect("test timed out")
            .expect_err("should not have received from tasks");
    }
}
/// Dropping the registry must leave its tasks running: every task should still report back.
#[tokio::test]
async fn dropping_registry_should_not_cancel_tasks() {
    // Per task: one channel to release it, and one on which it reports that it ran to the end.
    let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let spawner = TokioSpawner;
    // This iterator is lazy: each task is spawned as the registration loop below pulls it.
    let handles = rxs_in_task
        .into_iter()
        .zip(txs_in_task)
        .map(|(rx_in_task, tx_in_task)| async move {
            rx_in_task.await.expect("failed to rx in task");
            tx_in_task.send(()).expect("failed to tx in task");
        })
        .map(|task| spawner.spawn(task).expect("failed to spawn task"));

    let mut registry = Registry::new();
    for handle in handles {
        registry.register_handle(handle);
    }

    // Unblock all the tasks
    for tx in txs_to_task {
        tx.send(()).expect("failed to tx into task");
    }

    drop(registry);

    // Every task must still complete, so each receive is expected to succeed.
    for rx in rxs_from_task {
        time::timeout(Duration::from_secs(5), async move {
            rx.await.expect("should have received from tasks");
        })
        .await
        .expect("test timed out");
    }
}
#[tokio::test]
async fn join_all_waits_for_task_completion() {
let count = Arc::new(Mutex::new(0));
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>()));
let spawner = TokioSpawner;
let handles = rxs_in_task
.into_iter()
.map(|rx_in_task| {
let count_clone = count.clone();
async move {
rx_in_task.await.expect("failed to rx in task");
// This is a little ham-fisted, but gives us some assurance that the tasks won't finish
// before we're ready for them.
time::sleep(Duration::from_millis(100)).await;
count_clone.lock().await.add_assign(1);
}
})
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
let mut registry = Registry::new();
for handle in handles {
registry.register_handle(handle);
}
// Unblock all the tasks
for tx in txs_to_task {
tx.send(()).expect("failed to tx into task");
}
let join_res = time::timeout(Duration::from_secs(5), registry.join_all())
.await
.expect("test timed out");
join_res.expect("failed to join tasks");
assert_eq!(500, *count.lock().await);
}
#[tokio::test]
async fn removing_token_prevents_it_from_being_joined() {
let count = Arc::new(Mutex::new(0));
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>()));
let spawner = TokioSpawner;
let handles = rxs_in_task
.into_iter()
.map(|rx_in_task| {
let count_clone = count.clone();
async move {
rx_in_task.await.expect("failed to rx in task");
// This is a little ham-fisted, but gives us some assurance that the tasks won't finish
// before we're ready for them.
time::sleep(Duration::from_millis(100)).await;
count_clone.lock().await.add_assign(1);
}
})
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
let mut registry = Registry::new();
for handle in handles {
registry.register_handle(handle);
}
let (_tx, never_rx) = oneshot::channel::<()>();
let infinite_task_handle = spawner
.spawn(async move { never_rx.await })
.expect("failed to spawn");
let infinite_task_token = registry.register_handle(infinite_task_handle);
// Unblock all the finite tasks
for tx in txs_to_task {
tx.send(()).expect("failed to tx into task");
}
registry.unregister_handle(infinite_task_token);
let join_res = time::timeout(Duration::from_secs(5), registry.join_all())
.await
.expect("test timed out; it's possible the task wasn't unregistered");
join_res.expect("failed to join tasks");
}
/// A task that was unregistered must survive `cancel_all`, while the remaining tasks are stopped.
#[tokio::test]
async fn removing_token_prevents_it_from_being_cancelled() {
    // Per task: one channel to release it, and one on which it reports that it ran to the end.
    let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
    let spawner = TokioSpawner;
    // This iterator is lazy: each task is spawned when its handle is pulled below.
    let mut handles = rxs_in_task
        .into_iter()
        .zip(txs_in_task)
        .map(|(rx_in_task, tx_in_task)| async move {
            rx_in_task.await.expect("failed to rx in task");
            tx_in_task.send(()).expect("failed to tx in task");
        })
        .map(|task| spawner.spawn(task).expect("failed to spawn task"));

    let mut registry = Registry::new();
    let first_task_token = registry.register_handle(handles.next().unwrap());
    for handle in handles {
        registry.register_handle(handle);
    }

    registry.unregister_handle(first_task_token);
    registry.cancel_all();

    // Unblock all the tasks
    for tx in txs_to_task {
        // it's ok if this fails - the task might have already been cancelled
        let _res = tx.send(());
    }

    let mut rx_iter = rxs_from_task.into_iter();
    // Ensure we receive from the task we unregistered. This is bounded by a timeout like every
    // other wait in these tests, so a regression fails the test instead of hanging it.
    let unregistered_rx = rx_iter.next().unwrap();
    time::timeout(Duration::from_secs(5), unregistered_rx)
        .await
        .expect("test timed out")
        .expect("should have received from unregistered task");

    // ...but not from any of the others
    for rx in rx_iter {
        time::timeout(Duration::from_secs(5), async move {
            rx.await.expect_err("should not have received from tasks");
        })
        .await
        .expect("test timed out");
    }
}
}