From: Gary Guo <gary@garyguo.net>
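Convert the `#[kunit_tests]` procedural macro to parse its input with `syn` and generate its output with `quote`, instead of scanning raw token trees and building code from formatted strings. Malformed input is now reported as a compile error through `syn::Error` rather than a panic. This allows significant cleanups.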
Signed-off-by: Gary Guo <gary@garyguo.net>
---
 rust/macros/kunit.rs | 274 +++++++++++++++++++------------------------
 rust/macros/lib.rs   |   6 +-
 2 files changed, 123 insertions(+), 157 deletions(-)
diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs
index 7427c17ee5f5c..516219f5b1356 100644
--- a/rust/macros/kunit.rs
+++ b/rust/macros/kunit.rs
@@ -4,81 +4,50 @@
 //!
 //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
-use std::collections::HashMap; -use std::fmt::Write; - -use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; - -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - let attr = attr.to_string(); - - if attr.is_empty() { - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") - } - - if attr.len() > 255 { - panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") +use std::ffi::CString; + +use proc_macro2::TokenStream; +use quote::{ + format_ident, + quote, + ToTokens, // +}; +use syn::{ + parse_quote, + Error, + Ident, + Item, + ItemMod, + LitCStr, + Result, // +}; + +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { + if test_suite.to_string().len() > 255 { + return Err(Error::new_spanned( + test_suite, + "test suite names cannot exceed the maximum length of 255 bytes", + )); }
- let mut tokens: Vec<_> = ts.into_iter().collect(); - - // Scan for the `mod` keyword. - tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "mod" => Some(true), - _ => None, - }, - _ => None, - }) - .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); - - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("Cannot locate main body of module"), + // We cannot handle modules that defer to another file (e.g. `mod foo;`). + let Some((module_brace, module_items)) = module.content.take() else { + Err(Error::new_spanned( + module, + "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", + ))? };
- // Get the functions set as tests. Search for `[test]` -> `fn`. - let mut body_it = body.stream().into_iter(); - let mut tests = Vec::new(); - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { - if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { - // Collect attributes because we need to find which are tests. We also - // need to copy `cfg` attributes so tests can be conditionally enabled. - attributes - .entry(name.to_string()) - .or_default() - .extend([token, TokenTree::Group(g)]); - } - continue; - } - _ => (), - }, - TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { - if let Some(TokenTree::Ident(test_name)) = body_it.next() { - tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) - } - } - - _ => (), - } - attributes.clear(); - } + // Make the entire module gated behind `CONFIG_KUNIT`. + module + .attrs + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
- // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. - let config_kunit = "#[cfg(CONFIG_KUNIT="y")]".to_owned().parse().unwrap(); - tokens.insert( - 0, - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), - ); + let mut processed_items = Vec::new(); + let mut test_cases = Vec::new();
// Generate the test KUnit test suite and a test case for each `#[test]`. + // // The code generated for the following test module: // // ``` @@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { // // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); // ``` - let mut kunit_macros = "".to_owned(); - let mut test_cases = "".to_owned(); - let mut assert_macros = "".to_owned(); - let path = crate::helpers::file(); - let num_tests = tests.len(); - for (test, cfg_attr) in tests { - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); - // Append any `cfg` attributes the user might have written on their tests so we don't - // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce - // the length of the assert message. - let kunit_wrapper = format!( - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) - {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - {cfg_attr} {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; - use ::kernel::kunit::is_test_result_ok; - assert!(is_test_result_ok({test}())); + // + // Non-function items (e.g. imports) are preserved. + for item in module_items { + let Item::Fn(mut f) = item else { + processed_items.push(item); + continue; + }; + + // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85. + // Remove `#[test]` attributes applied on the function and count if any. + if !f.attrs.iter().any(|attr| attr.path().is_ident("test")) { + processed_items.push(Item::Fn(f)); + continue; + } + f.attrs.retain(|attr| !attr.path().is_ident("test")); + + let test = f.sig.ident.clone(); + + // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. 
+ let cfg_attrs: Vec<_> = f + .attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .cloned() + .collect(); + + // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call + // KUnit instead. + let test_str = test.to_string(); + let path = crate::helpers::file(); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert { + ($cond:expr $(,)?) => {{ + kernel::kunit_assert!(#test_str, #path, 0, $cond); + }} + } + }); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert_eq { + ($left:expr, $right:expr $(,)?) => {{ + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); }} - }}"#, + } + }); + + // Add back the test item. + processed_items.push(Item::Fn(f)); + + let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); + let test_cstr = LitCStr::new( + &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), + test.span(), ); - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); - writeln!( - test_cases, - " ::kernel::kunit::kunit_case(::kernel::c_str!("{test}"), {kunit_wrapper_fn_name})," - ) - .unwrap(); - writeln!( - assert_macros, - r#" -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert {{ - ($cond:expr $(,)?) => {{{{ - kernel::kunit_assert!("{test}", "{path}", 0, $cond); - }}}} -}} - -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert_eq {{ - ($left:expr, $right:expr $(,)?) => {{{{ - kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); - }}}} -}} - "# - ) - .unwrap(); - } + processed_items.push(parse_quote! { + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
- writeln!(kunit_macros).unwrap(); - writeln!( - kunit_macros, - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", - num_tests + 1 - ) - .unwrap(); - - writeln!( - kunit_macros, - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" - ) - .unwrap(); - - // Remove the `#[test]` macros. - // We do this at a token level, in order to preserve span information. - let mut new_body = vec![]; - let mut body_it = body.stream().into_iter(); - - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), - Some(next) => { - new_body.extend([token, next]); - } - _ => { - new_body.push(token); + // Append any `cfg` attributes the user might have written on their tests so we + // don't attempt to call them when they are `cfg`'d out. An extra `use` is used + // here to reduce the length of the assert message. + #(#cfg_attrs)* + { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; + use ::kernel::kunit::is_test_result_ok; + assert!(is_test_result_ok(#test())); } - }, - _ => { - new_body.push(token); } - } - } - - let mut final_body = TokenStream::new(); - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); - final_body.extend(new_body); - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); + });
- tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); + test_cases.push(quote!( + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) + )); + }
- tokens.into_iter().collect() + let num_tests_plus_1 = test_cases.len() + 1; + processed_items.push(parse_quote! { + static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ + #(#test_cases,)* + ::kernel::kunit::kunit_case_null(), + ]; + }); + processed_items.push(parse_quote! { + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); + }); + + module.content = Some((module_brace, processed_items)); + Ok(module.to_token_stream()) } diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index bb2dfd4a4dafc..9cfac9fce0d36 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -453,6 +453,8 @@ pub fn paste(input: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - kunit::kunit_tests(attr.into(), ts.into()).into() +pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { + kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() }