// lychee_lib/waiter.rs

//! Facility to wait for a dynamic set of tasks to complete, with a single
//! waiter and multiple waitees (things that are waited for). Notably, each
//! waitee can also start more work to be waited for.
//!
//! # Implementation Details
//!
//! The implementation of waiting in this module is just a wrapper around
//! [`tokio::sync::mpsc::channel`]. A [`WaitGroup`] holds the unique
//! [`tokio::sync::mpsc::Receiver`] and each [`WaitGuard`] holds a
//! [`tokio::sync::mpsc::Sender`]. Waiting finishes when the receiver observes
//! that the channel is closed, which happens once every sender (and therefore
//! every [`WaitGuard`]) has been dropped. Despite this simple implementation,
//! the [`WaitGroup`] and [`WaitGuard`] wrappers are useful because they make
//! the pattern discoverable.

13use futures::StreamExt;
14use futures::never::Never;
15use tokio::sync::mpsc::{Receiver, Sender, channel};
16use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
17use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
18
19/// Manager for a particular wait group. This can spawn a number of [`WaitGuard`]s
20/// and it can then wait for them to all complete.
21///
22/// Each [`WaitGroup`] is single-use—calling [`WaitGroup::wait`] to start
23/// waiting consumes the [`WaitGroup`]. Additionally, once all [`WaitGuard`]s
24/// have been dropped, it is not possible to create any more [`WaitGuard`]s.
25#[derive(Debug)]
26pub struct WaitGroup {
27    /// [`Receiver`] is held to wait for multiple [`Sender`]s and detect
28    /// when they have closed. The [`Never`] type means no value can/will
29    /// ever be received through the channel.
30    recv: Receiver<Never>,
31}
32
33/// RAII guard held by a task which is being waited for.
34///
35/// The existence of values of this type represents outstanding work for
36/// its corresponding [`WaitGroup`].
37///
38/// A [`WaitGuard`] can be cloned using [`WaitGuard::clone`]. This allows
39/// a task to spawn additional tasks, recursively.
40#[derive(Clone, Debug)]
41pub struct WaitGuard {
42    /// [`Sender`] is held to keep the [`Receiver`] end open (stored in [`WaitGroup`]).
43    /// The dropping of all senders will cause the receiver to detect and close.
44    /// The [`Never`] type means no value can/will ever be sent through the channel.
45    _send: Sender<Never>,
46}
47
48impl WaitGroup {
49    /// Creates a new [`WaitGroup`] and its first associated [`WaitGuard`].
50    ///
51    /// Note that [`WaitGroup`] itself has no ability to create new guards.
52    /// If needed, new guards should be created by cloning the returned [`WaitGuard`].
53    #[must_use]
54    pub fn new() -> (Self, WaitGuard) {
55        let (send, recv) = channel(1);
56        (Self { recv }, WaitGuard { _send: send })
57    }
58
59    /// Waits, asynchronously, until all the associated [`WaitGuard`]s have finished.
60    pub async fn wait(mut self) {
61        let None = self.recv.recv().await;
62    }
63}
64
65/// Demonstrates use of the [`WaitGroup`] and [`WaitGuard`] to (very inefficiently)
66/// compute the Fibonacci number `F(n)` using recursive channels.
67///
68/// The given `waiter` will be used to detect when the work has finished and it will
69/// close the channels. Additionally, `waiter` can be omitted to show that without
70/// the [`WaitGroup`], the tasks would not terminate.
71#[allow(dead_code)]
72async fn fibonacci_waiter_example(n: usize, waiter: Option<(WaitGroup, WaitGuard)>) -> usize {
73    let (send, recv) = unbounded_channel();
74    let (incr_count, recv_count) = channel(1);
75
76    let (waiter, guard) = match waiter {
77        Some((waiter, guard)) => (Some(waiter), Some(guard)),
78        None => (None, None),
79    };
80
81    let recursive_task = tokio::task::spawn({
82        let send = send.clone();
83        fibonacci_waiter_example_task(recv, send, incr_count, waiter)
84    });
85
86    let count_task = tokio::task::spawn(async move {
87        let count = ReceiverStream::new(recv_count).count();
88        count.await
89    });
90
91    send.send((guard, n)).expect("initial send"); // note `guard` must be moved!
92
93    let ((), result) = futures::try_join!(recursive_task, count_task).expect("join");
94    result
95}
96
97/// An inefficient Fibonacci implementation. This computes `F(n)` by sending
98/// by `n-1` and `n-2` back into the channel. This shows how one work item can
99/// create multiple subsequent work items.
100///
101/// The result is determined by sending `()` into an increment channel and
102/// reading the number of increments.
103///
104/// This is wildly inefficient because it does not cache any results. Computing
105/// `F(n)` would generate `O(2^n)` channel items.
106#[allow(dead_code)]
107async fn fibonacci_waiter_example_task(
108    recv: UnboundedReceiver<(Option<WaitGuard>, usize)>,
109    send: UnboundedSender<(Option<WaitGuard>, usize)>,
110    incr_count: Sender<()>,
111    waiter: Option<WaitGroup>,
112) {
113    let stream = UnboundedReceiverStream::new(recv);
114    let stream = match waiter {
115        Some(waiter) => stream.take_until(waiter.wait()).left_stream(),
116        None => stream.right_stream(),
117    };
118
119    stream
120        .for_each(async |(guard, n)| match n {
121            0 => (),
122            1 => incr_count.send(()).await.expect("send incr"),
123            n => {
124                send.send((guard.clone(), n - 1)).expect("send 1");
125                send.send((guard.clone(), n - 2)).expect("send 2");
126            }
127        })
128        .await;
129}
130
#[cfg(test)]
mod tests {
    use super::WaitGroup;
    use super::fibonacci_waiter_example;
    use std::time::Duration;

    /// How long a run is allowed to take before a test concludes that it hangs.
    const HANG_TIMEOUT: Duration = Duration::from_millis(250);

    /// Wraps `fut` in a [`tokio::time::timeout`] of [`HANG_TIMEOUT`].
    fn timeout<F: IntoFuture>(fut: F) -> tokio::time::Timeout<F::IntoFuture> {
        tokio::time::timeout(HANG_TIMEOUT, fut)
    }

    #[tokio::test]
    async fn fibonacci_basic_termination() {
        // (n, F(n)) pairs; every run gets its own fresh wait group.
        for (n, expected) in [(0, 0), (9, 34), (10, 55)] {
            let actual = fibonacci_waiter_example(n, Some(WaitGroup::new())).await;
            assert_eq!(actual, expected);
        }
    }

    #[tokio::test]
    async fn fibonacci_nontermination_without_waiter() {
        // Without a WaitGroup the recursive channel never closes, so the task
        // never terminates. Even the "trivial" n = 0 case hangs.
        for n in [9, 0] {
            let outcome = timeout(fibonacci_waiter_example(n, None)).await;
            assert!(outcome.is_err());
        }
    }

    #[tokio::test]
    async fn fibonacci_nontermination_with_extra_guard() {
        // A WaitGroup is used here, but the run still hangs. The test passes a
        // *clone* of the guard and keeps the original alive, which blocks the
        // WaitGroup from returning. This shows one way the waiter can be misused.
        let (first_waiter, first_guard) = WaitGroup::new();
        let first = timeout(fibonacci_waiter_example(
            9,
            Some((first_waiter, first_guard.clone())),
        ))
        .await;
        assert!(first.is_err());

        let (second_waiter, second_guard) = WaitGroup::new();
        let second = timeout(fibonacci_waiter_example(
            0,
            Some((second_waiter, second_guard.clone())),
        ))
        .await;
        assert!(second.is_err());
    }
}
183}