diff --git a/Cargo.toml b/Cargo.toml index e24bb95c..7d3c1d81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "aya", "aya-build", + "aya-common", "aya-log", "aya-log-common", "aya-log-parser", diff --git a/aya-common/Cargo.toml b/aya-common/Cargo.toml new file mode 100644 index 00000000..99b738af --- /dev/null +++ b/aya-common/Cargo.toml @@ -0,0 +1,23 @@ +[package] +description = "Library shared across eBPF and user-space" +documentation = "https://docs.rs/aya-common" +keywords = ["bpf", "ebpf", "common"] +name = "aya-common" +version = "0.1.0" + +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[lints] +workspace = true + +[dependencies] +aya = { path = "../aya", version = "^0.13.1", optional = true } +aya-ebpf-cty = { version = "^0.2.2", path = "../ebpf/aya-ebpf-cty" } + +[features] +user = ["dep:aya"] diff --git a/aya-common/src/lib.rs b/aya-common/src/lib.rs new file mode 100644 index 00000000..51c6b54b --- /dev/null +++ b/aya-common/src/lib.rs @@ -0,0 +1,5 @@ +#![no_std] + +pub mod spin_lock; + +pub use spin_lock::SpinLock; diff --git a/aya-common/src/spin_lock.rs b/aya-common/src/spin_lock.rs new file mode 100644 index 00000000..234cae2e --- /dev/null +++ b/aya-common/src/spin_lock.rs @@ -0,0 +1,22 @@ +use aya_ebpf_cty::c_uint; + +// #[expect(non_camel_case_types, reason = "Binding to a C type.")] +#[repr(C)] +#[derive(Debug, Copy, Clone, Default)] +pub struct bpf_spin_lock { + pub val: c_uint, +} + +/// A spin lock that can be used to procect shared data in eBPF maps. +#[repr(C)] +#[derive(Debug, Copy, Clone, Default)] +pub struct SpinLock(bpf_spin_lock); + +impl SpinLock { + pub fn as_ptr(&self) -> *mut bpf_spin_lock { + core::ptr::from_ref::(&self.0).cast_mut() + } +} + +#[cfg(feature = "user")] +unsafe impl aya::Pod for SpinLock {} diff --git a/ebpf/aya-ebpf/Cargo.toml b/ebpf/aya-ebpf/Cargo.toml index c15afbef..b50e00a1 100644 --- a/ebpf/aya-ebpf/Cargo.toml +++ b/ebpf/aya-ebpf/Cargo.toml @@ -14,6 +14,7 @@ rust-version.workspace = true workspace = true [dependencies] +aya-common = { version = "^0.1.0", path = "../../aya-common" } aya-ebpf-bindings = { version = "^0.1.1", path = "../aya-ebpf-bindings" } aya-ebpf-cty = { version = "^0.2.2", path = "../aya-ebpf-cty" } aya-ebpf-macros = { version = "^0.1.1", path = "../../aya-ebpf-macros" } diff --git a/ebpf/aya-ebpf/src/lib.rs b/ebpf/aya-ebpf/src/lib.rs index dc59d351..447b8a59 100644 --- a/ebpf/aya-ebpf/src/lib.rs +++ b/ebpf/aya-ebpf/src/lib.rs @@ -27,6 +27,7 @@ pub mod btf_maps; pub mod helpers; pub mod maps; pub mod programs; +pub mod spin_lock; use core::ptr::NonNull; diff --git a/ebpf/aya-ebpf/src/spin_lock.rs b/ebpf/aya-ebpf/src/spin_lock.rs new file mode 100644 index 00000000..35cb6b26 --- /dev/null +++ b/ebpf/aya-ebpf/src/spin_lock.rs @@ -0,0 +1,33 @@ +pub use aya_common::spin_lock::SpinLock; + +use crate::{bindings, helpers}; + +/// An RAII implementation of a scope of a spin lock. When this structure is +/// dropped (falls out of scope), the lock will be unlocked. +pub struct SpinLockGuard<'a> { + spin_lock: &'a SpinLock, +} + +impl Drop for SpinLockGuard<'_> { + fn drop(&mut self) { + unsafe { + helpers::bpf_spin_unlock(self.spin_lock.as_ptr().cast::()); + } + } +} + +/// Extension trait allowing to acquire a [`SpinLock`] in an eBPF program. +pub trait EbpfSpinLock { + /// Acquires a spin lock and returns a [`SpinLockGuard`]. The lock is + /// acquired as long as the guard is alive. + fn lock(&self) -> SpinLockGuard<'_>; +} + +impl EbpfSpinLock for SpinLock { + fn lock(&self) -> SpinLockGuard<'_> { + unsafe { + helpers::bpf_spin_lock(self.as_ptr().cast::()); + } + SpinLockGuard { spin_lock: self } + } +} diff --git a/test/integration-common/Cargo.toml b/test/integration-common/Cargo.toml index 437f87e9..1b731db2 100644 --- a/test/integration-common/Cargo.toml +++ b/test/integration-common/Cargo.toml @@ -15,6 +15,7 @@ workspace = true [dependencies] aya = { path = "../../aya", optional = true } +aya-common = { path = "../../aya-common" } [features] user = ["aya"] diff --git a/test/integration-common/src/lib.rs b/test/integration-common/src/lib.rs index 8932c15d..676c4bce 100644 --- a/test/integration-common/src/lib.rs +++ b/test/integration-common/src/lib.rs @@ -76,6 +76,20 @@ pub mod ring_buf { unsafe impl aya::Pod for Registers {} } +pub mod spin_lock { + use aya_common::SpinLock; + + #[derive(Copy, Clone)] + #[repr(C)] + pub struct Counter { + pub count: u32, + pub spin_lock: SpinLock, + } + + #[cfg(feature = "user")] + unsafe impl aya::Pod for Counter {} +} + pub mod strncmp { #[derive(Copy, Clone)] #[repr(C)] diff --git a/test/integration-ebpf/Cargo.toml b/test/integration-ebpf/Cargo.toml index d4d6da59..a95f25c3 100644 --- a/test/integration-ebpf/Cargo.toml +++ b/test/integration-ebpf/Cargo.toml @@ -76,6 +76,10 @@ path = "src/ring_buf.rs" name = "simple_prog" path = "src/simple_prog.rs" +[[bin]] +name = "spin_lock" +path = "src/spin_lock.rs" + [[bin]] name = "strncmp" path = "src/strncmp.rs" diff --git a/test/integration-ebpf/src/spin_lock.rs b/test/integration-ebpf/src/spin_lock.rs new file mode 100644 index 00000000..72057294 --- /dev/null +++ b/test/integration-ebpf/src/spin_lock.rs @@ -0,0 +1,32 @@ +#![no_std] +#![no_main] +#![expect(unused_crate_dependencies, reason = "used in other bins")] + +#[cfg(not(test))] +extern crate ebpf_panic; + +use aya_ebpf::{ + bindings::xdp_action, + btf_maps::Array, + macros::{btf_map, xdp}, + programs::XdpContext, + spin_lock::EbpfSpinLock as _, +}; +use integration_common::spin_lock::Counter; + +#[btf_map] +static COUNTER: Array = Array::new(); + +#[xdp] +fn packet_counter(_ctx: XdpContext) -> u32 { + let Some(counter) = COUNTER.get_ptr_mut(0) else { + return xdp_action::XDP_PASS; + }; + let counter = unsafe { &mut *counter }; + { + let _guard = counter.spin_lock.lock(); + counter.count = counter.count.saturating_add(1); + } + + xdp_action::XDP_PASS +} diff --git a/test/integration-test/src/lib.rs b/test/integration-test/src/lib.rs index f7da1d9f..e2cc7a7d 100644 --- a/test/integration-test/src/lib.rs +++ b/test/integration-test/src/lib.rs @@ -51,6 +51,7 @@ bpf_file!( RELOCATIONS => "relocations", RING_BUF => "ring_buf", SIMPLE_PROG => "simple_prog", + SPIN_LOCK => "spin_lock", STRNCMP => "strncmp", TCX => "tcx", TEST => "test", diff --git a/test/integration-test/src/tests.rs b/test/integration-test/src/tests.rs index 0e376c5a..c178fd7b 100644 --- a/test/integration-test/src/tests.rs +++ b/test/integration-test/src/tests.rs @@ -14,6 +14,7 @@ mod rbpf; mod relocations; mod ring_buf; mod smoke; +mod spin_lock; mod strncmp; mod tcx; mod uprobe_cookie; diff --git a/test/integration-test/src/tests/spin_lock.rs b/test/integration-test/src/tests/spin_lock.rs new file mode 100644 index 00000000..ed06b1e7 --- /dev/null +++ b/test/integration-test/src/tests/spin_lock.rs @@ -0,0 +1,51 @@ +use std::{net::UdpSocket, time::Duration}; + +use aya::{ + EbpfLoader, + maps::Array, + programs::{Xdp, XdpFlags}, +}; +use integration_common::spin_lock::Counter; + +use crate::utils::NetNsGuard; + +#[test_log::test] +fn test_spin_lock() { + let _netns = NetNsGuard::new(); + + let mut ebpf = EbpfLoader::new().load(crate::SPIN_LOCK).unwrap(); + + let prog: &mut Xdp = ebpf + .program_mut("packet_counter") + .unwrap() + .try_into() + .unwrap(); + prog.load().unwrap(); + prog.attach("lo", XdpFlags::default()).unwrap(); + + const PAYLOAD: &str = "hello counter"; + + let sock = UdpSocket::bind("127.0.0.1:0").unwrap(); + let addr = sock.local_addr().unwrap(); + sock.set_read_timeout(Some(Duration::from_secs(60))) + .unwrap(); + + let num_packets = 10; + for _ in 0..num_packets { + sock.send_to(PAYLOAD.as_bytes(), addr).unwrap(); + } + + // Read back the packets to ensure it went through the entire network stack, + // including the XDP program. + let mut buf = [0u8; PAYLOAD.len() + 1]; + for _ in 0..num_packets { + let n = sock.recv(&mut buf).unwrap(); + assert_eq!(n, PAYLOAD.len()); + assert_eq!(&buf[..n], PAYLOAD.as_bytes()); + } + + let counter_map = ebpf.map("COUNTER").unwrap(); + let counter_map = Array::<_, Counter>::try_from(counter_map).unwrap(); + let Counter { count, .. } = counter_map.get(&0, 0).unwrap(); + assert_eq!(count, num_packets); +}