wdk_mutex/
kmutex.rs

1//! A Rust idiomatic Windows Kernel Driver KMUTEX type which protects the inner type T
2
3use alloc::boxed::Box;
4use core::{
5    ffi::c_void,
6    fmt::Display,
7    ops::{Deref, DerefMut},
8    ptr::{self, drop_in_place, null_mut},
9};
10use wdk_sys::{
11    ntddk::{
12        ExAllocatePool2, ExFreePool, KeGetCurrentIrql, KeInitializeMutex, KeReleaseMutex,
13        KeWaitForSingleObject,
14    },
15    APC_LEVEL, DISPATCH_LEVEL, FALSE, KMUTEX, POOL_FLAG_NON_PAGED,
16    _KWAIT_REASON::Executive,
17    _MODE::KernelMode,
18};
19
20extern crate alloc;
21
22use crate::errors::DriverMutexError;
23/// A thread safe mutex implemented through acquiring a KMUTEX in the Windows kernel.
24///
25/// The type `Kmutex<T>` provides mutually exclusive access to the inner type T allocated through
26/// this crate in the non-paged pool. All data required to initialise the KMutex is allocated in the
27/// non-paged pool and as such is safe to pass stack data into the type as it will not go out of scope.
28///
29/// `KMutex` holds an inner value which is a pointer to a `KMutexInner` type which is the actual type
30/// allocated in the non-paged pool, and this holds information relating to the mutex.
31///
32/// Access to the `T` within the `KMutex` can be done through calling [`Self::lock`].
33///
34/// # Lifetimes
35///
36/// As the `KMutex` is designed to be used in the Windows Kernel, with the Windows `wdk` crate, the lifetimes of
37/// the `KMutex` must be considered by the caller. See examples below for usage.
38///
39/// The `KMutex` can exist in a locally scoped function with little additional configuration. To use the mutex across
40/// thread boundaries, or to use it in callback functions, you can use the `Grt` module found in this crate. See below for
41/// details.
42///
43/// # Deallocation
44///
45/// KMutex handles the deallocation of resources at the point the KMutex is dropped.
46///
47/// # Examples
48///
49/// ## Locally scoped mutex:
50///
51/// ```
52/// {
53///     let mtx = KMutex::new(0u32).unwrap();
54///     let lock = mtx.lock().unwrap();
55///
56///     // If T implements display, you do not need to dereference the lock to print.
57///     println!("The value is: {}", lock);
58/// } // Mutex will become unlocked as it is managed via RAII
59/// ```
60///
61/// ## Global scope via the `Grt` module in `wdk-mutex`:
62///
63/// ```
64/// // Initialise the mutex on DriverEntry
65///
66/// #[export_name = "DriverEntry"]
67/// pub unsafe extern "system" fn driver_entry(
68///     driver: &mut DRIVER_OBJECT,
69///     registry_path: PCUNICODE_STRING,
70/// ) -> NTSTATUS {
71///     if let Err(e) = Grt::init() {
72///         println!("Error creating Grt!: {:?}", e);
73///         return STATUS_UNSUCCESSFUL;
74///     }
75///
76///     // ...
77///     my_function();
78/// }
79///
80///
81/// // Register a new Mutex in the `Grt` of value 0u32:
82///
83/// pub fn my_function() {
84///     Grt::register_kmutex("my_test_mutex", 0u32);
85/// }
86///
87/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
88///     let my_mutex = Grt::get_kmutex::<u32>("my_test_mutex");
89///     if let Err(e) = my_mutex {
90///         println!("Error in thread: {:?}", e);
91///         return;
92///     }
93///
94///     let mut lock = my_mutex.unwrap().lock().unwrap();
95///     *lock += 1;
96/// }
97///
98///
99/// // Destroy the Grt to prevent memory leak on DriverExit
100///
101/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
102///     unsafe {Grt::destroy()};
103/// }
104/// ```
105pub struct KMutex<T> {
106    inner: *mut KMutexInner<T>,
107}
108
109/// The underlying data which is non-page pool allocated which is pointed to by the `KMutex`.
110struct KMutexInner<T> {
111    /// A KMUTEX structure allocated into KMutexInner
112    mutex: KMUTEX,
113    /// The data for which the mutex is protecting
114    data: T,
115}
116
117unsafe impl<T> Sync for KMutex<T> {}
118unsafe impl<T> Send for KMutex<T> {}
119
120impl<T> KMutex<T> {
121    /// Creates a new KMUTEX Windows Kernel Driver Mutex in a signaled (free) state.
122    ///
123    /// # IRQL
124    ///
125    /// This can be called at any IRQL.
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use wdk_mutex::Mutex;
131    ///
132    /// let my_mutex = wdk_mutex::KMutex::new(0u32);
133    /// ```
134    pub fn new(data: T) -> Result<Self, DriverMutexError> {
135        //
136        // Non-Paged heap alloc for all struct data required for KMutexInner
137        //
138        let total_sz_required = size_of::<KMutexInner<T>>();
139        let inner_heap_ptr: *mut c_void = unsafe {
140            ExAllocatePool2(
141                POOL_FLAG_NON_PAGED,
142                total_sz_required as u64,
143                u32::from_be_bytes(*b"kmtx"),
144            )
145        };
146        if inner_heap_ptr.is_null() {
147            return Err(DriverMutexError::PagedPoolAllocFailed);
148        }
149
150        // Cast the memory allocation to a pointer to the inner
151        let kmutex_inner_ptr = inner_heap_ptr as *mut KMutexInner<T>;
152
153        // SAFETY: This raw write is safe as the pointer validity is checked above.
154        unsafe {
155            ptr::write(
156                kmutex_inner_ptr,
157                KMutexInner {
158                    mutex: KMUTEX::default(),
159                    data,
160                },
161            );
162
163            // Initialise the KMUTEX object via the kernel
164            KeInitializeMutex(&(*kmutex_inner_ptr).mutex as *const _ as *mut _, 0);
165        }
166
167        Ok(Self {
168            inner: kmutex_inner_ptr,
169        })
170    }
171
172    /// Acquires a mutex in a non-alertable manner.
173    ///
174    /// Once the thread has acquired the mutex, it will return a `KMutexGuard` which is a RAII scoped
175    /// guard allowing exclusive access to the inner T.
176    ///
177    /// # Errors
178    ///
179    /// If the IRQL is too high, this function will return an error and will not acquire a lock. To prevent
180    /// a kernel panic, the caller should match the return value rather than just unwrapping the value.
181    ///
182    /// # IRQL
183    ///
184    /// This function must be called at IRQL `<= APC_LEVEL`, if the IRQL is higher than this,
185    /// the function will return an error.
186    ///
187    /// It is the callers responsibility to ensure the IRQL is sufficient to call this function and it
188    /// will not alter the IRQL for the caller, as this may introduce undefined behaviour elsewhere in the
189    /// driver / kernel.
190    ///
191    /// # Examples
192    ///
193    /// ```
194    /// let mtx = KMutex::new(0u32).unwrap();
195    /// let lock = mtx.lock().unwrap();
196    /// ```
197    pub fn lock(&self) -> Result<KMutexGuard<'_, T>, DriverMutexError> {
198        // Check the IRQL is <= APC_LEVEL as per remarks at
199        // https://learn.microsoft.com/en-us/windows-hardware/drivers/ddi/wdm/nf-wdm-kewaitforsingleobject
200        let irql = unsafe { KeGetCurrentIrql() };
201        if irql > APC_LEVEL as u8 {
202            return Err(DriverMutexError::IrqlTooHigh);
203        }
204
205        // Discard the return value; the status code does not represent an error or contain information
206        // relevant to the context of no timeout.
207        let _ = unsafe {
208            // SAFETY: The IRQL is sufficient for the operation as checked above, and we know our pointer
209            // is valid as RAII manages the lifetime of the heap allocation, ensuring it will only be deallocated
210            // once Self gets dropped.
211            KeWaitForSingleObject(
212                &mut (*self.inner).mutex as *mut _ as *mut _,
213                Executive,
214                KernelMode as i8,
215                FALSE as u8,
216                null_mut(),
217            )
218        };
219
220        Ok(KMutexGuard { kmutex: self })
221    }
222
223    /// Consumes the mutex and returns an owned copy of the protected data (`T`).
224    ///
225    /// This method performs a deep copy of the data (`T`) guarded by the mutex before
226    /// deallocating the internal memory. Be cautious when using this method with large
227    /// data types, as it may lead to inefficiencies or stack overflows.
228    ///
229    /// For scenarios involving large data that you prefer not to allocate on the stack,
230    /// consider using [`Self::to_owned_box`] instead.
231    ///
232    /// # Safety
233    ///
234    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned`], ensure that
235    ///   no other references (especially static or global ones) attempt to access the
236    ///   underlying mutex. This is because the mutexes memory is deallocated once this
237    ///   method is invoked.
238    /// - **Exclusive Access:** This function should only be called when you can guarantee
239    ///   that there will be no further access to the protected `T`. Violating this can
240    ///   lead to undefined behavior since the memory is freed after the call.
241    ///
242    /// # Example
243    ///
244    /// ```
245    /// unsafe {
246    ///     let owned_data: T = mutex.to_owned();
247    ///     // Use `owned_data` safely here
248    /// }
249    /// ```
250    pub unsafe fn to_owned(self) -> T {
251        let data_read = unsafe { ptr::read(&(*self.inner).data) };
252        data_read
253    }
254
255    /// Consumes the mutex and returns an owned `Box<T>` containing the protected data (`T`).
256    ///
257    /// This method is an alternative to [`Self::to_owned`] and is particularly useful when
258    /// dealing with large data types. By returning a `Box<T>`, the data is pool-allocated,
259    /// avoiding potential stack overflows associated with large stack allocations.
260    ///
261    /// # Safety
262    ///
263    /// - **Single Ownership Guarantee:** After calling [`Self::to_owned_box`], ensure that
264    /// no other references (especially static or global ones) attempt to access the
265    /// underlying mutex. This is because the mutexes memory is deallocated once this
266    /// method is invoked.
267    /// - **Exclusive Access:** This function should only be called when you can guarantee
268    /// that there will be no further access to the protected `T`. Violating this can
269    /// lead to undefined behavior since the memory is freed after the call.
270    ///
271    /// # Example
272    ///
273    /// ```rust
274    /// unsafe {
275    ///     let boxed_data: Box<T> = mutex.to_owned_box();
276    ///     // Use `boxed_data` safely here
277    /// }
278    /// ```
279    pub unsafe fn to_owned_box(self) -> Box<T> {
280        let data_read = unsafe { ptr::read(&(*self.inner).data) };
281        Box::new(data_read)
282    }
283}
284
285impl<T> Drop for KMutex<T> {
286    fn drop(&mut self) {
287        unsafe {
288            // Drop the underlying data and run destructors for the data, this would be relevant in the
289            // case where Self contains other heap allocated types which have their own deallocation
290            // methods.
291            drop_in_place(&mut (*self.inner).data);
292
293            // Free the memory we allocated
294            ExFreePool(self.inner as *mut _);
295        }
296    }
297}
298
299/// A RAII scoped guard for the inner data protected by the mutex. Once this guard is given out, the protected data
300/// may be safely mutated by the caller as we guarantee exclusive access via Windows Kernel Mutex primitives.
301///
302/// When this structure is dropped (falls out of scope), the lock will be unlocked.
303///
304/// # IRQL
305///
306/// Access to the data within this guard must be done at <= APC_LEVEL if a non-alertable lock was acquired, or <=
307/// DISPATCH_LEVEL if an alertable lock was acquired. It is the callers responsible to manage APC levels whilst
308/// using the KMutex.
309///
310/// If you wish to manually drop the lock with a safety check, call the function [`Self::drop_safe`].
311///
312/// # Kernel panic
313///
314/// Raising the IRQL above safe limits whilst using the mutex will cause a Kernel Panic if not appropriately handled.
315/// When RAII drops this type, the mutex is released, if the mutex goes out of scope whilst you hold an IRQL that
316/// is too high, you will receive a kernel panic.
317///
318pub struct KMutexGuard<'a, T> {
319    kmutex: &'a KMutex<T>,
320}
321
322impl<T> Display for KMutexGuard<'_, T>
323where
324    T: Display,
325{
326    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
327        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
328        write!(f, "{}", unsafe { &(*self.kmutex.inner).data })
329    }
330}
331
332impl<T> Deref for KMutexGuard<'_, T> {
333    type Target = T;
334
335    fn deref(&self) -> &Self::Target {
336        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
337        unsafe { &(*self.kmutex.inner).data }
338    }
339}
340
341impl<T> DerefMut for KMutexGuard<'_, T> {
342    fn deref_mut(&mut self) -> &mut Self::Target {
343        // SAFETY: Dereferencing the inner data is safe as RAII controls the memory allocations.
344        // Mutable access is safe due to Self only being given out whilst a mutex is held from the
345        // kernel.
346        unsafe { &mut (*self.kmutex.inner).data }
347    }
348}
349
350impl<T> Drop for KMutexGuard<'_, T> {
351    fn drop(&mut self) {
352        // NOT SAFE AT A IRQL TOO HIGH
353        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) };
354    }
355}
356
357impl<T> KMutexGuard<'_, T> {
358    /// Safely drop the KMutexGuard, an alternative to RAII.
359    ///
360    /// This function checks the IRQL before attempting to drop the guard.
361    ///
362    /// # Errors
363    ///
364    /// If the IRQL > DISPATCH_LEVEL, no unlock will occur and a DriverMutexError will be returned to the
365    /// caller.
366    ///
367    /// # IRQL
368    ///
369    /// This function is safe to call at any IRQL, but it will not release the mutex if IRQL > DISPATCH_LEVEL
370    pub fn drop_safe(&mut self) -> Result<(), DriverMutexError> {
371        let irql = unsafe { KeGetCurrentIrql() };
372        if irql > DISPATCH_LEVEL as u8 {
373            return Err(DriverMutexError::IrqlTooHigh);
374        }
375
376        unsafe { KeReleaseMutex(&mut (*self.kmutex.inner).mutex, FALSE as u8) };
377
378        Ok(())
379    }
380}