diff --git a/aya/src/maps/hash_map.rs b/aya/src/maps/hash_map.rs index 5ce1f41b..45cd7f6e 100644 --- a/aya/src/maps/hash_map.rs +++ b/aya/src/maps/hash_map.rs @@ -23,25 +23,29 @@ pub struct HashMap, K, V> { impl, K: Pod, V: Pod> HashMap { pub fn new(map: T) -> Result, MapError> { - let inner = map.deref(); - let map_type = inner.obj.def.map_type; + let map_type = map.obj.def.map_type; + + // validate the map definition if map_type != BPF_MAP_TYPE_HASH { return Err(MapError::InvalidMapType { map_type: map_type as u32, })?; } let size = mem::size_of::(); - let expected = inner.obj.def.key_size as usize; + let expected = map.obj.def.key_size as usize; if size != expected { return Err(MapError::InvalidKeySize { size, expected }); } let size = mem::size_of::(); - let expected = inner.obj.def.value_size as usize; + let expected = map.obj.def.value_size as usize; if size != expected { return Err(MapError::InvalidValueSize { size, expected }); } + // make sure the map has been created + let _fd = map.fd_or_err()?; + Ok(HashMap { inner: map, _k: PhantomData, @@ -220,26 +224,35 @@ mod tests { } #[test] - fn test_try_from_ok() { - let map = Map { + fn test_new_not_created() { + let mut map = Map { obj: new_obj_map("TEST"), fd: None, }; - assert!(HashMap::<_, u32, u32>::try_from(&map).is_ok()) + + assert!(matches!( + HashMap::<_, u32, u32>::new(&mut map), + Err(MapError::NotCreated { .. }) + )); } #[test] - fn test_insert_not_created() { + fn test_new_ok() { let mut map = Map { obj: new_obj_map("TEST"), - fd: None, + fd: Some(42), }; - let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap(); - assert!(matches!( - hm.insert(1, 42, 0), - Err(MapError::NotCreated { .. }) - )); + assert!(HashMap::<_, u32, u32>::new(&mut map).is_ok()); + } + + #[test] + fn test_try_from_ok() { + let map = Map { + obj: new_obj_map("TEST"), + fd: Some(42), + }; + assert!(HashMap::<_, u32, u32>::try_from(&map).is_ok()) } #[test] @@ -277,17 +290,6 @@ mod tests { assert!(hm.insert(1, 42, 0).is_ok()); } - #[test] - fn test_remove_not_created() { - let mut map = Map { - obj: new_obj_map("TEST"), - fd: None, - }; - let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap(); - - assert!(matches!(hm.remove(&1), Err(MapError::NotCreated { .. }))); - } - #[test] fn test_remove_syscall_error() { override_syscall(|_| sys_error(EFAULT)); @@ -323,20 +325,6 @@ mod tests { assert!(hm.remove(&1).is_ok()); } - #[test] - fn test_get_not_created() { - let map = Map { - obj: new_obj_map("TEST"), - fd: None, - }; - let hm = HashMap::<_, u32, u32>::new(&map).unwrap(); - - assert!(matches!( - unsafe { hm.get(&1, 0) }, - Err(MapError::NotCreated { .. }) - )); - } - #[test] fn test_get_syscall_error() { override_syscall(|_| sys_error(EFAULT)); @@ -370,20 +358,6 @@ mod tests { assert!(matches!(unsafe { hm.get(&1, 0) }, Ok(None))); } - #[test] - fn test_pop_not_created() { - let mut map = Map { - obj: new_obj_map("TEST"), - fd: None, - }; - let mut hm = HashMap::<_, u32, u32>::new(&mut map).unwrap(); - - assert!(matches!( - unsafe { hm.pop(&1) }, - Err(MapError::NotCreated { .. }) - )); - } - #[test] fn test_pop_syscall_error() { override_syscall(|_| sys_error(EFAULT)); diff --git a/aya/src/maps/perf_map/perf_map.rs b/aya/src/maps/perf_map/perf_map.rs index 7c151dc0..3e6120e6 100644 --- a/aya/src/maps/perf_map/perf_map.rs +++ b/aya/src/maps/perf_map/perf_map.rs @@ -73,6 +73,7 @@ impl> PerfMap { map_type: map_type as u32, })?; } + let _fd = map.fd_or_err()?; Ok(PerfMap { map: Arc::new(map), diff --git a/aya/src/maps/program_array.rs b/aya/src/maps/program_array.rs index 4383a5be..82712985 100644 --- a/aya/src/maps/program_array.rs +++ b/aya/src/maps/program_array.rs @@ -21,30 +21,30 @@ pub struct ProgramArray> { impl> ProgramArray { pub fn new(map: T) -> Result, MapError> { - let inner = map.deref(); - let map_type = inner.obj.def.map_type; + let map_type = map.obj.def.map_type; if map_type != BPF_MAP_TYPE_PROG_ARRAY { return Err(MapError::InvalidMapType { map_type: map_type as u32, })?; } let expected = mem::size_of::(); - let size = inner.obj.def.key_size as usize; + let size = map.obj.def.key_size as usize; if size != expected { return Err(MapError::InvalidKeySize { size, expected }); } let expected = mem::size_of::(); - let size = inner.obj.def.value_size as usize; + let size = map.obj.def.value_size as usize; if size != expected { return Err(MapError::InvalidValueSize { size, expected }); } + let _fd = map.fd_or_err()?; Ok(ProgramArray { inner: map }) } pub unsafe fn get(&self, key: &u32, flags: u64) -> Result, MapError> { - let fd = self.inner.deref().fd_or_err()?; + let fd = self.inner.fd_or_err()?; let fd = bpf_map_lookup_elem(fd, key, flags) .map_err(|(code, io_error)| MapError::LookupElementError { code, io_error })?; Ok(fd) @@ -59,8 +59,8 @@ impl> ProgramArray { } fn check_bounds(&self, index: u32) -> Result<(), MapError> { - let max_entries = self.inner.deref().obj.def.max_entries; - if index >= self.inner.deref().obj.def.max_entries { + let max_entries = self.inner.obj.def.max_entries; + if index >= self.inner.obj.def.max_entries { Err(MapError::OutOfBounds { index, max_entries }) } else { Ok(()) @@ -75,7 +75,7 @@ impl + DerefMut> ProgramArray { program: &dyn ProgramFd, flags: u64, ) -> Result<(), MapError> { - let fd = self.inner.deref().fd_or_err()?; + let fd = self.inner.fd_or_err()?; self.check_bounds(index)?; let prog_fd = program.fd().ok_or(MapError::ProgramNotLoaded)?; @@ -85,14 +85,14 @@ impl + DerefMut> ProgramArray { } pub unsafe fn pop(&mut self, index: &u32) -> Result, MapError> { - let fd = self.inner.deref().fd_or_err()?; + let fd = self.inner.fd_or_err()?; self.check_bounds(*index)?; bpf_map_lookup_and_delete_elem(fd, index) .map_err(|(code, io_error)| MapError::LookupAndDeleteElementError { code, io_error }) } pub fn remove(&mut self, index: &u32) -> Result<(), MapError> { - let fd = self.inner.deref().fd_or_err()?; + let fd = self.inner.fd_or_err()?; self.check_bounds(*index)?; bpf_map_delete_elem(fd, index) .map(|_| ()) @@ -102,7 +102,7 @@ impl + DerefMut> ProgramArray { impl> IterableMap for ProgramArray { fn fd(&self) -> Result { - self.inner.deref().fd_or_err() + self.inner.fd_or_err() } unsafe fn get(&self, index: &u32) -> Result, MapError> {