wdk_mutex/
grt.rs

//! GRT - Global Reference Tracker - a module that allows mutex objects to be allocated globally
//! through an easy-to-use API, which is simpler than manually adding and tracking every static
//! allocation yourself.
4
5extern crate alloc;
6
7use crate::{errors::GrtError, fast_mutex::FastMutex, kmutex::KMutex};
8use alloc::{boxed::Box, collections::BTreeMap};
9use core::{
10    any::Any,
11    ptr::null_mut,
12    sync::atomic::{AtomicPtr, Ordering::SeqCst},
13};
14
15// A static which points to an initialised box containing the `Grt`
16static WDK_MTX_GRT_PTR: AtomicPtr<Grt> = AtomicPtr::new(null_mut());
17
/// The Global Reference Tracker (Grt) for `wdk-mutex` is a module designed to improve the development ergonomics
/// of manually managing memory in a driver required for tracking objects passed between threads.
///
/// The `Grt` abstraction lets you register mutex objects and retrieve them from callbacks and threads at runtime
/// in the driver, with idiomatic error handling. The `Grt` makes several pool allocations, which it owns and frees
/// when it is destroyed. If you need absolute minimal overhead when accessing mutexes, you may wish to profile it
/// against a manual scheme for tracking mutexes.
///
/// The general way to use this is to call [`Self::init`] **once** during driver initialisation, and to call
/// [`Self::destroy`] **once** on driver exit. Between `init` and `destroy`, you may add a new `T` (which will be
/// protected by a `wdk-mutex`) to the `Grt`. The `&'static str` label you supply is the key of a `BTreeMap`, and
/// the mutex wrapping `T` is the value. **Note:** you do not pass a `Mutex` into [`Self::register_kmutex`],
/// [`Self::register_fast_mutex`] etc.; those functions wrap the data in the mutex for you.
///
/// [`Self::get_kmutex`] / [`Self::get_fast_mutex`] etc. then allow you to retrieve the `Mutex` dynamically.
///
/// The tracker's map itself is not protected by a lock. Register your mutexes (e.g. during `DriverEntry`) before
/// other threads start retrieving them.
///
/// # Examples
///
/// ```ignore
/// // Initialise the Grt once, early in DriverEntry
/// #[export_name = "DriverEntry"]
/// pub unsafe extern "system" fn driver_entry(
///     driver: &mut DRIVER_OBJECT,
///     registry_path: PCUNICODE_STRING,
/// ) -> NTSTATUS {
///     if let Err(e) = Grt::init() {
///         println!("Error creating Grt!: {:?}", e);
///         return STATUS_UNSUCCESSFUL;
///     }
///
///     // ...
///     my_function();
///
///     STATUS_SUCCESS
/// }
///
/// // Register a new KMutex in the `Grt` protecting the value 0u32:
/// pub fn my_function() {
///     if let Err(e) = Grt::register_kmutex("my_test_mutex", 0u32) {
///         println!("Error registering mutex: {:?}", e);
///     }
/// }
///
/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
///     let my_mutex = match Grt::get_kmutex::<u32>("my_test_mutex") {
///         Ok(m) => m,
///         Err(e) => {
///             println!("Error in thread: {:?}", e);
///             return;
///         }
///     };
///
///     let mut lock = my_mutex.lock().unwrap();
///     *lock += 1;
/// }
///
/// // Destroy the Grt on driver exit so it, and every mutex it tracks, is freed
/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
///     if let Err(e) = unsafe { Grt::destroy() } {
///         println!("Error destroying Grt: {:?}", e);
///     }
/// }
/// ```
pub struct Grt {
    // Every tracked mutex (both `KMutex` and `FastMutex`), type-erased and keyed by its label.
    global_kmutex: BTreeMap<&'static str, Box<dyn Any>>,
}
80
/// Identifies a kind of `wdk-mutex` mutex that the `Grt` can track: a [`FastMutex`] or a [`KMutex`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MutexType {
    /// A [`FastMutex`].
    FastMutex,
    /// A [`KMutex`].
    KMutex,
}
86
87impl Grt {
88    /// Initialise a new instance of the Global Reference Tracker for `wdk-mutex`.
89    ///
90    /// This should only be called once in your driver and will initialise the `Grt` to be globally available
91    /// at any point you wish to utilise it to retrieve a mutex and its wrapped T.
92    ///
93    /// # Errors
94    ///
95    /// This function will error if:
96    ///
97    /// - You have already initialised the `Grt`
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// #[export_name = "DriverEntry"]
103    /// pub unsafe extern "system" fn driver_entry(
104    ///     driver: &mut DRIVER_OBJECT,
105    ///     registry_path: PCUNICODE_STRING,
106    /// ) -> NTSTATUS {
107    ///     // A good place to initialise it is early during driver initialisation
108    ///     if let Err(e) = Grt::init() {
109    ///         println!("Error creating Grt! {:?}", e);
110    ///         return STATUS_UNSUCCESSFUL;
111    ///     }
112    /// }
113    /// ```
114    pub fn init() -> Result<(), GrtError> {
115        // Check we aren't double initialising
116        if !WDK_MTX_GRT_PTR.load(SeqCst).is_null() {
117            return Err(GrtError::GrtAlreadyExists);
118        }
119
120        //
121        // Initialise a new Grt in a box, which will be converted to a raw pointer and stored in the static
122        // AtomicPtr which is used for tracking the `Grt` structure.
123        // On `Grt::destroy()` being called The raw pointer will then be converted from a ptr into a box,
124        // allowing RAII to drop the memory properly when the destroy method is called.
125        //
126
127        let pool_ptr = Box::into_raw(Box::new(Grt {
128            global_kmutex: BTreeMap::new(),
129        }));
130
131        WDK_MTX_GRT_PTR.store(pool_ptr, SeqCst);
132
133        Ok(())
134    }
135
136    /// Register a new [`KMutex`] for the global reference tracker to control.
137    ///
138    /// The function takes a label as a static &str which is the key of a BTreeMap, and the type you wish
139    /// to protect with the mutex as the data. If the key already exists, the function will indiscriminately insert
140    /// a key and overwrite any existing data.
141    ///
142    /// If you wish to perform this function checking for an existing key before registering the mutex object,
143    /// use [`Self::register_kmutex_checked`].
144    ///
145    /// # Errors
146    ///
147    /// This function will error if:
148    ///
149    /// - `Grt` has not been initialised, see [`Grt::init`]
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// Grt::register_kmutex("my_test_mutex", 0u32);
155    /// ```
156    pub fn register_kmutex<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
157        // Check for a null pointer on the atomic
158        let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
159        if atomic_ptr.is_null() {
160            return Err(GrtError::GrtIsNull);
161        }
162
163        // Try initialise a new mutex
164        let mtx = Box::new(KMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
165
166        // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
167        unsafe {
168            (*atomic_ptr).global_kmutex.insert(label, mtx);
169        }
170
171        Ok(())
172    }
173
174    /// Register a new [`FastMutex`] for the global reference tracker to control.
175    ///
176    /// The function takes a label as a static &str which is the key of a BTreeMap, and the type you wish
177    /// to protect with the mutex as the data. If the key already exists, the function will indiscriminately insert
178    /// a key and overwrite any existing data.
179    ///
180    /// If you wish to perform this function checking for an existing key before registering the mutex object,
181    /// use [`Self::register_fast_mutex_checked`].
182    ///
183    /// # Errors
184    ///
185    /// This function will error if:
186    ///
187    /// - `Grt` has not been initialised, see [`Grt::init`]
188    ///
189    /// # Examples
190    ///
191    /// ```
192    /// Grt::register_fast_mutex("my_test_mutex", 0u32);
193    /// ```
194    pub fn register_fast_mutex<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
195        // Check for a null pointer on the atomic
196        let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
197        if atomic_ptr.is_null() {
198            return Err(GrtError::GrtIsNull);
199        }
200
201        // Try initialise a new mutex
202        let mtx = Box::new(FastMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
203
204        // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
205        unsafe {
206            (*atomic_ptr).global_kmutex.insert(label, mtx);
207        }
208
209        Ok(())
210    }
211
212    /// Register a new [`KMutex`] for the global reference tracker to control, throwing an error if the key already
213    /// exists.
214    ///
215    /// This is a checked alternative to [`Self::register_kmutex`], and as such incurs a little additional overhead.
216    ///
217    /// # Errors
218    ///
219    /// This function will error if:
220    ///
221    /// - `Grt` has not been initialised, see [`Grt::init`]
222    /// - The mutex key already exists
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// let result = Grt::register_kmutex_checked("my_test_mutex", 0u32);
228    /// ```
229    pub fn register_kmutex_checked<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
230        // Check for a null pointer on the atomic
231        let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
232        if atomic_ptr.is_null() {
233            return Err(GrtError::GrtIsNull);
234        }
235
236        // Try initialise a new mutex
237        let mtx = Box::new(KMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
238
239        // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
240        unsafe {
241            let bucket = (*atomic_ptr).global_kmutex.get(label);
242            if bucket.is_some() {
243                return Err(GrtError::KeyExists);
244            }
245
246            (*atomic_ptr).global_kmutex.insert(label, mtx);
247        }
248
249        Ok(())
250    }
251
252    /// Register a new [`FastMutex`] for the global reference tracker to control, throwing an error if the key already
253    /// exists.
254    ///
255    /// This is a checked alternative to [`Self::register_fast_mutex`], and as such incurs a little additional overhead.
256    ///
257    /// # Errors
258    ///
259    /// This function will error if:
260    ///
261    /// - `Grt` has not been initialised, see [`Grt::init`]
262    /// - The mutex key already exists
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// let result = Grt::register_fast_mutex_checked("my_test_mutex", 0u32);
268    /// ```
269    pub fn register_fast_mutex_checked<T: Any>(
270        label: &'static str,
271        data: T,
272    ) -> Result<(), GrtError> {
273        // Check for a null pointer on the atomic
274        let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
275        if atomic_ptr.is_null() {
276            return Err(GrtError::GrtIsNull);
277        }
278
279        // Try initialise a new mutex
280        let mtx = Box::new(FastMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
281
282        // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
283        unsafe {
284            let bucket = (*atomic_ptr).global_kmutex.get(label);
285            if bucket.is_some() {
286                return Err(GrtError::KeyExists);
287            }
288
289            (*atomic_ptr).global_kmutex.insert(label, mtx);
290        }
291
292        Ok(())
293    }
294
295    /// Retrieve a mutex by name from the `wdk-mutex` global reference tracker.
296    ///
297    /// This function takes in a static `&str` to lookup your Mutex by key (where the key is the argument). When calling
298    /// this function, a turbofish specifier is required to tell the compiler what type is contained in the `Mutex`. See
299    /// examples for more information.
300    ///
301    /// # Errors
302    ///
303    /// This function will error if:
304    ///
305    /// - The `Grt` has not been initialised
306    /// - The `Grt` is empty
307    /// - The key does not exist
308    /// - The mutex type is anything other than a [`KMutex`]
309    ///
310    /// # Examples
311    ///
312    /// ```
313    /// {
314    ///     let my_mutex = Grt::get_kmutex::<u32>("my_test_mutex");
315    ///     if let Err(e) = my_mutex {
316    ///         println!("An error occurred: {:?}", e);
317    ///         return;
318    ///     }
319    ///     let mut lock = my_mutex.unwrap().lock().unwrap();
320    ///     *lock += 1;
321    /// }
322    /// ```
323    pub fn get_kmutex<T>(key: &'static str) -> Result<&'static KMutex<T>, GrtError> {
324        //
325        // Perform checks for erroneous state
326        //
327        let ptr = WDK_MTX_GRT_PTR.load(SeqCst);
328        if ptr.is_null() {
329            return Err(GrtError::GrtIsNull);
330        }
331
332        let grt = unsafe { &(*ptr).global_kmutex };
333        if grt.is_empty() {
334            return Err(GrtError::GrtIsEmpty);
335        }
336
337        let mutex = grt.get(key);
338        if mutex.is_none() {
339            return Err(GrtError::KeyNotFound);
340        }
341
342        //
343        // The mutex is valid so obtain a reference to it which can be returned
344        //
345
346        // SAFETY: Null pointer and inner null pointers have both been checked in the above lines.
347        let m = &**mutex.unwrap();
348        let km = m.downcast_ref::<KMutex<T>>();
349
350        if km.is_none() {
351            return Err(GrtError::DowncastError);
352        }
353
354        Ok(km.unwrap())
355    }
356
357    /// Retrieve a mutex by name from the `wdk-mutex` global reference tracker.
358    ///
359    /// This function takes in a static `&str` to lookup your Mutex by key (where the key is the argument). When calling
360    /// this function, a turbofish specifier is required to tell the compiler what type is contained in the `Mutex`. See
361    /// examples for more information.
362    ///
363    /// # Errors
364    ///
365    /// This function will error if:
366    ///
367    /// - The `Grt` has not been initialised
368    /// - The `Grt` is empty
369    /// - The key does not exist
370    /// - The mutex type is anything other than a [`FastMutex`]
371    ///
372    /// # Examples
373    ///
374    /// ```
375    /// {
376    ///     let my_mutex = Grt::get_fast_mutex::<u32>("my_test_mutex");
377    ///     if let Err(e) = my_mutex {
378    ///         println!("An error occurred: {:?}", e);
379    ///         return;
380    ///     }
381    ///     let mut lock = my_mutex.unwrap().lock().unwrap();
382    ///     *lock += 1;
383    /// }
384    /// ```
385    pub fn get_fast_mutex<T>(key: &'static str) -> Result<&'static FastMutex<T>, GrtError> {
386        //
387        // Perform checks for erroneous state
388        //
389        let ptr = WDK_MTX_GRT_PTR.load(SeqCst);
390        if ptr.is_null() {
391            return Err(GrtError::GrtIsNull);
392        }
393
394        let grt = unsafe { &(*ptr).global_kmutex };
395        if grt.is_empty() {
396            return Err(GrtError::GrtIsEmpty);
397        }
398
399        let mutex = grt.get(key);
400        if mutex.is_none() {
401            return Err(GrtError::KeyNotFound);
402        }
403
404        //
405        // The mutex is valid so obtain a reference to it which can be returned
406        //
407
408        // SAFETY: Null pointer and inner null pointers have both been checked in the above lines.
409        let m = &**mutex.unwrap();
410        let km = m.downcast_ref::<FastMutex<T>>();
411
412        if km.is_none() {
413            return Err(GrtError::DowncastError);
414        }
415
416        Ok(km.unwrap())
417    }
418
419    /// Destroy the global reference tracker for `wdk-mutex`.
420    ///
421    /// Calling [`Self::destroy`] will destroy the 'runtime' provided for using globally accessible `wdk-mutex` mutexes
422    /// in your driver.
423    ///
424    /// # Safety
425    ///
426    /// Once this function is called you will no longer be able to access any mutexes who's lifetime is managed by the
427    /// `Grt`.
428    ///
429    /// **Note:** This function is marked `unsafe` as it could lead to UB if accidentally used whilst threads / callbacks
430    /// dependant upon a mutex that it managed. Although it is `unsafe`, attempting to access a mutex after the `Grt` is destroyed
431    /// will not cause a null pointer dereference (they are checked), but it could lead to UB as those setter/getter functions will
432    /// return an error.
433    ///
434    /// # Examples
435    ///
436    /// ```
437    /// /// Driver exit routine
438    /// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
439    ///     unsafe { Grt::destroy() };
440    /// }
441    /// ```
442    pub unsafe fn destroy() -> Result<(), GrtError> {
443        // Check that the static pointer is not already null
444        let grt_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
445        if grt_ptr.is_null() {
446            return Err(GrtError::GrtIsNull);
447        }
448
449        // Set the atomic global to null
450        WDK_MTX_GRT_PTR.store(null_mut(), SeqCst);
451
452        // Convert the pointer back to a box which wraps the inner `Grt`, allowing Box to drop all it's content
453        // which will free all inner memory, drop will properly be called on all Mutexes.
454        let _ = unsafe { Box::from_raw(grt_ptr) };
455
456        Ok(())
457    }
458}