Merge pull request #771 from aya-rs/xdp-raw

netlink: use OwnedFd
pull/775/head
Tamir Duberstein 1 year ago committed by GitHub
commit c4d1d1086a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,7 +11,7 @@ use std::{
ffi::CString, ffi::CString,
hash::Hash, hash::Hash,
io, io,
os::fd::{AsFd as _, AsRawFd as _, RawFd}, os::fd::{AsFd as _, AsRawFd as _, BorrowedFd, RawFd},
}; };
use thiserror::Error; use thiserror::Error;
@ -201,6 +201,8 @@ impl Xdp {
XdpLinkInner::NlLink(nl_link) => { XdpLinkInner::NlLink(nl_link) => {
let if_index = nl_link.if_index; let if_index = nl_link.if_index;
let old_prog_fd = nl_link.prog_fd; let old_prog_fd = nl_link.prog_fd;
// SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
let old_prog_fd = unsafe { BorrowedFd::borrow_raw(old_prog_fd) };
let flags = nl_link.flags; let flags = nl_link.flags;
let replace_flags = flags | XdpFlags::REPLACE; let replace_flags = flags | XdpFlags::REPLACE;
unsafe { unsafe {
@ -246,7 +248,9 @@ impl Link for NlLink {
} else { } else {
self.flags.bits() self.flags.bits()
}; };
let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(self.prog_fd), flags) }; // SAFETY: TODO(https://github.com/aya-rs/aya/issues/612): make this safe by not holding `RawFd`s.
let prog_fd = unsafe { BorrowedFd::borrow_raw(self.prog_fd) };
let _ = unsafe { netlink_set_xdp_fd(self.if_index, None, Some(prog_fd), flags) };
Ok(()) Ok(())
} }
} }

@ -2,13 +2,13 @@ use std::{
collections::HashMap, collections::HashMap,
ffi::CStr, ffi::CStr,
io, mem, io, mem,
os::fd::{AsRawFd as _, BorrowedFd, RawFd}, os::fd::{AsRawFd as _, BorrowedFd, FromRawFd as _, OwnedFd},
ptr, slice, ptr, slice,
}; };
use thiserror::Error; use thiserror::Error;
use libc::{ use libc::{
close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket,
AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE, AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFF_UP, IFLA_XDP, NETLINK_EXT_ACK, NETLINK_ROUTE,
NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE, NLA_ALIGNTO, NLA_F_NESTED, NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE,
NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER, NLM_F_DUMP, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER,
@ -32,7 +32,7 @@ const NLA_HDR_LEN: usize = align_to(mem::size_of::<nlattr>(), NLA_ALIGNTO as usi
pub(crate) unsafe fn netlink_set_xdp_fd( pub(crate) unsafe fn netlink_set_xdp_fd(
if_index: i32, if_index: i32,
fd: Option<BorrowedFd<'_>>, fd: Option<BorrowedFd<'_>>,
old_fd: Option<RawFd>, old_fd: Option<BorrowedFd<'_>>,
flags: u32, flags: u32,
) -> Result<(), io::Error> { ) -> Result<(), io::Error> {
let sock = NetlinkSocket::open()?; let sock = NetlinkSocket::open()?;
@ -64,7 +64,10 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
} }
if flags & XDP_FLAGS_REPLACE != 0 { if flags & XDP_FLAGS_REPLACE != 0 {
attrs.write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap())?; attrs.write_attr(
IFLA_XDP_EXPECTED_FD as u16,
old_fd.map(|fd| fd.as_raw_fd()).unwrap(),
)?;
} }
let nla_len = attrs.finish()?; let nla_len = attrs.finish()?;
@ -290,7 +293,7 @@ struct TcRequest {
} }
struct NetlinkSocket { struct NetlinkSocket {
sock: RawFd, sock: OwnedFd,
_nl_pid: u32, _nl_pid: u32,
} }
@ -301,12 +304,14 @@ impl NetlinkSocket {
if sock < 0 { if sock < 0 {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
} }
// SAFETY: `socket` returns a file descriptor.
let sock = unsafe { OwnedFd::from_raw_fd(sock) };
let enable = 1i32; let enable = 1i32;
// Safety: libc wrapper // Safety: libc wrapper
unsafe { unsafe {
setsockopt( setsockopt(
sock, sock.as_raw_fd(),
SOL_NETLINK, SOL_NETLINK,
NETLINK_EXT_ACK, NETLINK_EXT_ACK,
&enable as *const _ as *const _, &enable as *const _ as *const _,
@ -319,7 +324,13 @@ impl NetlinkSocket {
addr.nl_family = AF_NETLINK as u16; addr.nl_family = AF_NETLINK as u16;
let mut addr_len = mem::size_of::<sockaddr_nl>() as u32; let mut addr_len = mem::size_of::<sockaddr_nl>() as u32;
// Safety: libc wrapper // Safety: libc wrapper
if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0 if unsafe {
getsockname(
sock.as_raw_fd(),
&mut addr as *mut _ as *mut _,
&mut addr_len as *mut _,
)
} < 0
{ {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
} }
@ -331,7 +342,15 @@ impl NetlinkSocket {
} }
fn send(&self, msg: &[u8]) -> Result<(), io::Error> { fn send(&self, msg: &[u8]) -> Result<(), io::Error> {
if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 { if unsafe {
send(
self.sock.as_raw_fd(),
msg.as_ptr() as *const _,
msg.len(),
0,
)
} < 0
{
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
} }
Ok(()) Ok(())
@ -344,7 +363,14 @@ impl NetlinkSocket {
'out: while multipart { 'out: while multipart {
multipart = false; multipart = false;
// Safety: libc wrapper // Safety: libc wrapper
let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) }; let len = unsafe {
recv(
self.sock.as_raw_fd(),
buf.as_mut_ptr() as *mut _,
buf.len(),
0,
)
};
if len < 0 { if len < 0 {
return Err(io::Error::last_os_error()); return Err(io::Error::last_os_error());
} }
@ -430,13 +456,6 @@ impl NetlinkMessage {
} }
} }
impl Drop for NetlinkSocket {
fn drop(&mut self) {
// Safety: libc wrapper
unsafe { close(self.sock) };
}
}
const fn align_to(v: usize, align: usize) -> usize { const fn align_to(v: usize, align: usize) -> usize {
(v + (align - 1)) & !(align - 1) (v + (align - 1)) & !(align - 1)
} }

Loading…
Cancel
Save