diff --git a/test-distro/src/init.rs b/test-distro/src/init.rs index ccc8cdfc..49cc94f2 100644 --- a/test-distro/src/init.rs +++ b/test-distro/src/init.rs @@ -97,6 +97,14 @@ fn run() -> anyhow::Result<()> { data: None, target_mode: None, }, + Mount { + source: "securityfs", + target: "/sys/kernel/security", + fstype: "securityfs", + flags: nix::mount::MsFlags::empty(), + data: None, + target_mode: None, + }, ] { match target_mode { None => { diff --git a/test/integration-test/src/tests/lsm.rs b/test/integration-test/src/tests/lsm.rs index d7a6301c..01683e2e 100644 --- a/test/integration-test/src/tests/lsm.rs +++ b/test/integration-test/src/tests/lsm.rs @@ -1,18 +1,14 @@ -use std::{ - fs::{File, OpenOptions}, - io::{ErrorKind, Write as _}, - net::TcpListener, - path::Path, -}; +use std::{io::ErrorKind, net::TcpListener}; use aya::{ Btf, Ebpf, programs::{Lsm, lsm_cgroup::LsmCgroup}, sys::is_program_supported, - util::KernelVersion, }; -fn check_sys_lsm_enabled() -> bool { +use crate::utils::Cgroup; + +fn is_lsm_bpf_enabled() -> bool { std::fs::read_to_string("/sys/kernel/security/lsm") .unwrap() .contains("bpf") @@ -20,14 +16,7 @@ fn check_sys_lsm_enabled() -> bool { #[test] fn lsm_cgroup() { - let kernel_version = KernelVersion::current().unwrap(); - if kernel_version < KernelVersion::new(6, 0, 0) { - eprintln!("skipping lsm_cgroup test on kernel {kernel_version:?}"); - return; - } - - if !(is_program_supported(aya::programs::ProgramType::Lsm).unwrap()) || !check_sys_lsm_enabled() - { + if !(is_program_supported(aya::programs::ProgramType::Lsm).unwrap() && is_lsm_bpf_enabled()) { eprintln!("LSM programs are not supported"); return; } @@ -43,46 +32,32 @@ fn lsm_cgroup() { assert_matches::assert_matches!(TcpListener::bind("127.0.0.1:12345"), Ok(_)); - let pid = std::process::id(); + let root = Cgroup::root(); + let cgroup = root.create_child("aya-test-lsm-cgroup"); + let fd = cgroup.fd(); - let cgroup_dir = Path::new("/sys/fs/cgroup/lsm_cgroup_test"); - std::fs::create_dir_all(cgroup_dir).expect("could not create cgroup dir"); - let _guard = scopeguard::guard((), |()| { - std::fs::remove_dir(cgroup_dir).unwrap(); - }); - - let link_id = prog.attach(File::open(cgroup_dir).unwrap()).unwrap(); + let link_id = prog.attach(&fd).unwrap(); let _guard = scopeguard::guard((), |()| { prog.detach(link_id).unwrap(); }); + let pid = std::process::id(); + assert_matches::assert_matches!(TcpListener::bind("127.0.0.1:12345"), Ok(_)); - let proc_path = cgroup_dir.join("cgroup.procs"); - let mut procfs = OpenOptions::new().append(true).open(proc_path).unwrap(); - write!(procfs, "{pid}").expect("could not write into procs file"); - let _guard = scopeguard::guard((), |()| { - let mut file = OpenOptions::new() - .append(true) - .open("/sys/fs/cgroup/cgroup.procs") - .unwrap(); - write!(file, "{pid}").expect("could not write into procs file"); - }); + cgroup.into_cgroup().write_pid(pid); assert_matches::assert_matches!(TcpListener::bind("127.0.0.1:12345"), Err(e) => assert_eq!( e.kind(), ErrorKind::PermissionDenied)); + + root.write_pid(pid); + + assert_matches::assert_matches!(TcpListener::bind("127.0.0.1:12345"), Ok(_)); } #[test] fn lsm() { - let kernel_version = KernelVersion::current().unwrap(); - if kernel_version < KernelVersion::new(5, 7, 0) { - eprintln!("skipping lsm test on kernel {kernel_version:?}"); - return; - } - - if !(is_program_supported(aya::programs::ProgramType::Lsm).unwrap()) || !check_sys_lsm_enabled() - { + if !(is_program_supported(aya::programs::ProgramType::Lsm).unwrap() && is_lsm_bpf_enabled()) { eprintln!("LSM programs are not supported"); return; } diff --git a/test/integration-test/src/utils.rs b/test/integration-test/src/utils.rs index 1a8cc02b..2a782bae 100644 --- a/test/integration-test/src/utils.rs +++ b/test/integration-test/src/utils.rs @@ -1,8 +1,13 @@ //! Utilities to run tests use std::{ + borrow::Cow, + cell::OnceCell, ffi::CString, - io, process, + fs, + io::{self, Write as _}, + path::Path, + process, sync::atomic::{AtomicU64, Ordering}, }; @@ -10,6 +15,96 @@ use aya::netlink_set_link_up; use libc::if_nametoindex; use netns_rs::{NetNs, get_from_current_thread}; +const CGROUP_ROOT: &str = "/sys/fs/cgroup"; +const CGROUP_PROCS: &str = "cgroup.procs"; +pub(crate) struct ChildCgroup<'a> { + parent: &'a Cgroup<'a>, + path: Cow<'a, Path>, + fd: OnceCell, +} + +pub(crate) enum Cgroup<'a> { + Root, + Child(ChildCgroup<'a>), +} + +impl Cgroup<'static> { + pub(crate) fn root() -> Self { + Self::Root + } +} + +impl<'a> Cgroup<'a> { + fn path(&self) -> &Path { + match self { + Self::Root => Path::new(CGROUP_ROOT), + Self::Child(ChildCgroup { + parent: _, + path, + fd: _, + }) => path, + } + } + + pub(crate) fn create_child(&'a self, name: &str) -> ChildCgroup<'a> { + let path = self.path().join(name); + fs::create_dir(&path).unwrap(); + + ChildCgroup { + parent: self, + path: path.into(), + fd: OnceCell::new(), + } + } + + pub(crate) fn write_pid(&self, pid: u32) { + fs::write(self.path().join(CGROUP_PROCS), format!("{pid}\n")).unwrap(); + } +} + +impl<'a> ChildCgroup<'a> { + pub(crate) fn fd(&self) -> &fs::File { + let Self { + parent: _, + path, + fd, + } = self; + fd.get_or_init(|| { + fs::OpenOptions::new() + .read(true) + .open(path.as_ref()) + .unwrap() + }) + } + + pub(crate) fn into_cgroup(self) -> Cgroup<'a> { + Cgroup::Child(self) + } +} + +impl Drop for ChildCgroup<'_> { + fn drop(&mut self) { + let Self { + parent, + path, + fd: _, + } = self; + + let pids = fs::read_to_string(path.as_ref().join(CGROUP_PROCS)).unwrap(); + let mut dst = fs::OpenOptions::new() + .append(true) + .open(parent.path().join(CGROUP_PROCS)) + .unwrap(); + for pid in pids.split_inclusive('\n') { + dst.write_all(pid.as_bytes()).unwrap(); + } + + if let Err(e) = fs::remove_dir(&path) { + eprintln!("failed to remove {}: {e}", path.display()); + } + } +} + pub(crate) struct NetNsGuard { name: String, old_ns: NetNs,