1use core::cell::{RefCell, UnsafeCell};
5use core::fmt;
6use core::future::{poll_fn, Future};
7use core::ops::{Deref, DerefMut};
8use core::task::Poll;
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex as BlockingMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`] when the
/// lock cannot be acquired immediately without waiting.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;

// Public error types should be printable for end users, not only via `Debug`.
impl fmt::Display for TryLockError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RwLock is already locked")
    }
}
18
19#[derive(Debug)]
20struct State {
21 readers: usize,
22 writer: bool,
23 waker: WakerRegistration,
24}
25
26pub struct RwLock<M, T>
43where
44 M: RawMutex,
45 T: ?Sized,
46{
47 state: BlockingMutex<M, RefCell<State>>,
48 inner: UnsafeCell<T>,
49}
50
51unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for RwLock<M, T> {}
52unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {}
53
54impl<M, T> RwLock<M, T>
56where
57 M: RawMutex,
58{
59 pub const fn new(value: T) -> Self {
61 Self {
62 inner: UnsafeCell::new(value),
63 state: BlockingMutex::new(RefCell::new(State {
64 readers: 0,
65 writer: false,
66 waker: WakerRegistration::new(),
67 })),
68 }
69 }
70}
71
72impl<M, T> RwLock<M, T>
73where
74 M: RawMutex,
75 T: ?Sized,
76{
77 pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> {
81 poll_fn(|cx| {
82 let ready = self.state.lock(|s| {
83 let mut s = s.borrow_mut();
84 if s.writer {
85 s.waker.register(cx.waker());
86 false
87 } else {
88 s.readers += 1;
89 true
90 }
91 });
92
93 if ready {
94 Poll::Ready(RwLockReadGuard { rwlock: self })
95 } else {
96 Poll::Pending
97 }
98 })
99 }
100
101 pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> {
105 poll_fn(|cx| {
106 let ready = self.state.lock(|s| {
107 let mut s = s.borrow_mut();
108 if s.writer || s.readers > 0 {
109 s.waker.register(cx.waker());
110 false
111 } else {
112 s.writer = true;
113 true
114 }
115 });
116
117 if ready {
118 Poll::Ready(RwLockWriteGuard { rwlock: self })
119 } else {
120 Poll::Pending
121 }
122 })
123 }
124
125 pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> {
129 self.state
130 .lock(|s| {
131 let mut s = s.borrow_mut();
132 if s.writer {
133 return Err(());
134 }
135 s.readers += 1;
136 Ok(())
137 })
138 .map_err(|_| TryLockError)?;
139
140 Ok(RwLockReadGuard { rwlock: self })
141 }
142
143 pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> {
147 self.state
148 .lock(|s| {
149 let mut s = s.borrow_mut();
150 if s.writer || s.readers > 0 {
151 return Err(());
152 }
153 s.writer = true;
154 Ok(())
155 })
156 .map_err(|_| TryLockError)?;
157
158 Ok(RwLockWriteGuard { rwlock: self })
159 }
160
161 pub fn into_inner(self) -> T
163 where
164 T: Sized,
165 {
166 self.inner.into_inner()
167 }
168
169 pub fn get_mut(&mut self) -> &mut T {
174 self.inner.get_mut()
175 }
176}
177
178impl<M: RawMutex, T> From<T> for RwLock<M, T> {
179 fn from(from: T) -> Self {
180 Self::new(from)
181 }
182}
183
184impl<M, T> Default for RwLock<M, T>
185where
186 M: RawMutex,
187 T: Default,
188{
189 fn default() -> Self {
190 Self::new(Default::default())
191 }
192}
193
194impl<M, T> fmt::Debug for RwLock<M, T>
195where
196 M: RawMutex,
197 T: ?Sized + fmt::Debug,
198{
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 let mut d = f.debug_struct("RwLock");
201 match self.try_read() {
202 Ok(guard) => d.field("inner", &&*guard),
203 Err(TryLockError) => d.field("inner", &"Locked"),
204 }
205 .finish_non_exhaustive()
206 }
207}
208
209#[clippy::has_significant_drop]
216#[must_use = "if unused the RwLock will immediately unlock"]
217pub struct RwLockReadGuard<'a, R, T>
218where
219 R: RawMutex,
220 T: ?Sized,
221{
222 rwlock: &'a RwLock<R, T>,
223}
224
225impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
226where
227 M: RawMutex,
228 T: ?Sized,
229{
230 fn drop(&mut self) {
231 self.rwlock.state.lock(|s| {
232 let mut s = unwrap!(s.try_borrow_mut());
233 s.readers -= 1;
234 if s.readers == 0 {
235 s.waker.wake();
236 }
237 })
238 }
239}
240
241impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T>
242where
243 M: RawMutex,
244 T: ?Sized,
245{
246 type Target = T;
247 fn deref(&self) -> &Self::Target {
248 unsafe { &*(self.rwlock.inner.get() as *const T) }
251 }
252}
253
254impl<'a, M, T> fmt::Debug for RwLockReadGuard<'a, M, T>
255where
256 M: RawMutex,
257 T: ?Sized + fmt::Debug,
258{
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 fmt::Debug::fmt(&**self, f)
261 }
262}
263
264impl<'a, M, T> fmt::Display for RwLockReadGuard<'a, M, T>
265where
266 M: RawMutex,
267 T: ?Sized + fmt::Display,
268{
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 fmt::Display::fmt(&**self, f)
271 }
272}
273
274#[clippy::has_significant_drop]
281#[must_use = "if unused the RwLock will immediately unlock"]
282pub struct RwLockWriteGuard<'a, R, T>
283where
284 R: RawMutex,
285 T: ?Sized,
286{
287 rwlock: &'a RwLock<R, T>,
288}
289
290impl<'a, R, T> Drop for RwLockWriteGuard<'a, R, T>
291where
292 R: RawMutex,
293 T: ?Sized,
294{
295 fn drop(&mut self) {
296 self.rwlock.state.lock(|s| {
297 let mut s = unwrap!(s.try_borrow_mut());
298 s.writer = false;
299 s.waker.wake();
300 })
301 }
302}
303
304impl<'a, R, T> Deref for RwLockWriteGuard<'a, R, T>
305where
306 R: RawMutex,
307 T: ?Sized,
308{
309 type Target = T;
310 fn deref(&self) -> &Self::Target {
311 unsafe { &*(self.rwlock.inner.get() as *mut T) }
314 }
315}
316
317impl<'a, R, T> DerefMut for RwLockWriteGuard<'a, R, T>
318where
319 R: RawMutex,
320 T: ?Sized,
321{
322 fn deref_mut(&mut self) -> &mut Self::Target {
323 unsafe { &mut *(self.rwlock.inner.get()) }
326 }
327}
328
329impl<'a, R, T> fmt::Debug for RwLockWriteGuard<'a, R, T>
330where
331 R: RawMutex,
332 T: ?Sized + fmt::Debug,
333{
334 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335 fmt::Debug::fmt(&**self, f)
336 }
337}
338
339impl<'a, R, T> fmt::Display for RwLockWriteGuard<'a, R, T>
340where
341 R: RawMutex,
342 T: ?Sized + fmt::Display,
343{
344 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345 fmt::Display::fmt(&**self, f)
346 }
347}
348
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::rwlock::RwLock;

    /// Dropping a read guard must leave the lock readable again.
    #[futures_test::test]
    async fn read_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        for _ in 0..2 {
            let guard = rwlock.read().await;
            assert_eq!(*guard, [0, 1]);
            drop(guard);
        }

        assert_eq!(*rwlock.read().await, [0, 1]);
    }

    /// Dropping a write guard must release the lock and keep the written value.
    #[futures_test::test]
    async fn write_guard_releases_lock_when_dropped() {
        let rwlock: RwLock<NoopRawMutex, [i32; 2]> = RwLock::new([0, 1]);

        let mut writer = rwlock.write().await;
        assert_eq!(*writer, [0, 1]);
        writer[1] = 2;
        drop(writer);

        let reader = rwlock.read().await;
        assert_eq!(*reader, [0, 2]);
        drop(reader);

        assert_eq!(*rwlock.read().await, [0, 2]);
    }
}