From e55ab1cf20193bbb7f71bf1adc1b95eeab969732 Mon Sep 17 00:00:00 2001
From: Davide Bertola <dade@dadeb.it>
Date: Wed, 8 Jun 2022 13:09:09 +0200
Subject: [PATCH] aya-gen: allow custom --ctypes-prefix

Allows the user to pass a custom --ctypes-prefix to bindgen #186
---
 aya-gen/src/generate.rs | 52 +++++++++++++++++++++++++++++++----------
 1 file changed, 40 insertions(+), 12 deletions(-)

diff --git a/aya-gen/src/generate.rs b/aya-gen/src/generate.rs
index 24c0eaea..0229ca09 100644
--- a/aya-gen/src/generate.rs
+++ b/aya-gen/src/generate.rs
@@ -43,7 +43,21 @@ pub fn generate<T: AsRef<str>>(
     types: &[T],
     additional_flags: &[T],
 ) -> Result<String, Error> {
+    let additional_flags = additional_flags
+        .iter()
+        .map(|s| s.as_ref().into())
+        .collect::<Vec<_>>();
+
     let mut bindgen = bindgen::bpf_builder();
+    let (additional_flags, ctypes_prefix) = strip_ctypes_prefix(&additional_flags);
+
+    if let Some(prefix) = ctypes_prefix {
+        bindgen = bindgen.ctypes_prefix(prefix)
+    }
+
+    for ty in types {
+        bindgen = bindgen.allowlist_type(ty);
+    }
 
     let (c_header, name) = match &input_file {
         InputFile::Btf(path) => (c_header_from_btf(path)?, "kernel_types.h"),
@@ -53,22 +67,12 @@ pub fn generate<T: AsRef<str>>(
         ),
     };
 
-    for ty in types {
-        bindgen = bindgen.allowlist_type(ty);
-    }
-
     let dir = tempdir().unwrap();
     let file_path = dir.path().join(name);
     let mut file = File::create(&file_path).unwrap();
     let _ = file.write(c_header.as_bytes()).unwrap();
 
-    let flags = combine_flags(
-        &bindgen.command_line_flags(),
-        &additional_flags
-            .iter()
-            .map(|s| s.as_ref().into())
-            .collect::<Vec<_>>(),
-    );
+    let flags = combine_flags(&bindgen.command_line_flags(), &additional_flags);
 
     let output = Command::new("bindgen")
         .arg(file_path)
@@ -104,6 +108,20 @@ fn c_header_from_btf(path: &Path) -> Result<String, Error> {
     Ok(str::from_utf8(&output.stdout).unwrap().to_owned())
 }
 
+fn strip_ctypes_prefix(s: &[String]) -> (Vec<String>, Option<String>) {
+    let mut it = s.splitn(2, |el| el == "--ctypes-prefix");
+    let mut prefix = None;
+    let mut flags = Vec::new();
+    flags.extend_from_slice(it.next().unwrap());
+
+    if let Some(after) = it.next() {
+        prefix = after.get(0).cloned();
+        flags.extend_from_slice(&after[1..]);
+    }
+
+    (flags, prefix)
+}
+
 fn combine_flags(s1: &[String], s2: &[String]) -> Vec<String> {
     let mut flags = Vec::new();
     let mut extra = Vec::new();
@@ -129,13 +147,23 @@ fn combine_flags(s1: &[String], s2: &[String]) -> Vec<String> {
 
 #[cfg(test)]
 mod test {
-    use super::combine_flags;
+    use super::{combine_flags, strip_ctypes_prefix};
 
     fn to_vec(s: &str) -> Vec<String> {
         s.split(" ").map(|x| x.into()).collect()
     }
 
     #[test]
+    fn test_strip_ctypes_prefix() {
+        let (flags, prefix) = strip_ctypes_prefix(&to_vec("asd --ctypes-prefix foo dsa"));
+        assert_eq!(flags, to_vec("asd dsa"));
+        assert_eq!(prefix, Some("foo".to_string()));
+
+        let (flags, prefix) = strip_ctypes_prefix(&to_vec("asd --ctypes-prefi foo dsa"));
+        assert_eq!(flags, to_vec("asd --ctypes-prefi foo dsa"));
+        assert_eq!(prefix, None);
+    }
+
     #[test]
     fn test_combine_flags() {
         assert_eq!(