@ -11,9 +11,9 @@ use std::{
sync ::atomic ::{ AtomicU64 , Ordering } ,
sync ::atomic ::{ AtomicU64 , Ordering } ,
} ;
} ;
use anyhow ::{ Context as _ , Result } ;
use aya ::netlink_set_link_up ;
use aya ::netlink_set_link_up ;
use libc ::if_nametoindex ;
use libc ::if_nametoindex ;
use netns_rs ::{ NetNs , get_from_current_thread } ;
const CGROUP_ROOT : & str = "/sys/fs/cgroup" ;
const CGROUP_ROOT : & str = "/sys/fs/cgroup" ;
const CGROUP_PROCS : & str = "cgroup.procs" ;
const CGROUP_PROCS : & str = "cgroup.procs" ;
@ -91,46 +91,80 @@ impl Drop for ChildCgroup<'_> {
fd : _ ,
fd : _ ,
} = self ;
} = self ;
let pids = fs ::read_to_string ( path . as_ref ( ) . join ( CGROUP_PROCS ) ) . unwrap ( ) ;
match ( | | -> Result < ( ) > {
let dst = parent . path ( ) . join ( CGROUP_PROCS ) ;
let mut dst = fs ::OpenOptions ::new ( )
let mut dst = fs ::OpenOptions ::new ( )
. append ( true )
. append ( true )
. open ( parent . path ( ) . join ( CGROUP_PROCS ) )
. open ( & dst )
. unwrap ( ) ;
. with_context ( | | {
format! (
"fs::OpenOptions::new().append(true).open(\"{}\")" ,
dst . display ( )
)
} ) ? ;
let pids = path . as_ref ( ) . join ( CGROUP_PROCS ) ;
let pids = fs ::read_to_string ( & pids )
. with_context ( | | format! ( "fs::read_to_string(\"{}\")" , pids . display ( ) ) ) ? ;
for pid in pids . split_inclusive ( '\n' ) {
for pid in pids . split_inclusive ( '\n' ) {
dst . write_all ( pid . as_bytes ( ) ) . unwrap ( ) ;
dst . write_all ( pid . as_bytes ( ) )
. with_context ( | | format! ( "dst.write_all(\"{}\")" , pid ) ) ? ;
}
}
if let Err ( e ) = fs ::remove_dir ( & path ) {
fs ::remove_dir ( & path )
eprintln! ( "failed to remove {}: {e}" , path . display ( ) ) ;
. with_context ( | | format! ( "fs::remove_dir(\"{}\")" , path . display ( ) ) ) ? ;
Ok ( ( ) )
} ) ( ) {
Ok ( ( ) ) = > ( ) ,
Err ( err ) = > {
// Avoid panic in panic.
if std ::thread ::panicking ( ) {
eprintln! ( "{err:?}" ) ;
} else {
panic! ( "{err:?}" ) ;
}
}
}
}
}
}
}
}
pub ( crate ) struct NetNsGuard {
pub ( crate ) struct NetNsGuard {
name : String ,
name : String ,
old_ns : NetNs ,
old_ns : fs ::File ,
ns : Option < NetNs > ,
}
}
impl NetNsGuard {
impl NetNsGuard {
const PERSIST_DIR : & str = "/var/run/netns/" ;
pub ( crate ) fn new ( ) -> Self {
pub ( crate ) fn new ( ) -> Self {
let old_ns = get_from_current_thread ( ) . expect ( "Failed to get current netns" ) ;
let current_thread_netns_path = format! ( "/proc/self/task/{}/ns/net" , nix ::unistd ::gettid ( ) ) ;
let old_ns = fs ::File ::open ( & current_thread_netns_path ) . unwrap_or_else ( | err | {
panic! ( "fs::File::open(\"{current_thread_netns_path}\"): {err:?}" )
} ) ;
static COUNTER : AtomicU64 = AtomicU64 ::new ( 0 ) ;
static COUNTER : AtomicU64 = AtomicU64 ::new ( 0 ) ;
let pid = process ::id ( ) ;
let pid = process ::id ( ) ;
let name = format! ( "aya-test-{pid}-{}" , COUNTER . fetch_add ( 1 , Ordering ::Relaxed ) ) ;
let name = format! ( "aya-test-{pid}-{}" , COUNTER . fetch_add ( 1 , Ordering ::Relaxed ) ) ;
let ns = NetNs ::new ( & name ) . unwrap_or_else ( | e | panic! ( "failed to create netns {name}: {e}" ) ) ;
fs ::create_dir_all ( Self ::PERSIST_DIR )
. unwrap_or_else ( | err | panic! ( "fs::create_dir_all(\"{}\"): {err:?}" , Self ::PERSIST_DIR ) ) ;
let ns_path = Path ::new ( Self ::PERSIST_DIR ) . join ( & name ) ;
let _ : fs ::File = fs ::File ::create ( & ns_path )
. unwrap_or_else ( | err | panic! ( "fs::File::create(\"{}\"): {err:?}" , ns_path . display ( ) ) ) ;
nix ::sched ::unshare ( nix ::sched ::CloneFlags ::CLONE_NEWNET )
. expect ( "nix::sched::unshare(CLONE_NEWNET)" ) ;
nix ::mount ::mount (
Some ( current_thread_netns_path . as_str ( ) ) ,
& ns_path ,
Some ( "none" ) ,
nix ::mount ::MsFlags ::MS_BIND ,
None ::< & str > ,
)
. expect ( "nix::mount::mount" ) ;
ns . enter ( )
. unwrap_or_else ( | e | panic! ( "failed to enter network namespace {name}: {e}" ) ) ;
println! ( "entered network namespace {name}" ) ;
println! ( "entered network namespace {name}" ) ;
let ns = Self {
let ns = Self { old_ns , name } ;
old_ns ,
ns : Some ( ns ) ,
name ,
} ;
// By default, the loopback in a new netns is down. Set it up.
// By default, the loopback in a new netns is down. Set it up.
let lo = CString ::new ( "lo" ) . unwrap ( ) ;
let lo = CString ::new ( "lo" ) . unwrap ( ) ;
@ -153,17 +187,28 @@ impl NetNsGuard {
impl Drop for NetNsGuard {
impl Drop for NetNsGuard {
fn drop ( & mut self ) {
fn drop ( & mut self ) {
let Self { old_ns , ns , name } = self ;
let Self { old_ns , name } = self ;
match ( | | -> Result < ( ) > {
nix ::sched ::setns ( old_ns , nix ::sched ::CloneFlags ::CLONE_NEWNET )
. context ( "nix::sched::setns(_, CLONE_NEWNET)" ) ? ;
let ns_path = Path ::new ( Self ::PERSIST_DIR ) . join ( & name ) ;
nix ::mount ::umount2 ( & ns_path , nix ::mount ::MntFlags ::MNT_DETACH ) . with_context ( | | {
format! ( "nix::mount::umount2(\"{}\", MNT_DETACH)" , ns_path . display ( ) )
} ) ? ;
fs ::remove_file ( & ns_path )
. with_context ( | | format! ( "fs::remove_file(\"{}\")" , ns_path . display ( ) ) ) ? ;
Ok ( ( ) )
} ) ( ) {
Ok ( ( ) ) = > ( ) ,
Err ( err ) = > {
// Avoid panic in panic.
// Avoid panic in panic.
if let Err ( e ) = old_ns . enter ( ) {
if std ::thread ::panicking ( ) {
eprintln! ( "failed to return to original netns: {e}" ) ;
eprintln! ( "{err:?}" ) ;
} else {
panic! ( "{err:?}" ) ;
}
}
if let Some ( ns ) = ns . take ( ) {
if let Err ( e ) = ns . remove ( ) {
eprintln! ( "failed to remove netns {name}: {e}" ) ;
}
}
}
}
println! ( "exited network namespace {name}" ) ;
}
}
}
}