wdk_mutex/
fast_mutex.rs

1//! A Rust idiomatic Windows Kernel Driver FAST_MUTEX type which protects the inner type T
2
3use alloc::boxed::Box;
4use core::{
5    ffi::c_void,
6    fmt::Display,
7    ops::{Deref, DerefMut},
8    ptr::{self, drop_in_place},
9};
10use wdk_sys::{
11    ntddk::{
12        ExAcquireFastMutex, ExAllocatePool2, ExFreePool, ExReleaseFastMutex, KeGetCurrentIrql,
13        KeInitializeEvent,
14    },
15    APC_LEVEL, DISPATCH_LEVEL, FALSE, FAST_MUTEX, FM_LOCK_BIT, POOL_FLAG_NON_PAGED,
16    _EVENT_TYPE::SynchronizationEvent,
17};
18
19extern crate alloc;
20
21use crate::errors::DriverMutexError;
22
23/// An internal binding for the ExInitializeFastMutex routine.
24///
25/// # Safety
26///
27/// This function does not check the IRQL as the only place this function is used is in an area where the IRQL
28/// is already checked.
29#[allow(non_snake_case)]
30unsafe fn ExInitializeFastMutex(fast_mutex: *mut FAST_MUTEX) {
31    core::ptr::write_volatile(&mut (*fast_mutex).Count, FM_LOCK_BIT as i32);
32
33    (*fast_mutex).Owner = core::ptr::null_mut();
34    (*fast_mutex).Contention = 0;
35    KeInitializeEvent(&mut (*fast_mutex).Event, SynchronizationEvent, FALSE as _)
36}
37
38/// A thread safe mutex implemented through acquiring a `FAST_MUTEX` in the Windows kernel.
39///
40/// The type `FastMutex<T>` provides mutually exclusive access to the inner type T allocated through
41/// this crate in the non-paged pool. All data required to initialise the FastMutex is allocated in the
42/// non-paged pool and as such is safe to pass stack data into the type as it will not go out of scope.
43///
44/// `FastMutex` holds an inner value which is a pointer to a `FastMutexInner` type which is the actual type
45/// allocated in the non-paged pool, and this holds information relating to the mutex.
46///
47/// Access to the `T` within the `FastMutex` can be done through calling [`Self::lock`].
48///
49/// # Lifetimes
50///
51/// As the `FastMutex` is designed to be used in the Windows Kernel, with the Windows `wdk` crate, the lifetimes of
52/// the `FastMutex` must be considered by the caller. See examples below for usage.
53///
54/// The `FastMutex` can exist in a locally scoped function with little additional configuration. To use the mutex across
55/// thread boundaries, or to use it in callback functions, you can use the `Grt` module found in this crate. See below for
56/// details.
57///
58/// # Deallocation
59///
60/// FastMutex handles the deallocation of resources at the point the FastMutex is dropped.
61///
62/// # Examples
63///
64/// ## Locally scoped mutex:
65///
66/// ```
67/// {
68///     let mtx = FastMutex::new(0u32).unwrap();
69///     let lock = mtx.lock().unwrap();
70///
71///     // If T implements display, you do not need to dereference the lock to print.
72///     println!("The value is: {}", lock);
73/// } // Mutex will become unlocked as it is managed via RAII
74/// ```
75///
76/// ## Global scope via the `Grt` module in `wdk-mutex`:
77///
78/// ```
79/// // Initialise the mutex on DriverEntry
80///
81/// #[export_name = "DriverEntry"]
82/// pub unsafe extern "system" fn driver_entry(
83///     driver: &mut DRIVER_OBJECT,
84///     registry_path: PCUNICODE_STRING,
85/// ) -> NTSTATUS {
86///     if let Err(e) = Grt::init() {
87///         println!("Error creating Grt!: {:?}", e);
88///         return STATUS_UNSUCCESSFUL;
89///     }
90///
91///     // ...
92///     my_function();
93/// }
94///
95///
96/// // Register a new Mutex in the `Grt` of value 0u32:
97///
98/// pub fn my_function() {
99///     Grt::register_fast_mutex("my_test_mutex", 0u32);
100/// }
101///
102/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
103///     let my_mutex = Grt::get_fast_mutex::<u32>("my_test_mutex");
104///     if let Err(e) = my_mutex {
105///         println!("Error in thread: {:?}", e);
106///         return;
107///     }
108///
109///     let mut lock = my_mutex.unwrap().lock().unwrap();
110///     *lock += 1;
111/// }
112///
113///
114/// // Destroy the Grt to prevent memory leak on DriverExit
115///
116/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
117///     unsafe {Grt::destroy()};
118/// }
119/// ```
120pub struct FastMutex<T> {
121    inner: *mut FastMutexInner<T>,
122}
123
/// The pool-allocated payload that a `FastMutex` points to.
///
/// `FastMutex::new` allocates this from the non-paged pool, writes both fields in place, and
/// then initialises `mutex` at its final address.
struct FastMutexInner<T> {
    /// The kernel fast mutex that guards `data`.
    mutex: FAST_MUTEX,
    /// The value protected by `mutex`.
    data: T,
}
130
131unsafe impl<T> Sync for FastMutex<T> {}
132unsafe impl<T> Send for FastMutex<T> {}
133
134impl<T> FastMutex<T> {
135    /// Creates a new `FAST_MUTEX` Windows Kernel Driver Mutex.
136    ///
137    /// # IRQL
138    ///
139    /// This can be called at IRQL <= DISPATCH_LEVEL.
140    ///
141    /// # Examples
142    ///
143    /// ```
144    /// use wdk_mutex::Mutex;
145    ///
146    /// let my_mutex = wdk_mutex::FastMutex::new(0u32);
147    /// ```
148    pub fn new(data: T) -> Result<Self, DriverMutexError> {
149        // This can only be called at a level <= DISPATCH_LEVEL; check current IRQL
150        // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exinitializefastmutex
151        if unsafe { KeGetCurrentIrql() } > DISPATCH_LEVEL as u8 {
152            return Err(DriverMutexError::IrqlTooHigh);
153        }
154
155        //
156        // Non-Paged heap alloc for all struct data required for FastMutexInner
157        //
158        let total_sz_required = size_of::<FastMutexInner<T>>();
159        let inner_heap_ptr: *mut c_void = unsafe {
160            ExAllocatePool2(
161                POOL_FLAG_NON_PAGED,
162                total_sz_required as u64,
163                u32::from_be_bytes(*b"kmtx"),
164            )
165        };
166        if inner_heap_ptr.is_null() {
167            return Err(DriverMutexError::PagedPoolAllocFailed);
168        }
169
170        // Cast the memory allocation to a pointer to the inner
171        let fast_mtx_inner_ptr = inner_heap_ptr as *mut FastMutexInner<T>;
172
173        // SAFETY: This raw write is safe as the pointer validity is checked above.
174        unsafe {
175            ptr::write(
176                fast_mtx_inner_ptr,
177                FastMutexInner {
178                    mutex: FAST_MUTEX::default(),
179                    data,
180                },
181            );
182
183            // Initialise the FastMutex object via the kernel
184            ExInitializeFastMutex(&mut (*fast_mtx_inner_ptr).mutex);
185        }
186
187        Ok(Self {
188            inner: fast_mtx_inner_ptr,
189        })
190    }
191
192    /// Acquires the mutex, raising the IRQL to `APC_LEVEL`.
193    ///
194    /// Once the thread has acquired the mutex, it will return a `FastMutexGuard` which is a RAII scoped
195    /// guard allowing exclusive access to the inner T.
196    ///
197    /// # Errors
198    ///
199    /// If the IRQL is too high, this function will return an error and will not acquire a lock. To prevent
200    /// a kernel panic, the caller should match the return value rather than just unwrapping the value.
201    ///
202    /// # IRQL
203    ///
204    /// This function must be called at IRQL `<= APC_LEVEL`, if the IRQL is higher than this,
205    /// the function will return an error.
206    ///
207    /// It is the callers responsibility to ensure the IRQL is sufficient to call this function and it
208    /// will not alter the IRQL for the caller, as this may introduce undefined behaviour elsewhere in the
209    /// driver / kernel.
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// let mtx = FastMutex::new(0u32).unwrap();
215    /// let lock = mtx.lock().unwrap();
216    /// ```
217    pub fn lock(&self) -> Result<FastMutexGuard<'_, T>, DriverMutexError> {
218        // Check the IRQL is <= APC_LEVEL as per remarks at
219        // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-exacquirefastmutex
220        let irql = unsafe { KeGetCurrentIrql() };
221        if irql > APC_LEVEL as u8 {
222            return Err(DriverMutexError::IrqlTooHigh);
223        }
224
225        // SAFETY: RAII manages pointer validity and IRQL checked.
226        unsafe { ExAcquireFastMutex(&mut (*self.inner).mutex as *mut _ as *mut _) };
227
228        Ok(FastMutexGuard { fast_mutex: self })
229    }
230
231    /// Consumes the mutex and returns an owned copy of the protected data (`T`).
232    ///
233    /// This method performs a deep copy of the data (`T`) guarded by the mutex before
234    /// deallocating the internal memory. Be cautious when using this method with large
235    /// data types, as it may lead to inefficiencies or stack overflows.
236    ///
237    /// For scenarios involving large data that you prefer not to allocate on the stack,
238    /// consider using [`Self::to_owned_box`] instead.
239    ///
240    /// # Safety
241    ///
242    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned`], ensure that
243    ///   no other references (especially static or global ones) attempt to access the
244    ///   underlying mutex. This is because the mutexes memory is deallocated once this
245    ///   method is invoked.
246    /// - **Exclusive Access:** This function should only be called when you can guarantee
247    ///   that there will be no further access to the protected `T`. Violating this can
248    ///   lead to undefined behavior since the memory is freed after the call.
249    ///
250    /// # Example
251    ///
252    /// ```
253    /// unsafe {
254    ///     let owned_data: T = mutex.to_owned();
255    ///     // Use `owned_data` safely here
256    /// }
257    /// ```
258    pub unsafe fn to_owned(self) -> T {
259        let data_read = unsafe { ptr::read(&(*self.inner).data) };
260        data_read
261    }
262
263    /// Consumes the mutex and returns an owned `Box<T>` containing the protected data (`T`).
264    ///
265    /// This method is an alternative to [`Self::to_owned`] and is particularly useful when
266    /// dealing with large data types. By returning a `Box<T>`, the data is pool-allocated,
267    /// avoiding potential stack overflows associated with large stack allocations.
268    ///
269    /// # Safety
270    ///
271    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned_box`], ensure that
272    /// no other references (especially static or global ones) attempt to access the
273    /// underlying mutex. This is because the mutexes memory is deallocated once this
274    /// method is invoked.
275    /// - **Exclusive Access:** This function should only be called when you can guarantee
276    /// that there will be no further access to the protected `T`. Violating this can
277    /// lead to undefined behavior since the memory is freed after the call.
278    ///
279    /// # Example
280    ///
281    /// ```rust
282    /// unsafe {
283    ///     let boxed_data: Box<T> = mutex.to_owned_box();
284    ///     // Use `boxed_data` safely here
285    /// }
286    /// ```
287    pub unsafe fn to_owned_box(self) -> Box<T> {
288        let data_read = unsafe { ptr::read(&(*self.inner).data) };
289        Box::new(data_read)
290    }
291}
292
293impl<T> Drop for FastMutex<T> {
294    fn drop(&mut self) {
295        unsafe {
296            // Drop the underlying data and run destructors for the data, this would be relevant in the
297            // case where Self contains other heap allocated types which have their own deallocation
298            // methods.
299            drop_in_place(&mut (*self.inner).data);
300
301            // Free the memory we allocated
302            ExFreePool(self.inner as *mut _);
303        }
304    }
305}
306
307/// A RAII scoped guard for the inner data protected by the mutex. Once this guard is given out, the protected data
308/// may be safely mutated by the caller as we guarantee exclusive access via Windows Kernel Mutex primitives.
309///
310/// When this structure is dropped (falls out of scope), the lock will be unlocked.
311///
312/// # IRQL
313///
314/// Access to the data within this guard must be done at `APC_LEVEL` It is the callers responsible to
315/// manage IRQL whilst using the `FastMutex`. On calling [`FastMutex::lock`], the IRQL will automatically
316/// be raised to `APC_LEVEL`.
317///
318/// If you wish to manually drop the lock with a safety check, call the function [`Self::drop_safe`].
319///
320/// # Kernel panic
321///
322/// Raising the IRQL above safe limits whilst using the mutex will cause a Kernel Panic if not appropriately handled.
323///
324pub struct FastMutexGuard<'a, T> {
325    fast_mutex: &'a FastMutex<T>,
326}
327
328impl<T> Display for FastMutexGuard<'_, T>
329where
330    T: Display,
331{
332    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
333        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
334        write!(f, "{}", unsafe { &(*self.fast_mutex.inner).data })
335    }
336}
337
338impl<T> Deref for FastMutexGuard<'_, T> {
339    type Target = T;
340
341    fn deref(&self) -> &Self::Target {
342        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
343        unsafe { &(*self.fast_mutex.inner).data }
344    }
345}
346
347impl<T> DerefMut for FastMutexGuard<'_, T> {
348    fn deref_mut(&mut self) -> &mut Self::Target {
349        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
350        // Mutable access is safe due to Self only being given out whilst a mutex is held from the
351        // kernel.
352        unsafe { &mut (*self.fast_mutex.inner).data }
353    }
354}
355
356impl<T> Drop for FastMutexGuard<'_, T> {
357    fn drop(&mut self) {
358        // NOT SAFE AT AN INVALID IRQL
359        unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
360    }
361}
362
363impl<T> FastMutexGuard<'_, T> {
364    /// Safely drop the `FastMutexGuard`, an alternative to RAII.
365    ///
366    /// This function checks the IRQL before attempting to drop the guard.
367    ///
368    /// # Errors
369    ///
370    /// If the IRQL != `APC_LEVEL`, no unlock will occur and a DriverMutexError will be returned to the
371    /// caller.
372    ///
373    /// # IRQL
374    ///
375    /// This function must be called at `APC_LEVEL`
376    pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {
377        let irql = unsafe { KeGetCurrentIrql() };
378        if irql != APC_LEVEL as u8 {
379            return Err(DriverMutexError::IrqlTooHigh);
380        }
381
382        unsafe { ExReleaseFastMutex(&mut (*self.fast_mutex.inner).mutex) };
383
384        Ok(())
385    }
386}