lychee_lib/waiter.rs
1//! Facility to wait for a dynamic set of tasks to complete, with a single
2//! waiter and multiple waitees (things that are waited for). Notably, each
3//! waitee can also start more work to be waited for.
4//!
5//! # Implementation Details
6//!
7//! The implementation of waiting in this module is just a wrapper around
8//! [`tokio::sync::mpsc::channel`]. A [`WaitGroup`] holds the unique
9//! [`tokio::sync::mpsc::Receiver`] and each [`WaitGuard`] holds a
10//! [`tokio::sync::mpsc::Sender`]. Despite this simple implementation, the
11//! [`WaitGroup`] and [`WaitGuard`] wrappers are useful to make this discoverable.
12
13use futures::StreamExt;
14use futures::never::Never;
15use tokio::sync::mpsc::{Receiver, Sender, channel};
16use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
17use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
18
19/// Manager for a particular wait group. This can spawn a number of [`WaitGuard`]s
20/// and it can then wait for them to all complete.
21///
22/// Each [`WaitGroup`] is single-use—calling [`WaitGroup::wait`] to start
23/// waiting consumes the [`WaitGroup`]. Additionally, once all [`WaitGuard`]s
24/// have been dropped, it is not possible to create any more [`WaitGuard`]s.
25#[derive(Debug)]
26pub struct WaitGroup {
27 /// [`Receiver`] is held to wait for multiple [`Sender`]s and detect
28 /// when they have closed. The [`Never`] type means no value can/will
29 /// ever be received through the channel.
30 recv: Receiver<Never>,
31}
32
33/// RAII guard held by a task which is being waited for.
34///
35/// The existence of values of this type represents outstanding work for
36/// its corresponding [`WaitGroup`].
37///
38/// A [`WaitGuard`] can be cloned using [`WaitGuard::clone`]. This allows
39/// a task to spawn additional tasks, recursively.
40#[derive(Clone, Debug)]
41pub struct WaitGuard {
42 /// [`Sender`] is held to keep the [`Receiver`] end open (stored in [`WaitGroup`]).
43 /// The dropping of all senders will cause the receiver to detect and close.
44 /// The [`Never`] type means no value can/will ever be sent through the channel.
45 _send: Sender<Never>,
46}
47
48impl WaitGroup {
49 /// Creates a new [`WaitGroup`] and its first associated [`WaitGuard`].
50 ///
51 /// Note that [`WaitGroup`] itself has no ability to create new guards.
52 /// If needed, new guards should be created by cloning the returned [`WaitGuard`].
53 #[must_use]
54 pub fn new() -> (Self, WaitGuard) {
55 let (send, recv) = channel(1);
56 (Self { recv }, WaitGuard { _send: send })
57 }
58
59 /// Waits, asynchronously, until all the associated [`WaitGuard`]s have finished.
60 pub async fn wait(mut self) {
61 let None = self.recv.recv().await;
62 }
63}
64
65/// Demonstrates use of the [`WaitGroup`] and [`WaitGuard`] to (very inefficiently)
66/// compute the Fibonacci number `F(n)` using recursive channels.
67///
68/// The given `waiter` will be used to detect when the work has finished and it will
69/// close the channels. Additionally, `waiter` can be omitted to show that without
70/// the [`WaitGroup`], the tasks would not terminate.
71#[allow(dead_code)]
72async fn fibonacci_waiter_example(n: usize, waiter: Option<(WaitGroup, WaitGuard)>) -> usize {
73 let (send, recv) = unbounded_channel();
74 let (incr_count, recv_count) = channel(1);
75
76 let (waiter, guard) = match waiter {
77 Some((waiter, guard)) => (Some(waiter), Some(guard)),
78 None => (None, None),
79 };
80
81 let recursive_task = tokio::task::spawn({
82 let send = send.clone();
83 fibonacci_waiter_example_task(recv, send, incr_count, waiter)
84 });
85
86 let count_task = tokio::task::spawn(async move {
87 let count = ReceiverStream::new(recv_count).count();
88 count.await
89 });
90
91 send.send((guard, n)).expect("initial send"); // note `guard` must be moved!
92
93 let ((), result) = futures::try_join!(recursive_task, count_task).expect("join");
94 result
95}
96
97/// An inefficient Fibonacci implementation. This computes `F(n)` by sending
98/// by `n-1` and `n-2` back into the channel. This shows how one work item can
99/// create multiple subsequent work items.
100///
101/// The result is determined by sending `()` into an increment channel and
102/// reading the number of increments.
103///
104/// This is wildly inefficient because it does not cache any results. Computing
105/// `F(n)` would generate `O(2^n)` channel items.
106#[allow(dead_code)]
107async fn fibonacci_waiter_example_task(
108 recv: UnboundedReceiver<(Option<WaitGuard>, usize)>,
109 send: UnboundedSender<(Option<WaitGuard>, usize)>,
110 incr_count: Sender<()>,
111 waiter: Option<WaitGroup>,
112) {
113 let stream = UnboundedReceiverStream::new(recv);
114 let stream = match waiter {
115 Some(waiter) => stream.take_until(waiter.wait()).left_stream(),
116 None => stream.right_stream(),
117 };
118
119 stream
120 .for_each(async |(guard, n)| match n {
121 0 => (),
122 1 => incr_count.send(()).await.expect("send incr"),
123 n => {
124 send.send((guard.clone(), n - 1)).expect("send 1");
125 send.send((guard.clone(), n - 2)).expect("send 2");
126 }
127 })
128 .await;
129}
130
#[cfg(test)]
mod tests {
    use super::WaitGroup;
    use super::fibonacci_waiter_example;
    use std::time::Duration;

    /// How long a future may run before a test treats it as non-terminating.
    const TIME_LIMIT: Duration = Duration::from_millis(250);

    /// Wraps `fut` in a timeout of [`TIME_LIMIT`]; awaiting the result yields
    /// `Err` if `fut` did not complete in time.
    fn timeout<F: IntoFuture>(fut: F) -> tokio::time::Timeout<F::IntoFuture> {
        tokio::time::timeout(TIME_LIMIT, fut)
    }

    #[tokio::test]
    async fn fibonacci_basic_termination() {
        // (n, F(n)) pairs; with a proper wait group each computation finishes.
        for (n, expected) in [(0, 0), (9, 34), (10, 55)] {
            let actual = fibonacci_waiter_example(n, Some(WaitGroup::new())).await;
            assert_eq!(actual, expected);
        }
    }

    #[tokio::test]
    async fn fibonacci_nontermination_without_waiter() {
        // without a WaitGroup the task never terminates, due to the recursive
        // channels. this holds even for the "trivial" input 0.
        for n in [9, 0] {
            assert!(timeout(fibonacci_waiter_example(n, None)).await.is_err());
        }
    }

    #[tokio::test]
    async fn fibonacci_nontermination_with_extra_guard() {
        // here a WaitGroup *is* used, but the example only receives a clone of
        // the guard while this test keeps the original. the extra guard blocks
        // the WaitGroup from returning, so the example never terminates. this
        // is an example of something that can go wrong when using the waiter.
        for n in [9, 0] {
            let (waiter, guard) = WaitGroup::new();
            let example = fibonacci_waiter_example(n, Some((waiter, guard.clone())));
            assert!(timeout(example).await.is_err());
            // the extra guard stays alive until the timeout has been observed
            drop(guard);
        }
    }
}