diff --git a/Cargo.toml b/Cargo.toml index 0f618766..6741082a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Alessandro Decina "] edition = "2018" [dependencies] -libc = "0.2" +libc = { version = "0.2", features = ["extra_traits"] } thiserror = "1" object = "0.23" bytes = "1" diff --git a/scripts/gen-bindings b/scripts/gen-bindings index 09703be2..c12fac4b 100755 --- a/scripts/gen-bindings +++ b/scripts/gen-bindings @@ -74,6 +74,16 @@ PERF_VARS="\ PERF_EVENT_.* " +NETLINK_TYPES="\ + ifinfomsg + " + +NETLINK_VARS="\ + NLMSG_ALIGNTO \ + IFLA_XDP_FD \ + XDP_FLAGS_.* + " + bindgen $LIBBPF_DIR/include/uapi/linux/bpf.h \ --no-layout-tests \ --default-enum-style moduleconsts \ @@ -116,4 +126,15 @@ bindgen include/perf_wrapper.h \ $(for var in $PERF_VARS; do echo --whitelist-var "$var" done) \ - > $OUTPUT_DIR/perf_bindings.rs \ No newline at end of file + > $OUTPUT_DIR/perf_bindings.rs + +bindgen include/netlink_wrapper.h \ + --no-layout-tests \ + --default-enum-style moduleconsts \ + $(for ty in $NETLINK_TYPES; do + echo --whitelist-type "$ty" + done) \ + $(for var in $NETLINK_VARS; do + echo --whitelist-var "$var" + done) \ + > $OUTPUT_DIR/netlink_bindings.rs \ No newline at end of file diff --git a/src/generated/mod.rs b/src/generated/mod.rs index 3030c9be..582f7a60 100644 --- a/src/generated/mod.rs +++ b/src/generated/mod.rs @@ -5,9 +5,11 @@ mod bpf_bindings; mod btf_bindings; mod btf_internal_bindings; +mod netlink_bindings; mod perf_bindings; pub use bpf_bindings::*; pub use btf_bindings::*; pub use btf_internal_bindings::*; +pub use netlink_bindings::*; pub use perf_bindings::*; diff --git a/src/programs/mod.rs b/src/programs/mod.rs index 7be70c9f..6341aaf7 100644 --- a/src/programs/mod.rs +++ b/src/programs/mod.rs @@ -58,6 +58,13 @@ pub enum ProgramError { io_error: io::Error, }, + #[error("error attaching XDP program using netlink: {io_error}")] + NetlinkXdpFailed { + program: String, + #[source] + io_error: io::Error, + }, + #[error("unkown network interface {name}")] UnkownInterface { name: String }, diff --git a/src/programs/xdp.rs b/src/programs/xdp.rs index e3c8162f..c30210a7 100644 --- a/src/programs/xdp.rs +++ b/src/programs/xdp.rs @@ -2,11 +2,13 @@ use std::{cell::RefCell, ffi::CString, rc::Rc}; use libc::if_nametoindex; -use crate::RawFd; +use crate::{generated::XDP_FLAGS_REPLACE, RawFd}; use crate::{ generated::{bpf_attach_type::BPF_XDP, bpf_prog_type::BPF_PROG_TYPE_XDP}, programs::{load_program, FdLink, Link, LinkRef, ProgramData, ProgramError}, sys::bpf_link_create, + sys::kernel_version, + sys::netlink_set_xdp_fd, }; #[derive(Debug)] @@ -34,15 +36,72 @@ impl Xdp { })?; } - let link_fd = bpf_link_create(prog_fd, if_index, BPF_XDP, 0).map_err(|(_, io_error)| { - ProgramError::BpfLinkCreateFailed { - program: self.name(), - io_error, - } - })? as RawFd; - let link = Rc::new(RefCell::new(FdLink { fd: Some(link_fd) })); - self.data.links.push(link.clone()); + let k_ver = kernel_version().unwrap(); + if k_ver >= (5, 7, 0) { + let link_fd = + bpf_link_create(prog_fd, if_index, BPF_XDP, 0).map_err(|(_, io_error)| { + ProgramError::BpfLinkCreateFailed { + program: self.name(), + io_error, + } + })? as RawFd; + let link = Rc::new(RefCell::new(XdpLink::FdLink(FdLink { fd: Some(link_fd) }))); + self.data.links.push(link.clone()); - Ok(LinkRef::new(&link)) + Ok(LinkRef::new(&link)) + } else { + unsafe { netlink_set_xdp_fd(if_index, prog_fd, None, 0) }.map_err(|io_error| { + ProgramError::NetlinkXdpFailed { + program: self.name(), + io_error, + } + })?; + + let link = Rc::new(RefCell::new(XdpLink::NlLink(NlLink { + if_index, + prog_fd: Some(prog_fd), + }))); + self.data.links.push(link.clone()); + + Ok(LinkRef::new(&link)) + } + } +} + +#[derive(Debug)] +struct NlLink { + if_index: i32, + prog_fd: Option, +} + +impl Link for NlLink { + fn detach(&mut self) -> Result<(), ProgramError> { + if let Some(fd) = self.prog_fd.take() { + let _ = unsafe { netlink_set_xdp_fd(self.if_index, -1, Some(fd), XDP_FLAGS_REPLACE) }; + Ok(()) + } else { + Err(ProgramError::AlreadyDetached) + } + } +} + +impl Drop for NlLink { + fn drop(&mut self) { + let _ = self.detach(); + } +} + +#[derive(Debug)] +enum XdpLink { + FdLink(FdLink), + NlLink(NlLink), +} + +impl Link for XdpLink { + fn detach(&mut self) -> Result<(), ProgramError> { + match self { + XdpLink::FdLink(link) => link.detach(), + XdpLink::NlLink(link) => link.detach(), + } } } diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 08bafc05..35c193c0 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -1,19 +1,18 @@ mod bpf; +mod netlink; mod perf_event; #[cfg(test)] mod fake; -use std::{ - ffi::{CStr, CString}, - io, mem, -}; +use std::{ffi::CString, io, mem}; use libc::{c_int, c_long, c_ulong, pid_t, utsname}; pub(crate) use bpf::*; #[cfg(test)] pub(crate) use fake::*; +pub(crate) use netlink::*; pub(crate) use perf_event::*; use crate::generated::{bpf_attr, bpf_cmd, perf_event_attr}; diff --git a/src/sys/netlink.rs b/src/sys/netlink.rs new file mode 100644 index 00000000..9f908cab --- /dev/null +++ b/src/sys/netlink.rs @@ -0,0 +1,265 @@ +use std::{io, mem, os::unix::io::RawFd, ptr}; +use thiserror::Error; + +use libc::{ + c_int, close, getsockname, nlattr, nlmsgerr, nlmsghdr, recv, send, setsockopt, sockaddr_nl, + socket, AF_NETLINK, AF_UNSPEC, IFLA_XDP, NETLINK_ROUTE, NLA_ALIGNTO, NLA_F_NESTED, NLMSG_DONE, + NLMSG_ERROR, NLM_F_ACK, NLM_F_MULTI, NLM_F_REQUEST, RTM_SETLINK, SOCK_RAW, SOL_NETLINK, +}; + +use crate::generated::{ + _bindgen_ty_41::{IFLA_XDP_EXPECTED_FD, IFLA_XDP_FD, IFLA_XDP_FLAGS}, + ifinfomsg, NLMSG_ALIGNTO, XDP_FLAGS_REPLACE, +}; + +const NETLINK_EXT_ACK: c_int = 11; + +// Safety: marking this as unsafe overall because of all the pointer math required to comply with +// netlink alignments +pub(crate) unsafe fn netlink_set_xdp_fd( + if_index: i32, + fd: RawFd, + old_fd: Option, + flags: u32, +) -> Result<(), io::Error> { + let sock = NetlinkSocket::open().unwrap(); + + let seq = 1; + // Safety: Request is POD so this is safe + let mut req = mem::zeroed::(); + + req.header = nlmsghdr { + nlmsg_len: (mem::size_of::() + mem::size_of::()) as u32, + nlmsg_flags: (NLM_F_REQUEST | NLM_F_ACK) as u16, + nlmsg_type: RTM_SETLINK, + nlmsg_pid: 0, + nlmsg_seq: seq, + }; + req.if_info.ifi_family = AF_UNSPEC as u8; + req.if_info.ifi_index = if_index; + + let attrs_addr = &req as *const _ as usize + req.header.nlmsg_len as usize; + let attrs_addr = align_to(attrs_addr, NLMSG_ALIGNTO as usize); + let nla_hdr_len = align_to(mem::size_of::(), NLA_ALIGNTO as usize); + + // length of the root attribute + let mut nla_len = nla_hdr_len as u16; + + // set the program fd + let mut offset = attrs_addr + nla_len as usize; + let attr = nlattr { + nla_type: IFLA_XDP_FD as u16, + // header len + fd + nla_len: (nla_hdr_len + mem::size_of::()) as u16, + }; + // write the header + ptr::write(offset as *mut nlattr, attr); + offset += nla_hdr_len; + // write the fd + ptr::write(offset as *mut RawFd, fd); + offset += 4; + nla_len += attr.nla_len; + + if flags > 0 { + // set the flags + let attr = nlattr { + nla_type: IFLA_XDP_FLAGS as u16, + // header len + flags + nla_len: (nla_hdr_len + mem::size_of::()) as u16, + }; + // write the header + ptr::write(offset as *mut nlattr, attr); + offset += nla_hdr_len; + // write the flags + ptr::write(offset as *mut u32, flags); + offset += 4; + nla_len += attr.nla_len; + } + + if flags & XDP_FLAGS_REPLACE != 0 { + // set the expected fd + let attr = nlattr { + nla_type: IFLA_XDP_EXPECTED_FD as u16, + // header len + fd + nla_len: (nla_hdr_len + mem::size_of::()) as u16, + }; + // write the header + ptr::write(offset as *mut nlattr, attr); + offset += nla_hdr_len; + // write the old fd + ptr::write(offset as *mut RawFd, old_fd.unwrap()); + // offset += 4; + nla_len += attr.nla_len; + } + + // now write the root header + let attr = nlattr { + nla_type: NLA_F_NESTED as u16 | IFLA_XDP as u16, + nla_len, + }; + offset = attrs_addr; + ptr::write(offset as *mut nlattr, attr); + + req.header.nlmsg_len += align_to(nla_len as usize, NLA_ALIGNTO as usize) as u32; + + if send( + sock.sock, + &req as *const _ as *const _, + req.header.nlmsg_len as usize, + 0, + ) < 0 + { + return Err(io::Error::last_os_error())?; + } + + sock.recv()?; + + Ok(()) +} + +#[repr(C)] +struct Request { + header: nlmsghdr, + if_info: ifinfomsg, + attrs: [u8; 64], +} + +struct NetlinkSocket { + sock: RawFd, + nl_pid: u32, +} + +impl NetlinkSocket { + fn open() -> Result { + // Safety: libc wrapper + let sock = unsafe { socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) }; + if sock < 0 { + return Err(io::Error::last_os_error())?; + } + + let enable = 1i32; + // Safety: libc wrapper + unsafe { + setsockopt( + sock, + SOL_NETLINK, + NETLINK_EXT_ACK, + &enable as *const _ as *const _, + mem::size_of::() as u32, + ) + }; + + // Safety: sockaddr_nl is POD so this is safe + let mut addr = unsafe { mem::zeroed::() }; + addr.nl_family = AF_NETLINK as u16; + let mut addr_len = mem::size_of::() as u32; + // Safety: libc wrapper + if unsafe { getsockname(sock, &mut addr as *mut _ as *mut _, &mut addr_len as *mut _) } < 0 + { + return Err(io::Error::last_os_error())?; + } + + Ok(NetlinkSocket { + sock, + nl_pid: addr.nl_pid, + }) + } + + fn recv(&self) -> Result<(), io::Error> { + let mut buf = [0u8; 4096]; + + let mut multipart = true; + while multipart { + multipart = false; + // Safety: libc wrapper + let len = unsafe { recv(self.sock, buf.as_mut_ptr() as *mut _, buf.len(), 0) }; + if len < 0 { + return Err(io::Error::last_os_error())?; + } + if len == 0 { + break; + } + + let len = len as usize; + let mut offset = 0; + while offset < len { + let message = NetlinkMessage::read(&buf[offset..])?; + offset += align_to(message.header.nlmsg_len as usize, NLMSG_ALIGNTO as usize); + + multipart = message.header.nlmsg_flags & NLM_F_MULTI as u16 != 0; + + match message.header.nlmsg_type as i32 { + NLMSG_ERROR => { + let err = message.error.unwrap(); + if err.error == 0 { + // this is an ACK + continue; + } + return Err(io::Error::new( + io::ErrorKind::Other, + format!("netlink error: {}", err.error), + )); + } + NLMSG_DONE => break, + _ => {} + } + } + } + + Ok(()) + } +} + +#[derive(Debug)] +struct NetlinkMessage { + header: nlmsghdr, + data: Vec, + error: Option, +} + +impl NetlinkMessage { + fn read(buf: &[u8]) -> Result { + if mem::size_of::() > buf.len() { + return Err(io::Error::new(io::ErrorKind::Other, "need more data")); + } + + // Safety: nlmsghdr is POD so read is safe + let header = unsafe { ptr::read_unaligned(buf.as_ptr() as *const nlmsghdr) }; + let data_offset = align_to(mem::size_of::(), NLMSG_ALIGNTO as usize); + if data_offset >= buf.len() { + return Err(io::Error::new(io::ErrorKind::Other, "need more data")); + } + + let (data, error) = if header.nlmsg_type == NLMSG_ERROR as u16 { + if data_offset + mem::size_of::() > buf.len() { + return Err(io::Error::new(io::ErrorKind::Other, "need more data")); + } + ( + Vec::new(), + // Safety: nlmsgerr is POD so read is safe + Some(unsafe { + ptr::read_unaligned(buf[data_offset..].as_ptr() as *const nlmsgerr) + }), + ) + } else { + (buf[data_offset..].to_vec(), None) + }; + + Ok(NetlinkMessage { + header, + data, + error, + }) + } +} + +impl Drop for NetlinkSocket { + fn drop(&mut self) { + // Safety: libc wrapper + unsafe { close(self.sock) }; + } +} + +fn align_to(v: usize, align: usize) -> usize { + (v + (align - 1)) & !(align - 1) +}