From cee0265b5291acb747cf3a9532cfbf61c455f398 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 29 Aug 2023 17:38:55 -0400 Subject: [PATCH] netlink: use OwnedFd Updates #612. --- aya/src/programs/xdp.rs | 8 +++++-- aya/src/sys/netlink.rs | 51 ++++++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/aya/src/programs/xdp.rs b/aya/src/programs/xdp.rs index 499b2b38..89574328 100644 --- a/aya/src/programs/xdp.rs +++ b/aya/src/programs/xdp.rs @@ -11,7 +11,7 @@ use std::{ ffi::CString, hash::Hash, io, - os::fd::{AsFd as _, AsRawFd as _, RawFd}, + os::fd::{AsFd as _, AsRawFd as _, BorrowedFd, RawFd}, }; use thiserror::Error; @@ -201,6 +201,8 @@ impl Xdp { XdpLinkInner::NlLink(nl_link) => { let if_index = nl_link.if_index; let old_prog_fd = nl_link.prog_fd; + // SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s. + let old_prog_fd = unsafe { BorrowedFd::borrow_raw(old_prog_fd) }; let flags = nl_link.flags; let replace_flags = flags | XdpFlags::REPLACE; unsafe { @@ -246,7 +248,9 @@ impl Link for NlLink { } else { self.flags.bits() }; - let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(self.prog_fd), flags) }; + // SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s. + let prog_fd = unsafe { BorrowedFd::borrow_raw(self.prog_fd) }; + let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(prog_fd), flags) }; Ok(()) } } diff --git a/aya/src/sys/netlink.rs b/aya/src/sys/netlink.rs index 3643e95b..c4679aab 100644 --- a/aya/src/sys/netlink.rs +++ b/aya/src/sys/netlink.rs @@ -2,13 +2,13 @@ use std::{ collections::HashMap, ffi::CStr, io, mem, - os::fd::{AsRawFd as _, BorrowedFd, RawFd}, + os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _, OwnedFd}, ptr, slice, }; use thiserror::Error; use libc::{ - close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket, + getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket, AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE, NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER, @@ -32,7 +32,7 @@ const NLA_HDR_LEN: usize = align_to(mem::size_of::(), NLA_ALIGNTO as usi pub(crate) unsafe fn netlink_set_xdp_fd( if_index: i32, fd: Option>, - old_fd: Option, + old_fd: Option>, flags: u32, ) -> Result<(), io::Error> { let sock = NetlinkSocket::open()?; @@ -64,7 +64,10 @@ pub(crate) unsafe fn netlink_set_xdp_fd( } if flags & XDP_FLAGS_REPLACE != 0 { - attrs.write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap())?; + attrs.write_attr( + IFLA_XDP_EXPECTED_FD as u16, + old_fd.map(|fd| fd.as_raw_fd()).unwrap(), + )?; } let nla_len = attrs.finish()?; @@ -290,7 +293,7 @@ struct TcRequest { } struct NetlinkSocket { - sock: RawFd, + sock: OwnedFd, _nl_pid: u32, } @@ -301,12 +304,14 @@ impl NetlinkSocket { if sock < 0 { return Err(io::Error::last_os_error()); } + // SAFETY: `socket` returns a file descriptor. + let sock = unsafe { OwnedFd::from_raw_fd(sock) }; let enable = 1i32; // Safety: libc wrapper unsafe { setsockopt( - sock, + sock.as_raw_fd(), SOL_NETLINK, NETLINK_EXT_ACK, &enable as *const _ as *const _, @@ -319,7 +324,13 @@ impl NetlinkSocket { addr.nl_family = AF_NETLINK as u16; let mut addr_len = mem::size_of::() as u32; // Safety: libc wrapper - if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0 + if unsafe { + getsockname( + sock.as_raw_fd(), + &mut addr as *mut _ as *mut _, + &mut addr_len as *mut _, + ) + } < 0 { return Err(io::Error::last_os_error()); } @@ -331,7 +342,15 @@ impl NetlinkSocket { } fn send(&self, msg: &[u8]) -> Result<(), io::Error> { - if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 { + if unsafe { + send( + self.sock.as_raw_fd(), + msg.as_ptr() as *const _, + msg.len(), + 0, + ) + } < 0 + { return Err(io::Error::last_os_error()); } Ok(()) @@ -344,7 +363,14 @@ impl NetlinkSocket { 'out: while multipart { multipart = false; // Safety: libc wrapper - let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) }; + let len = unsafe { + recv( + self.sock.as_raw_fd(), + buf.as_mut_ptr() as *mut _, + buf.len(), + 0, + ) + }; if len < 0 { return Err(io::Error::last_os_error()); } @@ -430,13 +456,6 @@ impl NetlinkMessage { } } -impl Drop for NetlinkSocket { - fn drop(&mut self) { - // Safety: libc wrapper - unsafe { close(self.sock) }; - } -} - const fn align_to(v: usize, align: usize) -> usize { (v + (align - 1)) & !(align - 1) }