mirror of https://github.com/aya-rs/aya
feat: implement core eBPF traffic monitoring functionality
Core eBPF Program (traffic_monitor.bpf.rs): - XDP-based packet processing for high performance - IP header parsing and CIDR range matching - Configurable packet dropping or logging - Ring buffer event logging to userspace Supporting Modules: - config.rs: JSON configuration management for CIDR ranges - ip_utils.rs: CIDR parsing and IP matching utilities - event_handler.rs: Traffic event processing and statistics - lib.rs: Module exports and shared structures Key Features: - Line-rate packet filtering in kernel space - Support for up to 256 permitted CIDR ranges - Real-time event streaming via ring buffers - Protocol-aware logging (TCP/UDP/ICMP/etc.) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>reviewable/pr1291/r2
parent
c590290bdf
commit
21bd2041e7
@ -0,0 +1,86 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TrafficMonitorConfig {
|
||||
pub permitted_cidrs: Vec<String>,
|
||||
}
|
||||
|
||||
impl TrafficMonitorConfig {
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content = fs::read_to_string(path.as_ref())
|
||||
.with_context(|| format!("Failed to read config file: {:?}", path.as_ref()))?;
|
||||
|
||||
serde_json::from_str(&content)
|
||||
.with_context(|| "Failed to parse config file as JSON")
|
||||
}
|
||||
|
||||
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(self)
|
||||
.context("Failed to serialize config to JSON")?;
|
||||
|
||||
fs::write(path.as_ref(), content)
|
||||
.with_context(|| format!("Failed to write config file: {:?}", path.as_ref()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TrafficMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
permitted_cidrs: vec![
|
||||
"127.0.0.0/8".to_string(), // Localhost
|
||||
"10.0.0.0/8".to_string(), // Private network
|
||||
"172.16.0.0/12".to_string(), // Private network
|
||||
"192.168.0.0/16".to_string(), // Private network
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export for convenience
|
||||
pub use TrafficMonitorConfig as Config;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = TrafficMonitorConfig {
|
||||
permitted_cidrs: vec![
|
||||
"192.168.1.0/24".to_string(),
|
||||
"10.0.0.0/8".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: TrafficMonitorConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(config.permitted_cidrs, deserialized.permitted_cidrs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_file_operations() {
|
||||
let config = TrafficMonitorConfig::default();
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
|
||||
// Save config
|
||||
config.save(temp_file.path()).unwrap();
|
||||
|
||||
// Load config
|
||||
let loaded_config = TrafficMonitorConfig::load(temp_file.path()).unwrap();
|
||||
|
||||
assert_eq!(config.permitted_cidrs, loaded_config.permitted_cidrs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = TrafficMonitorConfig::default();
|
||||
assert!(!config.permitted_cidrs.is_empty());
|
||||
assert!(config.permitted_cidrs.contains(&"127.0.0.0/8".to_string()));
|
||||
}
|
||||
}
|
@ -0,0 +1,212 @@
|
||||
use log::{info, warn};
|
||||
use crate::TrafficEvent;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::Ipv4Addr,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
// Mirror of the eBPF TrafficEvent structure
|
||||
#[repr(C)]
|
||||
pub struct TrafficEvent {
|
||||
pub src_ip: u32,
|
||||
pub dst_ip: u32,
|
||||
pub src_port: u16,
|
||||
pub dst_port: u16,
|
||||
pub protocol: u8,
|
||||
pub packet_size: u16,
|
||||
pub action: u8, // 0 = allowed, 1 = dropped
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EventStats {
|
||||
pub count: u64,
|
||||
pub total_bytes: u64,
|
||||
pub first_seen: Instant,
|
||||
pub last_seen: Instant,
|
||||
pub protocols: HashMap<u8, u64>,
|
||||
}
|
||||
|
||||
impl EventStats {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
count: 0,
|
||||
total_bytes: 0,
|
||||
first_seen: Instant::now(),
|
||||
last_seen: Instant::now(),
|
||||
protocols: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, event: &TrafficEvent) {
|
||||
self.count += 1;
|
||||
self.total_bytes += event.packet_size as u64;
|
||||
self.last_seen = Instant::now();
|
||||
*self.protocols.entry(event.protocol).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EventHandler {
|
||||
stats: HashMap<u32, EventStats>, // keyed by source IP
|
||||
last_summary: Instant,
|
||||
summary_interval: Duration,
|
||||
}
|
||||
|
||||
impl EventHandler {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
stats: HashMap::new(),
|
||||
last_summary: Instant::now(),
|
||||
summary_interval: Duration::from_secs(60), // Print summary every minute
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_event(&mut self, event: TrafficEvent) {
|
||||
let src_ip = Ipv4Addr::from(u32::from_be(event.src_ip));
|
||||
let dst_ip = Ipv4Addr::from(u32::from_be(event.dst_ip));
|
||||
|
||||
// Update statistics
|
||||
let stats = self.stats.entry(event.src_ip).or_insert_with(EventStats::new);
|
||||
stats.update(&event);
|
||||
|
||||
// Log the event
|
||||
let protocol_name = protocol_to_string(event.protocol);
|
||||
let action_str = if event.action == 1 { "DROPPED" } else { "LOGGED" };
|
||||
|
||||
if event.src_port != 0 && event.dst_port != 0 {
|
||||
info!(
|
||||
"[{}] Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {} bytes)",
|
||||
action_str, src_ip, event.src_port, dst_ip, event.dst_port,
|
||||
protocol_name, event.packet_size
|
||||
);
|
||||
} else {
|
||||
info!(
|
||||
"[{}] Non-permitted traffic: {} -> {} (proto: {}, size: {} bytes)",
|
||||
action_str, src_ip, dst_ip, protocol_name, event.packet_size
|
||||
);
|
||||
}
|
||||
|
||||
// Print periodic summary
|
||||
if self.last_summary.elapsed() >= self.summary_interval {
|
||||
self.print_summary();
|
||||
self.last_summary = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
fn print_summary(&self) {
|
||||
if self.stats.is_empty() {
|
||||
info!("No non-permitted traffic detected in the last minute");
|
||||
return;
|
||||
}
|
||||
|
||||
info!("=== Traffic Summary (last {} seconds) ===", self.summary_interval.as_secs());
|
||||
|
||||
let mut sorted_ips: Vec<_> = self.stats.iter().collect();
|
||||
sorted_ips.sort_by(|a, b| b.1.count.cmp(&a.1.count));
|
||||
|
||||
for (ip, stats) in sorted_ips.iter().take(10) { // Top 10 most active IPs
|
||||
let src_ip = Ipv4Addr::from(u32::from_be(**ip));
|
||||
let duration = stats.last_seen.duration_since(stats.first_seen);
|
||||
|
||||
info!(
|
||||
" {}: {} packets, {} bytes, duration: {:.1}s",
|
||||
src_ip, stats.count, stats.total_bytes, duration.as_secs_f64()
|
||||
);
|
||||
|
||||
// Show protocol breakdown
|
||||
for (proto, count) in &stats.protocols {
|
||||
info!(" {}: {} packets", protocol_to_string(*proto), count);
|
||||
}
|
||||
}
|
||||
|
||||
let total_ips = self.stats.len();
|
||||
let total_packets: u64 = self.stats.values().map(|s| s.count).sum();
|
||||
let total_bytes: u64 = self.stats.values().map(|s| s.total_bytes).sum();
|
||||
|
||||
info!(
|
||||
"Total: {} unique IPs, {} packets, {} bytes",
|
||||
total_ips, total_packets, total_bytes
|
||||
);
|
||||
info!("=== End Summary ===");
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> &HashMap<u32, EventStats> {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
pub fn clear_stats(&mut self) {
|
||||
self.stats.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EventHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn protocol_to_string(protocol: u8) -> &'static str {
|
||||
match protocol {
|
||||
1 => "ICMP",
|
||||
6 => "TCP",
|
||||
17 => "UDP",
|
||||
47 => "GRE",
|
||||
50 => "ESP",
|
||||
51 => "AH",
|
||||
58 => "ICMPv6",
|
||||
132 => "SCTP",
|
||||
_ => "Unknown",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_event_stats_update() {
|
||||
let mut stats = EventStats::new();
|
||||
let event = TrafficEvent {
|
||||
src_ip: 0x0100007f, // 127.0.0.1 in network byte order
|
||||
dst_ip: 0x0200a8c0, // 192.168.0.2 in network byte order
|
||||
src_port: 12345,
|
||||
dst_port: 80,
|
||||
protocol: 6, // TCP
|
||||
packet_size: 1500,
|
||||
action: 0,
|
||||
};
|
||||
|
||||
stats.update(&event);
|
||||
|
||||
assert_eq!(stats.count, 1);
|
||||
assert_eq!(stats.total_bytes, 1500);
|
||||
assert_eq!(stats.protocols.get(&6), Some(&1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protocol_names() {
|
||||
assert_eq!(protocol_to_string(1), "ICMP");
|
||||
assert_eq!(protocol_to_string(6), "TCP");
|
||||
assert_eq!(protocol_to_string(17), "UDP");
|
||||
assert_eq!(protocol_to_string(255), "Unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_event_handler_basic() {
|
||||
let mut handler = EventHandler::new();
|
||||
let event = TrafficEvent {
|
||||
src_ip: 0x0100007f,
|
||||
dst_ip: 0x0200a8c0,
|
||||
src_port: 12345,
|
||||
dst_port: 80,
|
||||
protocol: 6,
|
||||
packet_size: 1500,
|
||||
action: 0,
|
||||
};
|
||||
|
||||
handler.handle_event(event);
|
||||
|
||||
assert_eq!(handler.stats.len(), 1);
|
||||
assert!(handler.stats.contains_key(&0x0100007f));
|
||||
}
|
||||
}
|
@ -0,0 +1,160 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
/// Parse a CIDR notation string into network address and prefix length
|
||||
pub fn parse_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
|
||||
let parts: Vec<&str> = cidr.split('/').collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
return Err(anyhow!("Invalid CIDR format: {}", cidr));
|
||||
}
|
||||
|
||||
let ip: Ipv4Addr = parts[0].parse()
|
||||
.map_err(|_| anyhow!("Invalid IP address: {}", parts[0]))?;
|
||||
|
||||
let prefix_len: u8 = parts[1].parse()
|
||||
.map_err(|_| anyhow!("Invalid prefix length: {}", parts[1]))?;
|
||||
|
||||
if prefix_len > 32 {
|
||||
return Err(anyhow!("Prefix length must be 0-32, got: {}", prefix_len));
|
||||
}
|
||||
|
||||
// Calculate the network address by applying the subnet mask
|
||||
let ip_u32 = u32::from(ip);
|
||||
let mask = if prefix_len == 0 {
|
||||
0
|
||||
} else {
|
||||
!((1u32 << (32 - prefix_len)) - 1)
|
||||
};
|
||||
let network_u32 = ip_u32 & mask;
|
||||
let network_ip = Ipv4Addr::from(network_u32);
|
||||
|
||||
Ok((network_ip, prefix_len))
|
||||
}
|
||||
|
||||
/// Check if an IP address is within a CIDR range
|
||||
pub fn ip_in_cidr(ip: Ipv4Addr, network: Ipv4Addr, prefix_len: u8) -> bool {
|
||||
if prefix_len == 0 {
|
||||
return true; // 0.0.0.0/0 matches everything
|
||||
}
|
||||
if prefix_len > 32 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let ip_u32 = u32::from(ip);
|
||||
let network_u32 = u32::from(network);
|
||||
|
||||
let mask = if prefix_len == 32 {
|
||||
0xFFFFFFFF
|
||||
} else {
|
||||
!((1u32 << (32 - prefix_len)) - 1)
|
||||
};
|
||||
|
||||
(ip_u32 & mask) == (network_u32 & mask)
|
||||
}
|
||||
|
||||
/// Convert IP address to human-readable string with additional info
|
||||
pub fn format_ip_info(ip: Ipv4Addr) -> String {
|
||||
let octets = ip.octets();
|
||||
let class = match octets[0] {
|
||||
1..=126 => "Class A",
|
||||
128..=191 => "Class B",
|
||||
192..=223 => "Class C",
|
||||
224..=239 => "Class D (Multicast)",
|
||||
240..=255 => "Class E (Reserved)",
|
||||
_ => "Invalid",
|
||||
};
|
||||
|
||||
let special = if ip.is_private() {
|
||||
" (Private)"
|
||||
} else if ip.is_loopback() {
|
||||
" (Loopback)"
|
||||
} else if ip.is_multicast() {
|
||||
" (Multicast)"
|
||||
} else if ip.is_broadcast() {
|
||||
" (Broadcast)"
|
||||
} else {
|
||||
" (Public)"
|
||||
};
|
||||
|
||||
format!("{} [{}{}]", ip, class, special)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_cidr_valid() {
|
||||
let (network, prefix) = parse_cidr("192.168.1.0/24").unwrap();
|
||||
assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0));
|
||||
assert_eq!(prefix, 24);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_cidr_network_calculation() {
|
||||
// Input IP is not the network address, should be calculated
|
||||
let (network, prefix) = parse_cidr("192.168.1.100/24").unwrap();
|
||||
assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0));
|
||||
assert_eq!(prefix, 24);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_cidr_invalid_format() {
|
||||
assert!(parse_cidr("192.168.1.0").is_err());
|
||||
assert!(parse_cidr("192.168.1.0/24/extra").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_cidr_invalid_ip() {
|
||||
assert!(parse_cidr("256.1.1.1/24").is_err());
|
||||
assert!(parse_cidr("invalid.ip/24").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_cidr_invalid_prefix() {
|
||||
assert!(parse_cidr("192.168.1.0/33").is_err());
|
||||
assert!(parse_cidr("192.168.1.0/abc").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_in_cidr() {
|
||||
let network = Ipv4Addr::new(192, 168, 1, 0);
|
||||
|
||||
// Test IPs within the range
|
||||
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 1), network, 24));
|
||||
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 100), network, 24));
|
||||
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 255), network, 24));
|
||||
|
||||
// Test IPs outside the range
|
||||
assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 2, 1), network, 24));
|
||||
assert!(!ip_in_cidr(Ipv4Addr::new(10, 0, 0, 1), network, 24));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_in_cidr_edge_cases() {
|
||||
let network = Ipv4Addr::new(0, 0, 0, 0);
|
||||
|
||||
// /0 should match everything
|
||||
assert!(ip_in_cidr(Ipv4Addr::new(1, 2, 3, 4), network, 0));
|
||||
assert!(ip_in_cidr(Ipv4Addr::new(255, 255, 255, 255), network, 0));
|
||||
|
||||
// /32 should match only exact IP
|
||||
let exact_ip = Ipv4Addr::new(192, 168, 1, 1);
|
||||
assert!(ip_in_cidr(exact_ip, exact_ip, 32));
|
||||
assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 1, 2), exact_ip, 32));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_ip_info() {
|
||||
let info = format_ip_info(Ipv4Addr::new(192, 168, 1, 1));
|
||||
assert!(info.contains("192.168.1.1"));
|
||||
assert!(info.contains("Class C"));
|
||||
assert!(info.contains("Private"));
|
||||
|
||||
let info = format_ip_info(Ipv4Addr::new(127, 0, 0, 1));
|
||||
assert!(info.contains("127.0.0.1"));
|
||||
assert!(info.contains("Class A"));
|
||||
assert!(info.contains("Loopback"));
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
pub mod config;
|
||||
pub mod event_handler;
|
||||
pub mod ip_utils;
|
||||
|
||||
pub use config::TrafficMonitorConfig;
|
||||
pub use event_handler::{EventHandler, TrafficEvent};
|
||||
pub use ip_utils::{format_ip_info, ip_in_cidr, parse_cidr};
|
@ -0,0 +1,214 @@
|
||||
#![no_std]
|
||||
#![no_main]
|
||||
|
||||
use aya_ebpf::{
|
||||
bindings::{xdp_action, BPF_F_NO_PREALLOC},
|
||||
macros::{map, xdp},
|
||||
maps::{HashMap, RingBuf},
|
||||
programs::XdpContext,
|
||||
};
|
||||
use aya_log_ebpf::info;
|
||||
use core::mem;
|
||||
use network_types::{
|
||||
eth::{EthHdr, EtherType},
|
||||
ip::{IpProto, Ipv4Hdr, Ipv6Hdr},
|
||||
tcp::TcpHdr,
|
||||
udp::UdpHdr,
|
||||
};
|
||||
|
||||
// Maximum number of CIDR ranges we can store
|
||||
const MAX_CIDR_RANGES: u32 = 256;
|
||||
|
||||
// Configuration passed from userspace
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Config {
|
||||
pub drop_packets: u8, // 0 = log only, 1 = log and drop
|
||||
}
|
||||
|
||||
// CIDR range for IPv4
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct CidrRange {
|
||||
pub network: u32, // Network address in network byte order
|
||||
pub prefix_len: u8, // Prefix length (0-32)
|
||||
}
|
||||
|
||||
// Traffic event to send to userspace
|
||||
#[repr(C)]
|
||||
pub struct TrafficEvent {
|
||||
pub src_ip: u32,
|
||||
pub dst_ip: u32,
|
||||
pub src_port: u16,
|
||||
pub dst_port: u16,
|
||||
pub protocol: u8,
|
||||
pub packet_size: u16,
|
||||
pub action: u8, // 0 = allowed, 1 = dropped
|
||||
}
|
||||
|
||||
// Maps
|
||||
#[map]
|
||||
static CONFIG: HashMap<u32, Config> = HashMap::with_max_entries(1, BPF_F_NO_PREALLOC);
|
||||
|
||||
#[map]
|
||||
static PERMITTED_CIDRS: HashMap<u32, CidrRange> = HashMap::with_max_entries(MAX_CIDR_RANGES, 0);
|
||||
|
||||
#[map]
|
||||
static EVENTS: RingBuf = RingBuf::with_byte_size(256 * 1024, 0);
|
||||
|
||||
#[inline(always)]
|
||||
fn is_ip_in_cidr(ip: u32, cidr: &CidrRange) -> bool {
|
||||
if cidr.prefix_len == 0 {
|
||||
return true; // 0.0.0.0/0 matches everything
|
||||
}
|
||||
if cidr.prefix_len > 32 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mask = if cidr.prefix_len == 32 {
|
||||
0xFFFFFFFF
|
||||
} else {
|
||||
!((1u32 << (32 - cidr.prefix_len)) - 1)
|
||||
};
|
||||
|
||||
(ip & mask) == (cidr.network & mask)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn is_permitted_ip(ip: u32) -> bool {
|
||||
// Check against all CIDR ranges
|
||||
for i in 0..MAX_CIDR_RANGES {
|
||||
if let Some(cidr) = unsafe { PERMITTED_CIDRS.get(&i) } {
|
||||
if is_ip_in_cidr(ip, cidr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ptr_at<T>(ctx: &XdpContext, offset: usize) -> Result<*const T, ()> {
|
||||
let start = ctx.data();
|
||||
let end = ctx.data_end();
|
||||
let len = mem::size_of::<T>();
|
||||
|
||||
if start + offset + len > end {
|
||||
return Err(());
|
||||
}
|
||||
|
||||
Ok((start + offset) as *const T)
|
||||
}
|
||||
|
||||
#[xdp]
|
||||
pub fn traffic_monitor(ctx: XdpContext) -> u32 {
|
||||
match try_traffic_monitor(ctx) {
|
||||
Ok(ret) => ret,
|
||||
Err(_) => xdp_action::XDP_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_traffic_monitor(ctx: XdpContext) -> Result<u32, ()> {
|
||||
// Get configuration
|
||||
let config = unsafe { CONFIG.get(&0) }.unwrap_or(&Config { drop_packets: 0 });
|
||||
|
||||
// Parse Ethernet header
|
||||
let ethhdr: *const EthHdr = ptr_at(&ctx, 0)?;
|
||||
let eth_proto = unsafe { (*ethhdr).ether_type };
|
||||
|
||||
match eth_proto {
|
||||
EtherType::Ipv4 => {
|
||||
handle_ipv4(&ctx, config)?;
|
||||
}
|
||||
EtherType::Ipv6 => {
|
||||
// For now, we'll pass IPv6 traffic through
|
||||
// Could extend to support IPv6 CIDR ranges later
|
||||
return Ok(xdp_action::XDP_PASS);
|
||||
}
|
||||
_ => {
|
||||
// Non-IP traffic, pass through
|
||||
return Ok(xdp_action::XDP_PASS);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(xdp_action::XDP_PASS)
|
||||
}
|
||||
|
||||
fn handle_ipv4(ctx: &XdpContext, config: &Config) -> Result<(), ()> {
|
||||
let ipv4hdr: *const Ipv4Hdr = ptr_at(ctx, EthHdr::LEN)?;
|
||||
let src_ip = unsafe { (*ipv4hdr).src_addr };
|
||||
let dst_ip = unsafe { (*ipv4hdr).dst_addr };
|
||||
let protocol = unsafe { (*ipv4hdr).proto };
|
||||
let total_len = unsafe { u16::from_be((*ipv4hdr).tot_len) };
|
||||
|
||||
// Calculate IP header length
|
||||
let ip_hdr_len = (unsafe { (*ipv4hdr).version_ihl() } & 0x0F) as usize * 4;
|
||||
|
||||
let (src_port, dst_port) = match protocol {
|
||||
IpProto::Tcp => {
|
||||
let tcphdr: *const TcpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?;
|
||||
(
|
||||
unsafe { u16::from_be((*tcphdr).source) },
|
||||
unsafe { u16::from_be((*tcphdr).dest) },
|
||||
)
|
||||
}
|
||||
IpProto::Udp => {
|
||||
let udphdr: *const UdpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?;
|
||||
(
|
||||
unsafe { u16::from_be((*udphdr).source) },
|
||||
unsafe { u16::from_be((*udphdr).dest) },
|
||||
)
|
||||
}
|
||||
_ => (0, 0), // Other protocols don't have ports
|
||||
};
|
||||
|
||||
// Check if source IP is permitted
|
||||
let is_permitted = is_permitted_ip(src_ip);
|
||||
|
||||
if !is_permitted {
|
||||
// Log the event
|
||||
let action = if config.drop_packets == 1 { 1 } else { 0 };
|
||||
|
||||
let event = TrafficEvent {
|
||||
src_ip,
|
||||
dst_ip,
|
||||
src_port,
|
||||
dst_port,
|
||||
protocol: protocol as u8,
|
||||
packet_size: total_len,
|
||||
action,
|
||||
};
|
||||
|
||||
// Send event to userspace
|
||||
if let Some(mut entry) = EVENTS.reserve::<TrafficEvent>(0) {
|
||||
unsafe {
|
||||
*entry.as_mut_ptr() = event;
|
||||
}
|
||||
entry.submit(0);
|
||||
}
|
||||
|
||||
info!(
|
||||
ctx,
|
||||
"Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {}, action: {})",
|
||||
u32::from_be(src_ip),
|
||||
src_port,
|
||||
u32::from_be(dst_ip),
|
||||
dst_port,
|
||||
protocol as u8,
|
||||
total_len,
|
||||
if action == 1 { "DROP" } else { "ALLOW" }
|
||||
);
|
||||
|
||||
// Drop packet if configured to do so
|
||||
if config.drop_packets == 1 {
|
||||
return Err(()); // This will cause XDP_DROP
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[panic_handler]
|
||||
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||
unsafe { core::hint::unreachable_unchecked() }
|
||||
}
|
Loading…
Reference in New Issue