diff --git a/aya/src/bpf.rs b/aya/src/bpf.rs index 33640239..91d3a3aa 100644 --- a/aya/src/bpf.rs +++ b/aya/src/bpf.rs @@ -16,7 +16,7 @@ use crate::{ maps::{Map, MapError, MapLock, MapRef, MapRefMut}, obj::{ btf::{Btf, BtfError}, - Object, ParseError, ProgramKind, + Object, ParseError, ProgramSection, }, programs::{ CgroupSkb, CgroupSkbAttachType, KProbe, LircMode2, ProbeKind, Program, ProgramData, @@ -148,55 +148,57 @@ impl Bpf { .programs .drain() .map(|(name, obj)| { - let kind = obj.kind; + let section = obj.section.clone(); let data = ProgramData { obj, name: name.clone(), fd: None, links: Vec::new(), }; - let program = match kind { - ProgramKind::KProbe => Program::KProbe(KProbe { + let program = match section { + ProgramSection::KProbe { .. } => Program::KProbe(KProbe { data, kind: ProbeKind::KProbe, }), - ProgramKind::KRetProbe => Program::KProbe(KProbe { + ProgramSection::KRetProbe { .. } => Program::KProbe(KProbe { data, kind: ProbeKind::KRetProbe, }), - ProgramKind::UProbe => Program::UProbe(UProbe { + ProgramSection::UProbe { .. } => Program::UProbe(UProbe { data, kind: ProbeKind::UProbe, }), - ProgramKind::URetProbe => Program::UProbe(UProbe { + ProgramSection::URetProbe { .. } => Program::UProbe(UProbe { data, kind: ProbeKind::URetProbe, }), - ProgramKind::TracePoint => Program::TracePoint(TracePoint { data }), - ProgramKind::SocketFilter => Program::SocketFilter(SocketFilter { data }), - ProgramKind::Xdp => Program::Xdp(Xdp { data }), - ProgramKind::SkMsg => Program::SkMsg(SkMsg { data }), - ProgramKind::SkSkbStreamParser => Program::SkSkb(SkSkb { + ProgramSection::TracePoint { .. } => Program::TracePoint(TracePoint { data }), + ProgramSection::SocketFilter { .. } => { + Program::SocketFilter(SocketFilter { data }) + } + ProgramSection::Xdp { .. } => Program::Xdp(Xdp { data }), + ProgramSection::SkMsg { .. } => Program::SkMsg(SkMsg { data }), + ProgramSection::SkSkbStreamParser { .. } => Program::SkSkb(SkSkb { data, kind: SkSkbKind::StreamParser, }), - ProgramKind::SkSkbStreamVerdict => Program::SkSkb(SkSkb { + ProgramSection::SkSkbStreamVerdict { .. } => Program::SkSkb(SkSkb { data, kind: SkSkbKind::StreamVerdict, }), - ProgramKind::SockOps => Program::SockOps(SockOps { data }), - ProgramKind::SchedClassifier => { + ProgramSection::SockOps { .. } => Program::SockOps(SockOps { data }), + ProgramSection::SchedClassifier { .. } => { Program::SchedClassifier(SchedClassifier { data }) } - ProgramKind::CgroupSkbIngress => Program::CgroupSkb(CgroupSkb { + ProgramSection::CgroupSkbIngress { .. } => Program::CgroupSkb(CgroupSkb { data, expected_attach_type: Some(CgroupSkbAttachType::Ingress), }), - ProgramKind::CgroupSkbEgress => Program::CgroupSkb(CgroupSkb { + ProgramSection::CgroupSkbEgress { .. } => Program::CgroupSkb(CgroupSkb { data, expected_attach_type: Some(CgroupSkbAttachType::Egress), }), - ProgramKind::LircMode2 => Program::LircMode2(LircMode2 { data }), + ProgramSection::LircMode2 { .. } => Program::LircMode2(LircMode2 { data }), }; (name, program) diff --git a/aya/src/obj/mod.rs b/aya/src/obj/mod.rs index 7caf76d4..45b1a3ce 100644 --- a/aya/src/obj/mod.rs +++ b/aya/src/obj/mod.rs @@ -51,7 +51,7 @@ pub struct Map { pub(crate) struct Program { pub(crate) license: CString, pub(crate) kernel_version: KernelVersion, - pub(crate) kind: ProgramKind, + pub(crate) section: ProgramSection, pub(crate) function: Function, } @@ -64,49 +64,86 @@ pub(crate) struct Function { pub(crate) instructions: Vec, } -#[derive(Debug, Copy, Clone)] -pub enum ProgramKind { - KProbe, - KRetProbe, - UProbe, - URetProbe, - TracePoint, - SocketFilter, - Xdp, - SkMsg, - SkSkbStreamParser, - SkSkbStreamVerdict, - SockOps, - SchedClassifier, - CgroupSkbIngress, - CgroupSkbEgress, - LircMode2, +#[derive(Debug, Clone)] +pub enum ProgramSection { + KRetProbe { name: String }, + KProbe { name: String }, + UProbe { name: String }, + URetProbe { name: String }, + TracePoint { name: String }, + SocketFilter { name: String }, + Xdp { name: String }, + SkMsg { name: String }, + SkSkbStreamParser { name: String }, + SkSkbStreamVerdict { name: String }, + SockOps { name: String }, + SchedClassifier { name: String }, + CgroupSkbIngress { name: String }, + CgroupSkbEgress { name: String }, + LircMode2 { name: String }, } -impl FromStr for ProgramKind { +impl ProgramSection { + fn name(&self) -> &str { + match self { + ProgramSection::KRetProbe { name } => name, + ProgramSection::KProbe { name } => name, + ProgramSection::UProbe { name } => name, + ProgramSection::URetProbe { name } => name, + ProgramSection::TracePoint { name } => name, + ProgramSection::SocketFilter { name } => name, + ProgramSection::Xdp { name } => name, + ProgramSection::SkMsg { name } => name, + ProgramSection::SkSkbStreamParser { name } => name, + ProgramSection::SkSkbStreamVerdict { name } => name, + ProgramSection::SockOps { name } => name, + ProgramSection::SchedClassifier { name } => name, + ProgramSection::CgroupSkbIngress { name } => name, + ProgramSection::CgroupSkbEgress { name } => name, + ProgramSection::LircMode2 { name } => name, + } + } +} + +impl FromStr for ProgramSection { type Err = ParseError; - fn from_str(kind: &str) -> Result { - use ProgramKind::*; + fn from_str(section: &str) -> Result { + use ProgramSection::*; + + // parse the common case, eg "xdp/program_name" or + // "sk_skb/stream_verdict/program_name" + let mut parts = section.rsplitn(2, "/").collect::>(); + if parts.len() == 1 { + parts.push(parts[0]); + } + let kind = parts[1]; + let name = parts[0].to_owned(); + Ok(match kind { - "kprobe" => KProbe, - "kretprobe" => KRetProbe, - "uprobe" => UProbe, - "uretprobe" => URetProbe, - "xdp" => Xdp, - "tracepoint" => TracePoint, - "socket_filter" => SocketFilter, - "sk_msg" => SkMsg, - "sk_skb/stream_parser" => SkSkbStreamParser, - "sk_skb/stream_verdict" => SkSkbStreamVerdict, - "sockops" => SockOps, - "classifier" => SchedClassifier, - "cgroup_skb/ingress" => CgroupSkbIngress, - "cgroup_skb/egress" => CgroupSkbEgress, - "lirc_mode2" => LircMode2, + "kprobe" => KProbe { name }, + "kretprobe" => KRetProbe { name }, + "uprobe" => UProbe { name }, + "uretprobe" => URetProbe { name }, + "xdp" => Xdp { name }, + _ if kind.starts_with("tracepoint") || kind.starts_with("tp") => { + // tracepoint sections are named `tracepoint/category/event_name`, + // and we want to parse the name as "category/event_name" + let name = section.splitn(2, "/").last().unwrap().to_owned(); + TracePoint { name } + } + "socket_filter" => SocketFilter { name }, + "sk_msg" => SkMsg { name }, + "sk_skb/stream_parser" => SkSkbStreamParser { name }, + "sk_skb/stream_verdict" => SkSkbStreamVerdict { name }, + "sockops" => SockOps { name }, + "classifier" => SchedClassifier { name }, + "cgroup_skb/ingress" => CgroupSkbIngress { name }, + "cgroup_skb/egress" => CgroupSkbEgress { name }, + "lirc_mode2" => LircMode2 { name }, _ => { - return Err(ParseError::InvalidProgramKind { - kind: kind.to_string(), + return Err(ParseError::InvalidProgramSection { + section: section.to_owned(), }) } }) @@ -181,18 +218,15 @@ impl Object { Ok(()) } - fn parse_program( - &self, - section: &Section, - ty: &str, - name: &str, - ) -> Result { + fn parse_program(&self, section: &Section) -> Result { + let prog_sec = ProgramSection::from_str(section.name)?; + let name = prog_sec.name().to_owned(); Ok(Program { license: self.license.clone(), kernel_version: self.kernel_version, - kind: ProgramKind::from_str(ty)?, + section: prog_sec, function: Function { - name: name.to_owned(), + name, address: section.address, section_index: section.index, section_offset: 0, @@ -276,38 +310,23 @@ impl Object { } } - match parts.as_slice() { - &[name] - if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") => - { + match section.name { + name if name == ".bss" || name.starts_with(".data") || name.starts_with(".rodata") => { self.maps .insert(name.to_string(), parse_map(§ion, name)?); } - &[name] if name.starts_with(".text") => self.parse_text_section(section)?, - &[".BTF"] => self.parse_btf(§ion)?, - &[".BTF.ext"] => self.parse_btf_ext(§ion)?, - &["maps", name] => { + name if name.starts_with(".text") => self.parse_text_section(section)?, + ".BTF" => self.parse_btf(§ion)?, + ".BTF.ext" => self.parse_btf_ext(§ion)?, + map if map.starts_with("maps/") => { + let name = map.splitn(2, "/").last().unwrap(); self.maps .insert(name.to_string(), parse_map(§ion, name)?); } - &[ty @ "kprobe", name] - | &[ty @ "kretprobe", name] - | &[ty @ "uprobe", name] - | &[ty @ "uretprobe", name] - | &[ty @ "socket_filter", name] - | &[ty @ "xdp", name] - | &[ty @ "tracepoint", name] - | &[ty @ "sk_msg", name] - | &[ty @ "sk_skb/stream_parser", name] - | &[ty @ "sk_skb/stream_verdict", name] - | &[ty @ "sockops", name] - | &[ty @ "classifier", name] - | &[ty @ "cgroup_skb/ingress", name] - | &[ty @ "cgroup_skb/egress", name] - | &[ty @ "cgroup/skb", name] - | &[ty @ "lirc_mode2", name] => { + name if is_program_section(name) => { + let program = self.parse_program(§ion)?; self.programs - .insert(name.to_string(), self.parse_program(§ion, ty, name)?); + .insert(program.section.name().to_owned(), program); if !section.relocations.is_empty() { self.relocations.insert( section.index, @@ -351,8 +370,8 @@ pub enum ParseError { #[error("unsupported relocation target")] UnsupportedRelocationTarget, - #[error("invalid program kind `{kind}`")] - InvalidProgramKind { kind: String }, + #[error("invalid program section `{section}`")] + InvalidProgramSection { section: String }, #[error("invalid program code")] InvalidProgramCode, @@ -516,6 +535,34 @@ fn copy_instructions(data: &[u8]) -> Result, ParseError> { Ok(instructions) } +fn is_program_section(name: &str) -> bool { + for prefix in &[ + "classifier", + "cgroup/skb", + "cgroup_skb/egress", + "cgroup_skb/ingress", + "kprobe", + "kretprobe", + "lirc_mode2", + "sk_msg", + "sk_skb/stream_parser", + "sk_skb/stream_verdict", + "socket_filter", + "sockops", + "tp", + "tracepoint", + "uprobe", + "uretprobe", + "xdp", + ] { + if name.starts_with(prefix) { + return true; + } + } + + false +} + #[cfg(test)] mod tests { use matches::assert_matches; @@ -723,11 +770,7 @@ mod tests { let obj = fake_obj(); assert_matches!( - obj.parse_program( - &fake_section("kprobe/foo", &42u32.to_ne_bytes(),), - "kprobe", - "foo" - ), + obj.parse_program(&fake_section("kprobe/foo", &42u32.to_ne_bytes(),),), Err(ParseError::InvalidProgramCode) ); } @@ -737,11 +780,11 @@ mod tests { let obj = fake_obj(); assert_matches!( - obj.parse_program(&fake_section("kprobe/foo", bytes_of(&fake_ins())), "kprobe", "foo"), + obj.parse_program(&fake_section("kprobe/foo", bytes_of(&fake_ins()))), Ok(Program { license, kernel_version: KernelVersion::Any, - kind: ProgramKind::KProbe, + section: ProgramSection::KProbe { .. }, function: Function { name, address: 0, @@ -820,7 +863,7 @@ mod tests { assert_matches!( obj.programs.get("foo"), Some(Program { - kind: ProgramKind::KProbe, + section: ProgramSection::KProbe { .. }, .. }) ); @@ -837,7 +880,7 @@ mod tests { assert_matches!( obj.programs.get("foo"), Some(Program { - kind: ProgramKind::UProbe, + section: ProgramSection::UProbe { .. }, .. }) ); @@ -854,7 +897,19 @@ mod tests { assert_matches!( obj.programs.get("foo"), Some(Program { - kind: ProgramKind::TracePoint, + section: ProgramSection::TracePoint { .. }, + .. + }) + ); + + assert_matches!( + obj.parse_section(fake_section("tp/foo/bar", bytes_of(&fake_ins()))), + Ok(()) + ); + assert_matches!( + obj.programs.get("foo/bar"), + Some(Program { + section: ProgramSection::TracePoint { .. }, .. }) ); @@ -871,7 +926,7 @@ mod tests { assert_matches!( obj.programs.get("foo"), Some(Program { - kind: ProgramKind::SocketFilter, + section: ProgramSection::SocketFilter { .. }, .. }) ); @@ -888,7 +943,7 @@ mod tests { assert_matches!( obj.programs.get("foo"), Some(Program { - kind: ProgramKind::Xdp, + section: ProgramSection::Xdp { .. }, .. }) );