diff --git a/ebpf/aya-ebpf/src/maps/array.rs b/ebpf/aya-ebpf/src/maps/array.rs index d9911547..55888c34 100644 --- a/ebpf/aya-ebpf/src/maps/array.rs +++ b/ebpf/aya-ebpf/src/maps/array.rs @@ -1,4 +1,4 @@ -use core::{cell::UnsafeCell, marker::PhantomData, mem, ptr::NonNull}; +use core::{cell::UnsafeCell, fmt, marker::PhantomData, mem, ptr::NonNull}; use aya_ebpf_cty::c_long; @@ -14,9 +14,30 @@ pub struct Array { _t: PhantomData, } +#[derive(Debug)] +pub struct OutOfBounds { + length: u32, + index: u32, +} + +impl core::error::Error for OutOfBounds {} + +impl fmt::Display for OutOfBounds { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Array index {} is out of bounds; length = {}", + self.index, self.length + ) + } +} + unsafe impl Sync for Array {} impl Array { + /// Constructs a new [`Array`] with `max_entries` entries. + /// + /// All the entries are zero-initialized. pub const fn with_max_entries(max_entries: u32, flags: u32) -> Array { Array { def: UnsafeCell::new(bpf_map_def { @@ -47,25 +68,36 @@ impl Array { } } + /// Get a reference to an array element at the given index. + /// + /// Arrays are zero-initialised upon creation so this will always return a reference to `T` as long as the index is within bounds. + /// All fields of `T` may therefore be in their default state, i.e. `bool` will be `false`, `u32` will be `0`, etc. #[inline(always)] - pub fn get(&self, index: u32) -> Option<&T> { + pub fn get(&self, index: u32) -> Result<&T, OutOfBounds> { // FIXME: alignment unsafe { self.lookup(index).map(|p| p.as_ref()) } } + /// Like [`Array::get`] but returns a pointer instead. #[inline(always)] - pub fn get_ptr(&self, index: u32) -> Option<*const T> { + pub fn get_ptr(&self, index: u32) -> Result<*const T, OutOfBounds> { unsafe { self.lookup(index).map(|p| p.as_ptr() as *const T) } } + /// Like [`Array::get`] but returns a mutable pointer instead. #[inline(always)] - pub fn get_ptr_mut(&self, index: u32) -> Option<*mut T> { + pub fn get_ptr_mut(&self, index: u32) -> Result<*mut T, OutOfBounds> { unsafe { self.lookup(index).map(|p| p.as_ptr()) } } #[inline(always)] - unsafe fn lookup(&self, index: u32) -> Option> { - lookup(self.def.get(), &index) + unsafe fn lookup(&self, index: u32) -> Result, OutOfBounds> { + let map_def = self.def.get(); + + lookup(map_def, &index).ok_or(OutOfBounds { + length: unsafe { (*map_def).max_entries }, + index, + }) } /// Sets the value of the element at the given index. diff --git a/test/integration-ebpf/src/bpf_probe_read.rs b/test/integration-ebpf/src/bpf_probe_read.rs index 2a622a25..2e304713 100644 --- a/test/integration-ebpf/src/bpf_probe_read.rs +++ b/test/integration-ebpf/src/bpf_probe_read.rs @@ -22,7 +22,7 @@ fn read_str_bytes( let Some(ilen) = ilen else { return; }; - let Some(ptr) = RESULT.get_ptr_mut(0) else { + let Ok(ptr) = RESULT.get_ptr_mut(0) else { return; }; let dst = unsafe { ptr.as_mut() }; @@ -68,6 +68,7 @@ pub fn test_bpf_probe_read_kernel_str_bytes(ctx: ProbeContext) { bpf_probe_read_kernel_str_bytes, KERNEL_BUFFER .get_ptr(0) + .ok() .and_then(|ptr| unsafe { ptr.as_ref() }) .map(|buf| buf.as_ptr()), ctx.arg::(0), diff --git a/test/integration-ebpf/src/map_test.rs b/test/integration-ebpf/src/map_test.rs index 6eaf4123..0b22fa62 100644 --- a/test/integration-ebpf/src/map_test.rs +++ b/test/integration-ebpf/src/map_test.rs @@ -29,7 +29,7 @@ static MAP_WITH_LOOOONG_NAAAAAAAAME: HashMap = HashMap::::with #[socket_filter] pub fn simple_prog(_ctx: SkBuffContext) -> i64 { // So that these maps show up under the `map_ids` field. - FOO.get(0); + let _ = FOO.get(0); // If we use the literal value `0` instead of the local variable `i`, then an additional // `.rodata` map will be associated with the program. let i = 0; diff --git a/test/integration-ebpf/src/raw_tracepoint.rs b/test/integration-ebpf/src/raw_tracepoint.rs index 513d4ab5..99b7a018 100644 --- a/test/integration-ebpf/src/raw_tracepoint.rs +++ b/test/integration-ebpf/src/raw_tracepoint.rs @@ -18,7 +18,7 @@ pub fn sys_enter(ctx: RawTracePointContext) -> i32 { let common_type: u16 = unsafe { ctx.arg(0) }; let common_flags: u8 = unsafe { ctx.arg(1) }; - if let Some(ptr) = RESULT.get_ptr_mut(0) { + if let Ok(ptr) = RESULT.get_ptr_mut(0) { unsafe { (*ptr).common_type = common_type; (*ptr).common_flags = common_flags; diff --git a/test/integration-ebpf/src/redirect.rs b/test/integration-ebpf/src/redirect.rs index b559c06d..98b3428e 100644 --- a/test/integration-ebpf/src/redirect.rs +++ b/test/integration-ebpf/src/redirect.rs @@ -72,7 +72,7 @@ pub fn redirect_dev_chain(_ctx: XdpContext) -> u32 { #[inline(always)] fn inc_hit(index: u32) { - if let Some(hit) = HITS.get_ptr_mut(index) { + if let Ok(hit) = HITS.get_ptr_mut(index) { unsafe { *hit += 1 }; } } diff --git a/test/integration-ebpf/src/relocations.rs b/test/integration-ebpf/src/relocations.rs index 03b0d342..64d2d801 100644 --- a/test/integration-ebpf/src/relocations.rs +++ b/test/integration-ebpf/src/relocations.rs @@ -30,7 +30,7 @@ pub fn test_64_32_call_relocs(_ctx: ProbeContext) { #[inline(never)] fn set_result(index: u32, value: u64) { unsafe { - if let Some(v) = RESULTS.get_ptr_mut(index) { + if let Ok(v) = RESULTS.get_ptr_mut(index) { *v = value; } } diff --git a/test/integration-ebpf/src/strncmp.rs b/test/integration-ebpf/src/strncmp.rs index 9b831157..81ca4124 100644 --- a/test/integration-ebpf/src/strncmp.rs +++ b/test/integration-ebpf/src/strncmp.rs @@ -21,7 +21,7 @@ pub fn test_bpf_strncmp(ctx: ProbeContext) -> Result<(), c_long> { let mut b1 = [0u8; 3]; let _: &[u8] = unsafe { bpf_probe_read_user_str_bytes(s1, &mut b1) }?; - let ptr = RESULT.get_ptr_mut(0).ok_or(-1)?; + let ptr = RESULT.get_ptr_mut(0).map_err(|_| -1)?; let dst = unsafe { ptr.as_mut() }; let TestResult(dst_res) = dst.ok_or(-1)?; *dst_res = bpf_strncmp(&b1, c"ff");