wdk_mutex/grt.rs
//! GRT (Global Reference Tracker): a module for globally registering mutex objects behind an
//! easy-to-use API, so you don't have to manually add, allocate and track a separate static for
//! every mutex your driver shares.
4
5extern crate alloc;
6
7use crate::{errors::GrtError, fast_mutex::FastMutex, kmutex::KMutex};
8use alloc::{boxed::Box, collections::BTreeMap};
9use core::{
10 any::Any,
11 ptr::null_mut,
12 sync::atomic::{AtomicPtr, Ordering::SeqCst},
13};
14
15// A static which points to an initialised box containing the `Grt`
16static WDK_MTX_GRT_PTR: AtomicPtr<Grt> = AtomicPtr::new(null_mut());
17
/// The Global Reference Tracker (`Grt`) for `wdk-mutex`: a driver-wide registry of named mutexes.
///
/// Sharing a mutex between threads and callbacks in a driver normally means hand-rolling a
/// static, allocating the mutex, and remembering to free it on unload. The `Grt` does that
/// bookkeeping for you. Each registered mutex is boxed on the heap and stored in a `BTreeMap`
/// under a `&'static str` label. It is looked up by that label at runtime with idiomatic error
/// handling, and it is dropped when the `Grt` is destroyed. Every lookup walks the map and
/// performs a type downcast, so if mutex access is on a very hot path you may wish to profile
/// this against a hand-rolled static.
///
/// Typical lifecycle:
///
/// 1. Call [`Self::init`] **once** during driver initialisation.
/// 2. Register data with [`Self::register_kmutex`] / [`Self::register_fast_mutex`] (or their
///    `_checked` variants). Pass the bare `T`, not a mutex; the function wraps it for you.
/// 3. Retrieve the mutex from anywhere with [`Self::get_kmutex`] / [`Self::get_fast_mutex`],
///    naming the protected type with a turbofish.
/// 4. Call [`Self::destroy`] **once** on driver unload, after every user of the registered
///    mutexes has finished with them.
///
/// **Note:** the registry map itself is not guarded by a lock. Perform your registrations before
/// other threads or callbacks start looking mutexes up, e.g. during `DriverEntry`.
///
/// # Examples
///
/// ```
/// // Initialise the Grt
///
/// #[export_name = "DriverEntry"]
/// pub unsafe extern "system" fn driver_entry(
///     driver: &mut DRIVER_OBJECT,
///     registry_path: PCUNICODE_STRING,
/// ) -> NTSTATUS {
///     if let Err(e) = Grt::init() {
///         println!("Error creating Grt!: {:?}", e);
///         return STATUS_UNSUCCESSFUL;
///     }
///
///     // ...
///     my_function();
///
///     STATUS_SUCCESS
/// }
///
/// // Register a new KMutex in the `Grt` protecting the value 0u32:
///
/// pub fn my_function() {
///     if let Err(e) = Grt::register_kmutex("my_test_mutex", 0u32) {
///         println!("Error registering mutex: {:?}", e);
///     }
/// }
///
/// unsafe extern "C" fn my_thread_fn_pointer(_: *mut c_void) {
///     let my_mutex = match Grt::get_kmutex::<u32>("my_test_mutex") {
///         Ok(m) => m,
///         Err(e) => {
///             println!("Error in thread: {:?}", e);
///             return;
///         }
///     };
///
///     let mut lock = my_mutex.lock().unwrap();
///     *lock += 1;
/// }
///
/// // Destroy the Grt on driver unload so every registered mutex is freed
///
/// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
///     let _ = unsafe { Grt::destroy() };
/// }
/// ```
pub struct Grt {
    // Every registered mutex (both `KMutex` and `FastMutex`), type-erased behind `dyn Any` so
    // mutexes protecting different `T`s can share one map, keyed by the caller-supplied label.
    global_kmutex: BTreeMap<&'static str, Box<dyn Any>>,
}
80
/// Identifies which kind of `wdk-mutex` mutex is meant: a fast mutex or a kernel mutex.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MutexType {
    /// A [`crate::fast_mutex::FastMutex`].
    FastMutex,
    /// A [`crate::kmutex::KMutex`].
    KMutex,
}
86
87impl Grt {
88 /// Initialise a new instance of the Global Reference Tracker for `wdk-mutex`.
89 ///
90 /// This should only be called once in your driver and will initialise the `Grt` to be globally available
91 /// at any point you wish to utilise it to retrieve a mutex and its wrapped T.
92 ///
93 /// # Errors
94 ///
95 /// This function will error if:
96 ///
97 /// - You have already initialised the `Grt`
98 ///
99 /// # Examples
100 ///
101 /// ```
102 /// #[export_name = "DriverEntry"]
103 /// pub unsafe extern "system" fn driver_entry(
104 /// driver: &mut DRIVER_OBJECT,
105 /// registry_path: PCUNICODE_STRING,
106 /// ) -> NTSTATUS {
107 /// // A good place to initialise it is early during driver initialisation
108 /// if let Err(e) = Grt::init() {
109 /// println!("Error creating Grt! {:?}", e);
110 /// return STATUS_UNSUCCESSFUL;
111 /// }
112 /// }
113 /// ```
114 pub fn init() -> Result<(), GrtError> {
115 // Check we aren't double initialising
116 if !WDK_MTX_GRT_PTR.load(SeqCst).is_null() {
117 return Err(GrtError::GrtAlreadyExists);
118 }
119
120 //
121 // Initialise a new Grt in a box, which will be converted to a raw pointer and stored in the static
122 // AtomicPtr which is used for tracking the `Grt` structure.
123 // On `Grt::destroy()` being called The raw pointer will then be converted from a ptr into a box,
124 // allowing RAII to drop the memory properly when the destroy method is called.
125 //
126
127 let pool_ptr = Box::into_raw(Box::new(Grt {
128 global_kmutex: BTreeMap::new(),
129 }));
130
131 WDK_MTX_GRT_PTR.store(pool_ptr, SeqCst);
132
133 Ok(())
134 }
135
136 /// Register a new [`KMutex`] for the global reference tracker to control.
137 ///
138 /// The function takes a label as a static &str which is the key of a BTreeMap, and the type you wish
139 /// to protect with the mutex as the data. If the key already exists, the function will indiscriminately insert
140 /// a key and overwrite any existing data.
141 ///
142 /// If you wish to perform this function checking for an existing key before registering the mutex object,
143 /// use [`Self::register_kmutex_checked`].
144 ///
145 /// # Errors
146 ///
147 /// This function will error if:
148 ///
149 /// - `Grt` has not been initialised, see [`Grt::init`]
150 ///
151 /// # Examples
152 ///
153 /// ```
154 /// Grt::register_kmutex("my_test_mutex", 0u32);
155 /// ```
156 pub fn register_kmutex<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
157 // Check for a null pointer on the atomic
158 let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
159 if atomic_ptr.is_null() {
160 return Err(GrtError::GrtIsNull);
161 }
162
163 // Try initialise a new mutex
164 let mtx = Box::new(KMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
165
166 // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
167 unsafe {
168 (*atomic_ptr).global_kmutex.insert(label, mtx);
169 }
170
171 Ok(())
172 }
173
174 /// Register a new [`FastMutex`] for the global reference tracker to control.
175 ///
176 /// The function takes a label as a static &str which is the key of a BTreeMap, and the type you wish
177 /// to protect with the mutex as the data. If the key already exists, the function will indiscriminately insert
178 /// a key and overwrite any existing data.
179 ///
180 /// If you wish to perform this function checking for an existing key before registering the mutex object,
181 /// use [`Self::register_fast_mutex_checked`].
182 ///
183 /// # Errors
184 ///
185 /// This function will error if:
186 ///
187 /// - `Grt` has not been initialised, see [`Grt::init`]
188 ///
189 /// # Examples
190 ///
191 /// ```
192 /// Grt::register_fast_mutex("my_test_mutex", 0u32);
193 /// ```
194 pub fn register_fast_mutex<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
195 // Check for a null pointer on the atomic
196 let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
197 if atomic_ptr.is_null() {
198 return Err(GrtError::GrtIsNull);
199 }
200
201 // Try initialise a new mutex
202 let mtx = Box::new(FastMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
203
204 // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
205 unsafe {
206 (*atomic_ptr).global_kmutex.insert(label, mtx);
207 }
208
209 Ok(())
210 }
211
212 /// Register a new [`KMutex`] for the global reference tracker to control, throwing an error if the key already
213 /// exists.
214 ///
215 /// This is a checked alternative to [`Self::register_kmutex`], and as such incurs a little additional overhead.
216 ///
217 /// # Errors
218 ///
219 /// This function will error if:
220 ///
221 /// - `Grt` has not been initialised, see [`Grt::init`]
222 /// - The mutex key already exists
223 ///
224 /// # Examples
225 ///
226 /// ```
227 /// let result = Grt::register_kmutex_checked("my_test_mutex", 0u32);
228 /// ```
229 pub fn register_kmutex_checked<T: Any>(label: &'static str, data: T) -> Result<(), GrtError> {
230 // Check for a null pointer on the atomic
231 let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
232 if atomic_ptr.is_null() {
233 return Err(GrtError::GrtIsNull);
234 }
235
236 // Try initialise a new mutex
237 let mtx = Box::new(KMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
238
239 // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
240 unsafe {
241 let bucket = (*atomic_ptr).global_kmutex.get(label);
242 if bucket.is_some() {
243 return Err(GrtError::KeyExists);
244 }
245
246 (*atomic_ptr).global_kmutex.insert(label, mtx);
247 }
248
249 Ok(())
250 }
251
252 /// Register a new [`FastMutex`] for the global reference tracker to control, throwing an error if the key already
253 /// exists.
254 ///
255 /// This is a checked alternative to [`Self::register_fast_mutex`], and as such incurs a little additional overhead.
256 ///
257 /// # Errors
258 ///
259 /// This function will error if:
260 ///
261 /// - `Grt` has not been initialised, see [`Grt::init`]
262 /// - The mutex key already exists
263 ///
264 /// # Examples
265 ///
266 /// ```
267 /// let result = Grt::register_fast_mutex_checked("my_test_mutex", 0u32);
268 /// ```
269 pub fn register_fast_mutex_checked<T: Any>(
270 label: &'static str,
271 data: T,
272 ) -> Result<(), GrtError> {
273 // Check for a null pointer on the atomic
274 let atomic_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
275 if atomic_ptr.is_null() {
276 return Err(GrtError::GrtIsNull);
277 }
278
279 // Try initialise a new mutex
280 let mtx = Box::new(FastMutex::new(data).map_err(|e| GrtError::DriverMutexError(e))?);
281
282 // SAFETY: The atomic pointer is checked at the start of the fn for a nullptr
283 unsafe {
284 let bucket = (*atomic_ptr).global_kmutex.get(label);
285 if bucket.is_some() {
286 return Err(GrtError::KeyExists);
287 }
288
289 (*atomic_ptr).global_kmutex.insert(label, mtx);
290 }
291
292 Ok(())
293 }
294
295 /// Retrieve a mutex by name from the `wdk-mutex` global reference tracker.
296 ///
297 /// This function takes in a static `&str` to lookup your Mutex by key (where the key is the argument). When calling
298 /// this function, a turbofish specifier is required to tell the compiler what type is contained in the `Mutex`. See
299 /// examples for more information.
300 ///
301 /// # Errors
302 ///
303 /// This function will error if:
304 ///
305 /// - The `Grt` has not been initialised
306 /// - The `Grt` is empty
307 /// - The key does not exist
308 /// - The mutex type is anything other than a [`KMutex`]
309 ///
310 /// # Examples
311 ///
312 /// ```
313 /// {
314 /// let my_mutex = Grt::get_kmutex::<u32>("my_test_mutex");
315 /// if let Err(e) = my_mutex {
316 /// println!("An error occurred: {:?}", e);
317 /// return;
318 /// }
319 /// let mut lock = my_mutex.unwrap().lock().unwrap();
320 /// *lock += 1;
321 /// }
322 /// ```
323 pub fn get_kmutex<T>(key: &'static str) -> Result<&'static KMutex<T>, GrtError> {
324 //
325 // Perform checks for erroneous state
326 //
327 let ptr = WDK_MTX_GRT_PTR.load(SeqCst);
328 if ptr.is_null() {
329 return Err(GrtError::GrtIsNull);
330 }
331
332 let grt = unsafe { &(*ptr).global_kmutex };
333 if grt.is_empty() {
334 return Err(GrtError::GrtIsEmpty);
335 }
336
337 let mutex = grt.get(key);
338 if mutex.is_none() {
339 return Err(GrtError::KeyNotFound);
340 }
341
342 //
343 // The mutex is valid so obtain a reference to it which can be returned
344 //
345
346 // SAFETY: Null pointer and inner null pointers have both been checked in the above lines.
347 let m = &**mutex.unwrap();
348 let km = m.downcast_ref::<KMutex<T>>();
349
350 if km.is_none() {
351 return Err(GrtError::DowncastError);
352 }
353
354 Ok(km.unwrap())
355 }
356
357 /// Retrieve a mutex by name from the `wdk-mutex` global reference tracker.
358 ///
359 /// This function takes in a static `&str` to lookup your Mutex by key (where the key is the argument). When calling
360 /// this function, a turbofish specifier is required to tell the compiler what type is contained in the `Mutex`. See
361 /// examples for more information.
362 ///
363 /// # Errors
364 ///
365 /// This function will error if:
366 ///
367 /// - The `Grt` has not been initialised
368 /// - The `Grt` is empty
369 /// - The key does not exist
370 /// - The mutex type is anything other than a [`FastMutex`]
371 ///
372 /// # Examples
373 ///
374 /// ```
375 /// {
376 /// let my_mutex = Grt::get_fast_mutex::<u32>("my_test_mutex");
377 /// if let Err(e) = my_mutex {
378 /// println!("An error occurred: {:?}", e);
379 /// return;
380 /// }
381 /// let mut lock = my_mutex.unwrap().lock().unwrap();
382 /// *lock += 1;
383 /// }
384 /// ```
385 pub fn get_fast_mutex<T>(key: &'static str) -> Result<&'static FastMutex<T>, GrtError> {
386 //
387 // Perform checks for erroneous state
388 //
389 let ptr = WDK_MTX_GRT_PTR.load(SeqCst);
390 if ptr.is_null() {
391 return Err(GrtError::GrtIsNull);
392 }
393
394 let grt = unsafe { &(*ptr).global_kmutex };
395 if grt.is_empty() {
396 return Err(GrtError::GrtIsEmpty);
397 }
398
399 let mutex = grt.get(key);
400 if mutex.is_none() {
401 return Err(GrtError::KeyNotFound);
402 }
403
404 //
405 // The mutex is valid so obtain a reference to it which can be returned
406 //
407
408 // SAFETY: Null pointer and inner null pointers have both been checked in the above lines.
409 let m = &**mutex.unwrap();
410 let km = m.downcast_ref::<FastMutex<T>>();
411
412 if km.is_none() {
413 return Err(GrtError::DowncastError);
414 }
415
416 Ok(km.unwrap())
417 }
418
419 /// Destroy the global reference tracker for `wdk-mutex`.
420 ///
421 /// Calling [`Self::destroy`] will destroy the 'runtime' provided for using globally accessible `wdk-mutex` mutexes
422 /// in your driver.
423 ///
424 /// # Safety
425 ///
426 /// Once this function is called you will no longer be able to access any mutexes who's lifetime is managed by the
427 /// `Grt`.
428 ///
429 /// **Note:** This function is marked `unsafe` as it could lead to UB if accidentally used whilst threads / callbacks
430 /// dependant upon a mutex that it managed. Although it is `unsafe`, attempting to access a mutex after the `Grt` is destroyed
431 /// will not cause a null pointer dereference (they are checked), but it could lead to UB as those setter/getter functions will
432 /// return an error.
433 ///
434 /// # Examples
435 ///
436 /// ```
437 /// /// Driver exit routine
438 /// extern "C" fn driver_exit(driver: *mut DRIVER_OBJECT) {
439 /// unsafe { Grt::destroy() };
440 /// }
441 /// ```
442 pub unsafe fn destroy() -> Result<(), GrtError> {
443 // Check that the static pointer is not already null
444 let grt_ptr = WDK_MTX_GRT_PTR.load(SeqCst);
445 if grt_ptr.is_null() {
446 return Err(GrtError::GrtIsNull);
447 }
448
449 // Set the atomic global to null
450 WDK_MTX_GRT_PTR.store(null_mut(), SeqCst);
451
452 // Convert the pointer back to a box which wraps the inner `Grt`, allowing Box to drop all it's content
453 // which will free all inner memory, drop will properly be called on all Mutexes.
454 let _ = unsafe { Box::from_raw(grt_ptr) };
455
456 Ok(())
457 }
458}