@ -8,18 +8,19 @@ use std::{
use libc ::{
getsockname , nlattr , nlmsgerr , nlmsghdr , recv , send , setsockopt , sockaddr_nl , socket ,
AF_NETLINK , AF_UNSPEC , ETH_P_ALL , IFF_UP , IFLA_XDP , NETLINK_ EXT_ACK, NETLINK_ROUTE ,
N LA_ALIGNTO, NLA_F_NESTED , NLA_TYPE_MASK , NLMSG_DONE , NLMSG_ERROR , NLM_F_ACK , NLM_F_CREATE ,
NLM_F_ DUMP, NLM_F_ECHO , NLM_F_EXCL , NLM_F_MULTI , NLM_F_REQUEST , RTM_DELTFILTER , RTM_GETTFILTER ,
RTM_ NEWQDISC, RTM_NEWTFILTER , RTM_SETLINK , SOCK_RAW , SOL_NETLINK ,
AF_NETLINK , AF_UNSPEC , ETH_P_ALL , IFF_UP , IFLA_XDP , NETLINK_ CAP_ACK, NETLINK_EXT_ACK ,
N ETLINK_ROUTE, N LA_ALIGNTO, NLA_F_NESTED , NLA_TYPE_MASK , NLMSG_DONE , NLMSG_ERROR , NLM_F_ACK ,
NLM_F_ CREATE, NLM_F_ DUMP, NLM_F_ECHO , NLM_F_EXCL , NLM_F_MULTI , NLM_F_REQUEST , RTM_DELTFILTER ,
RTM_ GETTFILTER, RTM_ NEWQDISC, RTM_NEWTFILTER , RTM_SETLINK , SOCK_RAW , SOL_NETLINK ,
} ;
use thiserror ::Error ;
use crate ::{
generated ::{
ifinfomsg , tcmsg , IFLA_XDP_EXPECTED_FD , IFLA_XDP_FD , IFLA_XDP_FLAGS , NLMSG_ALIGNTO ,
TCA_BPF_FD , TCA_BPF_FLAGS , TCA_BPF_FLAG_ACT_DIRECT , TCA_BPF_NAME , TCA_KIND , TCA_OPTIONS ,
TC_H_CLSACT , TC_H_INGRESS , TC_H_MAJ_MASK , TC_H_UNSPEC , XDP_FLAGS_REPLACE ,
ifinfomsg , nlmsgerr_attrs ::NLMSGERR_ATTR_MSG , tcmsg , IFLA_XDP_EXPECTED_FD , IFLA_XDP_FD ,
IFLA_XDP_FLAGS , NLMSG_ALIGNTO , TCA_BPF_FD , TCA_BPF_FLAGS , TCA_BPF_FLAG_ACT_DIRECT ,
TCA_BPF_NAME , TCA_KIND , TCA_OPTIONS , TC_H_CLSACT , TC_H_INGRESS , TC_H_MAJ_MASK , TC_H_UNSPEC ,
XDP_FLAGS_REPLACE ,
} ,
programs ::TcAttachType ,
util ::tc_handler_make ,
@ -27,6 +28,28 @@ use crate::{
const NLA_HDR_LEN : usize = align_to ( mem ::size_of ::< nlattr > ( ) , NLA_ALIGNTO as usize ) ;
/// A private error type for internal use in this module.
#[ derive(Error, Debug) ]
pub ( crate ) enum NetlinkErrorInternal {
#[ error( " netlink error: {message} " ) ]
Error {
message : String ,
#[ source ]
source : io ::Error ,
} ,
#[ error(transparent) ]
IoError ( #[ from ] io ::Error ) ,
#[ error(transparent) ]
NulError ( #[ from ] std ::ffi ::NulError ) ,
#[ error(transparent) ]
NlAttrError ( #[ from ] NlAttrError ) ,
}
/// An error occurred during a netlink operation.
#[ derive(Error, Debug) ]
#[ error(transparent) ]
pub struct NetlinkError ( #[ from ] NetlinkErrorInternal ) ;
// Safety: marking this as unsafe overall because of all the pointer math required to comply with
// netlink alignments
pub ( crate ) unsafe fn netlink_set_xdp_fd (
@ -34,7 +57,7 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
fd : Option < BorrowedFd < ' _ > > ,
old_fd : Option < BorrowedFd < ' _ > > ,
flags : u32 ,
) -> Result < ( ) , io:: Error> {
) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
// Safety: Request is POD so this is safe
@ -54,33 +77,39 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
// write the attrs
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
let mut attrs = NestedAttrs ::new ( attrs_buf , IFLA_XDP ) ;
attrs . write_attr (
IFLA_XDP_FD as u16 ,
fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap_or ( - 1 ) ,
) ? ;
attrs
. write_attr (
IFLA_XDP_FD as u16 ,
fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap_or ( - 1 ) ,
)
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
if flags > 0 {
attrs . write_attr ( IFLA_XDP_FLAGS as u16 , flags ) ? ;
attrs
. write_attr ( IFLA_XDP_FLAGS as u16 , flags )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
}
if flags & XDP_FLAGS_REPLACE ! = 0 {
attrs . write_attr (
IFLA_XDP_EXPECTED_FD as u16 ,
old_fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap ( ) ,
) ? ;
attrs
. write_attr (
IFLA_XDP_EXPECTED_FD as u16 ,
old_fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap ( ) ,
)
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
}
let nla_len = attrs . finish ( ) ? ;
let nla_len = attrs
. finish ( )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( nla_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
sock . recv ( ) ? ;
Ok ( ( ) )
}
pub ( crate ) unsafe fn netlink_qdisc_add_clsact ( if_index : i32 ) -> Result < ( ) , io:: Error> {
pub ( crate ) unsafe fn netlink_qdisc_add_clsact ( if_index : i32 ) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -101,7 +130,8 @@ pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::E
// add the TCA_KIND attribute
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
let attr_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" clsact \0 " ) ? ;
let attr_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" clsact \0 " )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( attr_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
@ -118,7 +148,7 @@ pub(crate) unsafe fn netlink_qdisc_attach(
priority : u16 ,
handle : u32 ,
create : bool ,
) -> Result < ( u16 , u32 ) , io:: Error> {
) -> Result < ( u16 , u32 ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -152,15 +182,24 @@ pub(crate) unsafe fn netlink_qdisc_attach(
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
// add TCA_KIND
let kind_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" bpf \0 " ) ? ;
let kind_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" bpf \0 " )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
// add TCA_OPTIONS which includes TCA_BPF_FD, TCA_BPF_NAME and TCA_BPF_FLAGS
let mut options = NestedAttrs ::new ( & mut attrs_buf [ kind_len .. ] , TCA_OPTIONS as u16 ) ;
options . write_attr ( TCA_BPF_FD as u16 , prog_fd ) ? ;
options . write_attr_bytes ( TCA_BPF_NAME as u16 , prog_name . to_bytes_with_nul ( ) ) ? ;
options
. write_attr ( TCA_BPF_FD as u16 , prog_fd )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
options
. write_attr_bytes ( TCA_BPF_NAME as u16 , prog_name . to_bytes_with_nul ( ) )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
let flags : u32 = TCA_BPF_FLAG_ACT_DIRECT ;
options . write_attr ( TCA_BPF_FLAGS as u16 , flags ) ? ;
let options_len = options . finish ( ) ? ;
options
. write_attr ( TCA_BPF_FLAGS as u16 , flags )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
let options_len = options
. finish ( )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( kind_len + options_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
@ -176,10 +215,10 @@ pub(crate) unsafe fn netlink_qdisc_attach(
None = > {
// if sock.recv() succeeds we should never get here unless there's a
// bug in the kernel
return Err ( io::Error ::new (
return Err ( NetlinkError( NetlinkErrorInternal ::IoError ( io::Error ::new (
io ::ErrorKind ::Other ,
"no RTM_NEWTFILTER reply received, this is a bug." ,
) ) ;
) ) )) ;
}
} ;
@ -192,7 +231,7 @@ pub(crate) unsafe fn netlink_qdisc_detach(
attach_type : & TcAttachType ,
priority : u16 ,
handle : u32 ,
) -> Result < ( ) , io:: Error> {
) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -222,7 +261,7 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
if_index : i32 ,
attach_type : TcAttachType ,
name : & CStr ,
) -> Result < Vec < ( u16 , u32 ) > , io:: Error> {
) -> Result < Vec < ( u16 , u32 ) > , Netlink Error> {
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
let nlmsg_len = mem ::size_of ::< nlmsghdr > ( ) + mem ::size_of ::< tcmsg > ( ) ;
@ -249,10 +288,12 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
let tc_msg = ptr ::read_unaligned ( msg . data . as_ptr ( ) as * const tcmsg ) ;
let priority = ( tc_msg . tcm_info > > 16 ) as u16 ;
let attrs = parse_attrs ( & msg . data [ mem ::size_of ::< tcmsg > ( ) .. ] ) ? ;
let attrs = parse_attrs ( & msg . data [ mem ::size_of ::< tcmsg > ( ) .. ] )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::NlAttrError ( e ) ) ) ? ;
if let Some ( opts ) = attrs . get ( & ( TCA_OPTIONS as u16 ) ) {
let opts = parse_attrs ( opts . data ) ? ;
let opts = parse_attrs ( opts . data )
. map_err ( | e | NetlinkError ( NetlinkErrorInternal ::NlAttrError ( e ) ) ) ? ;
if let Some ( f_name ) = opts . get ( & ( TCA_BPF_NAME as u16 ) ) {
if let Ok ( f_name ) = CStr ::from_bytes_with_nul ( f_name . data ) {
if name = = f_name {
@ -267,7 +308,7 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
}
#[ doc(hidden) ]
pub unsafe fn netlink_set_link_up ( if_index : i32 ) -> Result < ( ) , io:: Error> {
pub unsafe fn netlink_set_link_up ( if_index : i32 ) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
// Safety: Request is POD so this is safe
@ -312,11 +353,11 @@ struct NetlinkSocket {
}
impl NetlinkSocket {
fn open ( ) -> Result < Self , io::Error > {
fn open ( ) -> Result < Self , NetlinkErrorInternal > {
// Safety: libc wrapper
let sock = unsafe { socket ( AF_NETLINK , SOCK_RAW , NETLINK_ROUTE ) } ;
if sock < 0 {
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorInternal::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
// SAFETY: `socket` returns a file descriptor.
let sock = unsafe { crate ::MockableFd ::from_raw_fd ( sock ) } ;
@ -324,13 +365,29 @@ impl NetlinkSocket {
let enable = 1 i32 ;
// Safety: libc wrapper
unsafe {
setsockopt (
// Set NETLINK_EXT_ACK to get extended attributes.
if setsockopt (
sock . as_raw_fd ( ) ,
SOL_NETLINK ,
NETLINK_EXT_ACK ,
& enable as * const _ as * const _ ,
mem ::size_of ::< i32 > ( ) as u32 ,
)
) < 0
{
return Err ( NetlinkErrorInternal ::IoError ( io ::Error ::last_os_error ( ) ) ) ;
} ;
// Set NETLINK_CAP_ACK to avoid getting copies of request payload.
if setsockopt (
sock . as_raw_fd ( ) ,
SOL_NETLINK ,
NETLINK_CAP_ACK ,
& enable as * const _ as * const _ ,
mem ::size_of ::< i32 > ( ) as u32 ,
) < 0
{
return Err ( NetlinkErrorInternal ::IoError ( io ::Error ::last_os_error ( ) ) ) ;
} ;
} ;
// Safety: sockaddr_nl is POD so this is safe
@ -346,7 +403,7 @@ impl NetlinkSocket {
)
} < 0
{
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorInternal::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
Ok ( Self {
@ -355,7 +412,7 @@ impl NetlinkSocket {
} )
}
fn send ( & self , msg : & [ u8 ] ) -> Result < ( ) , io::Error > {
fn send ( & self , msg : & [ u8 ] ) -> Result < ( ) , NetlinkErrorInternal > {
if unsafe {
send (
self . sock . as_raw_fd ( ) ,
@ -365,12 +422,12 @@ impl NetlinkSocket {
)
} < 0
{
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorInternal::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
Ok ( ( ) )
}
fn recv ( & self ) -> Result < Vec < NetlinkMessage > , io::Error > {
fn recv ( & self ) -> Result < Vec < NetlinkMessage > , NetlinkErrorInternal > {
let mut buf = [ 0 u8 ; 4096 ] ;
let mut messages = Vec ::new ( ) ;
let mut multipart = true ;
@ -386,7 +443,7 @@ impl NetlinkSocket {
)
} ;
if len < 0 {
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorInternal::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
if len = = 0 {
break ;
@ -405,7 +462,22 @@ impl NetlinkSocket {
// this is an ACK
continue ;
}
return Err ( io ::Error ::from_raw_os_error ( - err . error ) ) ;
let attrs = parse_attrs ( & message . data ) ? ;
let err_msg = attrs . get ( & ( NLMSGERR_ATTR_MSG as u16 ) ) . and_then ( | msg | {
CStr ::from_bytes_with_nul ( msg . data )
. ok ( )
. map ( | s | s . to_string_lossy ( ) . into_owned ( ) )
} ) ;
let e = match err_msg {
Some ( err_msg ) = > NetlinkErrorInternal ::Error {
message : err_msg ,
source : io ::Error ::from_raw_os_error ( - err . error ) ,
} ,
None = > NetlinkErrorInternal ::IoError ( io ::Error ::from_raw_os_error (
- err . error ,
) ) ,
} ;
return Err ( e ) ;
}
NLMSG_DONE = > break 'out ,
_ = > messages . push ( message ) ,
@ -444,7 +516,7 @@ impl NetlinkMessage {
return Err ( io ::Error ::new ( io ::ErrorKind ::Other , "need more data" ) ) ;
}
let ( data , error ) = if header . nlmsg_type = = NLMSG_ERROR as u16 {
let ( rest , error ) = if header . nlmsg_type = = NLMSG_ERROR as u16 {
if data_offset + mem ::size_of ::< nlmsgerr > ( ) > buf . len ( ) {
return Err ( io ::Error ::new (
io ::ErrorKind ::Other ,
@ -452,19 +524,19 @@ impl NetlinkMessage {
) ) ;
}
(
Vec ::new ( ) ,
& buf [ data_offset + mem ::size_of ::< nlmsgerr > ( ) .. msg_len ] ,
// Safety: nlmsgerr is POD so read is safe
Some ( unsafe {
ptr ::read_unaligned ( buf [ data_offset .. ] . as_ptr ( ) as * const nlmsgerr )
} ) ,
)
} else {
( buf [ data_offset .. msg_len ] . to_vec ( ) , None )
( & buf [ data_offset .. msg_len ] , None )
} ;
Ok ( Self {
header ,
data ,
data : rest . to_vec ( ) ,
error ,
} )
}
@ -628,7 +700,7 @@ struct NlAttr<'a> {
}
#[ derive(Debug, Error, PartialEq, Eq) ]
enum NlAttrError {
pub ( crate ) enum NlAttrError {
#[ error( " invalid buffer size `{size}`, expected `{expected}` " ) ]
InvalidBufferLength { size : usize , expected : usize } ,