From ad961db0382866869883effb5ca3d0167887c7d2 Mon Sep 17 00:00:00 2001 From: Dave Tucker Date: Wed, 26 Jul 2023 22:45:18 +0100 Subject: [PATCH] aya: Refactor Netlink Error Handling Signed-off-by: Dave Tucker --- aya/src/programs/mod.rs | 6 ++- aya/src/programs/tc.rs | 39 +++++++------- aya/src/programs/xdp.rs | 6 +-- aya/src/sys/netlink.rs | 110 +++++++++++++++++++++++++++------------- 4 files changed, 105 insertions(+), 56 deletions(-) diff --git a/aya/src/programs/mod.rs b/aya/src/programs/mod.rs index f868dfc8..6f0dcf63 100644 --- a/aya/src/programs/mod.rs +++ b/aya/src/programs/mod.rs @@ -64,7 +64,7 @@ pub mod uprobe; mod utils; pub mod xdp; -use crate::util::KernelVersion; +use crate::{sys::NetlinkError, util::KernelVersion}; use libc::ENOSPC; use std::{ ffi::CString, @@ -212,6 +212,10 @@ pub enum ProgramError { /// An error occurred while working with IO. #[error(transparent)] IOError(#[from] io::Error), + + /// An error occurred while working with Netlink. + #[error(transparent)] + NetlinkError(#[from] NetlinkError), } /// A [`Program`] file descriptor. diff --git a/aya/src/programs/tc.rs b/aya/src/programs/tc.rs index fad633a3..5c87762e 100644 --- a/aya/src/programs/tc.rs +++ b/aya/src/programs/tc.rs @@ -14,7 +14,7 @@ use crate::{ programs::{define_link_wrapper, load_program, Link, ProgramData, ProgramError}, sys::{ netlink_find_filter_with_name, netlink_qdisc_add_clsact, netlink_qdisc_attach, - netlink_qdisc_detach, + netlink_qdisc_detach, NetlinkError, }, util::{ifindex_from_ifname, tc_handler_make}, VerifierLogLevel, @@ -80,12 +80,16 @@ pub struct SchedClassifier { #[derive(Debug, Error)] pub enum TcError { /// netlink error while attaching ebpf program - #[error("netlink error while attaching ebpf program to tc")] - NetlinkError { - /// the [`io::Error`] from the netlink call - #[source] - io_error: io::Error, - }, + #[error(transparent)] + NetlinkError(#[from] NetlinkError), + + /// the provided string contains a nul byte + #[error(transparent)] + NulError(#[from] std::ffi::NulError), + + #[error(transparent)] + /// an IO error occurred + IoError(#[from] io::Error), /// the clsact qdisc is already attached #[error("the clsact qdisc is already attached")] AlreadyAttached, @@ -153,8 +157,7 @@ impl SchedClassifier { options: TcOptions, ) -> Result { let prog_fd = self.data.fd_or_err()?; - let if_index = ifindex_from_ifname(interface) - .map_err(|io_error| TcError::NetlinkError { io_error })?; + let if_index = ifindex_from_ifname(interface).map_err(TcError::IoError)?; let (priority, handle) = unsafe { netlink_qdisc_attach( if_index as i32, @@ -165,7 +168,7 @@ impl SchedClassifier { options.handle, ) } - .map_err(|io_error| TcError::NetlinkError { io_error })?; + .map_err(ProgramError::NetlinkError)?; self.data.links.insert(SchedClassifierLink::new(TcLink { if_index: if_index as i32, @@ -230,7 +233,7 @@ impl Link for TcLink { unsafe { netlink_qdisc_detach(self.if_index, &self.attach_type, self.priority, self.handle) } - .map_err(|io_error| TcError::NetlinkError { io_error })?; + .map_err(ProgramError::NetlinkError)?; Ok(()) } } @@ -306,9 +309,9 @@ impl SchedClassifierLink { /// /// The `clsact` qdisc must be added to an interface before [`SchedClassifier`] /// programs can be attached. -pub fn qdisc_add_clsact(if_name: &str) -> Result<(), io::Error> { +pub fn qdisc_add_clsact(if_name: &str) -> Result<(), TcError> { let if_index = ifindex_from_ifname(if_name)?; - unsafe { netlink_qdisc_add_clsact(if_index as i32) } + unsafe { netlink_qdisc_add_clsact(if_index as i32).map_err(TcError::NetlinkError) } } /// Detaches the programs with the given name. @@ -322,8 +325,8 @@ pub fn qdisc_detach_program( if_name: &str, attach_type: TcAttachType, name: &str, -) -> Result<(), io::Error> { - let cstr = CString::new(name)?; +) -> Result<(), TcError> { + let cstr = CString::new(name).map_err(TcError::NulError)?; qdisc_detach_program_fast(if_name, attach_type, &cstr) } @@ -340,15 +343,15 @@ fn qdisc_detach_program_fast( if_name: &str, attach_type: TcAttachType, name: &CStr, -) -> Result<(), io::Error> { +) -> Result<(), TcError> { let if_index = ifindex_from_ifname(if_name)? as i32; let filter_info = unsafe { netlink_find_filter_with_name(if_index, attach_type, name)? }; if filter_info.is_empty() { - return Err(io::Error::new( + return Err(TcError::IoError(io::Error::new( io::ErrorKind::NotFound, name.to_string_lossy(), - )); + ))); } for (prio, handle) in filter_info { diff --git a/aya/src/programs/xdp.rs b/aya/src/programs/xdp.rs index a85c87c2..d4222071 100644 --- a/aya/src/programs/xdp.rs +++ b/aya/src/programs/xdp.rs @@ -1,9 +1,9 @@ //! eXpress Data Path (XDP) programs. -use crate::util::KernelVersion; +use crate::{sys::NetlinkError, util::KernelVersion}; use bitflags; use libc::if_nametoindex; -use std::{convert::TryFrom, ffi::CString, hash::Hash, io, mem, os::unix::io::RawFd}; +use std::{convert::TryFrom, ffi::CString, hash::Hash, mem, os::unix::io::RawFd}; use thiserror::Error; use crate::{ @@ -28,7 +28,7 @@ pub enum XdpError { NetlinkError { /// the [`io::Error`] from the netlink call #[source] - io_error: io::Error, + io_error: NetlinkError, }, } diff --git a/aya/src/sys/netlink.rs b/aya/src/sys/netlink.rs index a5619003..27ef2259 100644 --- a/aya/src/sys/netlink.rs +++ b/aya/src/sys/netlink.rs @@ -29,7 +29,7 @@ pub(crate) unsafe fn netlink_set_xdp_fd( fd: RawFd, old_fd: Option, flags: u32, -) -> Result<(), io::Error> { +) -> Result<(), NetlinkError> { let sock = NetlinkSocket::open()?; // Safety: Request is POD so this is safe @@ -49,17 +49,25 @@ pub(crate) unsafe fn netlink_set_xdp_fd( // write the attrs let attrs_buf = request_attributes(&mut req, nlmsg_len); let mut attrs = NestedAttrs::new(attrs_buf, IFLA_XDP); - attrs.write_attr(IFLA_XDP_FD as u16, fd)?; + attrs + .write_attr(IFLA_XDP_FD as u16, fd) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; if flags > 0 { - attrs.write_attr(IFLA_XDP_FLAGS as u16, flags)?; + attrs + .write_attr(IFLA_XDP_FLAGS as u16, flags) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; } if flags & XDP_FLAGS_REPLACE != 0 { - attrs.write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap())?; + attrs + .write_attr(IFLA_XDP_EXPECTED_FD as u16, old_fd.unwrap()) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; } - let nla_len = attrs.finish()?; + let nla_len = attrs + .finish() + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; req.header.nlmsg_len += align_to(nla_len, NLA_ALIGNTO as usize) as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; @@ -67,7 +75,7 @@ pub(crate) unsafe fn netlink_set_xdp_fd( Ok(()) } -pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::Error> { +pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), NetlinkError> { let sock = NetlinkSocket::open()?; let mut req = mem::zeroed::(); @@ -88,7 +96,8 @@ pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::E // add the TCA_KIND attribute let attrs_buf = request_attributes(&mut req, nlmsg_len); - let attr_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"clsact\0")?; + let attr_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"clsact\0") + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; req.header.nlmsg_len += align_to(attr_len, NLA_ALIGNTO as usize) as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; @@ -104,7 +113,7 @@ pub(crate) unsafe fn netlink_qdisc_attach( prog_name: &CStr, priority: u16, handle: u32, -) -> Result<(u16, u32), io::Error> { +) -> Result<(u16, u32), NetlinkError> { let sock = NetlinkSocket::open()?; let mut req = mem::zeroed::(); @@ -125,15 +134,24 @@ pub(crate) unsafe fn netlink_qdisc_attach( let attrs_buf = request_attributes(&mut req, nlmsg_len); // add TCA_KIND - let kind_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"bpf\0")?; + let kind_len = write_attr_bytes(attrs_buf, 0, TCA_KIND as u16, b"bpf\0") + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; // add TCA_OPTIONS which includes TCA_BPF_FD, TCA_BPF_NAME and TCA_BPF_FLAGS let mut options = NestedAttrs::new(&mut attrs_buf[kind_len..], TCA_OPTIONS as u16); - options.write_attr(TCA_BPF_FD as u16, prog_fd)?; - options.write_attr_bytes(TCA_BPF_NAME as u16, prog_name.to_bytes_with_nul())?; + options + .write_attr(TCA_BPF_FD as u16, prog_fd) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; + options + .write_attr_bytes(TCA_BPF_NAME as u16, prog_name.to_bytes_with_nul()) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; let flags: u32 = TCA_BPF_FLAG_ACT_DIRECT; - options.write_attr(TCA_BPF_FLAGS as u16, flags)?; - let options_len = options.finish()?; + options + .write_attr(TCA_BPF_FLAGS as u16, flags) + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; + let options_len = options + .finish() + .map_err(|e| NetlinkError(NetlinkErrorRepr::IoError(e)))?; req.header.nlmsg_len += align_to(kind_len + options_len, NLA_ALIGNTO as usize) as u32; sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; @@ -149,10 +167,10 @@ pub(crate) unsafe fn netlink_qdisc_attach( None => { // if sock.recv() succeeds we should never get here unless there's a // bug in the kernel - return Err(io::Error::new( + return Err(NetlinkError(NetlinkErrorRepr::IoError(io::Error::new( io::ErrorKind::Other, "no RTM_NEWTFILTER reply received, this is a bug.", - )); + )))); } }; @@ -165,7 +183,7 @@ pub(crate) unsafe fn netlink_qdisc_detach( attach_type: &TcAttachType, priority: u16, handle: u32, -) -> Result<(), io::Error> { +) -> Result<(), NetlinkError> { let sock = NetlinkSocket::open()?; let mut req = mem::zeroed::(); @@ -195,7 +213,7 @@ pub(crate) unsafe fn netlink_find_filter_with_name( if_index: i32, attach_type: TcAttachType, name: &CStr, -) -> Result, io::Error> { +) -> Result, NetlinkError> { let mut req = mem::zeroed::(); let nlmsg_len = mem::size_of::() + mem::size_of::(); @@ -222,10 +240,12 @@ pub(crate) unsafe fn netlink_find_filter_with_name( let tc_msg = ptr::read_unaligned(msg.data.as_ptr() as *const tcmsg); let priority = (tc_msg.tcm_info >> 16) as u16; - let attrs = parse_attrs(&msg.data[mem::size_of::()..])?; + let attrs = parse_attrs(&msg.data[mem::size_of::()..]) + .map_err(|e| NetlinkError(NetlinkErrorRepr::NlAttrError(e)))?; if let Some(opts) = attrs.get(&(TCA_OPTIONS as u16)) { - let opts = parse_attrs(opts.data)?; + let opts = parse_attrs(opts.data) + .map_err(|e| NetlinkError(NetlinkErrorRepr::NlAttrError(e)))?; if let Some(f_name) = opts.get(&(TCA_BPF_NAME as u16)) { if let Ok(f_name) = CStr::from_bytes_with_nul(f_name.data) { if name == f_name { @@ -258,12 +278,32 @@ struct NetlinkSocket { _nl_pid: u32, } +#[derive(Error, Debug)] +#[error(transparent)] +pub struct NetlinkError(#[from] NetlinkErrorRepr); + +#[derive(Error, Debug)] +pub(crate) enum NetlinkErrorRepr { + #[error("netlink error: {message}")] + Error { + message: String, + #[source] + source: io::Error, + }, + #[error(transparent)] + IoError(#[from] io::Error), + #[error(transparent)] + NulError(#[from] std::ffi::NulError), + #[error(transparent)] + NlAttrError(#[from] NlAttrError), +} + impl NetlinkSocket { - fn open() -> Result { + fn open() -> Result { // Safety: libc wrapper let sock = unsafe { socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) }; if sock < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); } let enable = 1i32; @@ -278,7 +318,7 @@ impl NetlinkSocket { mem::size_of::() as u32, ) < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); }; // Set NETLINK_CAP_ACK to avoid getting copies of request payload. @@ -290,7 +330,7 @@ impl NetlinkSocket { mem::size_of::() as u32, ) < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); }; }; @@ -301,7 +341,7 @@ impl NetlinkSocket { // Safety: libc wrapper if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); } Ok(NetlinkSocket { @@ -310,14 +350,14 @@ impl NetlinkSocket { }) } - fn send(&self, msg: &[u8]) -> Result<(), io::Error> { + fn send(&self, msg: &[u8]) -> Result<(), NetlinkErrorRepr> { if unsafe { send(self.sock, msg.as_ptr() as *const _, msg.len(), 0) } < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); } Ok(()) } - fn recv(&self) -> Result, io::Error> { + fn recv(&self) -> Result, NetlinkErrorRepr> { let mut buf = [0u8; 4096]; let mut messages = Vec::new(); let mut multipart = true; @@ -326,7 +366,7 @@ impl NetlinkSocket { // Safety: libc wrapper let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) }; if len < 0 { - return Err(io::Error::last_os_error()); + return Err(NetlinkErrorRepr::IoError(io::Error::last_os_error())); } if len == 0 { break; @@ -353,13 +393,15 @@ impl NetlinkSocket { }); match err_msg { Some(err_msg) => { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("netlink error: {}", err_msg), - )); + return Err(NetlinkErrorRepr::Error { + message: err_msg, + source: io::Error::from_raw_os_error(-err.error), + }); } None => { - return Err(io::Error::from_raw_os_error(-err.error)); + return Err(NetlinkErrorRepr::IoError( + io::Error::from_raw_os_error(-err.error), + )); } } } @@ -591,7 +633,7 @@ struct NlAttr<'a> { } #[derive(Debug, Error, PartialEq, Eq)] -enum NlAttrError { +pub(crate) enum NlAttrError { #[error("invalid buffer size `{size}`, expected `{expected}`")] InvalidBufferLength { size: usize, expected: usize },