From c2eea45e1c55928fabf092fb2561ba5c82ad5cb9 Mon Sep 17 00:00:00 2001
From: Davide Bertola <dade@dadeb.it>
Date: Fri, 3 Jun 2022 22:28:25 +0200
Subject: [PATCH] aya-gen: add support for passing additional bindgen args

---
 aya-gen/src/bin/aya-gen.rs |  15 +++-
 aya-gen/src/btf_types.rs   | 105 -----------------------
 aya-gen/src/generate.rs    | 170 +++++++++++++++++++++++++++++++++++++
 aya-gen/src/lib.rs         |   2 +-
 4 files changed, 182 insertions(+), 110 deletions(-)
 delete mode 100644 aya-gen/src/btf_types.rs
 create mode 100644 aya-gen/src/generate.rs

diff --git a/aya-gen/src/bin/aya-gen.rs b/aya-gen/src/bin/aya-gen.rs
index b57a3bae..8f976aa0 100644
--- a/aya-gen/src/bin/aya-gen.rs
+++ b/aya-gen/src/bin/aya-gen.rs
@@ -1,4 +1,4 @@
-use aya_gen::btf_types::{generate, InputFile};
+use aya_gen::generate::{generate, InputFile};
 
 use std::{path::PathBuf, process::exit};
 
@@ -19,6 +19,8 @@ enum Command {
         #[clap(long, conflicts_with = "btf")]
         header: Option<PathBuf>,
         names: Vec<String>,
+        #[clap(last = true)]
+        bindgen_args: Vec<String>,
     },
 }
 
@@ -32,12 +34,17 @@ fn main() {
 fn try_main() -> Result<(), anyhow::Error> {
     let opts = Options::parse();
     match opts.command {
-        Command::Generate { btf, header, names } => {
+        Command::Generate {
+            btf,
+            header,
+            names,
+            bindgen_args,
+        } => {
             let bindings: String;
             if let Some(header) = header {
-                bindings = generate(InputFile::Header(header), &names)?;
+                bindings = generate(InputFile::Header(header), &names, &bindgen_args)?;
             } else {
-                bindings = generate(InputFile::Btf(btf), &names)?;
+                bindings = generate(InputFile::Btf(btf), &names, &bindgen_args)?;
             }
             println!("{}", bindings);
         }
diff --git a/aya-gen/src/btf_types.rs b/aya-gen/src/btf_types.rs
deleted file mode 100644
index d23fe2bc..00000000
--- a/aya-gen/src/btf_types.rs
+++ /dev/null
@@ -1,105 +0,0 @@
-use std::{
-    fs::{self, File},
-    io::{self, Write},
-    path::{Path, PathBuf},
-    process::Command,
-    str::from_utf8,
-};
-
-use tempfile::tempdir;
-
-use thiserror::Error;
-
-use crate::bindgen;
-
-#[derive(Error, Debug)]
-pub enum Error {
-    #[error("error executing bpftool")]
-    BpfTool(#[source] io::Error),
-
-    #[error("{stderr}\nbpftool failed with exit code {code}")]
-    BpfToolExit { code: i32, stderr: String },
-
-    #[error("bindgen failed")]
-    Bindgen(#[source] io::Error),
-
-    #[error("{stderr}\nbindgen failed with exit code {code}")]
-    BindgenExit { code: i32, stderr: String },
-
-    #[error("rustfmt failed")]
-    Rustfmt(#[source] io::Error),
-
-    #[error("error reading header file")]
-    ReadHeaderFile,
-}
-
-pub enum InputFile {
-    Btf(PathBuf),
-    Header(PathBuf),
-}
-
-pub fn generate<T: AsRef<str>>(input_file: InputFile, types: &[T]) -> Result<String, Error> {
-    let mut bindgen = bindgen::bpf_builder();
-
-    let (c_header, name) = match input_file {
-        InputFile::Btf(path) => (c_header_from_btf(&path)?, "kernel_types.h".to_string()),
-        InputFile::Header(header) => (
-            fs::read_to_string(&header).map_err(|_| Error::ReadHeaderFile)?,
-            header.file_name().unwrap().to_str().unwrap().to_owned(),
-        ),
-    };
-
-    for ty in types {
-        bindgen = bindgen.allowlist_type(ty);
-    }
-
-    // TODO: check if this part should be moved to bindgen::generate()
-    let dir = tempdir().unwrap();
-    let file_path = dir.path().join(name);
-    let mut file = File::create(&file_path).unwrap();
-    let _ = file.write(c_header.as_bytes()).unwrap();
-
-    let flags = bindgen.command_line_flags();
-
-    // TODO: check proper logging; it seems useful to see what command is
-    // launched but we can't disturb the normal output of the aya-gen tool
-    println!(
-        "Launching bindgen {} {}",
-        file_path.to_str().unwrap(),
-        flags.join(" ")
-    );
-
-    let output = Command::new("bindgen")
-        .arg(file_path)
-        .args(flags)
-        // TODO: pass additional arguments after --
-        .output()
-        .map_err(Error::Bindgen)?;
-
-    if !output.status.success() {
-        return Err(Error::BindgenExit {
-            code: output.status.code().unwrap(),
-            stderr: from_utf8(&output.stderr).unwrap().to_owned(),
-        });
-    }
-
-    Ok(from_utf8(&output.stdout).unwrap().to_owned())
-}
-
-fn c_header_from_btf(path: &Path) -> Result<String, Error> {
-    let output = Command::new("bpftool")
-        .args(&["btf", "dump", "file"])
-        .arg(path)
-        .args(&["format", "c"])
-        .output()
-        .map_err(Error::BpfTool)?;
-
-    if !output.status.success() {
-        return Err(Error::BpfToolExit {
-            code: output.status.code().unwrap(),
-            stderr: from_utf8(&output.stderr).unwrap().to_owned(),
-        });
-    }
-
-    Ok(from_utf8(&output.stdout).unwrap().to_owned())
-}
diff --git a/aya-gen/src/generate.rs b/aya-gen/src/generate.rs
new file mode 100644
index 00000000..12ac7034
--- /dev/null
+++ b/aya-gen/src/generate.rs
@@ -0,0 +1,170 @@
+use std::{
+    fs::{self, File},
+    io::{self, Write},
+    path::{Path, PathBuf},
+    process::Command,
+    str,
+};
+
+use tempfile::tempdir;
+
+use thiserror::Error;
+
+use crate::bindgen;
+
+#[derive(Error, Debug)]
+pub enum Error {
+    #[error("error executing bpftool")]
+    BpfTool(#[source] io::Error),
+
+    #[error("{stderr}\nbpftool failed with exit code {code}")]
+    BpfToolExit { code: i32, stderr: String },
+
+    #[error("bindgen failed")]
+    Bindgen(#[source] io::Error),
+
+    #[error("{stderr}\nbindgen failed with exit code {code}")]
+    BindgenExit { code: i32, stderr: String },
+
+    #[error("rustfmt failed")]
+    Rustfmt(#[source] io::Error),
+
+    #[error("error reading header file")]
+    ReadHeaderFile(#[source] io::Error),
+}
+
+pub enum InputFile {
+    Btf(PathBuf),
+    Header(PathBuf),
+}
+
+pub fn generate<T: AsRef<str>>(
+    input_file: InputFile,
+    types: &[T],
+    additional_flags: &[T],
+) -> Result<String, Error> {
+    let mut bindgen = bindgen::bpf_builder();
+
+    let (c_header, name) = match &input_file {
+        InputFile::Btf(path) => (c_header_from_btf(path)?, "kernel_types.h"),
+        InputFile::Header(header) => (
+            fs::read_to_string(&header).map_err(Error::ReadHeaderFile)?,
+            header.file_name().unwrap().to_str().unwrap(),
+        ),
+    };
+
+    for ty in types {
+        bindgen = bindgen.allowlist_type(ty);
+    }
+
+    let dir = tempdir().unwrap();
+    let file_path = dir.path().join(name);
+    let mut file = File::create(&file_path).unwrap();
+    let _ = file.write(c_header.as_bytes()).unwrap();
+
+    let flags = combine_flags(
+        &bindgen.command_line_flags(),
+        &additional_flags
+            .iter()
+            .map(|s| s.as_ref().into())
+            .collect::<Vec<_>>(),
+    );
+
+    let output = Command::new("bindgen")
+        .arg(file_path)
+        .args(&flags)
+        .output()
+        .map_err(Error::Bindgen)?;
+
+    if !output.status.success() {
+        return Err(Error::BindgenExit {
+            code: output.status.code().unwrap(),
+            stderr: str::from_utf8(&output.stderr).unwrap().to_owned(),
+        });
+    }
+
+    Ok(str::from_utf8(&output.stdout).unwrap().to_owned())
+}
+
+fn c_header_from_btf(path: &Path) -> Result<String, Error> {
+    let output = Command::new("bpftool")
+        .args(&["btf", "dump", "file"])
+        .arg(path)
+        .args(&["format", "c"])
+        .output()
+        .map_err(Error::BpfTool)?;
+
+    if !output.status.success() {
+        return Err(Error::BpfToolExit {
+            code: output.status.code().unwrap(),
+            stderr: str::from_utf8(&output.stderr).unwrap().to_owned(),
+        });
+    }
+
+    Ok(str::from_utf8(&output.stdout).unwrap().to_owned())
+}
+
+fn combine_flags(s1: &[String], s2: &[String]) -> Vec<String> {
+    let mut args = Vec::new();
+    let mut extra = Vec::new();
+
+    for s in [s1, s2] {
+        let mut s = s.splitn(2, |el| el == "--");
+        // append args
+        args.extend(s.next().unwrap().iter().cloned());
+        if let Some(e) = s.next() {
+            // append extra args
+            extra.extend(e.iter().cloned());
+        }
+    }
+
+    // append extra args
+    if !extra.is_empty() {
+        args.push("--".to_string());
+        args.extend(extra);
+    }
+
+    args
+}
+
+#[cfg(test)]
+mod test {
+    use super::combine_flags;
+
+    fn to_vec(s: &str) -> Vec<String> {
+        s.split(" ").map(|x| x.into()).collect()
+    }
+
+    #[test]
+    fn combine_arguments_test() {
+        assert_eq!(
+            combine_flags(&to_vec("a b"), &to_vec("c d"),).join(" "),
+            "a b c d".to_string(),
+        );
+
+        assert_eq!(
+            combine_flags(&to_vec("a -- b"), &to_vec("a b"),).join(" "),
+            "a a b -- b".to_string(),
+        );
+
+        assert_eq!(
+            combine_flags(&to_vec("a -- b"), &to_vec("c d"),).join(" "),
+            "a c d -- b".to_string(),
+        );
+
+        assert_eq!(
+            combine_flags(&to_vec("a b"), &to_vec("c -- d"),).join(" "),
+            "a b c -- d".to_string(),
+        );
+
+        assert_eq!(
+            combine_flags(&to_vec("a -- b"), &to_vec("c -- d"),).join(" "),
+            "a c -- b d".to_string(),
+        );
+
+        assert_eq!(
+            combine_flags(&to_vec("a -- b"), &to_vec("-- c d"),).join(" "),
+            "a -- b c d".to_string(),
+        );
+    }
+}
diff --git a/aya-gen/src/lib.rs b/aya-gen/src/lib.rs
index 558a8e0d..cb1b9545 100644
--- a/aya-gen/src/lib.rs
+++ b/aya-gen/src/lib.rs
@@ -5,7 +5,7 @@ use std::{
 };
 
 pub mod bindgen;
-pub mod btf_types;
+pub mod generate;
 pub mod rustfmt;
 
 pub fn write_to_file<T: AsRef<Path>>(path: T, code: &str) -> Result<(), io::Error> {