diff --git a/aya/src/programs/tc.rs b/aya/src/programs/tc.rs index 341a8f95..e07441d6 100644 --- a/aya/src/programs/tc.rs +++ b/aya/src/programs/tc.rs @@ -8,7 +8,10 @@ use crate::{ bpf_prog_type::BPF_PROG_TYPE_SCHED_CLS, TC_H_CLSACT, TC_H_MIN_EGRESS, TC_H_MIN_INGRESS, }, programs::{load_program, Link, LinkRef, ProgramData, ProgramError}, - sys::{netlink_qdisc_add_clsact, netlink_qdisc_attach, netlink_qdisc_detach}, + sys::{ + netlink_find_filter_with_name, netlink_qdisc_add_clsact, netlink_qdisc_attach, + netlink_qdisc_detach, + }, util::{ifindex_from_ifname, tc_handler_make}, }; @@ -130,6 +133,7 @@ impl SchedClassifier { let priority = unsafe { netlink_qdisc_attach(if_index as i32, &attach_type, prog_fd, &name) } .map_err(|io_error| TcError::NetlinkError { io_error })?; + Ok(self.data.link(TcLink { if_index: if_index as i32, attach_type, @@ -165,3 +169,30 @@ pub fn qdisc_add_clsact(if_name: &str) -> Result<(), io::Error> { let if_index = ifindex_from_ifname(if_name)?; unsafe { netlink_qdisc_add_clsact(if_index as i32) } } + +/// Detaches the programs with the given name. +/// +/// # Errors +/// +/// Returns [`io::ErrorKind::NotFound`] to indicate that no programs with the +/// given name were found, so nothing was detached. Other error kinds indicate +/// an actual failure while detaching a program. +pub fn qdisc_detach_program( + if_name: &str, + attach_type: TcAttachType, + name: &str, +) -> Result<(), io::Error> { + let if_index = ifindex_from_ifname(if_name)? as i32; + let c_name = CString::new(name).unwrap(); + + let prios = unsafe { netlink_find_filter_with_name(if_index, attach_type, &c_name)? }; + if prios.is_empty() { + return Err(io::Error::new(io::ErrorKind::NotFound, format!("{}", name))); + } + + for prio in prios { + unsafe { netlink_qdisc_detach(if_index, &attach_type, prio)? }; + } + + Ok(()) +} diff --git a/aya/src/sys/netlink.rs b/aya/src/sys/netlink.rs index 36c508de..365cc702 100644 --- a/aya/src/sys/netlink.rs +++ b/aya/src/sys/netlink.rs @@ -1,11 +1,12 @@ -use std::{ffi::CStr, io, mem, os::unix::io::RawFd, ptr, slice}; +use std::{collections::HashMap, ffi::CStr, io, mem, os::unix::io::RawFd, ptr, slice}; +use thiserror::Error; use libc::{ c_int, close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, socket, AF_NETLINK, AF_UNSPEC, ETH_P_ALL, IFLA_XDP, NETLINK_ROUTE, NLA_ALIGNTO, NLA_F_NESTED, - NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE, NLM_F_ECHO, NLM_F_EXCL, NLM_F_MULTI, - NLM_F_REQUEST, RTM_DELTFILTER, RTM_NEWQDISC, RTM_NEWTFILTER, RTM_SETLINK, SOCK_RAW, - SOL_NETLINK, + NLA_TYPE_MASK, NLMSG_DONE, NLMSG_ERROR, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_ECHO, + NLM_F_EXCL, NLM_F_MULTI, NLM_F_REQUEST, RTM_DELTFILTER, RTM_GETTFILTER, RTM_NEWQDISC, + RTM_NEWTFILTER, RTM_SETLINK, SOCK_RAW, SOL_NETLINK, }; use crate::{ @@ -192,6 +193,54 @@ pub(crate) unsafe fn netlink_qdisc_detach( Ok(()) } +pub(crate) unsafe fn netlink_find_filter_with_name( + if_index: i32, + attach_type: TcAttachType, + name: &CStr, +) -> Result, io::Error> { + let mut req = mem::zeroed::(); + + let nlmsg_len = mem::size_of::() + mem::size_of::(); + req.header = nlmsghdr { + nlmsg_len: nlmsg_len as u32, + nlmsg_type: RTM_GETTFILTER, + nlmsg_flags: (NLM_F_REQUEST | NLM_F_DUMP) as u16, + nlmsg_pid: 0, + nlmsg_seq: 1, + }; + req.tc_info.tcm_family = AF_UNSPEC as u8; + req.tc_info.tcm_handle = 0; // auto-assigned, if not provided + req.tc_info.tcm_ifindex = if_index; + req.tc_info.tcm_parent = attach_type.parent(); + + let sock = NetlinkSocket::open()?; + sock.send(&bytes_of(&req)[..req.header.nlmsg_len as usize])?; + + let mut prios = Vec::new(); + for msg in sock.recv()? { + if msg.header.nlmsg_type != RTM_NEWTFILTER { + continue; + } + + let tc_msg = ptr::read_unaligned(msg.data.as_ptr() as *const tcmsg); + let priority = tc_msg.tcm_info >> 16; + let attrs = parse_attrs(&msg.data[mem::size_of::()..])?; + + if let Some(opts) = attrs.get(&(TCA_OPTIONS as u16)) { + let opts = parse_attrs(&opts.data)?; + if let Some(f_name) = opts.get(&(TCA_BPF_NAME as u16)) { + if let Ok(f_name) = CStr::from_bytes_with_nul(f_name.data) { + if name == f_name { + prios.push(priority); + } + } + } + } + } + + Ok(prios) +} + #[repr(C)] struct Request { header: nlmsghdr, @@ -447,8 +496,84 @@ fn write_bytes(buf: &mut [u8], offset: usize, value: &[u8]) -> Result { + attrs: &'a [u8], + offset: usize, +} + +impl<'a> NlAttrsIterator<'a> { + fn new(attrs: &[u8]) -> NlAttrsIterator { + NlAttrsIterator { attrs, offset: 0 } + } +} + +impl<'a> Iterator for NlAttrsIterator<'a> { + type Item = Result, NlAttrError>; + + fn next(&mut self) -> Option { + let buf = &self.attrs[self.offset..]; + if buf.len() == 0 { + return None; + } + + if NLA_HDR_LEN > buf.len() { + self.offset = buf.len(); + return Some(Err(NlAttrError::InvalidBufferLength { + size: buf.len(), + expected: NLA_HDR_LEN, + })); + } + + let attr = unsafe { ptr::read_unaligned(buf.as_ptr() as *const nlattr) }; + let len = attr.nla_len as usize; + let align_len = align_to(len, NLA_ALIGNTO as usize); + if len < NLA_HDR_LEN { + return Some(Err(NlAttrError::InvalidHeaderLength(len))); + } + if align_len > buf.len() { + return Some(Err(NlAttrError::InvalidBufferLength { + size: buf.len(), + expected: align_len, + })); + } + + let data = &buf[NLA_HDR_LEN..len]; + + self.offset += align_len; + Some(Ok(NlAttr { header: attr, data })) + } +} + +fn parse_attrs(buf: &[u8]) -> Result, NlAttrError> { + let mut attrs = HashMap::new(); + for attr in NlAttrsIterator::new(buf) { + let attr = attr?; + attrs.insert(attr.header.nla_type & NLA_TYPE_MASK as u16, attr); + } + Ok(attrs) +} + +#[derive(Clone)] +struct NlAttr<'a> { + header: nlattr, + data: &'a [u8], } +#[derive(Debug, Error, PartialEq)] +enum NlAttrError { + #[error("invalid buffer size `{size}`, expected `{expected}`")] + InvalidBufferLength { size: usize, expected: usize }, + + #[error("invalid nlattr header length `{0}`")] + InvalidHeaderLength(usize), +} + +impl From for io::Error { + fn from(e: NlAttrError) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) + } +} unsafe fn request_attributes(req: &mut T, msg_len: usize) -> &mut [u8] { let attrs_addr = align_to( @@ -463,8 +588,11 @@ fn bytes_of(val: &T) -> &[u8] { let size = mem::size_of::(); unsafe { slice::from_raw_parts(slice::from_ref(val).as_ptr().cast(), size) } } + #[cfg(test)] mod tests { + use std::{convert::TryInto, ffi::CString}; + use super::*; #[test] @@ -510,4 +638,88 @@ mod tests { }; assert_eq!(fd, 24); } + + #[test] + fn test_nlattr_iterator_empty() { + let mut iter = NlAttrsIterator::new(&[]); + assert!(iter.next().is_none()); + } + + #[test] + fn test_nlattr_iterator_one() { + let mut buf = [0; NLA_HDR_LEN + mem::size_of::()]; + + write_attr(&mut buf, 0, IFLA_XDP_FD as u16, 42u32).unwrap(); + + let mut iter = NlAttrsIterator::new(&buf); + let attr = iter.next().unwrap().unwrap(); + assert_eq!(attr.header.nla_type, IFLA_XDP_FD as u16); + assert_eq!(attr.data.len(), mem::size_of::()); + assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 42); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_nlattr_iterator_many() { + let mut buf = [0; (NLA_HDR_LEN + mem::size_of::()) * 2]; + + write_attr(&mut buf, 0, IFLA_XDP_FD as u16, 42u32).unwrap(); + write_attr( + &mut buf, + NLA_HDR_LEN + mem::size_of::(), + IFLA_XDP_EXPECTED_FD as u16, + 12u32, + ) + .unwrap(); + + let mut iter = NlAttrsIterator::new(&buf); + + let attr = iter.next().unwrap().unwrap(); + assert_eq!(attr.header.nla_type, IFLA_XDP_FD as u16); + assert_eq!(attr.data.len(), mem::size_of::()); + assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 42); + + let attr = iter.next().unwrap().unwrap(); + assert_eq!(attr.header.nla_type, IFLA_XDP_EXPECTED_FD as u16); + assert_eq!(attr.data.len(), mem::size_of::()); + assert_eq!(u32::from_ne_bytes(attr.data.try_into().unwrap()), 12); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_nlattr_iterator_nested() { + let mut buf = [0; 1024]; + + let mut options = NestedAttrs::new(&mut buf, TCA_OPTIONS as u16); + options.write_attr(TCA_BPF_FD as u16, 42).unwrap(); + + let name = CString::new("foo").unwrap(); + options + .write_attr_bytes(TCA_BPF_NAME as u16, name.to_bytes_with_nul()) + .unwrap(); + options.finish().unwrap(); + + let mut iter = NlAttrsIterator::new(&buf); + let outer = iter.next().unwrap().unwrap(); + assert_eq!( + outer.header.nla_type & NLA_TYPE_MASK as u16, + TCA_OPTIONS as u16 + ); + + let mut iter = NlAttrsIterator::new(&outer.data); + let inner = iter.next().unwrap().unwrap(); + assert_eq!( + inner.header.nla_type & NLA_TYPE_MASK as u16, + TCA_BPF_FD as u16 + ); + let inner = iter.next().unwrap().unwrap(); + assert_eq!( + inner.header.nla_type & NLA_TYPE_MASK as u16, + TCA_BPF_NAME as u16 + ); + let name = CStr::from_bytes_with_nul(inner.data).unwrap(); + assert_eq!(name.to_string_lossy(), "foo"); + } }