Skip to main content

shared/
flags.rs

/// Declares one or more bit-flag set types using an `enum`-like syntax.
///
/// Each declaration `vis enum Name: Repr { ... }` expands to a
/// `#[repr(transparent)]` tuple struct `Name(Repr)` deriving `Copy`, `Clone`,
/// `PartialEq` and `Eq`. The struct gets one associated constant per listed
/// flag, plus `None` (no bits set) and `All` (every listed flag OR-ed
/// together). Several declarations may follow each other in one invocation.
///
/// Two forms are accepted (they cannot be mixed inside one declaration):
///
/// * without values (`A, B, C`): flags are numbered in declaration order as
///   `1 << 0`, `1 << 1`, ...; the lookup table below covers shifts 0 to 31,
///   so at most 32 flags can be declared this way;
/// * with explicit values (`A = 0b01, B = 0b10`): each flag uses the given
///   expression.
///
/// The generated type implements `Debug`, the operators `|`, `&`, `^`, `!`
/// and `-` (set difference), and the in-place forms `|=`, `&=`, `^=`, `-=`.
///
/// ```ignore
/// flags! {
///     pub enum Access: u8 {
///         Read,
///         Write,
///     }
/// }
///
/// let mut access = Access::Read;
/// access |= Access::Write;
/// assert!(access.contains(Access::Read | Access::Write));
/// ```
#[macro_export]
macro_rules! flags {
    // Nothing (left) to expand: ends the `$($next)*` recursion.
    () => {};

    // Entry point for enumerations without values
    ($(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty { $($(#[$variant:meta])* $k:ident),+ $(,)* } $($next:tt)*) => {
        $crate::flags! {
            @count_and_gen $(#[$attributes])* $visibility enum $identifier: $t
            { } // accumulated variants
            [ $( ($(#[$variant])* $k) )+ ] // remaining to process
            [] // counter (empty = 0)
            $($next)*
        }
    };

    // Process each variant, incrementing the counter.
    // The counter is a list of `+` tokens: its length is the bit index that
    // the current variant receives.
    (@count_and_gen $(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty
        { $($accumulated:tt)* }
        [ ($(#[$variant_meta:meta])* $current:ident) $(($($rest_items:tt)*))* ]
        [ $($counter:tt)* ]
        $($next:tt)*
    ) => {
        $crate::flags! {
            @count_and_gen $(#[$attributes])* $visibility enum $identifier: $t
            { $($accumulated)* $(#[$variant_meta])* $current = $crate::flags!(@bit_value [ $($counter)* ]), }
            [ $(($($rest_items)*))* ]
            [ $($counter)* + ] // increment counter
            $($next)*
        }
    };

    // When all variants are processed, hand the now fully-valued declaration
    // over to the explicit-value entry point, which generates the struct.
    (@count_and_gen $(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty
        { $($accumulated:tt)* }
        [ ]
        [ $($counter:tt)* ]
        $($next:tt)*
    ) => {
        $crate::flags! { $(#[$attributes])* $visibility enum $identifier: $t { $($accumulated)* } $($next)* }
    };

    // Convert counter tokens to bit shift value
    (@bit_value []) => { 1 << 0 };
    (@bit_value [+]) => { 1 << 1 };
    (@bit_value [+ +]) => { 1 << 2 };
    (@bit_value [+ + +]) => { 1 << 3 };
    (@bit_value [+ + + +]) => { 1 << 4 };
    (@bit_value [+ + + + +]) => { 1 << 5 };
    (@bit_value [+ + + + + +]) => { 1 << 6 };
    (@bit_value [+ + + + + + +]) => { 1 << 7 };
    (@bit_value [+ + + + + + + +]) => { 1 << 8 };
    (@bit_value [+ + + + + + + + +]) => { 1 << 9 };
    (@bit_value [+ + + + + + + + + +]) => { 1 << 10 };
    (@bit_value [+ + + + + + + + + + +]) => { 1 << 11 };
    (@bit_value [+ + + + + + + + + + + +]) => { 1 << 12 };
    (@bit_value [+ + + + + + + + + + + + +]) => { 1 << 13 };
    (@bit_value [+ + + + + + + + + + + + + +]) => { 1 << 14 };
    (@bit_value [+ + + + + + + + + + + + + + +]) => { 1 << 15 };
    (@bit_value [+ + + + + + + + + + + + + + + +]) => { 1 << 16 };
    (@bit_value [+ + + + + + + + + + + + + + + + +]) => { 1 << 17 };
    (@bit_value [+ + + + + + + + + + + + + + + + + +]) => { 1 << 18 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + +]) => { 1 << 19 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + +]) => { 1 << 20 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + +]) => { 1 << 21 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 22 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 23 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 24 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 25 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 26 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 27 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 28 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 29 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 30 };
    (@bit_value [+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +]) => { 1 << 31 };

    // Entry point for enumerations with explicit values
    ($(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty { $($(#[$variant:meta])*$k:ident = $v:expr),* $(,)* } $($next:tt)*) => {
        $(#[$attributes])*
        #[derive(Copy, Clone, PartialEq, Eq)]
        #[repr(transparent)]
        $visibility struct $identifier($t);

        impl $identifier {
            $(
                #[allow(non_upper_case_globals)]
                $(#[$variant])*
                $visibility const $k: Self = Self($v);
            )*

            /// The empty flag set (no bits set)
            #[allow(non_upper_case_globals)]
            $visibility const None: Self = Self(0);

            /// The union of every declared flag
            #[allow(non_upper_case_globals)]
            $visibility const All: Self = Self($( $v )|*);


            /// Checks if the flag set contains the specified flag(s)
            #[allow(dead_code)]
            $visibility const fn contains(&self, other: Self) -> bool {
                (self.0 & other.0) == other.0
            }

            /// Checks if the flag set contains any of the specified flag(s)
            #[allow(dead_code)]
            $visibility const fn intersects(&self, other: Self) -> bool {
                (self.0 & other.0) != 0
            }

            /// Returns a copy of the set with the specified flag(s) inserted.
            ///
            /// The receiver is taken by value and is not modified; use `|=`
            /// to update a set in place.
            #[allow(dead_code)]
            #[must_use = "this returns the updated set and does not modify the original"]
            $visibility const fn insert(mut self, other: Self) -> Self {
                self.0 |= other.0;
                self
            }

            /// Returns a copy of the set with the specified flag(s) removed.
            ///
            /// The receiver is taken by value and is not modified; use `-=`
            /// to update a set in place.
            #[allow(dead_code)]
            #[must_use = "this returns the updated set and does not modify the original"]
            $visibility const fn remove(mut self, other: Self) -> Self {
                self.0 &= !other.0;
                self
            }

            /// Returns a copy of the set with the specified flag(s) toggled.
            ///
            /// The receiver is taken by value and is not modified; use `^=`
            /// to update a set in place.
            #[allow(dead_code)]
            #[must_use = "this returns the updated set and does not modify the original"]
            $visibility const fn toggle(mut self, other: Self) -> Self {
                self.0 ^= other.0;
                self
            }

            /// Returns a copy of the set with the specified flag(s) inserted
            /// when `value` is `true`, or removed when it is `false`
            #[allow(dead_code)]
            #[must_use = "this returns the updated set and does not modify the original"]
            $visibility const fn set(self, other: Self, value: bool) -> Self {
                if value {
                    self.insert(other)
                } else {
                    self.remove(other)
                }
            }

            /// Returns the intersection of the two flag sets
            #[allow(dead_code)]
            $visibility const fn intersection(self, other: Self) -> Self {
                Self(self.0 & other.0)
            }

            /// Returns the union of the two flag sets
            #[allow(dead_code)]
            $visibility const fn union(self, other: Self) -> Self {
                Self(self.0 | other.0)
            }

            /// Returns the difference between the two flag sets
            /// (flags in `self` that are not in `other`)
            $visibility const fn difference(self, other: Self) -> Self {
                Self(self.0 & !other.0)
            }

            /// Returns the symmetric difference between the two flag sets
            /// (flags present in exactly one of them)
            $visibility const fn symmetric_difference(self, other: Self) -> Self {
                Self(self.0 ^ other.0)
            }

            /// Returns the complement of the flag set.
            ///
            /// Every bit of the underlying integer is inverted, so bits that
            /// do not correspond to a declared flag end up set as well.
            $visibility const fn complement(self) -> Self {
                Self(!self.0)
            }

            /// Checks if the flag set is empty
            #[allow(dead_code)]
            $visibility const fn is_empty(&self) -> bool {
                self.0 == 0
            }

            /// Returns the raw value
            #[allow(dead_code)]
            $visibility const fn bits(&self) -> $t {
                self.0
            }

            /// Returns the number of bits required to represent all defined flags
            #[allow(dead_code)]
            $visibility const fn bits_used() -> u8 {
                let all_bits = Self::All.0;
                if all_bits == 0 {
                    0
                } else {
                    // Calculate the position of the highest set bit + 1
                    (core::mem::size_of::<$t>() * 8) as u8 - all_bits.leading_zeros() as u8
                }
            }

            /// Creates a flag set from raw bits.
            ///
            /// Returns `None` when `bits` has a bit set that is not part of `All`.
            #[allow(dead_code)]
            $visibility const fn from_bits(bits: $t) -> Option<Self> {
                let all_bits = Self::All.0;
                if (bits & !all_bits) == 0 {
                    Some(Self(bits))
                } else {
                    None
                }
            }

            /// Creates a flag set from raw bits, truncating any unknown bits
            #[allow(dead_code)]
            $visibility const fn from_bits_truncate(bits: $t) -> Self {
                let all_bits = Self::All.0;
                Self(bits & all_bits)
            }

            /// Creates a flag set from raw bits without checking validity.
            ///
            /// # Safety
            ///
            /// No validation is performed: bits that do not correspond to a
            /// declared flag are stored as-is. The caller is responsible for
            /// passing only bits that the rest of the program can handle.
            #[allow(dead_code)]
            $visibility const unsafe fn from_bits_unchecked(bits: $t) -> Self {
                Self(bits)
            }

            /// Counts the bits that are set in the flag set
            #[allow(dead_code)]
            $visibility const fn count_ones(&self) -> u32 {
                self.0.count_ones()
            }

        }

        // Prints `Name { A | B }`, listing the declared flags contained in the set.
        impl core::fmt::Debug for $identifier {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                let mut first = true;
                write!(f, "{} {{ ", stringify!($identifier))?;
                $(
                    if self.contains(Self::$k) {
                        // Emit the separator before every flag except the first.
                        if !core::mem::replace(&mut first, false) {
                            write!(f, " | ")?;
                        }
                        write!(f, "{}", stringify!($k))?;
                    }
                )*
                write!(f, " }}")
            }
        }

        impl core::ops::BitOr for $identifier {
            type Output = Self;

            fn bitor(self, other: Self) -> Self {
                self.union(other)
            }
        }

        // The `*Assign` impls write through `self.0` directly. `insert`,
        // `remove` and `toggle` take `self` by value and return the updated
        // copy, so calling them on `&mut self` and dropping the result would
        // leave the receiver untouched.
        impl core::ops::BitOrAssign for $identifier {
            fn bitor_assign(&mut self, other: Self) {
                self.0 |= other.0;
            }
        }

        impl core::ops::BitAnd for $identifier {
            type Output = Self;

            fn bitand(self, other: Self) -> Self {
                self.intersection(other)
            }
        }

        impl core::ops::BitAndAssign for $identifier {
            fn bitand_assign(&mut self, other: Self) {
                self.0 &= other.0;
            }
        }

        impl core::ops::BitXor for $identifier {
            type Output = Self;

            fn bitxor(self, other: Self) -> Self {
                self.symmetric_difference(other)
            }
        }

        impl core::ops::BitXorAssign for $identifier {
            fn bitxor_assign(&mut self, other: Self) {
                self.0 ^= other.0;
            }
        }

        impl core::ops::Not for $identifier {
            type Output = Self;

            fn not(self) -> Self {
                self.complement()
            }
        }

        impl core::ops::Sub for $identifier {
            type Output = Self;

            fn sub(self, other: Self) -> Self {
                self.difference(other)
            }
        }

        impl core::ops::SubAssign for $identifier {
            fn sub_assign(&mut self, other: Self) {
                self.0 &= !other.0;
            }
        }

        // Expand any further declarations that follow this one.
        $crate::flags! { $($next)* }
    };
}
306
#[cfg(test)]
mod tests {
    extern crate alloc;

    use alloc::format;

    // Flag set declared without values: bits follow declaration order.
    flags! {
        pub enum TestFlags: u8 {
            FlagA,
            FlagB,
            FlagC,
        }
    }

    // Flag set declared with an explicit value per flag.
    flags! {
        pub enum CustomFlags: u16 {
            Read = 0b0001,
            Write = 0b0010,
            Execute = 0b0100,
            Admin = 0b1000,
        }
    }

    /// `Debug` lists the contained flags separated by ` | `.
    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", TestFlags::FlagA | TestFlags::FlagC);
        assert_eq!(rendered, "TestFlags { FlagA | FlagC }");
    }

    /// Exercises the value-returning set operations on auto-numbered flags.
    #[test]
    fn test_flagset_operations() {
        let (a, b, c) = (TestFlags::FlagA, TestFlags::FlagB, TestFlags::FlagC);

        // Individual flags occupy bits 0, 1 and 2.
        assert_eq!(a.bits(), 1);
        assert_eq!(b.bits(), 2);
        assert_eq!(c.bits(), 4);

        // The `None` constant is the empty set.
        let nothing = TestFlags::None;
        assert!(nothing.is_empty());
        assert_eq!(nothing.bits(), 0);

        // The `All` constant is 1 | 2 | 4.
        let everything = TestFlags::All;
        assert_eq!(everything.bits(), 7);
        assert!(!everything.is_empty());

        // contains / insert
        let start = TestFlags::None;
        assert!(!start.contains(a));
        let with_a = start.insert(a);
        assert!(with_a.contains(a));
        assert!(!with_a.contains(b));

        // Union through the `|` operator.
        let a_or_b = a | b;
        assert!(a_or_b.contains(a));
        assert!(a_or_b.contains(b));
        assert!(!a_or_b.contains(c));

        // set(true) inserts, set(false) removes.
        let switched_on = TestFlags::None.set(a, true);
        assert!(switched_on.contains(a));
        let switched_off = switched_on.set(a, false);
        assert!(!switched_off.contains(a));

        // remove
        let only_b = (a | b).remove(a);
        assert!(!only_b.contains(a));
        assert!(only_b.contains(b));

        // toggle
        let toggled_once = a.toggle(b);
        assert!(toggled_once.contains(a));
        assert!(toggled_once.contains(b));
        let toggled_twice = toggled_once.toggle(a);
        assert!(!toggled_twice.contains(a));
        assert!(toggled_twice.contains(b));

        // intersects
        assert!((a | b).intersects(b | c));
        assert!(!a.intersects(c));

        // Conditional flag setting through `set`.
        let mut acc = TestFlags::None;
        acc = acc.set(a, true);
        assert!(acc.contains(a));
        acc = acc.set(b, false);
        assert!(!acc.contains(b));
        acc = acc.set(c, true);
        assert!(acc.contains(c));

        // Three flags on bits 0, 1 and 2 need three bits.
        assert_eq!(TestFlags::bits_used(), 3);
    }

    /// `Debug` output for explicit-value flags, including the empty and full sets.
    #[test]
    fn test_debug_custom() {
        assert_eq!(
            format!("{:?}", CustomFlags::Read | CustomFlags::Execute),
            "CustomFlags { Read | Execute }"
        );

        assert_eq!(format!("{:?}", CustomFlags::None), "CustomFlags {  }");

        assert_eq!(
            format!("{:?}", CustomFlags::All),
            "CustomFlags { Read | Write | Execute | Admin }"
        );
    }

    /// Explicit values are used verbatim and drive `All`, `from_bits*` and `bits_used`.
    #[test]
    fn test_explicit_values() {
        let (r, w, x, adm) = (
            CustomFlags::Read,
            CustomFlags::Write,
            CustomFlags::Execute,
            CustomFlags::Admin,
        );

        // Each flag carries exactly the declared bit pattern.
        assert_eq!(r.bits(), 0b0001);
        assert_eq!(w.bits(), 0b0010);
        assert_eq!(x.bits(), 0b0100);
        assert_eq!(adm.bits(), 0b1000);

        // Combinations
        let rw = r | w;
        assert!(rw.contains(r));
        assert!(rw.contains(w));
        assert!(!rw.contains(x));

        // OR-ing every flag yields `All`.
        let full = r | w | x | adm;
        assert_eq!(full.bits(), 0b1111);
        assert_eq!(full, CustomFlags::All);

        // from_bits accepts known bits and rejects an unknown one.
        assert_eq!(CustomFlags::from_bits(0b0011), Some(rw));
        assert_eq!(CustomFlags::from_bits(0b10000), None);

        // from_bits_truncate drops the unknown bit.
        assert_eq!(CustomFlags::from_bits_truncate(0b10011), rw);

        // Four flags on bits 0 to 3 need four bits.
        assert_eq!(CustomFlags::bits_used(), 4);
    }
}