Skip to content

Implement async_std::sync::Condvar #369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
417 changes: 417 additions & 0 deletions src/sync/condvar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,417 @@
use std::fmt;
use std::pin::Pin;
use std::time::Duration;

use super::mutex::{guard_lock, MutexGuard};
use crate::future::{timeout, Future};
use crate::sync::WakerSet;
use crate::task::{Context, Poll};

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct WaitTimeoutResult(bool);

/// A type indicating whether a timed wait on a condition variable returned due to a time out or
/// not
impl WaitTimeoutResult {
/// Returns `true` if the wait was known to have timed out.
pub fn timed_out(self) -> bool {
self.0
}
}

/// A Condition Variable
///
/// This type is an async version of [`std::sync::Mutex`].
///
/// [`std::sync::Condvar`]: https://doc.rust-lang.org/std/sync/struct.Condvar.html
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use std::sync::Arc;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// // Inside of our lock, spawn a new thread, and then wait for it to start.
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // Wait for the thread to start up.
/// let (lock, cvar) = &*pair;
/// let mut started = lock.lock().await;
/// while !*started {
/// started = cvar.wait(started).await;
/// }
///
/// # })
/// ```
pub struct Condvar {
wakers: WakerSet,
}

unsafe impl Send for Condvar {}
unsafe impl Sync for Condvar {}

impl Default for Condvar {
fn default() -> Self {
Condvar::new()
}
}

impl Condvar {
/// Creates a new condition variable
///
/// # Examples
///
/// ```
/// use async_std::sync::Condvar;
///
/// let cvar = Condvar::new();
/// ```
pub fn new() -> Self {
Condvar {
wakers: WakerSet::new(),
}
}

/// Blocks the current task until this condition variable receives a notification.
///
/// Unlike the std equivalent, this does not check that a single mutex is used at runtime.
/// However, as a best practice avoid using with multiple mutexes.
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// use std::sync::Arc;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // Wait for the thread to start up.
/// let (lock, cvar) = &*pair;
/// let mut started = lock.lock().await;
/// while !*started {
/// started = cvar.wait(started).await;
/// }
/// # })
/// ```
#[allow(clippy::needless_lifetimes)]
pub async fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
let mutex = guard_lock(&guard);

self.await_notify(guard).await;

mutex.lock().await
}

fn await_notify<'a, T>(&self, guard: MutexGuard<'a, T>) -> AwaitNotify<'_, 'a, T> {
AwaitNotify {
cond: self,
guard: Some(guard),
key: None,
}
}

/// Blocks the current taks until this condition variable receives a notification and the
/// required condition is met. Spurious wakeups are ignored and this function will only
/// return once the condition has been met.
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use std::sync::Arc;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // Wait for the thread to start up.
/// let (lock, cvar) = &*pair;
/// // As long as the value inside the `Mutex<bool>` is `false`, we wait.
/// let _guard = cvar.wait_until(lock.lock().await, |started| { *started }).await;
/// #
/// # })
/// ```
#[allow(clippy::needless_lifetimes)]
pub async fn wait_until<'a, T, F>(
&self,
mut guard: MutexGuard<'a, T>,
mut condition: F,
) -> MutexGuard<'a, T>
where
F: FnMut(&mut T) -> bool,
{
while !condition(&mut *guard) {
guard = self.wait(guard).await;
}
guard
}

/// Waits on this condition variable for a notification, timing out after a specified duration.
///
/// For these reasons `Condvar::wait_timeout_until` is recommended in most cases.
///
/// # Examples
///
/// ```
/// # async_std::task::block_on(async {
/// #
/// use std::sync::Arc;
/// use std::time::Duration;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // wait for the thread to start up
/// let (lock, cvar) = &*pair;
/// let mut started = lock.lock().await;
/// loop {
/// let result = cvar.wait_timeout(started, Duration::from_millis(10)).await;
/// started = result.0;
/// if *started == true {
/// // We received the notification and the value has been updated, we can leave.
/// break
/// }
/// }
/// #
/// # })
/// ```
#[allow(clippy::needless_lifetimes)]
pub async fn wait_timeout<'a, T>(
&self,
guard: MutexGuard<'a, T>,
dur: Duration,
) -> (MutexGuard<'a, T>, WaitTimeoutResult) {
let mutex = guard_lock(&guard);
match timeout(dur, self.wait(guard)).await {
Ok(guard) => (guard, WaitTimeoutResult(false)),
Err(_) => (mutex.lock().await, WaitTimeoutResult(true)),
}
}

/// Waits on this condition variable for a notification, timing out after a specified duration.
/// Spurious wakes will not cause this function to return.
///
/// # Examples
/// ```
/// # async_std::task::block_on(async {
/// use std::sync::Arc;
/// use std::time::Duration;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // wait for the thread to start up
/// let (lock, cvar) = &*pair;
/// let result = cvar.wait_timeout_until(
/// lock.lock().await,
/// Duration::from_millis(100),
/// |&mut started| started,
/// ).await;
/// if result.1.timed_out() {
/// // timed-out without the condition ever evaluating to true.
/// }
/// // access the locked mutex via result.0
/// # });
/// ```
#[allow(clippy::needless_lifetimes)]
pub async fn wait_timeout_until<'a, T, F>(
&self,
guard: MutexGuard<'a, T>,
dur: Duration,
condition: F,
) -> (MutexGuard<'a, T>, WaitTimeoutResult)
where
F: FnMut(&mut T) -> bool,
{
let mutex = guard_lock(&guard);
match timeout(dur, self.wait_until(guard, condition)).await {
Ok(guard) => (guard, WaitTimeoutResult(false)),
Err(_) => (mutex.lock().await, WaitTimeoutResult(true)),
}
}

/// Wakes up one blocked task on this condvar.
///
/// # Examples
///
/// ```
/// # fn main() { async_std::task::block_on(async {
/// use std::sync::Arc;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_one();
/// });
///
/// // Wait for the thread to start up.
/// let (lock, cvar) = &*pair;
/// let mut started = lock.lock().await;
/// while !*started {
/// started = cvar.wait(started).await;
/// }
/// # }) }
/// ```
pub fn notify_one(&self) {
self.wakers.notify_one();
}

/// Wakes up all blocked tasks on this condvar.
///
/// # Examples
/// ```
/// # fn main() { async_std::task::block_on(async {
/// #
/// use std::sync::Arc;
///
/// use async_std::sync::{Mutex, Condvar};
/// use async_std::task;
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = pair.clone();
///
/// task::spawn(async move {
/// let (lock, cvar) = &*pair2;
/// let mut started = lock.lock().await;
/// *started = true;
/// // We notify the condvar that the value has changed.
/// cvar.notify_all();
/// });
///
/// // Wait for the thread to start up.
/// let (lock, cvar) = &*pair;
/// let mut started = lock.lock().await;
/// // As long as the value inside the `Mutex<bool>` is `false`, we wait.
/// while !*started {
/// started = cvar.wait(started).await;
/// }
/// #
/// # }) }
/// ```
pub fn notify_all(&self) {
self.wakers.notify_all();
}
}

impl fmt::Debug for Condvar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("Condvar { .. }")
}
}

/// A future that waits for another task to notify the condition variable.
///
/// This is an internal future that `wait` and `wait_until` await on.
struct AwaitNotify<'a, 'b, T> {
/// The condition variable that we are waiting on
cond: &'a Condvar,
/// The lock used with `cond`.
/// This will be released the first time the future is polled,
/// after registering the context to be notified.
guard: Option<MutexGuard<'b, T>>,
/// A key into the conditions variable's `WakerSet`.
/// This is set to the index of the `Waker` for the context each time
/// the future is polled and not completed.
key: Option<usize>,
}

impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.guard.take() {
Some(_) => {
self.key = Some(self.cond.wakers.insert(cx));
// the guard is dropped when we return, which frees the lock
Poll::Pending
}
None => {
if let Some(key) = self.key {
if self.cond.wakers.remove_if_notified(key, cx) {
self.key = None;
Poll::Ready(())
} else {
Poll::Pending
}
} else {
// This should only happen if it is polled twice after receiving a notification
Poll::Ready(())
}
}
}
}
}

impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
fn drop(&mut self) {
if let Some(key) = self.key {
self.cond.wakers.cancel(key);
}
}
}
2 changes: 2 additions & 0 deletions src/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -185,8 +185,10 @@ mod rwlock;
cfg_unstable! {
pub use barrier::{Barrier, BarrierWaitResult};
pub use channel::{channel, Sender, Receiver, RecvError, TryRecvError, TrySendError};
pub use condvar::Condvar;

mod barrier;
mod condvar;
mod channel;
}

5 changes: 5 additions & 0 deletions src/sync/mutex.rs
Original file line number Diff line number Diff line change
@@ -287,3 +287,8 @@ impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
unsafe { &mut *self.0.value.get() }
}
}

#[cfg(feature = "unstable")]
pub fn guard_lock<'a, T>(guard: &MutexGuard<'a, T>) -> &'a Mutex<T> {
guard.0
}
22 changes: 22 additions & 0 deletions src/sync/waker_set.rs
Original file line number Diff line number Diff line change
@@ -80,6 +80,28 @@ impl WakerSet {
}
}

/// If the waker for this key is still waiting for a notification, then update
/// the waker for the entry, and return false. If the waker has been notified,
/// treat the entry as completed and return true.
#[cfg(feature = "unstable")]
pub fn remove_if_notified(&self, key: usize, cx: &Context<'_>) -> bool {
let mut inner = self.lock();

match &mut inner.entries[key] {
None => {
inner.entries.remove(key);
true
}
Some(w) => {
// We were never woken, so update instead
if !w.will_wake(cx.waker()) {
*w = cx.waker().clone();
}
false
}
}
}

/// Removes the waker of a cancelled operation.
///
/// Returns `true` if another blocked operation from the set was notified.
91 changes: 91 additions & 0 deletions tests/condvar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#![cfg(feature = "unstable")]
use std::sync::Arc;
use std::time::Duration;

use async_std::sync::{Condvar, Mutex};
use async_std::task::{self, JoinHandle};

#[test]
fn wait_timeout_with_lock() {
task::block_on(async {
let pair = Arc::new((Mutex::new(false), Condvar::new()));
let pair2 = pair.clone();

task::spawn(async move {
let (m, c) = &*pair2;
let _g = m.lock().await;
task::sleep(Duration::from_millis(20)).await;
c.notify_one();
});

let (m, c) = &*pair;
let (_, wait_result) = c
.wait_timeout(m.lock().await, Duration::from_millis(10))
.await;
assert!(wait_result.timed_out());
})
}

#[test]
fn wait_timeout_without_lock() {
task::block_on(async {
let m = Mutex::new(false);
let c = Condvar::new();

let (_, wait_result) = c
.wait_timeout(m.lock().await, Duration::from_millis(10))
.await;
assert!(wait_result.timed_out());
})
}

#[test]
fn wait_timeout_until_timed_out() {
task::block_on(async {
let m = Mutex::new(false);
let c = Condvar::new();

let (_, wait_result) = c
.wait_timeout_until(m.lock().await, Duration::from_millis(10), |&mut started| {
started
})
.await;
assert!(wait_result.timed_out());
})
}

#[test]
fn notify_all() {
task::block_on(async {
let mut tasks: Vec<JoinHandle<()>> = Vec::new();
let pair = Arc::new((Mutex::new(0u32), Condvar::new()));

for _ in 0..10 {
let pair = pair.clone();
tasks.push(task::spawn(async move {
let (m, c) = &*pair;
let mut count = m.lock().await;
while *count == 0 {
count = c.wait(count).await;
}
*count += 1;
}));
}

// Give some time for tasks to start up
task::sleep(Duration::from_millis(5)).await;

let (m, c) = &*pair;
{
let mut count = m.lock().await;
*count += 1;
c.notify_all();
}

for t in tasks {
t.await;
}
let count = m.lock().await;
assert_eq!(11, *count);
})
}