From 4f13576594de87d8c6c4da12a6783c23155e359d Mon Sep 17 00:00:00 2001 From: Davide Bertola Date: Mon, 13 Jun 2022 10:21:53 +0200 Subject: [PATCH] aya-gen: allow passing custom --ctypes-prefix to bindgen (#312) Allows users to pass a custom --ctypes-prefix to bindgen #186 --- aya-gen/src/generate.rs | 78 +++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/aya-gen/src/generate.rs b/aya-gen/src/generate.rs index 12ac7034..4bd4b8e7 100644 --- a/aya-gen/src/generate.rs +++ b/aya-gen/src/generate.rs @@ -43,7 +43,21 @@ pub fn generate>( types: &[T], additional_flags: &[T], ) -> Result { + let additional_flags = additional_flags + .iter() + .map(|s| s.as_ref().into()) + .collect::>(); + let mut bindgen = bindgen::bpf_builder(); + let (additional_flags, ctypes_prefix) = extract_ctypes_prefix(&additional_flags); + + if let Some(prefix) = ctypes_prefix { + bindgen = bindgen.ctypes_prefix(prefix) + } + + for ty in types { + bindgen = bindgen.allowlist_type(ty); + } let (c_header, name) = match &input_file { InputFile::Btf(path) => (c_header_from_btf(path)?, "kernel_types.h"), @@ -53,22 +67,12 @@ pub fn generate>( ), }; - for ty in types { - bindgen = bindgen.allowlist_type(ty); - } - let dir = tempdir().unwrap(); let file_path = dir.path().join(name); let mut file = File::create(&file_path).unwrap(); let _ = file.write(c_header.as_bytes()).unwrap(); - let flags = combine_flags( - &bindgen.command_line_flags(), - &additional_flags - .iter() - .map(|s| s.as_ref().into()) - .collect::>(), - ); + let flags = combine_flags(&bindgen.command_line_flags(), &additional_flags); let output = Command::new("bindgen") .arg(file_path) @@ -104,14 +108,28 @@ fn c_header_from_btf(path: &Path) -> Result { Ok(str::from_utf8(&output.stdout).unwrap().to_owned()) } +fn extract_ctypes_prefix(s: &[String]) -> (Vec, Option) { + if let Some(index) = s.iter().position(|el| el == "--ctypes-prefix") { + if index < s.len() - 1 { + let mut flags = Vec::new(); + flags.extend_from_slice(&s[0..index]); + // skip ["--ctypes-prefix", "value"] + flags.extend_from_slice(&s[index + 2..]); + return (flags, s.get(index + 1).cloned()); + } + } + + (s.to_vec(), None) +} + fn combine_flags(s1: &[String], s2: &[String]) -> Vec { - let mut args = Vec::new(); + let mut flags = Vec::new(); let mut extra = Vec::new(); for s in [s1, s2] { let mut s = s.splitn(2, |el| el == "--"); // append args - args.extend(s.next().unwrap().iter().cloned()); + flags.extend(s.next().unwrap().iter().cloned()); if let Some(e) = s.next() { // append extra args extra.extend(e.iter().cloned()); @@ -120,23 +138,47 @@ fn combine_flags(s1: &[String], s2: &[String]) -> Vec { // append extra args if !extra.is_empty() { - args.push("--".to_string()); - args.extend(extra); + flags.push("--".to_string()); + flags.extend(extra); } - args + flags } #[cfg(test)] mod test { - use super::combine_flags; + use super::{combine_flags, extract_ctypes_prefix}; fn to_vec(s: &str) -> Vec { s.split(" ").map(|x| x.into()).collect() } #[test] - fn combine_arguments_test() { + fn test_extract_ctypes_prefix() { + let (flags, prefix) = extract_ctypes_prefix(&to_vec("foo --ctypes-prefix bar baz")); + assert_eq!(flags, to_vec("foo baz")); + assert_eq!(prefix, Some("bar".to_string())); + + let (flags, prefix) = extract_ctypes_prefix(&to_vec("foo --ctypes-prefi bar baz")); + assert_eq!(flags, to_vec("foo --ctypes-prefi bar baz")); + assert_eq!(prefix, None); + + let (flags, prefix) = extract_ctypes_prefix(&to_vec("--foo bar --ctypes-prefix")); + assert_eq!(flags, to_vec("--foo bar --ctypes-prefix")); + assert_eq!(prefix, None); + + let (flags, prefix) = extract_ctypes_prefix(&to_vec("--ctypes-prefix foo")); + let empty: Vec = Vec::new(); + assert_eq!(flags, empty); + assert_eq!(prefix, Some("foo".to_string())); + + let (flags, prefix) = extract_ctypes_prefix(&to_vec("--ctypes-prefix")); + assert_eq!(flags, to_vec("--ctypes-prefix")); + assert_eq!(prefix, None); + } + + #[test] + fn test_combine_flags() { assert_eq!( combine_flags(&to_vec("a b"), &to_vec("c d"),).join(" "), "a b c d".to_string(),