diff --git a/Cargo.lock b/Cargo.lock index 46ba0bf..ab3da58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2072,6 +2072,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "uuid" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83" +dependencies = [ + "getrandom 0.2.7", +] + [[package]] name = "value-bag" version = "1.0.0-alpha.9" @@ -2335,4 +2344,5 @@ dependencies = [ "thiserror", "tokio", "url", + "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 99d4d05..b306adb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ regex = "1.6" reqwest = { version = "0.11.12", features = ["json"] } url = "2.2" anyhow = "1.0" +uuid = { version = "1.2", features = ["v4"] } # For annoying reasons, we must pin exactly the same versions as async-imap if we want to use # their types. diff --git a/README.md b/README.md index f7ab61b..5a00715 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ YNABifier, while functional, is still a bit of a work in progress. There are som Planned improvements: - [ ] Streamline the configuration process to not require fetching YNAB IDs by hand. - - [ ] IMAP sessions need to be more properly cleaned up. + - [x] IMAP sessions need to be more properly cleaned up. 
- [ ] Various QoL features, such as allowing alternate configuration paths ## Setup diff --git a/src/email.rs b/src/email.rs index 008ffcd..08d7e1e 100644 --- a/src/email.rs +++ b/src/email.rs @@ -1,5 +1,9 @@ use futures::{ - channel::mpsc::{self, Receiver}, + channel::{ + mpsc::{self, Receiver}, + oneshot, + }, + lock::Mutex, Sink, SinkExt, Stream, StreamExt, }; use mailparse::{MailParseError, ParsedMail}; @@ -14,7 +18,7 @@ use stop_token::{StopSource, StopToken}; use thiserror::Error; use crate::{ - task::{self, ResolveOrStop, Spawn, SpawnError}, + task::{self, Join, Registry, ResolveOrStop, Spawn, SpawnError}, CloseableStream, CHANNEL_SIZE, }; @@ -162,6 +166,7 @@ where F::Error: Sync + Send, S: Spawn + Send + Sync + 'static, S::Handle: 'static, + <::Handle as Join>::Error: Send, { let (tx, rx) = mpsc::channel(CHANNEL_SIZE); @@ -199,6 +204,7 @@ async fn stream_incoming_messages_to_sink( ) where S: Spawn + Send + Sync + 'static, S::Handle: 'static, + <::Handle as Join>::Error: Send, N: CloseableStream + Send + Unpin, F: MessageFetcher + Send + Sync + 'static, F::Error: Send + Sync, @@ -206,6 +212,7 @@ async fn stream_incoming_messages_to_sink( O::Error: Debug, { let fetcher_arc = Arc::new(fetcher); + let task_registry = Arc::new(Mutex::new(Registry::new())); while let Some(sequence_number) = sequence_number_stream .next() .resolve_or_stop(stop_token) @@ -243,16 +250,40 @@ async fn stream_incoming_messages_to_sink( sequence_number ); - // TODO: this needs to be able to be handled by Close - let spawn_res = spawn.spawn(fetch_future); - if let Err(spawn_err) = spawn_res { - error!( - "failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}" - ); + let (token_tx, token_rx) = oneshot::channel(); + let registry_weak = Arc::downgrade(&task_registry); + let spawn_res = spawn.spawn(async move { + let task_registry = registry_weak; + // This cannot happen, as the tx channel cannot have been dropped. 
+ let task_token = token_rx.await.expect("failed to get task token"); + fetch_future.await; + + // Unregister ourselves from the registry. + // We don't really care if the registry has been dropped, as all we're trying to signal is that we're done. NOTE(review): the joiner at the end of this function holds this same lock across `join_all().await`, so a task arriving here mid-join waits on the lock while the join waits on the task - confirm this cannot deadlock. + if let Some(registry) = task_registry.upgrade() { + registry.lock().await.unregister_handle(task_token); + } + }); + + match spawn_res { + Ok(task_handle) => { + let token = task_registry.lock().await.register_handle(task_handle); + // This cannot happen, as the task cannot have been dropped. + token_tx.send(token).expect("failed to send task token"); + } + Err(spawn_err) => { + error!( + "failed to spawn task to fetch sequence number {sequence_number}: {spawn_err:?}" + ); + } } } sequence_number_stream.close().await; + debug!("Joining all remaining fetch tasks..."); + if let Err(err) = task_registry.lock().await.join_all().await { + error!("failed to join all fetch tasks: {:?}", err); + }; } /// Fetch a message with the given sequence number, and send its output to this Task's diff --git a/src/lib.rs b/src/lib.rs index 82b21d5..0b7de69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ use email::{ message::RawFetcher, }; use futures::Stream; -use task::{Spawn, SpawnError}; +use task::{Join, Spawn, SpawnError}; const CHANNEL_SIZE: usize = 16; @@ -76,6 +76,7 @@ pub async fn stream_new_messages( where S: Spawn + Send + Sync + Unpin + 'static, S::Handle: Unpin + 'static, + <::Handle as Join>::Error: Send, { let session_generator_arc = Arc::new(ConfigSessionGenerator::new(imap_config.clone())); let watcher = diff --git a/src/task.rs b/src/task.rs index c9bd122..067b66e 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,12 +1,13 @@ use async_trait::async_trait; -pub(crate) use interrupt::ResolveOrStop; use std::error::Error; use std::fmt::{Display, Formatter}; +pub(crate) use {interrupt::ResolveOrStop, register::Registry}; use futures::Future; use thiserror::Error; mod interrupt; +mod register; /// `SpawnError` 
describes why a spawn may have failed to occur. #[derive(Error, Debug)] diff --git a/src/task/register.rs b/src/task/register.rs new file mode 100644 index 0000000..af18f98 --- /dev/null +++ b/src/task/register.rs @@ -0,0 +1,291 @@ +//! Provides the [`Registry`] type, which can be used to manage a set of tasks. + +use std::collections::HashMap; + +use std::mem; +use uuid::Uuid; + +use super::Handle; + +/// A `TaskToken` is a unique identifier that represents a task in a [`Registry`]. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub struct TaskToken(Uuid); + +/// A set of tasks that can be stopped as a whole. Each task can be registered individually, and +/// unregistered at any point. +pub struct Registry { + handles: HashMap, +} + +impl Registry { + pub fn new() -> Self { + Self { + handles: HashMap::new(), + } + } + + /// Register a task handle, and get back a unique [`TaskToken`] to handle its lifetime. + pub fn register_handle(&mut self, handle: H) -> TaskToken { + let token = TaskToken(Uuid::new_v4()); + self.handles.insert(token, handle); + token + } + + /// Unregister a handle using its task token. Once a task has been unregistered, [`Self::cancel_all`] and + /// [`Self::join_all`] will have no effect on this task. + pub fn unregister_handle(&mut self, token: TaskToken) { + self.handles.remove(&token); + } + + /// Cancel all of the tasks in this registry. If either [`Self::cancel_all`] or [`Self::join_all`] is + /// called after this, the tasks that were currently in the registry will no longer be acted + /// on. In effect, these tasks are unregistered before cancellation. + pub fn cancel_all(&mut self) { + let handles = mem::take(&mut self.handles); + for handle in handles.into_values() { + handle.cancel(); + } + } + + /// Join all of the tasks in this registry. If either [`Self::cancel_all`] or [`Self::join_all`] is + /// called after this, the tasks that were currently in the registry will no longer be acted + /// on. In effect, these tasks are unregistered before joining. 
+ /// + /// # Errors + /// + /// A list of all the join errors that occurred will be returned. Handles are awaited one at a time, and a failed join does not stop the remaining handles from being joined. + pub async fn join_all(&mut self) -> Result<(), Vec> { + let handles = mem::take(&mut self.handles); + + let mut errors = Vec::new(); + for handle in handles.into_values() { + if let Err(err) = handle.join().await { + errors.push(err); + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } + } +} + +#[cfg(test)] +mod tests { + use std::{ops::AddAssign, sync::Arc, time::Duration}; + + use super::*; + use futures::{channel::oneshot, lock::Mutex}; + use tokio::time; + + use crate::{task::Spawn, testutil::TokioSpawner}; + + fn unpack_pairs>(iter: I) -> (Vec, Vec) { + iter.fold((vec![], vec![]), |mut res, (tx, rx)| { + res.0.push(tx); + res.1.push(rx); + res + }) + } + + #[tokio::test] + async fn cancel_call_stops_all_tasks() { + let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>())); + let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>())); + let spawner = TokioSpawner; + let handles = rxs_in_task + .into_iter() + .zip(txs_in_task) + .map(|(rx_in_task, tx_in_task)| async move { + rx_in_task.await.expect("failed to rx in task"); + tx_in_task.send(()).expect("failed to tx in task"); + }) + .map(|task| spawner.spawn(task).expect("failed to spawn task")); + + let mut registry = Registry::new(); + for handle in handles { + registry.register_handle(handle); + } + + registry.cancel_all(); + + // Unblock all the tasks + for tx in txs_to_task { + // it's ok if this fails - the task might have already been cancelled + let _res = tx.send(()); + } + + for rx in rxs_from_task { + time::timeout(Duration::from_secs(5), async move { + rx.await.expect_err("should not have received from tasks"); + }) + .await + .expect("test timed out"); + } + } + + #[tokio::test] + async fn dropping_registry_should_not_cancel_tasks() { + let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| 
oneshot::channel::<()>())); + let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>())); + let spawner = TokioSpawner; + let handles = rxs_in_task + .into_iter() + .zip(txs_in_task) + .map(|(rx_in_task, tx_in_task)| async move { + rx_in_task.await.expect("failed to rx in task"); + tx_in_task.send(()).expect("failed to tx in task"); + }) + .map(|task| spawner.spawn(task).expect("failed to spawn task")); + + let mut registry = Registry::new(); + for handle in handles { + registry.register_handle(handle); + } + + // Unblock all the tasks + for tx in txs_to_task { + tx.send(()).expect("failed to tx into task"); + } + + drop(registry); + for rx in rxs_from_task { + time::timeout(Duration::from_secs(5), async move { + rx.await.expect("should have received from tasks"); + }) + .await + .expect("test timed out"); + } + } + + #[tokio::test] + async fn join_all_waits_for_task_completion() { + let count = Arc::new(Mutex::new(0)); + let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>())); + let spawner = TokioSpawner; + let handles = rxs_in_task + .into_iter() + .map(|rx_in_task| { + let count_clone = count.clone(); + async move { + rx_in_task.await.expect("failed to rx in task"); + // This is a little ham-fisted, but gives us some assurance that the tasks won't finish + // before we're ready for them. 
+ time::sleep(Duration::from_millis(100)).await; + count_clone.lock().await.add_assign(1); + } + }) + .map(|task| spawner.spawn(task).expect("failed to spawn task")); + + let mut registry = Registry::new(); + for handle in handles { + registry.register_handle(handle); + } + + // Unblock all the tasks + for tx in txs_to_task { + tx.send(()).expect("failed to tx into task"); + } + + let join_res = time::timeout(Duration::from_secs(5), registry.join_all()) + .await + .expect("test timed out"); + + join_res.expect("failed to join tasks"); + + assert_eq!(500, *count.lock().await); + } + + #[tokio::test] + async fn removing_token_prevents_it_from_being_joined() { + let count = Arc::new(Mutex::new(0)); + let (txs_to_task, rxs_in_task) = unpack_pairs((1..=500).map(|_| oneshot::channel::<()>())); + let spawner = TokioSpawner; + let handles = rxs_in_task + .into_iter() + .map(|rx_in_task| { + let count_clone = count.clone(); + async move { + rx_in_task.await.expect("failed to rx in task"); + // This is a little ham-fisted, but gives us some assurance that the tasks won't finish + // before we're ready for them. 
+ time::sleep(Duration::from_millis(100)).await; + count_clone.lock().await.add_assign(1); + } + }) + .map(|task| spawner.spawn(task).expect("failed to spawn task")); + + let mut registry = Registry::new(); + for handle in handles { + registry.register_handle(handle); + } + + let (_tx, never_rx) = oneshot::channel::<()>(); + let infinite_task_handle = spawner + .spawn(async move { never_rx.await }) + .expect("failed to spawn"); + let infinite_task_token = registry.register_handle(infinite_task_handle); + + // Unblock all the finite tasks + for tx in txs_to_task { + tx.send(()).expect("failed to tx into task"); + } + + registry.unregister_handle(infinite_task_token); + + let join_res = time::timeout(Duration::from_secs(5), registry.join_all()) + .await + .expect("test timed out; it's possible the task wasn't unregistered"); + + join_res.expect("failed to join tasks"); + } + + #[tokio::test] + async fn removing_token_prevents_it_from_being_cancelled() { + let (txs_to_task, rxs_in_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>())); + let (txs_in_task, rxs_from_task) = unpack_pairs((1..=5).map(|_| oneshot::channel::<()>())); + let spawner = TokioSpawner; + let mut handles = rxs_in_task + .into_iter() + .zip(txs_in_task) + .map(|(rx_in_task, tx_in_task)| async move { + rx_in_task.await.expect("failed to rx in task"); + tx_in_task.send(()).expect("failed to tx in task"); + }) + .map(|task| spawner.spawn(task).expect("failed to spawn task")); + + let mut registry = Registry::new(); + let first_task_token = registry.register_handle(handles.next().unwrap()); + for handle in handles { + registry.register_handle(handle); + } + + registry.unregister_handle(first_task_token); + registry.cancel_all(); + + // Unblock all the tasks + for tx in txs_to_task { + // it's ok if this fails - the task might have already been cancelled + let _res = tx.send(()); + } + + let mut rx_iter = rxs_from_task.into_iter(); + // Ensure we receive from the task we unregistered + rx_iter 
+ .next() + .unwrap() + .await + .expect("should have received from unregistered task"); + + // ...but not from any of the others + for rx in rx_iter { + time::timeout(Duration::from_secs(5), async move { + rx.await.expect_err("should not have received from tasks"); + }) + .await + .expect("test timed out"); + } + } +}