diff --git a/aya/src/maps/lpm_trie.rs b/aya/src/maps/lpm_trie.rs new file mode 100644 index 00000000..dad13c09 --- /dev/null +++ b/aya/src/maps/lpm_trie.rs @@ -0,0 +1,460 @@ +//! A LPM Trie. +use std::{convert::TryFrom, marker::PhantomData, mem, ops::Deref}; + +use crate::{ + generated::bpf_map_type::BPF_MAP_TYPE_LPM_TRIE, + maps::{IterableMap, Map, MapError, MapRef, MapRefMut}, + sys::{bpf_map_delete_elem, bpf_map_lookup_elem, bpf_map_update_elem}, + Pod, +}; + +/// A Longest Prefix Match Trie. +/// +/// # Minimum kernel version +/// +/// The minimum kernel version required to use this feature is 4.20. +/// +/// # Examples +/// +/// ```no_run +/// # let bpf = aya::Bpf::load(&[])?; +/// use aya::maps::lpm_trie::{LpmTrie, Key}; +/// use std::convert::TryFrom; +/// use std::net::Ipv4Addr; +/// +/// let mut trie = LpmTrie::try_from(bpf.map_mut("LPM_TRIE")?)?; +/// let ipaddr = Ipv4Addr::new(8, 8, 8, 8); +/// // The following represents a key for the "8.8.8.8/16" subnet. +/// // The first argument - the prefix length - represents how many bytes should be matched against. The second argument is the actual data to be matched. +/// let key = Key::new(16, u32::from(ipaddr).to_be()); +/// trie.insert(&key, 1, 0)?; +/// +/// // LpmTrie matches against the longest (most accurate) key. +/// let lookup = Key::new(32, u32::from(ipaddr).to_be()); +/// let value = trie.get(&lookup, 0)?; +/// assert_eq!(value, 1); +/// +/// // If we were to insert a key with longer 'prefix_len' +/// // our trie should match against it. +/// let longer_key = Key::new(24, u32::from(ipaddr).to_be()); +/// trie.insert(&longer_key, 2, 0)?; +/// let value = trie.get(&lookup, 0)?; +/// assert_eq!(value, 2); +/// # Ok::<(), aya::BpfError>(()) +/// ``` + +#[doc(alias = "BPF_MAP_TYPE_LPM_TRIE")] +pub struct LpmTrie, K, V> { + inner: T, + _k: PhantomData, + _v: PhantomData, +} + +/// A Key for and LpmTrie map. +/// +/// # Examples +/// +/// ```no_run +/// use aya::maps::lpm_trie::{LpmTrie, Key}; +/// use std::convert::TryFrom; +/// use std::net::Ipv4Addr; +/// +/// let ipaddr = Ipv4Addr::new(8,8,8,8); +/// let key = Key::new(16, u32::from(ipaddr).to_be()); +/// ``` +#[repr(packed)] +pub struct Key { + /// Represents the number of bytes matched against. + pub prefix_len: u32, + /// Represents arbitrary data stored in the LpmTrie. + pub data: K, +} + +impl Key { + /// Creates a new key. + /// + /// # Examples + /// + /// ```no_run + /// use aya::maps::lpm_trie::{LpmTrie, Key}; + /// use std::convert::TryFrom; + /// use std::net::Ipv4Addr; + /// + /// let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + /// let key = Key::new(16, u32::from(ipaddr).to_be()); + /// ``` + pub fn new(prefix_len: u32, data: K) -> Self { + Self { prefix_len, data } + } +} + +impl Copy for Key {} + +impl Clone for Key { + fn clone(&self) -> Self { + *self + } +} + +// A Pod impl is required as Key struct is a key for a map. +unsafe impl Pod for Key {} + +impl, K: Pod, V: Pod> LpmTrie { + pub(crate) fn new(map: T) -> Result, MapError> { + let map_type = map.obj.def.map_type; + + // validate the map definition + if map_type != BPF_MAP_TYPE_LPM_TRIE as u32 { + return Err(MapError::InvalidMapType { + map_type: map_type as u32, + }); + } + let size = mem::size_of::>(); + let expected = map.obj.def.key_size as usize; + if size != expected { + return Err(MapError::InvalidKeySize { size, expected }); + } + let size = mem::size_of::(); + let expected = map.obj.def.value_size as usize; + if size != expected { + return Err(MapError::InvalidValueSize { size, expected }); + }; + + let _ = map.fd_or_err()?; + + Ok(LpmTrie { + inner: map, + _k: PhantomData, + _v: PhantomData, + }) + } + + /// Returns a copy of the value associated with the longest prefix matching key in the LpmTrie. + pub fn get(&self, key: &Key, flags: u64) -> Result { + let fd = self.inner.deref().fd_or_err()?; + let value = bpf_map_lookup_elem(fd, key, flags).map_err(|(code, io_error)| { + MapError::SyscallError { + call: "bpf_map_lookup_elem".to_owned(), + code, + io_error, + } + })?; + value.ok_or(MapError::KeyNotFound) + } + + /// Inserts a key value pair into the map. + pub fn insert(&self, key: &Key, value: V, flags: u64) -> Result<(), MapError> { + let fd = self.inner.deref().fd_or_err()?; + bpf_map_update_elem(fd, key, &value, flags).map_err(|(code, io_error)| { + MapError::SyscallError { + call: "bpf_map_update_elem".to_owned(), + code, + io_error, + } + })?; + + Ok(()) + } + + /// Removes an element from the map. + /// + /// Both the prefix and data must match exactly - this method does not do a longest prefix match. + pub fn remove(&self, key: &Key) -> Result<(), MapError> { + let fd = self.inner.deref().fd_or_err()?; + bpf_map_delete_elem(fd, key) + .map(|_| ()) + .map_err(|(code, io_error)| MapError::SyscallError { + call: "bpf_map_delete_elem".to_owned(), + code, + io_error, + }) + } +} + +impl, K: Pod, V: Pod> IterableMap for LpmTrie { + fn map(&self) -> &Map { + &self.inner + } + + fn get(&self, key: &K) -> Result { + let lookup = Key::new(mem::size_of::() as u32, *key); + self.get(&lookup, 0) + } +} + +impl TryFrom for LpmTrie { + type Error = MapError; + + fn try_from(a: MapRef) -> Result, MapError> { + LpmTrie::new(a) + } +} + +impl TryFrom for LpmTrie { + type Error = MapError; + + fn try_from(a: MapRefMut) -> Result, MapError> { + LpmTrie::new(a) + } +} + +impl<'a, K: Pod, V: Pod> TryFrom<&'a Map> for LpmTrie<&'a Map, K, V> { + type Error = MapError; + + fn try_from(a: &'a Map) -> Result, MapError> { + LpmTrie::new(a) + } +} + +impl<'a, K: Pod, V: Pod> TryFrom<&'a mut Map> for LpmTrie<&'a mut Map, K, V> { + type Error = MapError; + + fn try_from(a: &'a mut Map) -> Result, MapError> { + LpmTrie::new(a) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + bpf_map_def, + generated::{ + bpf_cmd, + bpf_map_type::{BPF_MAP_TYPE_LPM_TRIE, BPF_MAP_TYPE_PERF_EVENT_ARRAY}, + }, + obj, + sys::{override_syscall, SysResult, Syscall}, + }; + use libc::{EFAULT, ENOENT}; + use std::{io, mem, net::Ipv4Addr}; + + fn new_obj_map() -> obj::Map { + obj::Map { + def: bpf_map_def { + map_type: BPF_MAP_TYPE_LPM_TRIE as u32, + key_size: mem::size_of::>() as u32, + value_size: 4, + max_entries: 1024, + ..Default::default() + }, + section_index: 0, + data: Vec::new(), + kind: obj::MapKind::Other, + } + } + + fn sys_error(value: i32) -> SysResult { + Err((-1, io::Error::from_raw_os_error(value))) + } + + #[test] + fn test_wrong_key_size() { + let map = Map { + obj: new_obj_map(), + fd: None, + pinned: false, + }; + assert!(matches!( + LpmTrie::<_, u16, u32>::new(&map), + Err(MapError::InvalidKeySize { + size: 6, + expected: 8 // four bytes for prefixlen and four bytes for data. + }) + )); + } + + #[test] + fn test_wrong_value_size() { + let map = Map { + obj: new_obj_map(), + fd: None, + pinned: false, + }; + assert!(matches!( + LpmTrie::<_, u32, u16>::new(&map), + Err(MapError::InvalidValueSize { + size: 2, + expected: 4 + }) + )); + } + + #[test] + fn test_try_from_wrong_map() { + let map = Map { + obj: obj::Map { + def: bpf_map_def { + map_type: BPF_MAP_TYPE_PERF_EVENT_ARRAY as u32, + key_size: 4, + value_size: 4, + max_entries: 1024, + ..Default::default() + }, + section_index: 0, + data: Vec::new(), + kind: obj::MapKind::Other, + }, + fd: None, + pinned: false, + }; + + assert!(matches!( + LpmTrie::<_, u32, u32>::try_from(&map), + Err(MapError::InvalidMapType { .. }) + )); + } + + #[test] + fn test_new_not_created() { + let mut map = Map { + obj: new_obj_map(), + fd: None, + pinned: false, + }; + + assert!(matches!( + LpmTrie::<_, u32, u32>::new(&mut map), + Err(MapError::NotCreated { .. }) + )); + } + + #[test] + fn test_new_ok() { + let mut map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + + assert!(LpmTrie::<_, u32, u32>::new(&mut map).is_ok()); + } + + #[test] + fn test_try_from_ok() { + let map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + assert!(LpmTrie::<_, u32, u32>::try_from(&map).is_ok()) + } + + #[test] + fn test_insert_syscall_error() { + override_syscall(|_| sys_error(EFAULT)); + + let mut map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + let trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + assert!(matches!( + trie.insert(&key, 1, 0), + Err(MapError::SyscallError { call, code: -1, io_error }) if call == "bpf_map_update_elem" && io_error.raw_os_error() == Some(EFAULT) + )); + } + + #[test] + fn test_insert_ok() { + override_syscall(|call| match call { + Syscall::Bpf { + cmd: bpf_cmd::BPF_MAP_UPDATE_ELEM, + .. + } => Ok(1), + _ => sys_error(EFAULT), + }); + + let mut map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + + let trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + assert!(trie.insert(&key, 1, 0).is_ok()); + } + + #[test] + fn test_remove_syscall_error() { + override_syscall(|_| sys_error(EFAULT)); + + let mut map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + let trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + assert!(matches!( + trie.remove(&key), + Err(MapError::SyscallError { call, code: -1, io_error }) if call == "bpf_map_delete_elem" && io_error.raw_os_error() == Some(EFAULT) + )); + } + + #[test] + fn test_remove_ok() { + override_syscall(|call| match call { + Syscall::Bpf { + cmd: bpf_cmd::BPF_MAP_DELETE_ELEM, + .. + } => Ok(1), + _ => sys_error(EFAULT), + }); + + let mut map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + let trie = LpmTrie::<_, u32, u32>::new(&mut map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + assert!(trie.remove(&key).is_ok()); + } + + #[test] + fn test_get_syscall_error() { + override_syscall(|_| sys_error(EFAULT)); + let map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + let trie = LpmTrie::<_, u32, u32>::new(&map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + + assert!(matches!( + trie.get(&key, 0), + Err(MapError::SyscallError { call, code: -1, io_error }) if call == "bpf_map_lookup_elem" && io_error.raw_os_error() == Some(EFAULT) + )); + } + + #[test] + fn test_get_not_found() { + override_syscall(|call| match call { + Syscall::Bpf { + cmd: bpf_cmd::BPF_MAP_LOOKUP_ELEM, + .. + } => sys_error(ENOENT), + _ => sys_error(EFAULT), + }); + let map = Map { + obj: new_obj_map(), + fd: Some(42), + pinned: false, + }; + let trie = LpmTrie::<_, u32, u32>::new(&map).unwrap(); + let ipaddr = Ipv4Addr::new(8, 8, 8, 8); + let key = Key::new(16, u32::from(ipaddr).to_be()); + + assert!(matches!(trie.get(&key, 0), Err(MapError::KeyNotFound))); + } +} diff --git a/aya/src/maps/mod.rs b/aya/src/maps/mod.rs index 29edec55..59c7b0c9 100644 --- a/aya/src/maps/mod.rs +++ b/aya/src/maps/mod.rs @@ -50,6 +50,7 @@ mod map_lock; pub mod array; pub mod hash_map; +pub mod lpm_trie; pub mod perf; pub mod queue; pub mod sock;