task_macros/
lib.rs

1use darling::{FromMeta, ast::NestedMeta};
2use proc_macro::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use syn::{ItemFn, parse_macro_input, parse_str};
5
6fn default_task_path() -> syn::Expr {
7    parse_str("task").unwrap()
8}
9
10fn default_executor() -> syn::Expr {
11    parse_str("drivers::standard_library::executor::instantiate_static_executor!()").unwrap()
12}
13
14#[derive(Debug, FromMeta, Clone)]
15struct TaskArguments {
16    #[darling(default = "default_task_path")]
17    pub task_path: syn::Expr,
18
19    #[darling(default = "default_executor")]
20    pub executor: syn::Expr,
21}
22
23impl TaskArguments {
24    fn from_token_stream(arguments: TokenStream) -> Result<Self, darling::Error> {
25        let arguments = NestedMeta::parse_meta_list(arguments.into()).unwrap();
26        Self::from_list(&arguments.clone())
27    }
28}
29
30/// A procedural macro to annotate test functions.
31///
32/// This macro wraps the annotated async function to be executed in a blocking context
33/// using embassy_futures::block_on, similar to how other modules handle async operations.
34///
35/// # Requirements
36///
37/// Test functions must:
38/// - Be async
39/// - Have no arguments
40/// - Have no return type (or return unit type `()`)
41#[proc_macro_attribute]
42#[allow(non_snake_case)]
43pub fn test(arguments: TokenStream, input: TokenStream) -> TokenStream {
44    let arguments = match TaskArguments::from_token_stream(arguments) {
45        Ok(o) => o,
46        Err(e) => return e.write_errors().into(),
47    };
48    let input_function = parse_macro_input!(input as ItemFn);
49
50    let executor = arguments.executor;
51    let task_path = arguments.task_path;
52
53    // Extract function details
54    let function_name = &input_function.sig.ident;
55
56    let function_name_string = function_name.to_string();
57
58    // Check if function is async
59    let is_asynchronous = input_function.sig.asyncness.is_some();
60
61    if !is_asynchronous {
62        return syn::Error::new_spanned(
63            input_function.sig.fn_token,
64            "Test functions must be async",
65        )
66        .to_compile_error()
67        .into();
68    }
69
70    // Check if function has no arguments
71    if !input_function.sig.inputs.is_empty() {
72        return syn::Error::new_spanned(
73            &input_function.sig.inputs,
74            "Test functions must not have any arguments",
75        )
76        .to_compile_error()
77        .into();
78    }
79
80    // Check if function has no return type (or returns unit type)
81    if let syn::ReturnType::Type(_, return_type) = &input_function.sig.output {
82        // Allow unit type () but reject any other return type
83        if let syn::Type::Tuple(tuple) = return_type.as_ref() {
84            if !tuple.elems.is_empty() {
85                return syn::Error::new_spanned(
86                    return_type,
87                    "Test functions must not have a return type",
88                )
89                .to_compile_error()
90                .into();
91            }
92        } else {
93            return syn::Error::new_spanned(
94                return_type,
95                "Test functions must not have a return type",
96            )
97            .to_compile_error()
98            .into();
99        }
100    }
101
102    // Change ident to __inner to avoid name conflicts
103    let mut Input_function = input_function.clone();
104    Input_function.sig.ident = format_ident!("__inner");
105
106    // Generate the new function
107    quote! {
108        #[std::prelude::v1::test]
109        fn #function_name() {
110            #Input_function
111
112            static mut __SPAWNER : usize = 0;
113
114            unsafe {
115                let __EXECUTOR = #executor;
116
117                __EXECUTOR.run(|Spawner, __executor| {
118                    let manager = #task_path::initialize();
119
120                    unsafe {
121                        __SPAWNER = manager.register_spawner(Spawner).expect("Failed to register spawner");
122                    }
123
124                    #task_path::futures::block_on(async move {
125                        manager.spawn(
126                            #task_path::Manager::ROOT_TASK_IDENTIFIER,
127                            #function_name_string,
128                            Some(__SPAWNER),
129                            async move |_task| {
130                                __inner().await;
131                                __executor.stop();
132                            }
133                        ).await
134                    }).expect("Failed to spawn task");
135                });
136            }
137            unsafe {
138                #task_path::get_instance().unregister_spawner(__SPAWNER).expect("Failed to unregister spawner");
139            }
140
141        }
142    }
143    .to_token_stream()
144    .into()
145}
146
147/// A procedural macro to annotate functions that should run with a specific executor.
148///
149/// This macro wraps the annotated async function to be executed with a provided
150/// executor, handling the registration, spawning, and cleanup automatically.
151///
152/// # Requirements
153///
154/// Functions must:
155/// - Be async
156/// - Have no arguments
157/// - Have no return type (or return unit type `()`)
158///
159/// # Usage
160///
161/// The macro accepts an executor expression as a parameter:
162///
163/// ```rust
164/// #[Run_with_executor(drivers::standard_library::Executor::Executor_type::new())]
165/// async fn my_function() {
166///     println!("Running with custom executor!");
167/// }
168/// ```
169///
170/// You can also use any executor expression:
171/// ```rust
172/// #[Run_with_executor(my_custom_executor)]
173/// async fn my_function() { ... }
174/// ```
175#[proc_macro_attribute]
176#[allow(non_snake_case)]
177pub fn run(Arguments: TokenStream, Input: TokenStream) -> TokenStream {
178    let Arguments = match TaskArguments::from_token_stream(Arguments) {
179        Ok(o) => o,
180        Err(e) => return e.write_errors().into(),
181    };
182    let Input_function = parse_macro_input!(Input as ItemFn);
183
184    let Task_path = Arguments.task_path;
185    let Executor_expression = Arguments.executor;
186
187    // Extract function details
188    let Function_name = &Input_function.sig.ident;
189    let Function_name_string = Function_name.to_string();
190
191    // Check if function is async
192    let is_asynchronous = Input_function.sig.asyncness.is_some();
193
194    if !is_asynchronous {
195        return syn::Error::new_spanned(
196            Input_function.sig.fn_token,
197            "Functions with Run_with_executor must be async",
198        )
199        .to_compile_error()
200        .into();
201    }
202
203    // Check if function has no arguments
204    if !Input_function.sig.inputs.is_empty() {
205        return syn::Error::new_spanned(
206            &Input_function.sig.inputs,
207            "Functions with Run_with_executor must not have any arguments",
208        )
209        .to_compile_error()
210        .into();
211    }
212
213    // Check if function has no return type (or returns unit type)
214    if let syn::ReturnType::Type(_, Return_type) = &Input_function.sig.output {
215        // Allow unit type () but reject any other return type
216        if let syn::Type::Tuple(tuple) = Return_type.as_ref() {
217            if !tuple.elems.is_empty() {
218                return syn::Error::new_spanned(
219                    Return_type,
220                    "Functions with Run_with_executor must not have a return type",
221                )
222                .to_compile_error()
223                .into();
224            }
225        } else {
226            return syn::Error::new_spanned(
227                Return_type,
228                "Functions with Run_with_executor must not have a return type",
229            )
230            .to_compile_error()
231            .into();
232        }
233    }
234
235    // Change ident to __inner to avoid name conflicts
236    let mut Input_function = Input_function.clone();
237    Input_function.sig.ident = format_ident!("__inner");
238
239    // Generate the new function
240    quote! {
241        fn #Function_name() {
242            #Input_function
243
244            static mut __SPAWNER : usize = 0;
245
246            unsafe {
247                let __EXECUTOR : &'static mut _ = #Executor_expression;
248
249                __EXECUTOR.run(|Spawner, __EXECUTOR| {
250                    let manager = #Task_path::initialize();
251
252                    unsafe {
253                        __SPAWNER = manager.register_spawner(Spawner).expect("Failed to register spawner");
254                    }
255
256                    #Task_path::futures::block_on(async move {
257                        manager.spawn(
258                            #Task_path::Manager::ROOT_TASK_IDENTIFIER,
259                            #Function_name_string,
260                            Some(__SPAWNER),
261                            async move |_task| {
262                                __inner().await;
263                                __EXECUTOR.stop();
264                            }
265                        ).await
266                    }).expect("Failed to spawn task");
267                });
268            }
269            unsafe {
270                #Task_path::get_instance().unregister_spawner(__SPAWNER).expect("Failed to unregister spawner");
271            }
272        }
273    }
274    .to_token_stream()
275    .into()
276}