Add fetch task cancellation
parent
9984684ae6
commit
44de0b43ce
|
@ -2072,6 +2072,15 @@ version = "0.7.6"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83"
|
||||
dependencies = [
|
||||
"getrandom 0.2.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "value-bag"
|
||||
version = "1.0.0-alpha.9"
|
||||
|
@ -2335,4 +2344,5 @@ dependencies = [
|
|||
"thiserror",
|
||||
"tokio",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
|
|
@ -25,6 +25,7 @@ regex = "1.6"
|
|||
reqwest = { version = "0.11.12", features = ["json"] }
|
||||
url = "2.2"
|
||||
anyhow = "1.0"
|
||||
uuid = { version = "1.2", features = ["v4"] }
|
||||
|
||||
# For annoying reasons, we must pin exactly the same versions as async-imap if we want to use
|
||||
# their types.
|
||||
|
|
|
@ -16,7 +16,7 @@ YNABifier, while functional, is still a bit of a work in progress. There are som
|
|||
|
||||
Planned improvements:
|
||||
- [ ] Streamline the configuration process to not require fetching YNAB IDs by hand.
|
||||
- [ ] IMAP sessions need to be more properly cleaned up.
|
||||
- [x] IMAP sessions need to be more properly cleaned up.
|
||||
- [ ] Various QoL features, such as allowing alternate configuration paths
|
||||
|
||||
## Setup
|
||||
|
|
47
src/email.rs
47
src/email.rs
|
@ -1,5 +1,9 @@
|
|||
use futures::{
|
||||
channel::mpsc::{self, Receiver},
|
||||
channel::{
|
||||
mpsc::{self, Receiver},
|
||||
oneshot,
|
||||
},
|
||||
lock::Mutex,
|
||||
Sink, SinkExt, Stream, StreamExt,
|
||||
};
|
||||
use mailparse::{MailParseError, ParsedMail};
|
||||
|
@ -14,7 +18,7 @@ use stop_token::{StopSource, StopToken};
|
|||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
task::{self, ResolveOrStop, Spawn, SpawnError},
|
||||
task::{self, Join, Registry, ResolveOrStop, Spawn, SpawnError},
|
||||
CloseableStream, CHANNEL_SIZE,
|
||||
};
|
||||
|
||||
|
@ -162,6 +166,7 @@ where
|
|||
F::Error: Sync + Send,
|
||||
S: Spawn + Send + Sync + 'static,
|
||||
S::Handle: 'static,
|
||||
<<S as Spawn>::Handle as Join>::Error: Send,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
|
||||
|
||||
|
@ -199,6 +204,7 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
|
|||
) where
|
||||
S: Spawn + Send + Sync + 'static,
|
||||
S::Handle: 'static,
|
||||
<<S as Spawn>::Handle as Join>::Error: Send,
|
||||
N: CloseableStream<Item = SequenceNumber> + Send + Unpin,
|
||||
F: MessageFetcher + Send + Sync + 'static,
|
||||
F::Error: Send + Sync,
|
||||
|
@ -206,6 +212,7 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
|
|||
O::Error: Debug,
|
||||
{
|
||||
let fetcher_arc = Arc::new(fetcher);
|
||||
let task_registry = Arc::new(Mutex::new(Registry::new()));
|
||||
while let Some(sequence_number) = sequence_number_stream
|
||||
.next()
|
||||
.resolve_or_stop(stop_token)
|
||||
|
@ -243,16 +250,40 @@ async fn stream_incoming_messages_to_sink<S, N, F, O>(
|
|||
sequence_number
|
||||
);
|
||||
|
||||
// TODO: this needs to be able to be handled by Close
|
||||
let spawn_res = spawn.spawn(fetch_future);
|
||||
if let Err(spawn_err) = spawn_res {
|
||||
error!(
|
||||
"failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}"
|
||||
);
|
||||
let (token_tx, token_rx) = oneshot::channel();
|
||||
let registry_weak = Arc::downgrade(&task_registry);
|
||||
let spawn_res = spawn.spawn(async move {
|
||||
let task_registry = registry_weak;
|
||||
// This cannot happen, as the tx channel cannot have been dropped.
|
||||
let task_token = token_rx.await.expect("failed to get task token");
|
||||
fetch_future.await;
|
||||
|
||||
// Unregister ourselves fro the registry.
|
||||
// We don't really care if the registry has been dropped, as all we're trying to signal is that we're done.
|
||||
if let Some(registry) = task_registry.upgrade() {
|
||||
registry.lock().await.unregister_handle(task_token);
|
||||
}
|
||||
});
|
||||
|
||||
match spawn_res {
|
||||
Ok(task_handle) => {
|
||||
let token = task_registry.lock().await.register_handle(task_handle);
|
||||
// This cannot happen, as the task cannot have been dropped.
|
||||
token_tx.send(token).expect("failed to send task token");
|
||||
}
|
||||
Err(spawn_err) => {
|
||||
error!(
|
||||
"failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sequence_number_stream.close().await;
|
||||
debug!("Joining all remaining fetch tasks...");
|
||||
if let Err(err) = task_registry.lock().await.join_all().await {
|
||||
error!("failed to join all fetch tasks: {:?}", err);
|
||||
};
|
||||
}
|
||||
|
||||
/// Fetch a message with the given sequence number, and send its output to this Task's
|
||||
|
|
|
@ -19,7 +19,7 @@ use email::{
|
|||
message::RawFetcher,
|
||||
};
|
||||
use futures::Stream;
|
||||
use task::{Spawn, SpawnError};
|
||||
use task::{Join, Spawn, SpawnError};
|
||||
|
||||
const CHANNEL_SIZE: usize = 16;
|
||||
|
||||
|
@ -76,6 +76,7 @@ pub async fn stream_new_messages<S>(
|
|||
where
|
||||
S: Spawn + Send + Sync + Unpin + 'static,
|
||||
S::Handle: Unpin + 'static,
|
||||
<<S as Spawn>::Handle as Join>::Error: Send,
|
||||
{
|
||||
let session_generator_arc = Arc::new(ConfigSessionGenerator::new(imap_config.clone()));
|
||||
let watcher =
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use async_trait::async_trait;
|
||||
pub(crate) use interrupt::ResolveOrStop;
|
||||
use std::error::Error;
|
||||
use std::fmt::{Display, Formatter};
|
||||
pub(crate) use {interrupt::ResolveOrStop, register::Registry};
|
||||
|
||||
use futures::Future;
|
||||
use thiserror::Error;
|
||||
|
||||
mod interrupt;
|
||||
mod register;
|
||||
|
||||
/// `SpawnError` describes why a spawn may have failed to occur.
|
||||
#[derive(Error, Debug)]
|
||||
|
|
|
@ -0,0 +1,291 @@
|
|||
//! Provides the [`Registry`] type, which can be used to manage a set of tasks
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use std::mem;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::Handle;
|
||||
|
||||
/// A `TaskToken` is a unique identifier that represents a task in Registry.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct TaskToken(Uuid);
|
||||
|
||||
// A set of tasks that can be stopped as a whole. Each task can be registered individually, and
|
||||
// unregistered at any point.
|
||||
pub struct Registry<H> {
|
||||
handles: HashMap<TaskToken, H>,
|
||||
}
|
||||
|
||||
impl<H: Handle> Registry<H> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handles: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a task handle, and get back a unique [`TaskToken`] to handle its lifetime.
|
||||
pub fn register_handle(&mut self, handle: H) -> TaskToken {
|
||||
let token = TaskToken(Uuid::new_v4());
|
||||
self.handles.insert(token, handle);
|
||||
token
|
||||
}
|
||||
|
||||
/// Unregister a handle using its task token. Once a task has been unregistered, [`cancel_all`] and
|
||||
/// [`join_all`] will have no effect on this task.
|
||||
pub fn unregister_handle(&mut self, token: TaskToken) {
|
||||
self.handles.remove(&token);
|
||||
}
|
||||
|
||||
/// Cancel all of the the tasks in this registry. If either [`cancel_all`] or [`join_all`] are
|
||||
/// called after this, the tasks that were currently in the registry will no longer be acted
|
||||
/// on. In effect, these tasks are unregistered before cancellation.
|
||||
pub fn cancel_all(&mut self) {
|
||||
let handles = mem::take(&mut self.handles);
|
||||
for handle in handles.into_values() {
|
||||
handle.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
/// Join all of the the tasks in this registry. If either [`cancel_all`] or [`join_all`] are
|
||||
/// called after this, the tasks that were currently in the registry will no longer be acted
|
||||
/// on. In effect, these tasks are unregistered before joining.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// A list of all the join errors that occurred will be returned.
|
||||
pub async fn join_all(&mut self) -> Result<(), Vec<H::Error>> {
|
||||
let handles = mem::take(&mut self.handles);
|
||||
|
||||
let mut errors = Vec::new();
|
||||
for handle in handles.into_values() {
|
||||
if let Err(err) = handle.join().await {
|
||||
errors.push(err);
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{ops::AddAssign, sync::Arc, time::Duration};
|
||||
|
||||
use super::*;
|
||||
use futures::{channel::oneshot, lock::Mutex};
|
||||
use tokio::time;
|
||||
|
||||
use crate::{task::Spawn, testutil::TokioSpawner};
|
||||
|
||||
fn unpack_pairs<T, U, I: Iterator<Item = (T, U)>>(iter: I) -> (Vec<T>, Vec<U>) {
|
||||
iter.fold((vec![], vec![]), |mut res, (tx, rx)| {
|
||||
res.0.push(tx);
|
||||
res.1.push(rx);
|
||||
res
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancel_call_stops_all_tasks() {
|
||||
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let spawner = TokioSpawner;
|
||||
let handles = rxs_in_task
|
||||
.into_iter()
|
||||
.zip(txs_in_task)
|
||||
.map(|(rx_in_task, tx_in_task)| async move {
|
||||
rx_in_task.await.expect("failed to rx in task");
|
||||
tx_in_task.send(()).expect("failed to tx in task");
|
||||
})
|
||||
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
|
||||
|
||||
let mut registry = Registry::new();
|
||||
for handle in handles {
|
||||
registry.register_handle(handle);
|
||||
}
|
||||
|
||||
registry.cancel_all();
|
||||
|
||||
// Unblock all the tasks
|
||||
for tx in txs_to_task {
|
||||
// it's ok if this fails - the task might have already been cancelled
|
||||
let _res = tx.send(());
|
||||
}
|
||||
|
||||
for rx in rxs_from_task {
|
||||
time::timeout(Duration::from_secs(5), async move {
|
||||
rx.await.expect_err("should not have received from tasks");
|
||||
})
|
||||
.await
|
||||
.expect("test timed out");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dropping_registry_should_not_cancel_tasks() {
|
||||
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let spawner = TokioSpawner;
|
||||
let handles = rxs_in_task
|
||||
.into_iter()
|
||||
.zip(txs_in_task)
|
||||
.map(|(rx_in_task, tx_in_task)| async move {
|
||||
rx_in_task.await.expect("failed to rx in task");
|
||||
tx_in_task.send(()).expect("failed to tx in task");
|
||||
})
|
||||
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
|
||||
|
||||
let mut registry = Registry::new();
|
||||
for handle in handles {
|
||||
registry.register_handle(handle);
|
||||
}
|
||||
|
||||
// Unblock all the tasks
|
||||
for tx in txs_to_task {
|
||||
tx.send(()).expect("failed to tx into task");
|
||||
}
|
||||
|
||||
drop(registry);
|
||||
for rx in rxs_from_task {
|
||||
time::timeout(Duration::from_secs(5), async move {
|
||||
rx.await.expect("should not received from tasks");
|
||||
})
|
||||
.await
|
||||
.expect("test timed out");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn join_all_waits_for_task_completion() {
|
||||
let count = Arc::new(Mutex::new(0));
|
||||
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>()));
|
||||
let spawner = TokioSpawner;
|
||||
let handles = rxs_in_task
|
||||
.into_iter()
|
||||
.map(|rx_in_task| {
|
||||
let count_clone = count.clone();
|
||||
async move {
|
||||
rx_in_task.await.expect("failed to rx in task");
|
||||
// This is a little ham-fisted, but gives us some assurance that the tasks won't finish
|
||||
// before we're ready for them.
|
||||
time::sleep(Duration::from_millis(100)).await;
|
||||
count_clone.lock().await.add_assign(1);
|
||||
}
|
||||
})
|
||||
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
|
||||
|
||||
let mut registry = Registry::new();
|
||||
for handle in handles {
|
||||
registry.register_handle(handle);
|
||||
}
|
||||
|
||||
// Unblock all the tasks
|
||||
for tx in txs_to_task {
|
||||
tx.send(()).expect("failed to tx into task");
|
||||
}
|
||||
|
||||
let join_res = time::timeout(Duration::from_secs(5), registry.join_all())
|
||||
.await
|
||||
.expect("test timed out");
|
||||
|
||||
join_res.expect("failed to join tasks");
|
||||
|
||||
assert_eq!(500, *count.lock().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_token_prevents_it_from_being_joined() {
|
||||
let count = Arc::new(Mutex::new(0));
|
||||
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>()));
|
||||
let spawner = TokioSpawner;
|
||||
let handles = rxs_in_task
|
||||
.into_iter()
|
||||
.map(|rx_in_task| {
|
||||
let count_clone = count.clone();
|
||||
async move {
|
||||
rx_in_task.await.expect("failed to rx in task");
|
||||
// This is a little ham-fisted, but gives us some assurance that the tasks won't finish
|
||||
// before we're ready for them.
|
||||
time::sleep(Duration::from_millis(100)).await;
|
||||
count_clone.lock().await.add_assign(1);
|
||||
}
|
||||
})
|
||||
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
|
||||
|
||||
let mut registry = Registry::new();
|
||||
for handle in handles {
|
||||
registry.register_handle(handle);
|
||||
}
|
||||
|
||||
let (_tx, never_rx) = oneshot::channel::<()>();
|
||||
let infinite_task_handle = spawner
|
||||
.spawn(async move { never_rx.await })
|
||||
.expect("failed to spawn");
|
||||
let infinite_task_token = registry.register_handle(infinite_task_handle);
|
||||
|
||||
// Unblock all the finite tasks
|
||||
for tx in txs_to_task {
|
||||
tx.send(()).expect("failed to tx into task");
|
||||
}
|
||||
|
||||
registry.unregister_handle(infinite_task_token);
|
||||
|
||||
let join_res = time::timeout(Duration::from_secs(5), registry.join_all())
|
||||
.await
|
||||
.expect("test timed out; it's possible the task wasn't unregistered");
|
||||
|
||||
join_res.expect("failed to join tasks");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn removing_token_prevents_it_from_being_cancelled() {
|
||||
let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>()));
|
||||
let spawner = TokioSpawner;
|
||||
let mut handles = rxs_in_task
|
||||
.into_iter()
|
||||
.zip(txs_in_task)
|
||||
.map(|(rx_in_task, tx_in_task)| async move {
|
||||
rx_in_task.await.expect("failed to rx in task");
|
||||
tx_in_task.send(()).expect("failed to tx in task");
|
||||
})
|
||||
.map(|task| spawner.spawn(task).expect("failed to spawn task"));
|
||||
|
||||
let mut registry = Registry::new();
|
||||
let first_task_token = registry.register_handle(handles.next().unwrap());
|
||||
for handle in handles {
|
||||
registry.register_handle(handle);
|
||||
}
|
||||
|
||||
registry.unregister_handle(first_task_token);
|
||||
registry.cancel_all();
|
||||
|
||||
// Unblock all the tasks
|
||||
for tx in txs_to_task {
|
||||
// it's ok if this fails - the task might have already been cancelled
|
||||
let _res = tx.send(());
|
||||
}
|
||||
|
||||
let mut rx_iter = rxs_from_task.into_iter();
|
||||
// Ensure we receive from the task we unregistered
|
||||
rx_iter
|
||||
.next()
|
||||
.unwrap()
|
||||
.await
|
||||
.expect("should have received from unregistered task");
|
||||
|
||||
// ...but not from any of the others
|
||||
for rx in rx_iter {
|
||||
time::timeout(Duration::from_secs(5), async move {
|
||||
rx.await.expect_err("should not have received from tasks");
|
||||
})
|
||||
.await
|
||||
.expect("test timed out");
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue