// shared/flags.rs

/// Declares a bit-flag set: a `#[repr(transparent)]` newtype over an integer
/// type with one associated constant per flag, plus `None` (no bits set) and
/// `All` (every declared flag).
///
/// Two declaration forms are accepted, and several declarations may follow one
/// another in a single invocation.
///
/// Without explicit values, flags get consecutive bits starting at bit 0 in
/// declaration order:
///
/// ```text
/// flags! {
///     pub enum Mode: u8 { Read, Write, Exec }
/// }
/// ```
///
/// With explicit values, each flag is initialised from its expression:
///
/// ```text
/// flags! {
///     pub enum Perm: u16 { Read = 0b01, Write = 0b10 }
/// }
/// ```
///
/// The generated type derives `Copy`, `Clone`, `PartialEq` and `Eq`,
/// implements `Debug`, and overloads `|`, `&`, `^`, `-`, `!` and the
/// compound-assignment forms of the binary operators.
#[macro_export]
macro_rules! flags {
    // Nothing left to declare (terminates the `$($next)*` recursion).
    () => {};

    // Entry point for enumerations without values: hand the variants to the
    // counting rules below, which assign bit positions in declaration order.
    ($(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty { $($(#[$variant:meta])* $k:ident),+ $(,)* } $($next:tt)*) => {
        $crate::flags! {
            @count_and_gen $(#[$attributes])* $visibility enum $identifier: $t
            { } // accumulated `Variant = value,` entries
            [ $( ($(#[$variant])* $k) )+ ] // variants still to process
            [] // counter: one `+` per variant already processed
            $($next)*
        }
    };

    // Emit `Variant = <bit>,` for the next variant and grow the counter by one.
    (@count_and_gen $(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty
        { $($accumulated:tt)* }
        [ ($(#[$variant_meta:meta])* $current:ident) $(($($rest_items:tt)*))* ]
        [ $($counter:tt)* ]
        $($next:tt)*
    ) => {
        $crate::flags! {
            @count_and_gen $(#[$attributes])* $visibility enum $identifier: $t
            { $($accumulated)* $(#[$variant_meta])* $current = $crate::flags!(@bit_value [ $($counter)* ]), }
            [ $(($($rest_items)*))* ]
            [ $($counter)* + ]
            $($next)*
        }
    };

    // Every variant has a value now: re-enter through the explicit-values form.
    (@count_and_gen $(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty
        { $($accumulated:tt)* }
        [ ]
        [ $($counter:tt)* ]
        $($next:tt)*
    ) => {
        $crate::flags! { $(#[$attributes])* $visibility enum $identifier: $t { $($accumulated)* } $($next)* }
    };

    // Turn a counter of N `+` tokens into `1 << N`. The counter tokens double as
    // the addition operators (`[+ +]` becomes `1 << (0 + 1 + 1)`), so there is
    // no fixed upper bound on the number of variants; a shift that does not fit
    // the backing type fails when the constant is evaluated at compile time.
    (@bit_value [ $($counter:tt)* ]) => { 1 << (0 $($counter 1)*) };

    // Entry point for enumerations with explicit values; also the final stage of
    // the implicit form. Generates the type and its whole API.
    ($(#[$attributes:meta])* $visibility:vis enum $identifier:ident: $t:ty { $($(#[$variant:meta])*$k:ident = $v:expr),* $(,)* } $($next:tt)*) => {
        $(#[$attributes])*
        #[derive(Copy, Clone, PartialEq, Eq)]
        #[repr(transparent)]
        $visibility struct $identifier($t);

        impl $identifier {
            $(
                #[allow(non_upper_case_globals)]
                $(#[$variant])*
                $visibility const $k: Self = Self($v);
            )*

            /// The empty set: no bits set.
            #[allow(non_upper_case_globals)]
            $visibility const None: Self = Self(0);

            /// Every declared flag. Folding starts from `0` so a declaration
            /// with no variants still produces a valid (empty) constant.
            #[allow(non_upper_case_globals)]
            $visibility const All: Self = Self(0 $(| $v)*);

            /// Returns `true` if every bit of `other` is also set in `self`.
            #[allow(dead_code)]
            $visibility const fn contains(&self, other: Self) -> bool {
                (self.0 & other.0) == other.0
            }

            /// Returns `true` if `self` and `other` share at least one bit.
            #[allow(dead_code)]
            $visibility const fn intersects(&self, other: Self) -> bool {
                (self.0 & other.0) != 0
            }

            /// Returns a copy of `self` with the bits of `other` set.
            ///
            /// `self` is taken by value and left untouched; use `|=` to update a
            /// variable in place.
            #[must_use = "`insert` returns the updated set and does not modify `self`"]
            #[allow(dead_code)]
            $visibility const fn insert(mut self, other: Self) -> Self {
                self.0 |= other.0;
                self
            }

            /// Returns a copy of `self` with the bits of `other` cleared.
            ///
            /// `self` is taken by value and left untouched; use `-=` to update a
            /// variable in place.
            #[must_use = "`remove` returns the updated set and does not modify `self`"]
            #[allow(dead_code)]
            $visibility const fn remove(mut self, other: Self) -> Self {
                self.0 &= !other.0;
                self
            }

            /// Returns a copy of `self` with the bits of `other` flipped.
            ///
            /// `self` is taken by value and left untouched; use `^=` to update a
            /// variable in place.
            #[must_use = "`toggle` returns the updated set and does not modify `self`"]
            #[allow(dead_code)]
            $visibility const fn toggle(mut self, other: Self) -> Self {
                self.0 ^= other.0;
                self
            }

            /// Returns a copy of `self` with the bits of `other` set when `value`
            /// is `true`, or cleared when it is `false`.
            #[must_use = "`set` returns the updated set and does not modify `self`"]
            #[allow(dead_code)]
            $visibility const fn set(self, other: Self, value: bool) -> Self {
                if value {
                    self.insert(other)
                } else {
                    self.remove(other)
                }
            }

            /// Returns the bits set in both `self` and `other`.
            #[allow(dead_code)]
            $visibility const fn intersection(self, other: Self) -> Self {
                Self(self.0 & other.0)
            }

            /// Returns the bits set in either `self` or `other`.
            #[allow(dead_code)]
            $visibility const fn union(self, other: Self) -> Self {
                Self(self.0 | other.0)
            }

            /// Returns the bits set in `self` but not in `other`.
            $visibility const fn difference(self, other: Self) -> Self {
                Self(self.0 & !other.0)
            }

            /// Returns the bits set in exactly one of `self` and `other`.
            $visibility const fn symmetric_difference(self, other: Self) -> Self {
                Self(self.0 ^ other.0)
            }

            /// Returns every declared flag that is not set in `self`.
            ///
            /// The result is masked with `Self::All`, so it never contains bits
            /// outside the declared flags and `!Self::None == Self::All`.
            $visibility const fn complement(self) -> Self {
                Self(!self.0 & Self::All.0)
            }

            /// Returns `true` if no bit is set.
            #[allow(dead_code)]
            $visibility const fn is_empty(&self) -> bool {
                self.0 == 0
            }

            /// Returns the raw underlying value.
            #[allow(dead_code)]
            $visibility const fn bits(&self) -> $t {
                self.0
            }

            /// Returns the number of bits needed to hold `All`: the index of its
            /// highest set bit plus one, or `0` when `All` is empty.
            #[allow(dead_code)]
            $visibility const fn bits_used() -> u8 {
                let all_bits = Self::All.0;
                if all_bits == 0 {
                    0
                } else {
                    // Width of the backing type minus the unused high bits.
                    (core::mem::size_of::<$t>() * 8) as u8 - all_bits.leading_zeros() as u8
                }
            }

            /// Builds a set from raw bits, returning `None` if any bit outside
            /// `Self::All` is set.
            #[allow(dead_code)]
            $visibility const fn from_bits(bits: $t) -> Option<Self> {
                let all_bits = Self::All.0;
                if (bits & !all_bits) == 0 {
                    Some(Self(bits))
                } else {
                    None
                }
            }

            /// Builds a set from raw bits, silently dropping any bit outside
            /// `Self::All`.
            #[allow(dead_code)]
            $visibility const fn from_bits_truncate(bits: $t) -> Self {
                let all_bits = Self::All.0;
                Self(bits & all_bits)
            }

            /// Builds a set from raw bits without checking them.
            ///
            /// # Safety
            ///
            /// The body is a plain constructor with no memory-safety requirement.
            /// The caller takes responsibility for any bits outside `Self::All`,
            /// which `from_bits` would have rejected and `Debug` does not print.
            #[allow(dead_code)]
            $visibility const unsafe fn from_bits_unchecked(bits: $t) -> Self {
                Self(bits)
            }
        }

        impl core::fmt::Debug for $identifier {
            /// Formats as `Name { FlagA | FlagB }`, listing declared flags only.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                // `first` is only touched inside the per-variant repetition, so it
                // is unused when the set declares no variants.
                #[allow(unused_mut, unused_variables)]
                let mut first = true;
                write!(f, "{} {{ ", stringify!($identifier))?;
                $(
                    if self.contains(Self::$k) {
                        // Emit the separator before every flag except the first.
                        if !core::mem::replace(&mut first, false) {
                            write!(f, " | ")?;
                        }
                        write!(f, "{}", stringify!($k))?;
                    }
                )*
                write!(f, " }}")
            }
        }

        impl core::ops::BitOr for $identifier {
            type Output = Self;

            fn bitor(self, other: Self) -> Self {
                self.union(other)
            }
        }

        impl core::ops::BitOrAssign for $identifier {
            fn bitor_assign(&mut self, other: Self) {
                // The set methods return a new value; it must be stored back,
                // otherwise only a temporary copy would change.
                *self = self.union(other);
            }
        }

        impl core::ops::BitAnd for $identifier {
            type Output = Self;

            fn bitand(self, other: Self) -> Self {
                self.intersection(other)
            }
        }

        impl core::ops::BitAndAssign for $identifier {
            fn bitand_assign(&mut self, other: Self) {
                *self = self.intersection(other);
            }
        }

        impl core::ops::BitXor for $identifier {
            type Output = Self;

            fn bitxor(self, other: Self) -> Self {
                self.symmetric_difference(other)
            }
        }

        impl core::ops::BitXorAssign for $identifier {
            fn bitxor_assign(&mut self, other: Self) {
                *self = self.symmetric_difference(other);
            }
        }

        impl core::ops::Not for $identifier {
            type Output = Self;

            fn not(self) -> Self {
                self.complement()
            }
        }

        impl core::ops::Sub for $identifier {
            type Output = Self;

            fn sub(self, other: Self) -> Self {
                self.difference(other)
            }
        }

        impl core::ops::SubAssign for $identifier {
            fn sub_assign(&mut self, other: Self) {
                *self = self.difference(other);
            }
        }

        $crate::flags! { $($next)* }
    };
}

#[cfg(test)]
mod tests {
    extern crate alloc;

    use alloc::format;

    // Implicitly numbered flags: bits 0, 1 and 2 in declaration order.
    flags! {
        pub enum TestFlags: u8 {
            FlagA,
            FlagB,
            FlagC,
        }
    }

    // Flags with explicitly written values.
    flags! {
        pub enum CustomFlags: u16 {
            Read = 0b0001,
            Write = 0b0010,
            Execute = 0b0100,
            Admin = 0b1000,
        }
    }

    /// `Debug` lists the set flags of an implicitly numbered set.
    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", TestFlags::FlagA | TestFlags::FlagC);
        assert_eq!(rendered, "TestFlags { FlagA | FlagC }");
    }

    /// Exercises the by-value set operations on implicitly numbered flags.
    #[test]
    fn test_flagset_operations() {
        let (a, b, c) = (TestFlags::FlagA, TestFlags::FlagB, TestFlags::FlagC);

        // Consecutive bits starting at bit 0.
        for &(flag, expected) in &[(a, 1u8), (b, 2), (c, 4)] {
            assert_eq!(flag.bits(), expected);
        }

        // The empty and the full set.
        assert!(TestFlags::None.is_empty());
        assert_eq!(TestFlags::None.bits(), 0);
        assert_eq!(TestFlags::All.bits(), 0b111);
        assert!(!TestFlags::All.is_empty());

        // insert + contains
        let fresh = TestFlags::None;
        assert!(!fresh.contains(a));
        let with_a = fresh.insert(a);
        assert!(with_a.contains(a));
        assert!(!with_a.contains(b));

        // union through `|`
        let pair = a | b;
        assert!(pair.contains(a));
        assert!(pair.contains(b));
        assert!(!pair.contains(c));

        // set, switching a flag on and then off again
        let switched_on = TestFlags::None.set(a, true);
        assert!(switched_on.contains(a));
        let switched_off = switched_on.set(a, false);
        assert!(!switched_off.contains(a));

        // remove
        let only_b = (a | b).remove(a);
        assert!(!only_b.contains(a));
        assert!(only_b.contains(b));

        // toggle
        let both = a.toggle(b);
        assert!(both.contains(a));
        assert!(both.contains(b));
        let back_to_b = both.toggle(a);
        assert!(!back_to_b.contains(a));
        assert!(back_to_b.contains(b));

        // intersects
        assert!((a | b).intersects(b | c));
        assert!(!a.intersects(c));

        // set with a mix of true and false, checked after each step
        let step_one = TestFlags::None.set(a, true);
        assert!(step_one.contains(a));
        let step_two = step_one.set(b, false);
        assert!(!step_two.contains(b));
        let step_three = step_two.set(c, true);
        assert!(step_three.contains(c));

        // Bits 0..=2 are in use, so three bits are needed.
        assert_eq!(TestFlags::bits_used(), 3);
    }

    /// `Debug` output for explicit-value flags, including the empty and full sets.
    #[test]
    fn test_debug_custom() {
        let cases = [
            (
                CustomFlags::Read | CustomFlags::Execute,
                "CustomFlags { Read | Execute }",
            ),
            (CustomFlags::None, "CustomFlags {  }"),
            (
                CustomFlags::All,
                "CustomFlags { Read | Write | Execute | Admin }",
            ),
        ];
        for &(value, expected) in cases.iter() {
            assert_eq!(format!("{:?}", value), expected);
        }
    }

    /// Explicit values are kept verbatim and drive `All`, `from_bits*` and `bits_used`.
    #[test]
    fn test_explicit_values() {
        let (read, write, execute, admin) = (
            CustomFlags::Read,
            CustomFlags::Write,
            CustomFlags::Execute,
            CustomFlags::Admin,
        );

        // Each flag holds exactly the value from its declaration.
        assert_eq!(read.bits(), 0b0001);
        assert_eq!(write.bits(), 0b0010);
        assert_eq!(execute.bits(), 0b0100);
        assert_eq!(admin.bits(), 0b1000);

        // Combining flags.
        let read_write = read | write;
        assert!(read_write.contains(read));
        assert!(read_write.contains(write));
        assert!(!read_write.contains(execute));

        // Every flag together equals `All`.
        let everything = read | write | execute | admin;
        assert_eq!(everything.bits(), 0b1111);
        assert_eq!(everything, CustomFlags::All);

        // `from_bits` rejects undeclared bits; `from_bits_truncate` drops them.
        assert_eq!(CustomFlags::from_bits(0b0011), Some(read_write));
        assert_eq!(CustomFlags::from_bits(0b1_0000), None);
        assert_eq!(CustomFlags::from_bits_truncate(0b1_0011), read_write);

        // Flags occupy bits 0..=3, so four bits are needed.
        assert_eq!(CustomFlags::bits_used(), 4);
    }
}