diff --git a/aya/src/maps/lpm_trie.rs b/aya/src/maps/lpm_trie.rs index 83bad16d..40d91f4c 100644 --- a/aya/src/maps/lpm_trie.rs +++ b/aya/src/maps/lpm_trie.rs @@ -2,8 +2,8 @@ use std::{convert::TryFrom, marker::PhantomData, mem, ops::Deref}; use crate::{ - generated::bpf_map_type::BPF_MAP_TYPE_LPM_TRIE, - maps::{Map, MapError}, + generated::{bpf_lpm_trie_key, bpf_map_type::BPF_MAP_TYPE_LPM_TRIE}, + maps::{Map, IterableMap, MapError, MapRef, MapRefMut}, sys::{bpf_map_delete_elem, bpf_map_lookup_elem, bpf_map_update_elem}, Pod, }; @@ -32,6 +32,28 @@ pub struct LpmTrie<T: Deref<Target = Map>, K, V> { _v: PhantomData<V>, } +#[derive(Clone, Copy)] +#[repr(packed)] +pub struct Key<K: Pod> { + pub key_base: bpf_lpm_trie_key, + pub data: K, +} + +impl<K: Pod> Key<K> { + pub fn new(prefixlen: u32, data: K) -> Self { + Self { + key_base: bpf_lpm_trie_key { + prefixlen: prefixlen, + data: Default::default(), + }, + data: data, + } + } +} + +// A Pod impl is required as bpf_lpm_trie_key is a key for a map. +unsafe impl<K: Pod> Pod for Key<K> {} + impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { pub(crate) fn new(map: T) -> Result<LpmTrie<T, K, V>, MapError> { let map_type = map.obj.def.map_type; @@ -42,7 +64,7 @@ impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { map_type: map_type as u32, }); } - let size = mem::size_of::<K>(); + let size = mem::size_of::<K>() + mem::size_of::<bpf_lpm_trie_key>(); let expected = map.obj.def.key_size as usize; if size != expected { return Err(MapError::InvalidKeySize { size, expected }); @@ -63,9 +85,9 @@ impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { } /// Returns a copy of the value associated with the key. - pub unsafe fn get(&self, key: &K, flags: u64) -> Result<V, MapError> { + pub fn get(&self, key: Key<K>, flags: u64) -> Result<V, MapError> { let fd = self.inner.deref().fd_or_err()?; - let value = bpf_map_lookup_elem(fd, key, flags).map_err(|(code, io_error)| { + let value = bpf_map_lookup_elem(fd, &key, flags).map_err(|(code, io_error)| { MapError::SyscallError { call: "bpf_map_lookup_elem".to_owned(), code, @@ -76,7 +98,7 @@ impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { } /// Updates a key-value pair for the value associated with the key. - pub unsafe fn insert(&self, key: K, value: V, flags: u64) -> Result<(), MapError> { + pub fn insert(&self, key: Key<K>, value: V, flags: u64) -> Result<(), MapError> { let fd = self.inner.deref().fd_or_err()?; bpf_map_update_elem(fd, &key, &value, flags).map_err(|(code, io_error)| { MapError::SyscallError { @@ -90,9 +112,9 @@ impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { } /// Deletes elements from the map by key. - pub unsafe fn remove(&self, key: &K) -> Result<(), MapError> { + pub fn remove(&self, key: &Key<K>) -> Result<(), MapError> { let fd = self.inner.deref().fd_or_err()?; - bpf_map_delete_elem(fd, key) + bpf_map_delete_elem(fd, &key) .map(|_| ()) .map_err(|(code, io_error)| MapError::SyscallError { call: "bpf_map_delete_elem".to_owned(), @@ -101,3 +123,47 @@ impl<T: Deref<Target = Map>, K: Pod, V: Pod> LpmTrie<T, K, V> { }) } } + + +impl<T: Deref<Target = Map>, K: Pod, V: Pod> IterableMap<K, V> for LpmTrie<T, K, V> { + fn map(&self) -> &Map { + &self.inner + } + + fn get(&self, key: &K) -> Result<V, MapError> { + let lookup = Key::new(mem::size_of::<K>() as u32, *key); + self.get(lookup, 0) + } +} + +impl<K: Pod, V: Pod> TryFrom<MapRef> for LpmTrie<MapRef, K, V> { + type Error = MapError; + + fn try_from(a: MapRef) -> Result<LpmTrie<MapRef, K, V>, MapError> { + LpmTrie::new(a) + } +} + +impl<K: Pod, V: Pod> TryFrom<MapRefMut> for LpmTrie<MapRefMut, K, V> { + type Error = MapError; + + fn try_from(a: MapRefMut) -> Result<LpmTrie<MapRefMut, K, V>, MapError> { + LpmTrie::new(a) + } +} + +impl<'a, K: Pod, V: Pod> TryFrom<&'a Map> for LpmTrie<&'a Map, K, V> { + type Error = MapError; + + fn try_from(a: &'a Map) -> Result<LpmTrie<&'a Map, K, V>, MapError> { + LpmTrie::new(a) + } +} + +impl<'a, K: Pod, V: Pod> TryFrom<&'a mut Map> for LpmTrie<&'a mut Map, K, V> { + type Error = MapError; + + fn try_from(a: &'a mut Map) -> Result<LpmTrie<&'a mut Map, K, V>, MapError> { + LpmTrie::new(a) + } +}