diff --git a/aya-gen/src/bin/aya-gen.rs b/aya-gen/src/bin/aya-gen.rs index 71867dbf..b57a3bae 100644 --- a/aya-gen/src/bin/aya-gen.rs +++ b/aya-gen/src/bin/aya-gen.rs @@ -1,4 +1,4 @@ -use aya_gen::btf_types; +use aya_gen::btf_types::{generate, InputFile}; use std::{path::PathBuf, process::exit}; @@ -12,10 +12,12 @@ pub struct Options { #[derive(Parser)] enum Command { - #[clap(name = "btf-types")] - BtfTypes { + #[clap(name = "generate")] + Generate { #[clap(long, default_value = "/sys/kernel/btf/vmlinux")] btf: PathBuf, + #[clap(long, conflicts_with = "btf")] + header: Option, names: Vec, }, } @@ -30,8 +32,13 @@ fn main() { fn try_main() -> Result<(), anyhow::Error> { let opts = Options::parse(); match opts.command { - Command::BtfTypes { btf, names } => { - let bindings = btf_types::generate(&btf, &names)?; + Command::Generate { btf, header, names } => { + let bindings: String; + if let Some(header) = header { + bindings = generate(InputFile::Header(header), &names)?; + } else { + bindings = generate(InputFile::Btf(btf), &names)?; + } println!("{}", bindings); } }; diff --git a/aya-gen/src/btf_types.rs b/aya-gen/src/btf_types.rs index 214c1b0e..a1c55443 100644 --- a/aya-gen/src/btf_types.rs +++ b/aya-gen/src/btf_types.rs @@ -1,4 +1,9 @@ -use std::{io, path::Path, process::Command, str::from_utf8}; +use std::{ + fs, io, + path::{Path, PathBuf}, + process::Command, + str::from_utf8, +}; use thiserror::Error; @@ -17,13 +22,30 @@ pub enum Error { #[error("rustfmt failed")] Rustfmt(#[source] io::Error), + + #[error("error reading header file")] + ReadHeaderFile, +} + +pub enum InputFile { + Btf(PathBuf), + Header(PathBuf), } -pub fn generate>(btf_file: &Path, types: &[T]) -> Result { +pub fn generate>(input_file: InputFile, types: &[T]) -> Result { let mut bindgen = bindgen::bpf_builder(); - let c_header = c_header_from_btf(btf_file)?; - bindgen = bindgen.header_contents("kernel_types.h", &c_header); + match input_file { + InputFile::Btf(path) => { + let c_header = c_header_from_btf(&path)?; + bindgen = bindgen.header_contents("kernel_types.h", &c_header); + } + InputFile::Header(header) => { + let c_header = fs::read_to_string(&header).map_err(|_| Error::ReadHeaderFile)?; + let name = Path::new(&header).file_name().unwrap().to_str().unwrap(); + bindgen = bindgen.header_contents(name, &c_header); + } + } for ty in types { bindgen = bindgen.allowlist_type(ty);