diff --git a/aya/src/bpf.rs b/aya/src/bpf.rs index 91d3a3aa..00c08870 100644 --- a/aya/src/bpf.rs +++ b/aya/src/bpf.rs @@ -48,13 +48,15 @@ unsafe_impl_pod!(i8, u8, i16, u16, i32, u32, i64, u64); #[allow(non_camel_case_types)] #[repr(C)] -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Default, PartialEq)] pub(crate) struct bpf_map_def { + // minimum features required by old BPF programs pub(crate) map_type: u32, pub(crate) key_size: u32, pub(crate) value_size: u32, pub(crate) max_entries: u32, pub(crate) map_flags: u32, + // optional features pub(crate) id: u32, pub(crate) pinning: u32, } diff --git a/aya/src/maps/hash_map/hash_map.rs b/aya/src/maps/hash_map/hash_map.rs index 63a260f9..1ba608f2 100644 --- a/aya/src/maps/hash_map/hash_map.rs +++ b/aya/src/maps/hash_map/hash_map.rs @@ -162,9 +162,7 @@ mod tests { key_size: 4, value_size: 4, max_entries: 1024, - map_flags: 0, - id: 0, - pinning: 0, + ..Default::default() }, section_index: 0, data: Vec::new(), @@ -215,9 +213,7 @@ mod tests { key_size: 4, value_size: 4, max_entries: 1024, - map_flags: 0, - id: 0, - pinning: 0, + ..Default::default() }, section_index: 0, data: Vec::new(), @@ -273,9 +269,7 @@ mod tests { key_size: 4, value_size: 4, max_entries: 1024, - map_flags: 0, - id: 0, - pinning: 0, + ..Default::default() }, section_index: 0, data: Vec::new(), diff --git a/aya/src/maps/mod.rs b/aya/src/maps/mod.rs index cb9176f7..f96e4b32 100644 --- a/aya/src/maps/mod.rs +++ b/aya/src/maps/mod.rs @@ -426,9 +426,7 @@ mod tests { key_size: 4, value_size: 4, max_entries: 1024, - map_flags: 0, - id: 0, - pinning: 0, + ..Default::default() }, section_index: 0, data: Vec::new(), diff --git a/aya/src/obj/mod.rs b/aya/src/obj/mod.rs index 45b1a3ce..819f8675 100644 --- a/aya/src/obj/mod.rs +++ b/aya/src/obj/mod.rs @@ -22,8 +22,11 @@ use crate::{ obj::btf::{Btf, BtfError, BtfExt}, BpfError, }; +use std::slice::from_raw_parts_mut; const KERNEL_VERSION_ANY: u32 = 0xFFFF_FFFE; +/// The first five __u32 of `bpf_map_def` must be defined. +const MINIMUM_MAP_SIZE: usize = mem::size_of::() * 5; #[derive(Clone)] pub struct Object { @@ -493,8 +496,7 @@ fn parse_map(section: &Section, name: &str) -> Result { value_size: section.data.len() as u32, max_entries: 1, map_flags: 0, /* FIXME: set rodata readonly */ - id: 0, - pinning: 0, + ..Default::default() }; (def, section.data.to_vec()) } else { @@ -510,13 +512,23 @@ fn parse_map(section: &Section, name: &str) -> Result { } fn parse_map_def(name: &str, data: &[u8]) -> Result { - if mem::size_of::() > data.len() { + if data.len() > mem::size_of::() || data.len() < MINIMUM_MAP_SIZE { return Err(ParseError::InvalidMapDefinition { name: name.to_owned(), }); } - Ok(unsafe { ptr::read_unaligned(data.as_ptr() as *const bpf_map_def) }) + if data.len() < mem::size_of::() { + let mut map_def = bpf_map_def::default(); + unsafe { + let map_def_ptr = + from_raw_parts_mut(&mut map_def as *mut bpf_map_def as *mut u8, data.len()); + map_def_ptr.copy_from_slice(data); + } + Ok(map_def) + } else { + Ok(unsafe { ptr::read_unaligned(data.as_ptr() as *const bpf_map_def) }) + } } fn copy_instructions(data: &[u8]) -> Result, ParseError> { @@ -662,6 +674,10 @@ mod tests { Err(ParseError::InvalidMapDefinition { .. }) )); assert!(matches!( + parse_map_def("foo", &[0u8; std::mem::size_of::() + 1]), + Err(ParseError::InvalidMapDefinition { .. }) + )); + assert_eq!( parse_map_def( "foo", bytes_of(&bpf_map_def { @@ -670,20 +686,55 @@ mod tests { value_size: 3, max_entries: 4, map_flags: 5, - id: 0, - pinning: 0 + ..Default::default() }) - ), - Ok(bpf_map_def { + ) + .unwrap(), + bpf_map_def { map_type: 1, key_size: 2, value_size: 3, max_entries: 4, map_flags: 5, - id: 0, - pinning: 0 - }) - )); + ..Default::default() + } + ); + + assert_eq!( + parse_map_def( + "foo", + &bytes_of(&bpf_map_def { + map_type: 1, + key_size: 2, + value_size: 3, + max_entries: 4, + map_flags: 5, + ..Default::default() + })[..(mem::size_of::() * 5)] + ) + .unwrap(), + bpf_map_def { + map_type: 1, + key_size: 2, + value_size: 3, + max_entries: 4, + map_flags: 5, + ..Default::default() + } + ); + let map = parse_map_def( + "foo", + &bytes_of(&bpf_map_def { + map_type: 1, + key_size: 2, + value_size: 3, + max_entries: 4, + map_flags: 5, + ..Default::default() + })[..(mem::size_of::() * 5)], + ) + .unwrap(); + assert!(map.id == 0 && map.pinning == 0) } #[test] @@ -691,6 +742,13 @@ mod tests { assert!(matches!( parse_map(&fake_section("maps/foo", &[]), "foo"), Err(ParseError::InvalidMapDefinition { .. }) + )); + assert!(matches!( + parse_map( + &fake_section("maps/foo", &[0u8; std::mem::size_of::() + 1]), + "foo" + ), + Err(ParseError::InvalidMapDefinition { .. }) )) } @@ -750,7 +808,7 @@ mod tests { max_entries: 1, map_flags: 0, id: 0, - pinning: 0 + pinning: 0, }, data }) if name == ".bss" && data == map_data && value_size == map_data.len() as u32 @@ -809,8 +867,7 @@ mod tests { value_size: 3, max_entries: 4, map_flags: 5, - id: 0, - pinning: 0 + ..Default::default() }) ),), Ok(())