diff options
| -rw-r--r-- | rust/kernel/kunit.rs | 26 | ||||
| -rw-r--r-- | rust/macros/concat_idents.rs | 39 | ||||
| -rw-r--r-- | rust/macros/export.rs | 26 | ||||
| -rw-r--r-- | rust/macros/fmt.rs | 4 | ||||
| -rw-r--r-- | rust/macros/helpers.rs | 131 | ||||
| -rw-r--r-- | rust/macros/kunit.rs | 275 | ||||
| -rw-r--r-- | rust/macros/lib.rs | 41 | ||||
| -rw-r--r-- | rust/macros/module.rs | 907 | ||||
| -rw-r--r-- | rust/macros/paste.rs | 2 | ||||
| -rw-r--r-- | rust/macros/quote.rs | 182 | ||||
| -rw-r--r-- | rust/macros/vtable.rs | 165 |
11 files changed, 816 insertions, 982 deletions
diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index 4ccc8fc4a800..f93f24a60bdd 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -189,9 +189,6 @@ pub fn is_test_result_ok(t: impl TestResult) -> bool { } /// Represents an individual test case. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a `NULL`-terminated list of valid test cases. -/// Use [`kunit_case_null`] to generate such a delimiter. #[doc(hidden)] pub const fn kunit_case( name: &'static kernel::str::CStr, @@ -212,27 +209,6 @@ pub const fn kunit_case( } } -/// Represents the `NULL` test case delimiter. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a `NULL`-terminated list of test cases. This -/// function returns such a delimiter. -#[doc(hidden)] -pub const fn kunit_case_null() -> kernel::bindings::kunit_case { - kernel::bindings::kunit_case { - run_case: None, - name: core::ptr::null_mut(), - generate_params: None, - attr: kernel::bindings::kunit_attributes { - speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, - }, - status: kernel::bindings::kunit_status_KUNIT_SUCCESS, - module_name: core::ptr::null_mut(), - log: core::ptr::null_mut(), - param_init: None, - param_exit: None, - } -} - /// Registers a KUnit test suite. /// /// # Safety @@ -251,7 +227,7 @@ pub const fn kunit_case_null() -> kernel::bindings::kunit_case { /// /// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [ /// kernel::kunit::kunit_case(c"name", test_fn), -/// kernel::kunit::kunit_case_null(), +/// pin_init::zeroed(), /// ]; /// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES); /// ``` diff --git a/rust/macros/concat_idents.rs b/rust/macros/concat_idents.rs index 7e4b450f3a50..47b6add378d2 100644 --- a/rust/macros/concat_idents.rs +++ b/rust/macros/concat_idents.rs @@ -1,23 +1,36 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Ident, TokenStream, TokenTree}; +use proc_macro2::{ + Ident, + TokenStream, + TokenTree, // +}; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Result, + Token, // +}; -use crate::helpers::expect_punct; +pub(crate) struct Input { + a: Ident, + _comma: Token![,], + b: Ident, +} -fn expect_ident(it: &mut token_stream::IntoIter) -> Ident { - if let Some(TokenTree::Ident(ident)) = it.next() { - ident - } else { - panic!("Expected Ident") +impl Parse for Input { + fn parse(input: ParseStream<'_>) -> Result<Self> { + Ok(Self { + a: input.parse()?, + _comma: input.parse()?, + b: input.parse()?, + }) } } -pub(crate) fn concat_idents(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); - let a = expect_ident(&mut it); - assert_eq!(expect_punct(&mut it), ','); - let b = expect_ident(&mut it); - assert!(it.next().is_none(), "only two idents can be concatenated"); +pub(crate) fn concat_idents(Input { a, b, .. }: Input) -> TokenStream { let res = Ident::new(&format!("{a}{b}"), b.span()); TokenStream::from_iter([TokenTree::Ident(res)]) } diff --git a/rust/macros/export.rs b/rust/macros/export.rs index a08f6337d5c8..6d53521f62fc 100644 --- a/rust/macros/export.rs +++ b/rust/macros/export.rs @@ -1,19 +1,16 @@ // SPDX-License-Identifier: GPL-2.0 -use crate::helpers::function_name; -use proc_macro::TokenStream; +use proc_macro2::TokenStream; +use quote::quote; /// Please see [`crate::export`] for documentation. -pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let Some(name) = function_name(ts.clone()) else { - return "::core::compile_error!(\"The #[export] attribute must be used on a function.\");" - .parse::<TokenStream>() - .unwrap(); - }; +pub(crate) fn export(f: syn::ItemFn) -> TokenStream { + let name = &f.sig.ident; - // This verifies that the function has the same signature as the declaration generated by - // bindgen. It makes use of the fact that all branches of an if/else must have the same type. - let signature_check = quote!( + quote! { + // This verifies that the function has the same signature as the declaration generated by + // bindgen. It makes use of the fact that all branches of an if/else must have the same + // type. const _: () = { if true { ::kernel::bindings::#name @@ -21,9 +18,8 @@ pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { #name }; }; - ); - let no_mangle = quote!(#[no_mangle]); - - TokenStream::from_iter([signature_check, no_mangle, ts]) + #[no_mangle] + #f + } } diff --git a/rust/macros/fmt.rs b/rust/macros/fmt.rs index 2f4b9f6e2211..19f709262552 100644 --- a/rust/macros/fmt.rs +++ b/rust/macros/fmt.rs @@ -1,8 +1,10 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Ident, TokenStream, TokenTree}; use std::collections::BTreeSet; +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::quote_spanned; + /// Please see [`crate::fmt`] for documentation. pub(crate) fn fmt(input: TokenStream) -> TokenStream { let mut input = input.into_iter(); diff --git a/rust/macros/helpers.rs b/rust/macros/helpers.rs index 365d7eb499c0..37ef6a6f2c85 100644 --- a/rust/macros/helpers.rs +++ b/rust/macros/helpers.rs @@ -1,101 +1,41 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Group, Ident, TokenStream, TokenTree}; - -pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Ident(ident)) = it.next() { - Some(ident.to_string()) - } else { - None - } -} - -pub(crate) fn try_sign(it: &mut token_stream::IntoIter) -> Option<char> { - let peek = it.clone().next(); - match peek { - Some(TokenTree::Punct(punct)) if punct.as_char() == '-' => { - let _ = it.next(); - Some(punct.as_char()) +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Attribute, + Error, + LitStr, + Result, // +}; + +/// A string literal that is required to have ASCII value only. +pub(crate) struct AsciiLitStr(LitStr); + +impl Parse for AsciiLitStr { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let s: LitStr = input.parse()?; + if !s.value().is_ascii() { + return Err(Error::new_spanned(s, "expected ASCII-only string literal")); } - _ => None, - } -} - -pub(crate) fn try_literal(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Literal(literal)) = it.next() { - Some(literal.to_string()) - } else { - None + Ok(Self(s)) } } -pub(crate) fn try_string(it: &mut token_stream::IntoIter) -> Option<String> { - try_literal(it).and_then(|string| { - if string.starts_with('\"') && string.ends_with('\"') { - let content = &string[1..string.len() - 1]; - if content.contains('\\') { - panic!("Escape sequences in string literals not yet handled"); - } - Some(content.to_string()) - } else if string.starts_with("r\"") { - panic!("Raw string literals are not yet handled"); - } else { - None - } - }) -} - -pub(crate) fn expect_ident(it: &mut token_stream::IntoIter) -> String { - try_ident(it).expect("Expected Ident") -} - -pub(crate) fn expect_punct(it: &mut token_stream::IntoIter) -> char { - if let TokenTree::Punct(punct) = it.next().expect("Reached end of token stream for Punct") { - punct.as_char() - } else { - panic!("Expected Punct"); +impl ToTokens for AsciiLitStr { + fn to_tokens(&self, ts: &mut TokenStream) { + self.0.to_tokens(ts); } } -pub(crate) fn expect_string(it: &mut token_stream::IntoIter) -> String { - try_string(it).expect("Expected string") -} - -pub(crate) fn expect_string_ascii(it: &mut token_stream::IntoIter) -> String { - let string = try_string(it).expect("Expected string"); - assert!(string.is_ascii(), "Expected ASCII string"); - string -} - -pub(crate) fn expect_group(it: &mut token_stream::IntoIter) -> Group { - if let TokenTree::Group(group) = it.next().expect("Reached end of token stream for Group") { - group - } else { - panic!("Expected Group"); - } -} - -pub(crate) fn expect_end(it: &mut token_stream::IntoIter) { - if it.next().is_some() { - panic!("Expected end"); - } -} - -/// Given a function declaration, finds the name of the function. -pub(crate) fn function_name(input: TokenStream) -> Option<Ident> { - let mut input = input.into_iter(); - while let Some(token) = input.next() { - match token { - TokenTree::Ident(i) if i.to_string() == "fn" => { - if let Some(TokenTree::Ident(i)) = input.next() { - return Some(i); - } - return None; - } - _ => continue, - } +impl AsciiLitStr { + pub(crate) fn value(&self) -> String { + self.0.value() } - None } pub(crate) fn file() -> String { @@ -115,16 +55,7 @@ pub(crate) fn file() -> String { } } -/// Parse a token stream of the form `expected_name: "value",` and return the -/// string in the position of "value". -/// -/// # Panics -/// -/// - On parse error. -pub(crate) fn expect_string_field(it: &mut token_stream::IntoIter, expected_name: &str) -> String { - assert_eq!(expect_ident(it), expected_name); - assert_eq!(expect_punct(it), ':'); - let string = expect_string(it); - assert_eq!(expect_punct(it), ','); - string +/// Obtain all `#[cfg]` attributes. +pub(crate) fn gather_cfg_attrs(attr: &[Attribute]) -> impl Iterator<Item = &Attribute> + '_ { + attr.iter().filter(|a| a.path().is_ident("cfg")) } diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs index 3d7724b35c0f..6be880d634e2 100644 --- a/rust/macros/kunit.rs +++ b/rust/macros/kunit.rs @@ -4,80 +4,50 @@ //! //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashMap; -use std::fmt::Write; - -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - let attr = attr.to_string(); - - if attr.is_empty() { - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") - } - - if attr.len() > 255 { - panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") +use std::ffi::CString; + +use proc_macro2::TokenStream; +use quote::{ + format_ident, + quote, + ToTokens, // +}; +use syn::{ + parse_quote, + Error, + Ident, + Item, + ItemMod, + LitCStr, + Result, // +}; + +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { + if test_suite.to_string().len() > 255 { + return Err(Error::new_spanned( + test_suite, + "test suite names cannot exceed the maximum length of 255 bytes", + )); } - let mut tokens: Vec<_> = ts.into_iter().collect(); - - // Scan for the `mod` keyword. - tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "mod" => Some(true), - _ => None, - }, - _ => None, - }) - .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); - - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("Cannot locate main body of module"), + // We cannot handle modules that defer to another file (e.g. `mod foo;`). + let Some((module_brace, module_items)) = module.content.take() else { + Err(Error::new_spanned( + module, + "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", + ))? }; - // Get the functions set as tests. Search for `[test]` -> `fn`. - let mut body_it = body.stream().into_iter(); - let mut tests = Vec::new(); - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { - if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { - // Collect attributes because we need to find which are tests. We also - // need to copy `cfg` attributes so tests can be conditionally enabled. - attributes - .entry(name.to_string()) - .or_default() - .extend([token, TokenTree::Group(g)]); - } - continue; - } - _ => (), - }, - TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { - if let Some(TokenTree::Ident(test_name)) = body_it.next() { - tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) - } - } - - _ => (), - } - attributes.clear(); - } + // Make the entire module gated behind `CONFIG_KUNIT`. + module + .attrs + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); - // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. - let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); - tokens.insert( - 0, - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), - ); + let mut processed_items = Vec::new(); + let mut test_cases = Vec::new(); // Generate the test KUnit test suite and a test case for each `#[test]`. + // // The code generated for the following test module: // // ``` @@ -104,103 +74,98 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo), // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar), - // ::kernel::kunit::kunit_case_null(), + // ::pin_init::zeroed(), // ]; // // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); // ``` - let mut kunit_macros = "".to_owned(); - let mut test_cases = "".to_owned(); - let mut assert_macros = "".to_owned(); - let path = crate::helpers::file(); - let num_tests = tests.len(); - for (test, cfg_attr) in tests { - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); - // Append any `cfg` attributes the user might have written on their tests so we don't - // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce - // the length of the assert message. - let kunit_wrapper = format!( - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) - {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - {cfg_attr} {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; - use ::kernel::kunit::is_test_result_ok; - assert!(is_test_result_ok({test}())); + // + // Non-function items (e.g. imports) are preserved. + for item in module_items { + let Item::Fn(mut f) = item else { + processed_items.push(item); + continue; + }; + + // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85. + let before_len = f.attrs.len(); + f.attrs.retain(|attr| !attr.path().is_ident("test")); + if f.attrs.len() == before_len { + processed_items.push(Item::Fn(f)); + continue; + } + + let test = f.sig.ident.clone(); + + // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. + let cfg_attrs: Vec<_> = f + .attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .cloned() + .collect(); + + // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call + // KUnit instead. + let test_str = test.to_string(); + let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL"); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert { + ($cond:expr $(,)?) => {{ + kernel::kunit_assert!(#test_str, #path, 0, $cond); + }} + } + }); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert_eq { + ($left:expr, $right:expr $(,)?) => {{ + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); }} - }}"#, + } + }); + + // Add back the test item. + processed_items.push(Item::Fn(f)); + + let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); + let test_cstr = LitCStr::new( + &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), + test.span(), ); - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); - writeln!( - test_cases, - " ::kernel::kunit::kunit_case(c\"{test}\", {kunit_wrapper_fn_name})," - ) - .unwrap(); - writeln!( - assert_macros, - r#" -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert {{ - ($cond:expr $(,)?) => {{{{ - kernel::kunit_assert!("{test}", c"{path}", 0, $cond); - }}}} -}} - -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert_eq {{ - ($left:expr, $right:expr $(,)?) => {{{{ - kernel::kunit_assert_eq!("{test}", c"{path}", 0, $left, $right); - }}}} -}} - "# - ) - .unwrap(); - } + processed_items.push(parse_quote! { + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - writeln!(kunit_macros).unwrap(); - writeln!( - kunit_macros, - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", - num_tests + 1 - ) - .unwrap(); - - writeln!( - kunit_macros, - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" - ) - .unwrap(); - - // Remove the `#[test]` macros. - // We do this at a token level, in order to preserve span information. - let mut new_body = vec![]; - let mut body_it = body.stream().into_iter(); - - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), - Some(next) => { - new_body.extend([token, next]); - } - _ => { - new_body.push(token); + // Append any `cfg` attributes the user might have written on their tests so we + // don't attempt to call them when they are `cfg`'d out. An extra `use` is used + // here to reduce the length of the assert message. + #(#cfg_attrs)* + { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; + use ::kernel::kunit::is_test_result_ok; + assert!(is_test_result_ok(#test())); } - }, - _ => { - new_body.push(token); } - } - } - - let mut final_body = TokenStream::new(); - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); - final_body.extend(new_body); - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); + }); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); + test_cases.push(quote!( + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) + )); + } - tokens.into_iter().collect() + let num_tests_plus_1 = test_cases.len() + 1; + processed_items.push(parse_quote! { + static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ + #(#test_cases,)* + ::pin_init::zeroed(), + ]; + }); + processed_items.push(parse_quote! { + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); + }); + + module.content = Some((module_brace, processed_items)); + Ok(module.to_token_stream()) } diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index f26775285dc5..85b7938c08e5 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -11,8 +11,6 @@ // to avoid depending on the full `proc_macro_span` on Rust >= 1.88.0. #![cfg_attr(not(CONFIG_RUSTC_HAS_SPAN_FILE), feature(proc_macro_span))] -#[macro_use] -mod quote; mod concat_idents; mod export; mod fmt; @@ -24,6 +22,8 @@ mod vtable; use proc_macro::TokenStream; +use syn::parse_macro_input; + /// Declares a kernel module. /// /// The `type` argument should be a type which implements the [`Module`] @@ -131,8 +131,10 @@ use proc_macro::TokenStream; /// - `firmware`: array of ASCII string literals of the firmware files of /// the kernel module. #[proc_macro] -pub fn module(ts: TokenStream) -> TokenStream { - module::module(ts) +pub fn module(input: TokenStream) -> TokenStream { + module::module(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Declares or implements a vtable trait. @@ -206,8 +208,11 @@ pub fn module(ts: TokenStream) -> TokenStream { /// /// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html #[proc_macro_attribute] -pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { - vtable::vtable(attr, ts) +pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + vtable::vtable(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Export a function so that C code can call it via a header file. @@ -229,8 +234,9 @@ pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { /// This macro is *not* the same as the C macros `EXPORT_SYMBOL_*`. All Rust symbols are currently /// automatically exported with `EXPORT_SYMBOL_GPL`. #[proc_macro_attribute] -pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { - export::export(attr, ts) +pub fn export(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + export::export(parse_macro_input!(input)).into() } /// Like [`core::format_args!`], but automatically wraps arguments in [`kernel::fmt::Adapter`]. @@ -248,7 +254,7 @@ pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { /// [`pr_info!`]: ../kernel/macro.pr_info.html #[proc_macro] pub fn fmt(input: TokenStream) -> TokenStream { - fmt::fmt(input) + fmt::fmt(input.into()).into() } /// Concatenate two identifiers. @@ -305,8 +311,8 @@ pub fn fmt(input: TokenStream) -> TokenStream { /// assert_eq!(BR_OK, binder_driver_return_protocol_BR_OK); /// ``` #[proc_macro] -pub fn concat_idents(ts: TokenStream) -> TokenStream { - concat_idents::concat_idents(ts) +pub fn concat_idents(input: TokenStream) -> TokenStream { + concat_idents::concat_idents(parse_macro_input!(input)).into() } /// Paste identifiers together. @@ -444,9 +450,12 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream { /// [`paste`]: https://docs.rs/paste/ #[proc_macro] pub fn paste(input: TokenStream) -> TokenStream { - let mut tokens = input.into_iter().collect(); + let mut tokens = proc_macro2::TokenStream::from(input).into_iter().collect(); paste::expand(&mut tokens); - tokens.into_iter().collect() + tokens + .into_iter() + .collect::<proc_macro2::TokenStream>() + .into() } /// Registers a KUnit test suite and its test cases using a user-space like syntax. @@ -472,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - kunit::kunit_tests(attr, ts) +pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { + kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } diff --git a/rust/macros/module.rs b/rust/macros/module.rs index 80cb9b16f5aa..e16298e520c7 100644 --- a/rust/macros/module.rs +++ b/rust/macros/module.rs @@ -1,32 +1,42 @@ // SPDX-License-Identifier: GPL-2.0 +use std::ffi::CString; + +use proc_macro2::{ + Literal, + TokenStream, // +}; +use quote::{ + format_ident, + quote, // +}; +use syn::{ + braced, + bracketed, + ext::IdentExt, + parse::{ + Parse, + ParseStream, // + }, + parse_quote, + punctuated::Punctuated, + Error, + Expr, + Ident, + LitStr, + Path, + Result, + Token, + Type, // +}; + use crate::helpers::*; -use proc_macro::{token_stream, Delimiter, Literal, TokenStream, TokenTree}; -use std::fmt::Write; - -fn expect_string_array(it: &mut token_stream::IntoIter) -> Vec<String> { - let group = expect_group(it); - assert_eq!(group.delimiter(), Delimiter::Bracket); - let mut values = Vec::new(); - let mut it = group.stream().into_iter(); - - while let Some(val) = try_string(&mut it) { - assert!(val.is_ascii(), "Expected ASCII string"); - values.push(val); - match it.next() { - Some(TokenTree::Punct(punct)) => assert_eq!(punct.as_char(), ','), - None => break, - _ => panic!("Expected ',' or end of array"), - } - } - values -} struct ModInfoBuilder<'a> { module: &'a str, counter: usize, - buffer: String, - param_buffer: String, + ts: TokenStream, + param_ts: TokenStream, } impl<'a> ModInfoBuilder<'a> { @@ -34,8 +44,8 @@ impl<'a> ModInfoBuilder<'a> { ModInfoBuilder { module, counter: 0, - buffer: String::new(), - param_buffer: String::new(), + ts: TokenStream::new(), + param_ts: TokenStream::new(), } } @@ -52,33 +62,31 @@ impl<'a> ModInfoBuilder<'a> { // Loadable modules' modinfo strings go as-is. format!("{field}={content}\0") }; - - let buffer = if param { - &mut self.param_buffer + let length = string.len(); + let string = Literal::byte_string(string.as_bytes()); + let cfg = if builtin { + quote!(#[cfg(not(MODULE))]) } else { - &mut self.buffer + quote!(#[cfg(MODULE)]) }; - write!( - buffer, - " - {cfg} - #[doc(hidden)] - #[cfg_attr(not(target_os = \"macos\"), link_section = \".modinfo\")] - #[used(compiler)] - pub static __{module}_{counter}: [u8; {length}] = *{string}; - ", - cfg = if builtin { - "#[cfg(not(MODULE))]" - } else { - "#[cfg(MODULE)]" - }, + let counter = format_ident!( + "__{module}_{counter}", module = self.module.to_uppercase(), - counter = self.counter, - length = string.len(), - string = Literal::byte_string(string.as_bytes()), - ) - .unwrap(); + counter = self.counter + ); + let item = quote! { + #cfg + #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")] + #[used(compiler)] + pub static #counter: [u8; #length] = *#string; + }; + + if param { + self.param_ts.extend(item); + } else { + self.ts.extend(item); + } self.counter += 1; } @@ -111,201 +119,160 @@ impl<'a> ModInfoBuilder<'a> { }; for param in params { - let ops = param_ops_path(¶m.ptype); + let param_name_str = param.name.to_string(); + let param_type_str = param.ptype.to_string(); + + let ops = param_ops_path(¶m_type_str); // Note: The spelling of these fields is dictated by the user space // tool `modinfo`. - self.emit_param("parmtype", ¶m.name, ¶m.ptype); - self.emit_param("parm", ¶m.name, ¶m.description); - - write!( - self.param_buffer, - " - pub(crate) static {param_name}: - ::kernel::module_param::ModuleParamAccess<{param_type}> = - ::kernel::module_param::ModuleParamAccess::new({param_default}); - - const _: () = {{ - #[link_section = \"__param\"] - #[used] - static __{module_name}_{param_name}_struct: + self.emit_param("parmtype", ¶m_name_str, ¶m_type_str); + self.emit_param("parm", ¶m_name_str, ¶m.description.value()); + + let static_name = format_ident!("__{}_{}_struct", self.module, param.name); + let param_name_cstr = + CString::new(param_name_str).expect("name contains NUL-terminator"); + let param_name_cstr_with_module = + CString::new(format!("{}.{}", self.module, param.name)) + .expect("name contains NUL-terminator"); + + let param_name = ¶m.name; + let param_type = ¶m.ptype; + let param_default = ¶m.default; + + self.param_ts.extend(quote! { + #[allow(non_upper_case_globals)] + pub(crate) static #param_name: + ::kernel::module_param::ModuleParamAccess<#param_type> = + ::kernel::module_param::ModuleParamAccess::new(#param_default); + + const _: () = { + #[allow(non_upper_case_globals)] + #[link_section = "__param"] + #[used(compiler)] + static #static_name: ::kernel::module_param::KernelParam = ::kernel::module_param::KernelParam::new( - ::kernel::bindings::kernel_param {{ - name: if ::core::cfg!(MODULE) {{ - ::kernel::c_str!(\"{param_name}\").to_bytes_with_nul() - }} else {{ - ::kernel::c_str!(\"{module_name}.{param_name}\") - .to_bytes_with_nul() - }}.as_ptr(), + ::kernel::bindings::kernel_param { + name: kernel::str::as_char_ptr_in_const_context( + if ::core::cfg!(MODULE) { + #param_name_cstr + } else { + #param_name_cstr_with_module + } + ), // SAFETY: `__this_module` is constructed by the kernel at load // time and will not be freed until the module is unloaded. #[cfg(MODULE)] - mod_: unsafe {{ + mod_: unsafe { core::ptr::from_ref(&::kernel::bindings::__this_module) .cast_mut() - }}, + }, #[cfg(not(MODULE))] mod_: ::core::ptr::null_mut(), - ops: core::ptr::from_ref(&{ops}), + ops: core::ptr::from_ref(&#ops), perm: 0, // Will not appear in sysfs level: -1, flags: 0, - __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {{ - arg: {param_name}.as_void_ptr() - }}, - }} + __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 { + arg: #param_name.as_void_ptr() + }, + } ); - }}; - ", - module_name = info.name, - param_type = param.ptype, - param_default = param.default, - param_name = param.name, - ops = ops, - ) - .unwrap(); + }; + }); } } } -fn param_ops_path(param_type: &str) -> &'static str { +fn param_ops_path(param_type: &str) -> Path { match param_type { - "i8" => "::kernel::module_param::PARAM_OPS_I8", - "u8" => "::kernel::module_param::PARAM_OPS_U8", - "i16" => "::kernel::module_param::PARAM_OPS_I16", - "u16" => "::kernel::module_param::PARAM_OPS_U16", - "i32" => "::kernel::module_param::PARAM_OPS_I32", - "u32" => "::kernel::module_param::PARAM_OPS_U32", - "i64" => "::kernel::module_param::PARAM_OPS_I64", - "u64" => "::kernel::module_param::PARAM_OPS_U64", - "isize" => "::kernel::module_param::PARAM_OPS_ISIZE", - "usize" => "::kernel::module_param::PARAM_OPS_USIZE", + "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8), + "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8), + "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16), + "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16), + "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32), + "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32), + "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64), + "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64), + "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE), + "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE), t => panic!("Unsupported parameter type {}", t), } } -fn expect_param_default(param_it: &mut token_stream::IntoIter) -> String { - assert_eq!(expect_ident(param_it), "default"); - assert_eq!(expect_punct(param_it), ':'); - let sign = try_sign(param_it); - let default = try_literal(param_it).expect("Expected default param value"); - assert_eq!(expect_punct(param_it), ','); - let mut value = sign.map(String::from).unwrap_or_default(); - value.push_str(&default); - value -} - -#[derive(Debug, Default)] -struct ModuleInfo { - type_: String, - license: String, - name: String, - authors: Option<Vec<String>>, - description: Option<String>, - alias: Option<Vec<String>>, - firmware: Option<Vec<String>>, - imports_ns: Option<Vec<String>>, - params: Option<Vec<Parameter>>, -} - -#[derive(Debug)] -struct Parameter { - name: String, - ptype: String, - default: String, - description: String, -} - -fn expect_params(it: &mut token_stream::IntoIter) -> Vec<Parameter> { - let params = expect_group(it); - assert_eq!(params.delimiter(), Delimiter::Brace); - let mut it = params.stream().into_iter(); - let mut parsed = Vec::new(); - - loop { - let param_name = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; - - assert_eq!(expect_punct(&mut it), ':'); - let param_type = expect_ident(&mut it); - let group = expect_group(&mut it); - assert_eq!(group.delimiter(), Delimiter::Brace); - assert_eq!(expect_punct(&mut it), ','); - - let mut param_it = group.stream().into_iter(); - let param_default = expect_param_default(&mut param_it); - let param_description = expect_string_field(&mut param_it, "description"); - expect_end(&mut param_it); - - parsed.push(Parameter { - name: param_name, - ptype: param_type, - default: param_default, - description: param_description, - }) - } - - parsed -} - -impl ModuleInfo { - fn parse(it: &mut token_stream::IntoIter) -> Self { - let mut info = ModuleInfo::default(); - - const EXPECTED_KEYS: &[&str] = &[ - "type", - "name", - "authors", - "description", - "license", - "alias", - "firmware", - "imports_ns", - "params", - ]; - const REQUIRED_KEYS: &[&str] = &["type", "name", "license"]; +/// Parse fields that are required to use a specific order. +/// +/// As fields must follow a specific order, we *could* just parse fields one by one by peeking. +/// However the error message generated when implementing that way is not very friendly. +/// +/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing, +/// and if the wrong order is used, the proper order is communicated to the user with error message. +/// +/// Usage looks like this: +/// ```ignore +/// parse_ordered_fields! { +/// from input; +/// +/// // This will extract "foo: <field>" into a variable named "foo". +/// // The variable will have type `Option<_>`. +/// foo => <expression that parses the field>, +/// +/// // If you need the variable name to be different than the key name. +/// // This extracts "baz: <field>" into a variable named "bar". +/// // You might want this if "baz" is a keyword. +/// baz as bar => <expression that parse the field>, +/// +/// // You can mark a key as required, and the variable will no longer be `Option`. +/// // foobar will be of type `Expr` instead of `Option<Expr>`. +/// foobar [required] => input.parse::<Expr>()?, +/// } +/// ``` +macro_rules! parse_ordered_fields { + (@gen + [$input:expr] + [$([$name:ident; $key:ident; $parser:expr])*] + [$([$req_name:ident; $req_key:ident])*] + ) => { + $(let mut $name = None;)* + + const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*]; + const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*]; + + let span = $input.span(); let mut seen_keys = Vec::new(); - loop { - let key = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; + while !$input.is_empty() { + let key = $input.call(Ident::parse_any)?; if seen_keys.contains(&key) { - panic!("Duplicated key \"{key}\". Keys can only be specified once."); + Err(Error::new_spanned( + &key, + format!(r#"duplicated key "{key}". Keys can only be specified once."#), + ))? } - assert_eq!(expect_punct(it), ':'); - - match key.as_str() { - "type" => info.type_ = expect_ident(it), - "name" => info.name = expect_string_ascii(it), - "authors" => info.authors = Some(expect_string_array(it)), - "description" => info.description = Some(expect_string(it)), - "license" => info.license = expect_string_ascii(it), - "alias" => info.alias = Some(expect_string_array(it)), - "firmware" => info.firmware = Some(expect_string_array(it)), - "imports_ns" => info.imports_ns = Some(expect_string_array(it)), - "params" => info.params = Some(expect_params(it)), - _ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."), + $input.parse::<Token![:]>()?; + + match &*key.to_string() { + $( + stringify!($key) => $name = Some($parser), + )* + _ => { + Err(Error::new_spanned( + &key, + format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#), + ))? + } } - assert_eq!(expect_punct(it), ','); - + $input.parse::<Token![,]>()?; seen_keys.push(key); } - expect_end(it); - for key in REQUIRED_KEYS { if !seen_keys.iter().any(|e| e == key) { - panic!("Missing required key \"{key}\"."); + Err(Error::new(span, format!(r#"missing required key "{key}""#)))? } } @@ -317,43 +284,190 @@ impl ModuleInfo { } if seen_keys != ordered_keys { - panic!("Keys are not ordered as expected. Order them like: {ordered_keys:?}."); + Err(Error::new( + span, + format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#), + ))? + } + + $(let $req_name = $req_name.expect("required field");)* + }; + + // Handle required fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)* + ) + }; + + // Handle optional fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)* + ) + }; + + (from $input:expr; $($tok:tt)*) => { + parse_ordered_fields!(@gen [$input] [] [] $($tok)*) + } +} + +struct Parameter { + name: Ident, + ptype: Ident, + default: Expr, + description: LitStr, +} + +impl Parse for Parameter { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let name = input.parse()?; + input.parse::<Token![:]>()?; + let ptype = input.parse()?; + + let fields; + braced!(fields in input); + + parse_ordered_fields! { + from fields; + default [required] => fields.parse()?, + description [required] => fields.parse()?, } - info + Ok(Self { + name, + ptype, + default, + description, + }) } } -pub(crate) fn module(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); +pub(crate) struct ModuleInfo { + type_: Type, + license: AsciiLitStr, + name: AsciiLitStr, + authors: Option<Punctuated<AsciiLitStr, Token![,]>>, + description: Option<LitStr>, + alias: Option<Punctuated<AsciiLitStr, Token![,]>>, + firmware: Option<Punctuated<AsciiLitStr, Token![,]>>, + imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>, + params: Option<Punctuated<Parameter, Token![,]>>, +} + +impl Parse for ModuleInfo { + fn parse(input: ParseStream<'_>) -> Result<Self> { + parse_ordered_fields!( + from input; + type as type_ [required] => input.parse()?, + name [required] => input.parse()?, + authors => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + description => input.parse()?, + license [required] => input.parse()?, + alias => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + firmware => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + imports_ns => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + params => { + let list; + braced!(list in input); + Punctuated::parse_terminated(&list)? + }, + ); + + Ok(ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params, + }) + } +} - let info = ModuleInfo::parse(&mut it); +pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> { + let ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params: _, + } = &info; // Rust does not allow hyphens in identifiers, use underscore instead. - let ident = info.name.replace('-', "_"); + let ident = name.value().replace('-', "_"); let mut modinfo = ModInfoBuilder::new(ident.as_ref()); - if let Some(authors) = &info.authors { + if let Some(authors) = authors { for author in authors { - modinfo.emit("author", author); + modinfo.emit("author", &author.value()); } } - if let Some(description) = &info.description { - modinfo.emit("description", description); + if let Some(description) = description { + modinfo.emit("description", &description.value()); } - modinfo.emit("license", &info.license); - if let Some(aliases) = &info.alias { + modinfo.emit("license", &license.value()); + if let Some(aliases) = alias { for alias in aliases { - modinfo.emit("alias", alias); + modinfo.emit("alias", &alias.value()); } } - if let Some(firmware) = &info.firmware { + if let Some(firmware) = firmware { for fw in firmware { - modinfo.emit("firmware", fw); + modinfo.emit("firmware", &fw.value()); } } - if let Some(imports) = &info.imports_ns { + if let Some(imports) = imports_ns { for ns in imports { - modinfo.emit("import_ns", ns); + modinfo.emit("import_ns", &ns.value()); } } @@ -364,182 +478,181 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { modinfo.emit_params(&info); - format!( - " - /// The module name. - /// - /// Used by the printing macros, e.g. [`info!`]. - const __LOG_PREFIX: &[u8] = b\"{name}\\0\"; - - // SAFETY: `__this_module` is constructed by the kernel at load time and will not be - // freed until the module is unloaded. - #[cfg(MODULE)] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - extern \"C\" {{ - static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; - }} - - ::kernel::ThisModule::from_ptr(__this_module.get()) - }}; - #[cfg(not(MODULE))] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) - }}; - - /// The `LocalModule` type is the type of the module created by `module!`, - /// `module_pci_driver!`, `module_platform_driver!`, etc. - type LocalModule = {type_}; - - impl ::kernel::ModuleMetadata for {type_} {{ - const NAME: &'static ::kernel::str::CStr = c\"{name}\"; - }} - - // Double nested modules, since then nobody can access the public items inside. - mod __module_init {{ - mod __module_init {{ - use super::super::{type_}; - use pin_init::PinInit; - - /// The \"Rust loadable module\" mark. - // - // This may be best done another way later on, e.g. as a new modinfo - // key or a new section. For the moment, keep it simple. - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - static __IS_RUST_MODULE: () = (); - - static mut __MOD: ::core::mem::MaybeUninit<{type_}> = - ::core::mem::MaybeUninit::uninit(); - - // Loadable modules need to export the `{{init,cleanup}}_module` identifiers. - /// # Safety - /// - /// This function must not be called after module initialization, because it may be - /// freed after that completes. - #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".init.text\"] - pub unsafe extern \"C\" fn init_module() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name. - unsafe {{ __init() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".init.data\"] - static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module; - - #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".exit.text\"] - pub extern \"C\" fn cleanup_module() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `init_module` has returned `0` - // (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".exit.data\"] - static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module; - - // Built-in modules are initialized through an initcall pointer - // and the identifiers need to be unique. - #[cfg(not(MODULE))] - #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] - #[doc(hidden)] - #[link_section = \"{initcall_section}\"] - #[used(compiler)] - pub static __{ident}_initcall: extern \"C\" fn() -> - ::kernel::ffi::c_int = __{ident}_init; - - #[cfg(not(MODULE))] - #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] - ::core::arch::global_asm!( - r#\".section \"{initcall_section}\", \"a\" - __{ident}_initcall: - .long __{ident}_init - . - .previous - \"# + let modinfo_ts = modinfo.ts; + let params_ts = modinfo.param_ts; + + let ident_init = format_ident!("__{ident}_init"); + let ident_exit = format_ident!("__{ident}_exit"); + let ident_initcall = format_ident!("__{ident}_initcall"); + let initcall_section = ".initcall6.init"; + + let global_asm = format!( + r#".section "{initcall_section}", "a" + __{ident}_initcall: + .long __{ident}_init - . + .previous + "# + ); + + let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator"); + + Ok(quote! { + /// The module name. + /// + /// Used by the printing macros, e.g. [`info!`]. + const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul(); + + // SAFETY: `__this_module` is constructed by the kernel at load time and will not be + // freed until the module is unloaded. + #[cfg(MODULE)] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + extern "C" { + static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; + }; + + ::kernel::ThisModule::from_ptr(__this_module.get()) + }; + + #[cfg(not(MODULE))] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) + }; + + /// The `LocalModule` type is the type of the module created by `module!`, + /// `module_pci_driver!`, `module_platform_driver!`, etc. + type LocalModule = #type_; + + impl ::kernel::ModuleMetadata for #type_ { + const NAME: &'static ::kernel::str::CStr = #name_cstr; + } + + // Double nested modules, since then nobody can access the public items inside. + #[doc(hidden)] + mod __module_init { + mod __module_init { + use pin_init::PinInit; + + /// The "Rust loadable module" mark. + // + // This may be best done another way later on, e.g. as a new modinfo + // key or a new section. For the moment, keep it simple. + #[cfg(MODULE)] + #[used(compiler)] + static __IS_RUST_MODULE: () = (); + + static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> = + ::core::mem::MaybeUninit::uninit(); + + // Loadable modules need to export the `{init,cleanup}_module` identifiers. + /// # Safety + /// + /// This function must not be called after module initialization, because it may be + /// freed after that completes. + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".init.text"] + pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name. + unsafe { __init() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".init.data"] + static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 = + init_module; + + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".exit.text"] + pub extern "C" fn cleanup_module() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `init_module` has returned `0` + // (which delegates to `__init`). + unsafe { __exit() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".exit.data"] + static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module; + + // Built-in modules are initialized through an initcall pointer + // and the identifiers need to be unique. + #[cfg(not(MODULE))] + #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] + #[link_section = #initcall_section] + #[used(compiler)] + pub static #ident_initcall: extern "C" fn() -> + ::kernel::ffi::c_int = #ident_init; + + #[cfg(not(MODULE))] + #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] + ::core::arch::global_asm!(#global_asm); + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // placement above in the initcall section. + unsafe { __init() } + } + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_exit() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `#ident_init` has + // returned `0` (which delegates to `__init`). + unsafe { __exit() } + } + + /// # Safety + /// + /// This function must only be called once. + unsafe fn __init() -> ::kernel::ffi::c_int { + let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init( + &super::super::THIS_MODULE ); + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. These functions are only + // called once and `__exit` cannot be called before or during `__init`. + match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } { + Ok(m) => 0, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// This function must + /// - only be called once, + /// - be called after `__init` has been called and returned `0`. + unsafe fn __exit() { + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. These functions are only + // called once and `__init` was already called. + unsafe { + // Invokes `drop()` on `__MOD`, which should be used for cleanup. + __MOD.assume_init_drop(); + } + } + + #modinfo_ts + } + } - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_init() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // placement above in the initcall section. - unsafe {{ __init() }} - }} - - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_exit() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `__{ident}_init` has - // returned `0` (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - /// # Safety - /// - /// This function must only be called once. - unsafe fn __init() -> ::kernel::ffi::c_int {{ - let initer = - <{type_} as ::kernel::InPlaceModule>::init(&super::super::THIS_MODULE); - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__exit` cannot be called before or during `__init`. - match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{ - Ok(m) => 0, - Err(e) => e.to_errno(), - }} - }} - - /// # Safety - /// - /// This function must - /// - only be called once, - /// - be called after `__init` has been called and returned `0`. - unsafe fn __exit() {{ - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__init` was already called. - unsafe {{ - // Invokes `drop()` on `__MOD`, which should be used for cleanup. - __MOD.assume_init_drop(); - }} - }} - {modinfo} - }} - }} - mod module_parameters {{ - {params} - }} - ", - type_ = info.type_, - name = info.name, - ident = ident, - modinfo = modinfo.buffer, - params = modinfo.param_buffer, - initcall_section = ".initcall6.init" - ) - .parse() - .expect("Error parsing formatted string into token stream.") + mod module_parameters { + #params_ts + } + }) } diff --git a/rust/macros/paste.rs b/rust/macros/paste.rs index cce712d19855..2181e312a7d3 100644 --- a/rust/macros/paste.rs +++ b/rust/macros/paste.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; +use proc_macro2::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; fn concat_helper(tokens: &[TokenTree]) -> Vec<(String, Span)> { let mut tokens = tokens.iter(); diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs deleted file mode 100644 index ddfc21577539..000000000000 --- a/rust/macros/quote.rs +++ /dev/null @@ -1,182 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use proc_macro::{TokenStream, TokenTree}; - -pub(crate) trait ToTokens { - fn to_tokens(&self, tokens: &mut TokenStream); -} - -impl<T: ToTokens> ToTokens for Option<T> { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some(v) = self { - v.to_tokens(tokens); - } - } -} - -impl ToTokens for proc_macro::Group { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for proc_macro::Ident { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for TokenTree { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([self.clone()]); - } -} - -impl ToTokens for TokenStream { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend(self.clone()); - } -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// the given span. -/// -/// This is a similar to the -/// [`quote_spanned!`](https://docs.rs/quote/latest/quote/macro.quote_spanned.html) macro from the -/// `quote` crate but provides only just enough functionality needed by the current `macros` crate. -macro_rules! quote_spanned { - ($span:expr => $($tt:tt)*) => {{ - let mut tokens = ::proc_macro::TokenStream::new(); - { - #[allow(unused_variables)] - let span = $span; - quote_spanned!(@proc tokens span $($tt)*); - } - tokens - }}; - (@proc $v:ident $span:ident) => {}; - (@proc $v:ident $span:ident #$id:ident $($tt:tt)*) => { - $crate::quote::ToTokens::to_tokens(&$id, &mut $v); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident #(#$id:ident)* $($tt:tt)*) => { - for token in $id { - $crate::quote::ToTokens::to_tokens(&token, &mut $v); - } - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ( $($inner:tt)* ) $($tt:tt)*) => { - #[allow(unused_mut)] - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Parenthesis, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident [ $($inner:tt)* ] $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Bracket, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident { $($inner:tt)* } $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Brace, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident :: $($tt:tt)*) => { - $v.extend([::proc_macro::Spacing::Joint, ::proc_macro::Spacing::Alone].map(|spacing| { - ::proc_macro::TokenTree::Punct(::proc_macro::Punct::new(':', spacing)) - })); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident : $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(':', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident , $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(',', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident @ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('@', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ! $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('!', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ; $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident + $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident = $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('=', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident # $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('#', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident & $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('&', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident _ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new("_", $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new(stringify!($id), $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// mixed site span ([`Span::mixed_site()`]). -/// -/// This is a similar to the [`quote!`](https://docs.rs/quote/latest/quote/macro.quote.html) macro -/// from the `quote` crate but provides only just enough functionality needed by the current -/// `macros` crate. -/// -/// [`Span::mixed_site()`]: https://doc.rust-lang.org/proc_macro/struct.Span.html#method.mixed_site -macro_rules! quote { - ($($tt:tt)*) => { - quote_spanned!(::proc_macro::Span::mixed_site() => $($tt)*) - } -} diff --git a/rust/macros/vtable.rs b/rust/macros/vtable.rs index ee06044fcd4f..c6510b0c4ea1 100644 --- a/rust/macros/vtable.rs +++ b/rust/macros/vtable.rs @@ -1,96 +1,105 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashSet; -use std::fmt::Write; +use std::{ + collections::HashSet, + iter::Extend, // +}; -pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let mut tokens: Vec<_> = ts.into_iter().collect(); +use proc_macro2::{ + Ident, + TokenStream, // +}; +use quote::ToTokens; +use syn::{ + parse_quote, + Error, + ImplItem, + Item, + ItemImpl, + ItemTrait, + Result, + TraitItem, // +}; - // Scan for the `trait` or `impl` keyword. - let is_trait = tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "trait" => Some(true), - "impl" => Some(false), - _ => None, - }, - _ => None, - }) - .expect("#[vtable] attribute should only be applied to trait or impl block"); +fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> { + let mut gen_items = Vec::new(); - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("cannot locate main body of trait or impl block"), - }; + gen_items.push(parse_quote! { + /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) + /// attribute when implementing this trait. + const USE_VTABLE_ATTR: (); + }); - let mut body_it = body.stream().into_iter(); - let mut functions = Vec::new(); - let mut consts = HashSet::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Ident(ident) if ident.to_string() == "fn" => { - let fn_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered a fn pointer type instead. - _ => continue, - }; - functions.push(fn_name); - } - TokenTree::Ident(ident) if ident.to_string() == "const" => { - let const_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered an inline const block instead. - _ => continue, - }; - consts.insert(const_name); - } - _ => (), + for item in &item.items { + if let TraitItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + + // We don't know on the implementation-site whether a method is required or provided + // so we have to generate a const for all methods. + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + let comment = + format!("Indicates if the `{name}` method is overridden by the implementor."); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + #[doc = #comment] + const #gen_const_name: bool = false; + }); } } - let mut const_items; - if is_trait { - const_items = " - /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) - /// attribute when implementing this trait. - const USE_VTABLE_ATTR: (); - " - .to_owned(); + item.items.extend(gen_items); + Ok(item) +} - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - // Skip if it's declared already -- this allows user override. - if consts.contains(&gen_const_name) { - continue; - } - // We don't know on the implementation-site whether a method is required or provided - // so we have to generate a const for all methods. - write!( - const_items, - "/// Indicates if the `{f}` method is overridden by the implementor. - const {gen_const_name}: bool = false;", - ) - .unwrap(); - consts.insert(gen_const_name); +fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> { + let mut gen_items = Vec::new(); + let mut defined_consts = HashSet::new(); + + // Iterate over all user-defined constants to gather any possible explicit overrides. + for item in &item.items { + if let ImplItem::Const(const_item) = item { + defined_consts.insert(const_item.ident.clone()); } - } else { - const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); + } + + gen_items.push(parse_quote! { + const USE_VTABLE_ATTR: () = (); + }); - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - if consts.contains(&gen_const_name) { + for item in &item.items { + if let ImplItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + // Skip if it's declared already -- this allows user override. + if defined_consts.contains(&gen_const_name) { continue; } - write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + const #gen_const_name: bool = true; + }); } } - let new_body = vec![const_items.parse().unwrap(), body.stream()] - .into_iter() - .collect(); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); - tokens.into_iter().collect() + item.items.extend(gen_items); + Ok(item) +} + +pub(crate) fn vtable(input: Item) -> Result<TokenStream> { + match input { + Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()), + Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()), + _ => Err(Error::new_spanned( + input, + "`#[vtable]` attribute should only be applied to trait or impl block", + ))?, + } } |
