netlink: use OwnedFd

Updates #612.
pull/771/head
Tamir Duberstein 1 year ago
parent c4643b395f
commit cee0265b52
No known key found for this signature in database

@ -11,7 +11,7 @@ use std::{
ffi::CString,
hash::Hash,
io,
os::fd::{AsFd as _, AsRawFd as _, RawFd},
os::fd::{AsFd as _, AsRawFd as _, BorrowedFd, RawFd},
};
use thiserror::Error;
@ -201,6 +201,8 @@ impl Xdp {
XdpLinkInner::NlLink(nl_link) => {
let if_index = nl_link.if_index;
let old_prog_fd = nl_link.prog_fd;
// SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
let old_prog_fd = unsafe { BorrowedFd::borrow_raw(old_prog_fd) };
let flags = nl_link.flags;
let replace_flags = flags | XdpFlags::REPLACE;
unsafe {
@ -246,7 +248,9 @@ impl Link for NlLink {
} else {
self.flags.bits()
};
let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(self.prog_fd), flags) };
// SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
let prog_fd = unsafe { BorrowedFd::borrow_raw(self.prog_fd) };
let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(prog_fd), flags) };
Ok(())
}
}

@ -2,13 +2,13 @@ use std::{
collections::HashMap,
ffi::CStr,
io, mem,
os::fd::{AsRawFd as _, BorrowedFd, RawFd},
os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _, OwnedFd},
ptr, slice,
};
use thiserror::Error;
use libc::{
close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE,
NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE,
NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER,
@ -32,7 +32,7 @@ const NLA_HDR_LEN: usize = align_to(mem::size_of::<nlattr>(), NLA_ALIGNTO as usi
pub(crate) unsafe fn netlink_set_xdp_fd(
if_index: i32,
fd: Option<BorrowedFd<'_>>,
old_fd: Option<RawFd>,
old_fd: Option<BorrowedFd<'_>>,
flags: u32,
) -> Result<(), io::Error> {
let sock = NetlinkSocket::open()?;
@ -64,7 +64,10 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
}
if flags & XDP_FLAGS_REPLACE != 0 {
attrs.write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap())?;
attrs.write_attr(
IFLA_XDP_EXPECTED_FD as u16,
old_fd.map(|fd| fd.as_raw_fd()).unwrap(),
)?;
}
let nla_len = attrs.finish()?;
@ -290,7 +293,7 @@ struct TcRequest {
}
struct NetlinkSocket {
sock: RawFd,
sock: OwnedFd,
_nl_pid: u32,
}
@ -301,12 +304,14 @@ impl NetlinkSocket {
if sock < 0 {
return Err(io::Error::last_os_error());
}
// SAFETY: `socket` returns a file descriptor.
let sock = unsafe { OwnedFd::from_raw_fd(sock) };
let enable = 1i32;
// Safety: libc wrapper
unsafe {
setsockopt(
sock,
sock.as_raw_fd(),
SOL_NETLINK,
NETLINK_EXT_ACK,
&enable as *const _ as *const _,
@ -319,7 +324,13 @@ impl NetlinkSocket {
addr.nl_family = AF_NETLINK as u16;
let mut addr_len = mem::size_of::<sockaddr_nl>() as u32;
// Safety: libc wrapper
if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0
if unsafe {
getsockname(
sock.as_raw_fd(),
&mut addr as *mut _ as *mut _,
&mut addr_len as *mut _,
)
} < 0
{
return Err(io::Error::last_os_error());
}
@ -331,7 +342,15 @@ impl NetlinkSocket {
}
fn send(&self, msg: &[u8]) -> Result<(), io::Error> {
if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 {
if unsafe {
send(
self.sock.as_raw_fd(),
msg.as_ptr() as *const _,
msg.len(),
0,
)
} < 0
{
return Err(io::Error::last_os_error());
}
Ok(())
@ -344,7 +363,14 @@ impl NetlinkSocket {
'out: while multipart {
multipart = false;
// Safety: libc wrapper
let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) };
let len = unsafe {
recv(
self.sock.as_raw_fd(),
buf.as_mut_ptr() as *mut _,
buf.len(),
0,
)
};
if len < 0 {
return Err(io::Error::last_os_error());
}
@ -430,13 +456,6 @@ impl NetlinkMessage {
}
}
impl Drop for NetlinkSocket {
fn drop(&mut self) {
// Safety: libc wrapper
unsafe { close(self.sock) };
}
}
const fn align_to(v: usize, align: usize) -> usize {
(v + (align - 1)) & !(align - 1)
}

Loading…
Cancel
Save