From 3694ccb4694f694b738178f57b36fafcd1e9b237 Mon Sep 17 00:00:00 2001
From: Davide Bertola <dade@dadeb.it>
Date: Fri, 3 Jun 2022 15:29:25 +0200
Subject: [PATCH] aya-gen: add --header option

---
 aya-gen/src/bin/aya-gen.rs | 17 ++++++++++++-----
 aya-gen/src/btf_types.rs   | 30 ++++++++++++++++++++++++++----
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/aya-gen/src/bin/aya-gen.rs b/aya-gen/src/bin/aya-gen.rs
index 71867dbf..b57a3bae 100644
--- a/aya-gen/src/bin/aya-gen.rs
+++ b/aya-gen/src/bin/aya-gen.rs
@@ -1,4 +1,4 @@
-use aya_gen::btf_types;
+use aya_gen::btf_types::{generate, InputFile};
 
 use std::{path::PathBuf, process::exit};
 
@@ -12,10 +12,12 @@ pub struct Options {
 
 #[derive(Parser)]
 enum Command {
-    #[clap(name = "btf-types")]
-    BtfTypes {
+    #[clap(name = "generate")]
+    Generate {
         #[clap(long, default_value = "/sys/kernel/btf/vmlinux")]
         btf: PathBuf,
+        #[clap(long, conflicts_with = "btf")]
+        header: Option<PathBuf>,
         names: Vec<String>,
     },
 }
@@ -30,8 +32,13 @@ fn main() {
 fn try_main() -> Result<(), anyhow::Error> {
     let opts = Options::parse();
     match opts.command {
-        Command::BtfTypes { btf, names } => {
-            let bindings = btf_types::generate(&btf, &names)?;
+        Command::Generate { btf, header, names } => {
+            let bindings: String;
+            if let Some(header) = header {
+                bindings = generate(InputFile::Header(header), &names)?;
+            } else {
+                bindings = generate(InputFile::Btf(btf), &names)?;
+            }
             println!("{}", bindings);
         }
     };
diff --git a/aya-gen/src/btf_types.rs b/aya-gen/src/btf_types.rs
index 214c1b0e..a1c55443 100644
--- a/aya-gen/src/btf_types.rs
+++ b/aya-gen/src/btf_types.rs
@@ -1,4 +1,9 @@
-use std::{io, path::Path, process::Command, str::from_utf8};
+use std::{
+    fs, io,
+    path::{Path, PathBuf},
+    process::Command,
+    str::from_utf8,
+};
 
 use thiserror::Error;
 
@@ -17,13 +22,30 @@ pub enum Error {
 
     #[error("rustfmt failed")]
     Rustfmt(#[source] io::Error),
+
+    #[error("error reading header file")]
+    ReadHeaderFile,
+}
+
+pub enum InputFile {
+    Btf(PathBuf),
+    Header(PathBuf),
 }
 
-pub fn generate<T: AsRef<str>>(btf_file: &Path, types: &[T]) -> Result<String, Error> {
+pub fn generate<T: AsRef<str>>(input_file: InputFile, types: &[T]) -> Result<String, Error> {
     let mut bindgen = bindgen::bpf_builder();
 
-    let c_header = c_header_from_btf(btf_file)?;
-    bindgen = bindgen.header_contents("kernel_types.h", &c_header);
+    match input_file {
+        InputFile::Btf(path) => {
+            let c_header = c_header_from_btf(&path)?;
+            bindgen = bindgen.header_contents("kernel_types.h", &c_header);
+        }
+        InputFile::Header(header) => {
+            let c_header = fs::read_to_string(&header).map_err(|_| Error::ReadHeaderFile)?;
+            let name = Path::new(&header).file_name().unwrap().to_str().unwrap();
+            bindgen = bindgen.header_contents(name, &c_header);
+        }
+    }
 
     for ty in types {
         bindgen = bindgen.allowlist_type(ty);