1use darling::{FromMeta, ast::NestedMeta};
2use proc_macro::TokenStream;
3use quote::{ToTokens, format_ident, quote};
4use syn::{ItemFn, parse_macro_input, parse_str};
5
6fn default_task_path() -> syn::Expr {
7 parse_str("task").unwrap()
8}
9
10fn default_executor() -> syn::Expr {
11 parse_str("drivers::standard_library::executor::instantiate_static_executor!()").unwrap()
12}
13
14#[derive(Debug, FromMeta, Clone)]
15struct TaskArguments {
16 #[darling(default = "default_task_path")]
17 pub task_path: syn::Expr,
18
19 #[darling(default = "default_executor")]
20 pub executor: syn::Expr,
21}
22
23impl TaskArguments {
24 fn from_token_stream(arguments: TokenStream) -> Result<Self, darling::Error> {
25 let arguments = NestedMeta::parse_meta_list(arguments.into()).unwrap();
26 Self::from_list(&arguments.clone())
27 }
28}
29
30#[proc_macro_attribute]
42#[allow(non_snake_case)]
43pub fn test(arguments: TokenStream, input: TokenStream) -> TokenStream {
44 let arguments = match TaskArguments::from_token_stream(arguments) {
45 Ok(o) => o,
46 Err(e) => return e.write_errors().into(),
47 };
48 let input_function = parse_macro_input!(input as ItemFn);
49
50 let executor = arguments.executor;
51 let task_path = arguments.task_path;
52
53 let function_name = &input_function.sig.ident;
55
56 let function_name_string = function_name.to_string();
57
58 let is_asynchronous = input_function.sig.asyncness.is_some();
60
61 if !is_asynchronous {
62 return syn::Error::new_spanned(
63 input_function.sig.fn_token,
64 "Test functions must be async",
65 )
66 .to_compile_error()
67 .into();
68 }
69
70 if !input_function.sig.inputs.is_empty() {
72 return syn::Error::new_spanned(
73 &input_function.sig.inputs,
74 "Test functions must not have any arguments",
75 )
76 .to_compile_error()
77 .into();
78 }
79
80 if let syn::ReturnType::Type(_, return_type) = &input_function.sig.output {
82 if let syn::Type::Tuple(tuple) = return_type.as_ref() {
84 if !tuple.elems.is_empty() {
85 return syn::Error::new_spanned(
86 return_type,
87 "Test functions must not have a return type",
88 )
89 .to_compile_error()
90 .into();
91 }
92 } else {
93 return syn::Error::new_spanned(
94 return_type,
95 "Test functions must not have a return type",
96 )
97 .to_compile_error()
98 .into();
99 }
100 }
101
102 let mut Input_function = input_function.clone();
104 Input_function.sig.ident = format_ident!("__inner");
105
106 quote! {
108 #[std::prelude::v1::test]
109 fn #function_name() {
110 #Input_function
111
112 static mut __SPAWNER : usize = 0;
113
114 unsafe {
115 let __EXECUTOR = #executor;
116
117 __EXECUTOR.run(|Spawner, __executor| {
118 let manager = #task_path::initialize();
119
120 unsafe {
121 __SPAWNER = manager.register_spawner(Spawner).expect("Failed to register spawner");
122 }
123
124 #task_path::futures::block_on(async move {
125 manager.spawn(
126 #task_path::Manager::ROOT_TASK_IDENTIFIER,
127 #function_name_string,
128 Some(__SPAWNER),
129 async move |_task| {
130 __inner().await;
131 __executor.stop();
132 }
133 ).await
134 }).expect("Failed to spawn task");
135 });
136 }
137 unsafe {
138 #task_path::get_instance().unregister_spawner(__SPAWNER).expect("Failed to unregister spawner");
139 }
140
141 }
142 }
143 .to_token_stream()
144 .into()
145}
146
147#[proc_macro_attribute]
176#[allow(non_snake_case)]
177pub fn run(Arguments: TokenStream, Input: TokenStream) -> TokenStream {
178 let Arguments = match TaskArguments::from_token_stream(Arguments) {
179 Ok(o) => o,
180 Err(e) => return e.write_errors().into(),
181 };
182 let Input_function = parse_macro_input!(Input as ItemFn);
183
184 let Task_path = Arguments.task_path;
185 let Executor_expression = Arguments.executor;
186
187 let Function_name = &Input_function.sig.ident;
189 let Function_name_string = Function_name.to_string();
190
191 let is_asynchronous = Input_function.sig.asyncness.is_some();
193
194 if !is_asynchronous {
195 return syn::Error::new_spanned(
196 Input_function.sig.fn_token,
197 "Functions with Run_with_executor must be async",
198 )
199 .to_compile_error()
200 .into();
201 }
202
203 if !Input_function.sig.inputs.is_empty() {
205 return syn::Error::new_spanned(
206 &Input_function.sig.inputs,
207 "Functions with Run_with_executor must not have any arguments",
208 )
209 .to_compile_error()
210 .into();
211 }
212
213 if let syn::ReturnType::Type(_, Return_type) = &Input_function.sig.output {
215 if let syn::Type::Tuple(tuple) = Return_type.as_ref() {
217 if !tuple.elems.is_empty() {
218 return syn::Error::new_spanned(
219 Return_type,
220 "Functions with Run_with_executor must not have a return type",
221 )
222 .to_compile_error()
223 .into();
224 }
225 } else {
226 return syn::Error::new_spanned(
227 Return_type,
228 "Functions with Run_with_executor must not have a return type",
229 )
230 .to_compile_error()
231 .into();
232 }
233 }
234
235 let mut Input_function = Input_function.clone();
237 Input_function.sig.ident = format_ident!("__inner");
238
239 quote! {
241 fn #Function_name() {
242 #Input_function
243
244 static mut __SPAWNER : usize = 0;
245
246 unsafe {
247 let __EXECUTOR : &'static mut _ = #Executor_expression;
248
249 __EXECUTOR.run(|Spawner, __EXECUTOR| {
250 let manager = #Task_path::initialize();
251
252 unsafe {
253 __SPAWNER = manager.register_spawner(Spawner).expect("Failed to register spawner");
254 }
255
256 #Task_path::futures::block_on(async move {
257 manager.spawn(
258 #Task_path::Manager::ROOT_TASK_IDENTIFIER,
259 #Function_name_string,
260 Some(__SPAWNER),
261 async move |_task| {
262 __inner().await;
263 __EXECUTOR.stop();
264 }
265 ).await
266 }).expect("Failed to spawn task");
267 });
268 }
269 unsafe {
270 #Task_path::get_instance().unregister_spawner(__SPAWNER).expect("Failed to unregister spawner");
271 }
272 }
273 }
274 .to_token_stream()
275 .into()
276}