diff --git a/Cargo.toml b/Cargo.toml index d3e4f644..b7146682 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,6 +93,7 @@ rbpf = { version = "0.3.0", default-features = false } rustdoc-json = { version = "0.9.0", default-features = false } rustup-toolchain = { version = "0.1.5", default-features = false } rustversion = { version = "1.0.0", default-features = false } +scopeguard = { version = "1.2.0", default-features = false } syn = { version = "2", default-features = false } tempfile = { version = "3", default-features = false } test-case = { version = "3.1.0", default-features = false } diff --git a/test/integration-test/Cargo.toml b/test/integration-test/Cargo.toml index a6b22133..162fc82c 100644 --- a/test/integration-test/Cargo.toml +++ b/test/integration-test/Cargo.toml @@ -29,6 +29,7 @@ netns-rs = { workspace = true } object = { workspace = true, features = ["elf", "read_core", "std"] } rand = { workspace = true, features = ["thread_rng"] } rbpf = { workspace = true } +scopeguard = { workspace = true } test-case = { workspace = true } test-log = { workspace = true, features = ["log"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } diff --git a/test/integration-test/src/tests/info.rs b/test/integration-test/src/tests/info.rs index 700ca936..d4410b77 100644 --- a/test/integration-test/src/tests/info.rs +++ b/test/integration-test/src/tests/info.rs @@ -12,16 +12,12 @@ use aya::{ Ebpf, maps::{Array, HashMap, IterableMap as _, MapError, MapType, loaded_maps}, programs::{ProgramError, ProgramType, SocketFilter, TracePoint, loaded_programs}, - sys::enable_stats, util::KernelVersion, }; use libc::EINVAL; use crate::utils::{kernel_assert, kernel_assert_eq}; -const BPF_JIT_ENABLE: &str = "/proc/sys/net/core/bpf_jit_enable"; -const BPF_STATS_ENABLED: &str = "/proc/sys/kernel/bpf_stats_enabled"; - #[test] fn test_loaded_programs() { // Load a program. @@ -57,16 +53,7 @@ fn test_loaded_programs() { #[test] fn test_program_info() { // Kernels below v4.15 have been observed to have `bpf_jit_enable` disabled by default. - let previously_enabled = is_sysctl_enabled(BPF_JIT_ENABLE); - // Restore to previous state when panic occurs. - let prev_panic = panic::take_hook(); - panic::set_hook(Box::new(move |panic_info| { - if !previously_enabled { - disable_sysctl_param(BPF_JIT_ENABLE); - } - prev_panic(panic_info); - })); - let jit_enabled = previously_enabled || enable_sysctl_param(BPF_JIT_ENABLE); + let _guard = ensure_sysctl_enabled("/proc/sys/net/core/bpf_jit_enable"); let mut bpf = Ebpf::load(crate::SIMPLE_PROG).unwrap(); let prog: &mut SocketFilter = bpf.program_mut("simple_prog").unwrap().try_into().unwrap(); @@ -81,9 +68,7 @@ fn test_program_info() { ); kernel_assert!(test_prog.id() > 0, KernelVersion::new(4, 13, 0)); kernel_assert!(test_prog.tag() > 0, KernelVersion::new(4, 13, 0)); - if jit_enabled { - kernel_assert!(test_prog.size_jitted() > 0, KernelVersion::new(4, 13, 0)); - } + kernel_assert!(test_prog.size_jitted() > 0, KernelVersion::new(4, 13, 0)); kernel_assert!( test_prog.size_translated().is_some(), KernelVersion::new(4, 13, 0), @@ -121,11 +106,6 @@ fn test_program_info() { // Ensure rest of the fields do not panic. test_prog.memory_locked().unwrap(); test_prog.fd().unwrap(); - - // Restore to previous state - if !previously_enabled { - disable_sysctl_param(BPF_JIT_ENABLE); - } } #[test] @@ -184,23 +164,7 @@ fn test_prog_stats() { return; } - let stats_fd = enable_stats(aya::sys::Stats::RunTime).ok(); - // Restore to previous state when panic occurs. - let previously_enabled = is_sysctl_enabled(BPF_STATS_ENABLED); - let prev_panic = panic::take_hook(); - panic::set_hook(Box::new(move |panic_info| { - if !previously_enabled { - disable_sysctl_param(BPF_STATS_ENABLED); - } - prev_panic(panic_info); - })); - - let stats_enabled = - stats_fd.is_some() || previously_enabled || enable_sysctl_param(BPF_STATS_ENABLED); - if !stats_enabled { - eprintln!("ignoring test completely as bpf stats could not be enabled on the host"); - return; - } + let _guard = ensure_sysctl_enabled("/proc/sys/kernel/bpf_stats_enabled"); let mut bpf = Ebpf::load(crate::TEST).unwrap(); let prog: &mut TracePoint = bpf @@ -213,11 +177,6 @@ fn test_prog_stats() { let test_prog = prog.info().unwrap(); kernel_assert!(test_prog.run_count() > 0, KernelVersion::new(5, 1, 0)); - - // Restore to previous state - if !previously_enabled { - disable_sysctl_param(BPF_STATS_ENABLED); - } } #[test] @@ -318,20 +277,12 @@ fn test_map_info() { array.fd().unwrap(); } -/// Whether sysctl parameter is enabled in the `/proc` file. -fn is_sysctl_enabled(path: &str) -> bool { - match fs::read_to_string(path) { - Ok(contents) => contents.chars().next().is_some_and(|c| c == '1'), - Err(_) => false, - } -} - -/// Enable sysctl parameter through procfs. -fn enable_sysctl_param(path: &str) -> bool { - fs::write(path, b"1").is_ok() -} - -/// Disable sysctl parameter through procfs. -fn disable_sysctl_param(path: &str) -> bool { - fs::write(path, b"0").is_ok() +fn ensure_sysctl_enabled<'a>( + path: &'a str, +) -> Option> { + let content = fs::read_to_string(path).unwrap(); + (!content.starts_with('1')).then(move || { + fs::write(path, b"1").unwrap(); + scopeguard::guard(path, |path| fs::write(path, b"0").unwrap()) + }) }