@ -8,10 +8,10 @@ use std::{
use libc ::{
getsockname , nlattr , nlmsgerr , nlmsghdr , recv , send , setsockopt , sockaddr_nl , socket ,
AF_NETLINK , AF_UNSPEC , ETH_P_ALL , IFF_UP , IFLA_XDP , NETLINK_ EXT_ACK, NETLINK_ROUTE ,
N LA_ALIGNTO, NLA_F_NESTED , NLA_TYPE_MASK , NLMSG_DONE , NLMSG_ERROR , NLM_F_ACK , NLM_F_CREATE ,
NLM_F_ DUMP, NLM_F_ECHO , NLM_F_EXCL , NLM_F_MULTI , NLM_F_REQUEST , RTM_DELTFILTER , RTM_GETTFILTER ,
RTM_ NEWQDISC, RTM_NEWTFILTER , RTM_SETLINK , SOCK_RAW , SOL_NETLINK ,
AF_NETLINK , AF_UNSPEC , ETH_P_ALL , IFF_UP , IFLA_XDP , NETLINK_ CAP_ACK, NETLINK_EXT_ACK ,
N ETLINK_ROUTE, N LA_ALIGNTO, NLA_F_NESTED , NLA_TYPE_MASK , NLMSG_DONE , NLMSG_ERROR , NLM_F_ACK ,
NLM_F_ CREATE, NLM_F_ DUMP, NLM_F_ECHO , NLM_F_EXCL , NLM_F_MULTI , NLM_F_REQUEST , RTM_DELTFILTER ,
RTM_ GETTFILTER, RTM_ NEWQDISC, RTM_NEWTFILTER , RTM_SETLINK , SOCK_RAW , SOL_NETLINK ,
} ;
use thiserror ::Error ;
@ -25,6 +25,7 @@ use crate::{
util ::tc_handler_make ,
} ;
const NLMSGERR_ATTR_MSG : u16 = 0x01 ;
const NLA_HDR_LEN : usize = align_to ( mem ::size_of ::< nlattr > ( ) , NLA_ALIGNTO as usize ) ;
// Safety: marking this as unsafe overall because of all the pointer math required to comply with
@ -34,7 +35,7 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
fd : Option < BorrowedFd < ' _ > > ,
old_fd : Option < BorrowedFd < ' _ > > ,
flags : u32 ,
) -> Result < ( ) , io:: Error> {
) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
// Safety: Request is POD so this is safe
@ -54,33 +55,39 @@ pub(crate) unsafe fn netlink_set_xdp_fd(
// write the attrs
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
let mut attrs = NestedAttrs ::new ( attrs_buf , IFLA_XDP ) ;
attrs . write_attr (
attrs
. write_attr (
IFLA_XDP_FD as u16 ,
fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap_or ( - 1 ) ,
) ? ;
)
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
if flags > 0 {
attrs . write_attr ( IFLA_XDP_FLAGS as u16 , flags ) ? ;
attrs
. write_attr ( IFLA_XDP_FLAGS as u16 , flags )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
}
if flags & XDP_FLAGS_REPLACE ! = 0 {
attrs . write_attr (
attrs
. write_attr (
IFLA_XDP_EXPECTED_FD as u16 ,
old_fd . map ( | fd | fd . as_raw_fd ( ) ) . unwrap ( ) ,
) ? ;
)
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
}
let nla_len = attrs . finish ( ) ? ;
let nla_len = attrs
. finish ( )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( nla_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
sock . recv ( ) ? ;
Ok ( ( ) )
}
pub ( crate ) unsafe fn netlink_qdisc_add_clsact ( if_index : i32 ) -> Result < ( ) , io:: Error> {
pub ( crate ) unsafe fn netlink_qdisc_add_clsact ( if_index : i32 ) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -101,7 +108,8 @@ pub(crate) unsafe fn netlink_qdisc_add_clsact(if_index: i32) -> Result<(), io::E
// add the TCA_KIND attribute
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
let attr_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" clsact \0 " ) ? ;
let attr_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" clsact \0 " )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( attr_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
@ -118,7 +126,7 @@ pub(crate) unsafe fn netlink_qdisc_attach(
priority : u16 ,
handle : u32 ,
create : bool ,
) -> Result < ( u16 , u32 ) , io:: Error> {
) -> Result < ( u16 , u32 ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -152,15 +160,24 @@ pub(crate) unsafe fn netlink_qdisc_attach(
let attrs_buf = request_attributes ( & mut req , nlmsg_len ) ;
// add TCA_KIND
let kind_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" bpf \0 " ) ? ;
let kind_len = write_attr_bytes ( attrs_buf , 0 , TCA_KIND as u16 , b" bpf \0 " )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
// add TCA_OPTIONS which includes TCA_BPF_FD, TCA_BPF_NAME and TCA_BPF_FLAGS
let mut options = NestedAttrs ::new ( & mut attrs_buf [ kind_len .. ] , TCA_OPTIONS as u16 ) ;
options . write_attr ( TCA_BPF_FD as u16 , prog_fd ) ? ;
options . write_attr_bytes ( TCA_BPF_NAME as u16 , prog_name . to_bytes_with_nul ( ) ) ? ;
options
. write_attr ( TCA_BPF_FD as u16 , prog_fd )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
options
. write_attr_bytes ( TCA_BPF_NAME as u16 , prog_name . to_bytes_with_nul ( ) )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
let flags : u32 = TCA_BPF_FLAG_ACT_DIRECT ;
options . write_attr ( TCA_BPF_FLAGS as u16 , flags ) ? ;
let options_len = options . finish ( ) ? ;
options
. write_attr ( TCA_BPF_FLAGS as u16 , flags )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
let options_len = options
. finish ( )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::IoError ( e ) ) ) ? ;
req . header . nlmsg_len + = align_to ( kind_len + options_len , NLA_ALIGNTO as usize ) as u32 ;
sock . send ( & bytes_of ( & req ) [ .. req . header . nlmsg_len as usize ] ) ? ;
@ -176,10 +193,10 @@ pub(crate) unsafe fn netlink_qdisc_attach(
None = > {
// if sock.recv() succeeds we should never get here unless there's a
// bug in the kernel
return Err ( io::Error ::new (
return Err ( NetlinkError( NetlinkErrorRepr ::IoError ( io::Error ::new (
io ::ErrorKind ::Other ,
"no RTM_NEWTFILTER reply received, this is a bug." ,
) ) ;
) ) )) ;
}
} ;
@ -192,7 +209,7 @@ pub(crate) unsafe fn netlink_qdisc_detach(
attach_type : & TcAttachType ,
priority : u16 ,
handle : u32 ,
) -> Result < ( ) , io:: Error> {
) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
@ -222,7 +239,7 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
if_index : i32 ,
attach_type : TcAttachType ,
name : & CStr ,
) -> Result < Vec < ( u16 , u32 ) > , io:: Error> {
) -> Result < Vec < ( u16 , u32 ) > , Netlink Error> {
let mut req = mem ::zeroed ::< TcRequest > ( ) ;
let nlmsg_len = mem ::size_of ::< nlmsghdr > ( ) + mem ::size_of ::< tcmsg > ( ) ;
@ -249,10 +266,12 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
let tc_msg = ptr ::read_unaligned ( msg . data . as_ptr ( ) as * const tcmsg ) ;
let priority = ( tc_msg . tcm_info > > 16 ) as u16 ;
let attrs = parse_attrs ( & msg . data [ mem ::size_of ::< tcmsg > ( ) .. ] ) ? ;
let attrs = parse_attrs ( & msg . data [ mem ::size_of ::< tcmsg > ( ) .. ] )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::NlAttrError ( e ) ) ) ? ;
if let Some ( opts ) = attrs . get ( & ( TCA_OPTIONS as u16 ) ) {
let opts = parse_attrs ( opts . data ) ? ;
let opts = parse_attrs ( opts . data )
. map_err ( | e | NetlinkError ( NetlinkErrorRepr ::NlAttrError ( e ) ) ) ? ;
if let Some ( f_name ) = opts . get ( & ( TCA_BPF_NAME as u16 ) ) {
if let Ok ( f_name ) = CStr ::from_bytes_with_nul ( f_name . data ) {
if name = = f_name {
@ -267,7 +286,7 @@ pub(crate) unsafe fn netlink_find_filter_with_name(
}
#[ doc(hidden) ]
pub unsafe fn netlink_set_link_up ( if_index : i32 ) -> Result < ( ) , io:: Error> {
pub unsafe fn netlink_set_link_up ( if_index : i32 ) -> Result < ( ) , Netlink Error> {
let sock = NetlinkSocket ::open ( ) ? ;
// Safety: Request is POD so this is safe
@ -311,12 +330,32 @@ struct NetlinkSocket {
_nl_pid : u32 ,
}
#[ derive(Error, Debug) ]
#[ error(transparent) ]
pub struct NetlinkError ( #[ from ] NetlinkErrorRepr ) ;
#[ derive(Error, Debug) ]
pub ( crate ) enum NetlinkErrorRepr {
#[ error( " netlink error: {message} " ) ]
Error {
message : String ,
#[ source ]
source : io ::Error ,
} ,
#[ error(transparent) ]
IoError ( #[ from ] io ::Error ) ,
#[ error(transparent) ]
NulError ( #[ from ] std ::ffi ::NulError ) ,
#[ error(transparent) ]
NlAttrError ( #[ from ] NlAttrError ) ,
}
impl NetlinkSocket {
fn open ( ) -> Result < Self , io ::Error > {
fn open ( ) -> Result < Self , NetlinkErrorRep r> {
// Safety: libc wrapper
let sock = unsafe { socket ( AF_NETLINK , SOCK_RAW , NETLINK_ROUTE ) } ;
if sock < 0 {
return Err ( io ::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorRepr::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
// SAFETY: `socket` returns a file descriptor.
let sock = unsafe { crate ::MockableFd ::from_raw_fd ( sock ) } ;
@ -324,13 +363,29 @@ impl NetlinkSocket {
let enable = 1 i32 ;
// Safety: libc wrapper
unsafe {
setsockopt (
// Set NETLINK_EXT_ACK to get extended attributes.
if setsockopt (
sock . as_raw_fd ( ) ,
SOL_NETLINK ,
NETLINK_EXT_ACK ,
& enable as * const _ as * const _ ,
mem ::size_of ::< i32 > ( ) as u32 ,
)
) < 0
{
return Err ( NetlinkErrorRepr ::IoError ( io ::Error ::last_os_error ( ) ) ) ;
} ;
// Set NETLINK_CAP_ACK to avoid getting copies of request payload.
if setsockopt (
sock . as_raw_fd ( ) ,
SOL_NETLINK ,
NETLINK_CAP_ACK ,
& enable as * const _ as * const _ ,
mem ::size_of ::< i32 > ( ) as u32 ,
) < 0
{
return Err ( NetlinkErrorRepr ::IoError ( io ::Error ::last_os_error ( ) ) ) ;
} ;
} ;
// Safety: sockaddr_nl is POD so this is safe
@ -346,7 +401,7 @@ impl NetlinkSocket {
)
} < 0
{
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorRepr::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
Ok ( Self {
@ -355,7 +410,7 @@ impl NetlinkSocket {
} )
}
fn send ( & self , msg : & [ u8 ] ) -> Result < ( ) , io::Erro r> {
fn send ( & self , msg : & [ u8 ] ) -> Result < ( ) , NetlinkErrorRep r> {
if unsafe {
send (
self . sock . as_raw_fd ( ) ,
@ -365,12 +420,12 @@ impl NetlinkSocket {
)
} < 0
{
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorRepr::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
Ok ( ( ) )
}
fn recv ( & self ) -> Result < Vec < NetlinkMessage > , io::Erro r> {
fn recv ( & self ) -> Result < Vec < NetlinkMessage > , NetlinkErrorRep r> {
let mut buf = [ 0 u8 ; 4096 ] ;
let mut messages = Vec ::new ( ) ;
let mut multipart = true ;
@ -386,7 +441,7 @@ impl NetlinkSocket {
)
} ;
if len < 0 {
return Err ( io::Error ::last_os_error ( ) ) ;
return Err ( NetlinkErrorRepr::IoError ( io::Error ::last_os_error ( ) ) ) ;
}
if len = = 0 {
break ;
@ -405,7 +460,25 @@ impl NetlinkSocket {
// this is an ACK
continue ;
}
return Err ( io ::Error ::from_raw_os_error ( - err . error ) ) ;
let attrs = parse_attrs ( & message . data ) ? ;
let err_msg = attrs . get ( & NLMSGERR_ATTR_MSG ) . and_then ( | msg | {
CStr ::from_bytes_with_nul ( msg . data )
. ok ( )
. map ( | s | s . to_string_lossy ( ) . into_owned ( ) )
} ) ;
match err_msg {
Some ( err_msg ) = > {
return Err ( NetlinkErrorRepr ::Error {
message : err_msg ,
source : io ::Error ::from_raw_os_error ( - err . error ) ,
} ) ;
}
None = > {
return Err ( NetlinkErrorRepr ::IoError (
io ::Error ::from_raw_os_error ( - err . error ) ,
) ) ;
}
}
}
NLMSG_DONE = > break 'out ,
_ = > messages . push ( message ) ,
@ -452,7 +525,7 @@ impl NetlinkMessage {
) ) ;
}
(
Vec ::new ( ) ,
buf [ data_offset + mem ::size_of ::< nlmsgerr > ( ) .. msg_len ] . to_vec ( ) ,
// Safety: nlmsgerr is POD so read is safe
Some ( unsafe {
ptr ::read_unaligned ( buf [ data_offset .. ] . as_ptr ( ) as * const nlmsgerr )
@ -628,7 +701,7 @@ struct NlAttr<'a> {
}
#[ derive(Debug, Error, PartialEq, Eq) ]
enum NlAttrError {
pub ( crate ) enum NlAttrError {
#[ error( " invalid buffer size `{size}`, expected `{expected}` " ) ]
InvalidBufferLength { size : usize , expected : usize } ,