mirror of https://github.com/aya-rs/aya
feat: implement core eBPF traffic monitoring functionality
Core eBPF Program (traffic_monitor.bpf.rs): - XDP-based packet processing for high performance - IP header parsing and CIDR range matching - Configurable packet dropping or logging - Ring buffer event logging to userspace Supporting Modules: - config.rs: JSON configuration management for CIDR ranges - ip_utils.rs: CIDR parsing and IP matching utilities - event_handler.rs: Traffic event processing and statistics - lib.rs: Module exports and shared structures Key Features: - Line-rate packet filtering in kernel space - Support for up to 256 permitted CIDR ranges - Real-time event streaming via ring buffers - Protocol-aware logging (TCP/UDP/ICMP/etc.) 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>reviewable/pr1291/r2
parent
c590290bdf
commit
21bd2041e7
@ -0,0 +1,86 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{fs, path::Path};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct TrafficMonitorConfig {
|
||||||
|
pub permitted_cidrs: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrafficMonitorConfig {
|
||||||
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||||
|
let content = fs::read_to_string(path.as_ref())
|
||||||
|
.with_context(|| format!("Failed to read config file: {:?}", path.as_ref()))?;
|
||||||
|
|
||||||
|
serde_json::from_str(&content)
|
||||||
|
.with_context(|| "Failed to parse config file as JSON")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||||
|
let content = serde_json::to_string_pretty(self)
|
||||||
|
.context("Failed to serialize config to JSON")?;
|
||||||
|
|
||||||
|
fs::write(path.as_ref(), content)
|
||||||
|
.with_context(|| format!("Failed to write config file: {:?}", path.as_ref()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TrafficMonitorConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
permitted_cidrs: vec![
|
||||||
|
"127.0.0.0/8".to_string(), // Localhost
|
||||||
|
"10.0.0.0/8".to_string(), // Private network
|
||||||
|
"172.16.0.0/12".to_string(), // Private network
|
||||||
|
"192.168.0.0/16".to_string(), // Private network
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-export for convenience
|
||||||
|
pub use TrafficMonitorConfig as Config;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_serialization() {
|
||||||
|
let config = TrafficMonitorConfig {
|
||||||
|
permitted_cidrs: vec![
|
||||||
|
"192.168.1.0/24".to_string(),
|
||||||
|
"10.0.0.0/8".to_string(),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
|
let deserialized: TrafficMonitorConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.permitted_cidrs, deserialized.permitted_cidrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_file_operations() {
|
||||||
|
let config = TrafficMonitorConfig::default();
|
||||||
|
let temp_file = NamedTempFile::new().unwrap();
|
||||||
|
|
||||||
|
// Save config
|
||||||
|
config.save(temp_file.path()).unwrap();
|
||||||
|
|
||||||
|
// Load config
|
||||||
|
let loaded_config = TrafficMonitorConfig::load(temp_file.path()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.permitted_cidrs, loaded_config.permitted_cidrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config() {
|
||||||
|
let config = TrafficMonitorConfig::default();
|
||||||
|
assert!(!config.permitted_cidrs.is_empty());
|
||||||
|
assert!(config.permitted_cidrs.contains(&"127.0.0.0/8".to_string()));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,212 @@
|
|||||||
|
use log::{info, warn};
|
||||||
|
use crate::TrafficEvent;
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
net::Ipv4Addr,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mirror of the eBPF TrafficEvent structure
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct TrafficEvent {
|
||||||
|
pub src_ip: u32,
|
||||||
|
pub dst_ip: u32,
|
||||||
|
pub src_port: u16,
|
||||||
|
pub dst_port: u16,
|
||||||
|
pub protocol: u8,
|
||||||
|
pub packet_size: u16,
|
||||||
|
pub action: u8, // 0 = allowed, 1 = dropped
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EventStats {
|
||||||
|
pub count: u64,
|
||||||
|
pub total_bytes: u64,
|
||||||
|
pub first_seen: Instant,
|
||||||
|
pub last_seen: Instant,
|
||||||
|
pub protocols: HashMap<u8, u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventStats {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
count: 0,
|
||||||
|
total_bytes: 0,
|
||||||
|
first_seen: Instant::now(),
|
||||||
|
last_seen: Instant::now(),
|
||||||
|
protocols: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self, event: &TrafficEvent) {
|
||||||
|
self.count += 1;
|
||||||
|
self.total_bytes += event.packet_size as u64;
|
||||||
|
self.last_seen = Instant::now();
|
||||||
|
*self.protocols.entry(event.protocol).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EventHandler {
|
||||||
|
stats: HashMap<u32, EventStats>, // keyed by source IP
|
||||||
|
last_summary: Instant,
|
||||||
|
summary_interval: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventHandler {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
stats: HashMap::new(),
|
||||||
|
last_summary: Instant::now(),
|
||||||
|
summary_interval: Duration::from_secs(60), // Print summary every minute
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_event(&mut self, event: TrafficEvent) {
|
||||||
|
let src_ip = Ipv4Addr::from(u32::from_be(event.src_ip));
|
||||||
|
let dst_ip = Ipv4Addr::from(u32::from_be(event.dst_ip));
|
||||||
|
|
||||||
|
// Update statistics
|
||||||
|
let stats = self.stats.entry(event.src_ip).or_insert_with(EventStats::new);
|
||||||
|
stats.update(&event);
|
||||||
|
|
||||||
|
// Log the event
|
||||||
|
let protocol_name = protocol_to_string(event.protocol);
|
||||||
|
let action_str = if event.action == 1 { "DROPPED" } else { "LOGGED" };
|
||||||
|
|
||||||
|
if event.src_port != 0 && event.dst_port != 0 {
|
||||||
|
info!(
|
||||||
|
"[{}] Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {} bytes)",
|
||||||
|
action_str, src_ip, event.src_port, dst_ip, event.dst_port,
|
||||||
|
protocol_name, event.packet_size
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"[{}] Non-permitted traffic: {} -> {} (proto: {}, size: {} bytes)",
|
||||||
|
action_str, src_ip, dst_ip, protocol_name, event.packet_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print periodic summary
|
||||||
|
if self.last_summary.elapsed() >= self.summary_interval {
|
||||||
|
self.print_summary();
|
||||||
|
self.last_summary = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_summary(&self) {
|
||||||
|
if self.stats.is_empty() {
|
||||||
|
info!("No non-permitted traffic detected in the last minute");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("=== Traffic Summary (last {} seconds) ===", self.summary_interval.as_secs());
|
||||||
|
|
||||||
|
let mut sorted_ips: Vec<_> = self.stats.iter().collect();
|
||||||
|
sorted_ips.sort_by(|a, b| b.1.count.cmp(&a.1.count));
|
||||||
|
|
||||||
|
for (ip, stats) in sorted_ips.iter().take(10) { // Top 10 most active IPs
|
||||||
|
let src_ip = Ipv4Addr::from(u32::from_be(**ip));
|
||||||
|
let duration = stats.last_seen.duration_since(stats.first_seen);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
" {}: {} packets, {} bytes, duration: {:.1}s",
|
||||||
|
src_ip, stats.count, stats.total_bytes, duration.as_secs_f64()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Show protocol breakdown
|
||||||
|
for (proto, count) in &stats.protocols {
|
||||||
|
info!(" {}: {} packets", protocol_to_string(*proto), count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_ips = self.stats.len();
|
||||||
|
let total_packets: u64 = self.stats.values().map(|s| s.count).sum();
|
||||||
|
let total_bytes: u64 = self.stats.values().map(|s| s.total_bytes).sum();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Total: {} unique IPs, {} packets, {} bytes",
|
||||||
|
total_ips, total_packets, total_bytes
|
||||||
|
);
|
||||||
|
info!("=== End Summary ===");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_stats(&self) -> &HashMap<u32, EventStats> {
|
||||||
|
&self.stats
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_stats(&mut self) {
|
||||||
|
self.stats.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EventHandler {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn protocol_to_string(protocol: u8) -> &'static str {
|
||||||
|
match protocol {
|
||||||
|
1 => "ICMP",
|
||||||
|
6 => "TCP",
|
||||||
|
17 => "UDP",
|
||||||
|
47 => "GRE",
|
||||||
|
50 => "ESP",
|
||||||
|
51 => "AH",
|
||||||
|
58 => "ICMPv6",
|
||||||
|
132 => "SCTP",
|
||||||
|
_ => "Unknown",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_event_stats_update() {
|
||||||
|
let mut stats = EventStats::new();
|
||||||
|
let event = TrafficEvent {
|
||||||
|
src_ip: 0x0100007f, // 127.0.0.1 in network byte order
|
||||||
|
dst_ip: 0x0200a8c0, // 192.168.0.2 in network byte order
|
||||||
|
src_port: 12345,
|
||||||
|
dst_port: 80,
|
||||||
|
protocol: 6, // TCP
|
||||||
|
packet_size: 1500,
|
||||||
|
action: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
stats.update(&event);
|
||||||
|
|
||||||
|
assert_eq!(stats.count, 1);
|
||||||
|
assert_eq!(stats.total_bytes, 1500);
|
||||||
|
assert_eq!(stats.protocols.get(&6), Some(&1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_protocol_names() {
|
||||||
|
assert_eq!(protocol_to_string(1), "ICMP");
|
||||||
|
assert_eq!(protocol_to_string(6), "TCP");
|
||||||
|
assert_eq!(protocol_to_string(17), "UDP");
|
||||||
|
assert_eq!(protocol_to_string(255), "Unknown");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_event_handler_basic() {
|
||||||
|
let mut handler = EventHandler::new();
|
||||||
|
let event = TrafficEvent {
|
||||||
|
src_ip: 0x0100007f,
|
||||||
|
dst_ip: 0x0200a8c0,
|
||||||
|
src_port: 12345,
|
||||||
|
dst_port: 80,
|
||||||
|
protocol: 6,
|
||||||
|
packet_size: 1500,
|
||||||
|
action: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
handler.handle_event(event);
|
||||||
|
|
||||||
|
assert_eq!(handler.stats.len(), 1);
|
||||||
|
assert!(handler.stats.contains_key(&0x0100007f));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,160 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
/// Parse a CIDR notation string into network address and prefix length
|
||||||
|
pub fn parse_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
|
||||||
|
let parts: Vec<&str> = cidr.split('/').collect();
|
||||||
|
|
||||||
|
if parts.len() != 2 {
|
||||||
|
return Err(anyhow!("Invalid CIDR format: {}", cidr));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip: Ipv4Addr = parts[0].parse()
|
||||||
|
.map_err(|_| anyhow!("Invalid IP address: {}", parts[0]))?;
|
||||||
|
|
||||||
|
let prefix_len: u8 = parts[1].parse()
|
||||||
|
.map_err(|_| anyhow!("Invalid prefix length: {}", parts[1]))?;
|
||||||
|
|
||||||
|
if prefix_len > 32 {
|
||||||
|
return Err(anyhow!("Prefix length must be 0-32, got: {}", prefix_len));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the network address by applying the subnet mask
|
||||||
|
let ip_u32 = u32::from(ip);
|
||||||
|
let mask = if prefix_len == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
!((1u32 << (32 - prefix_len)) - 1)
|
||||||
|
};
|
||||||
|
let network_u32 = ip_u32 & mask;
|
||||||
|
let network_ip = Ipv4Addr::from(network_u32);
|
||||||
|
|
||||||
|
Ok((network_ip, prefix_len))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if an IP address is within a CIDR range
|
||||||
|
pub fn ip_in_cidr(ip: Ipv4Addr, network: Ipv4Addr, prefix_len: u8) -> bool {
|
||||||
|
if prefix_len == 0 {
|
||||||
|
return true; // 0.0.0.0/0 matches everything
|
||||||
|
}
|
||||||
|
if prefix_len > 32 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip_u32 = u32::from(ip);
|
||||||
|
let network_u32 = u32::from(network);
|
||||||
|
|
||||||
|
let mask = if prefix_len == 32 {
|
||||||
|
0xFFFFFFFF
|
||||||
|
} else {
|
||||||
|
!((1u32 << (32 - prefix_len)) - 1)
|
||||||
|
};
|
||||||
|
|
||||||
|
(ip_u32 & mask) == (network_u32 & mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert IP address to human-readable string with additional info
|
||||||
|
pub fn format_ip_info(ip: Ipv4Addr) -> String {
|
||||||
|
let octets = ip.octets();
|
||||||
|
let class = match octets[0] {
|
||||||
|
1..=126 => "Class A",
|
||||||
|
128..=191 => "Class B",
|
||||||
|
192..=223 => "Class C",
|
||||||
|
224..=239 => "Class D (Multicast)",
|
||||||
|
240..=255 => "Class E (Reserved)",
|
||||||
|
_ => "Invalid",
|
||||||
|
};
|
||||||
|
|
||||||
|
let special = if ip.is_private() {
|
||||||
|
" (Private)"
|
||||||
|
} else if ip.is_loopback() {
|
||||||
|
" (Loopback)"
|
||||||
|
} else if ip.is_multicast() {
|
||||||
|
" (Multicast)"
|
||||||
|
} else if ip.is_broadcast() {
|
||||||
|
" (Broadcast)"
|
||||||
|
} else {
|
||||||
|
" (Public)"
|
||||||
|
};
|
||||||
|
|
||||||
|
format!("{} [{}{}]", ip, class, special)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cidr_valid() {
|
||||||
|
let (network, prefix) = parse_cidr("192.168.1.0/24").unwrap();
|
||||||
|
assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0));
|
||||||
|
assert_eq!(prefix, 24);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cidr_network_calculation() {
|
||||||
|
// Input IP is not the network address, should be calculated
|
||||||
|
let (network, prefix) = parse_cidr("192.168.1.100/24").unwrap();
|
||||||
|
assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0));
|
||||||
|
assert_eq!(prefix, 24);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cidr_invalid_format() {
|
||||||
|
assert!(parse_cidr("192.168.1.0").is_err());
|
||||||
|
assert!(parse_cidr("192.168.1.0/24/extra").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cidr_invalid_ip() {
|
||||||
|
assert!(parse_cidr("256.1.1.1/24").is_err());
|
||||||
|
assert!(parse_cidr("invalid.ip/24").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_cidr_invalid_prefix() {
|
||||||
|
assert!(parse_cidr("192.168.1.0/33").is_err());
|
||||||
|
assert!(parse_cidr("192.168.1.0/abc").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ip_in_cidr() {
|
||||||
|
let network = Ipv4Addr::new(192, 168, 1, 0);
|
||||||
|
|
||||||
|
// Test IPs within the range
|
||||||
|
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 1), network, 24));
|
||||||
|
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 100), network, 24));
|
||||||
|
assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 255), network, 24));
|
||||||
|
|
||||||
|
// Test IPs outside the range
|
||||||
|
assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 2, 1), network, 24));
|
||||||
|
assert!(!ip_in_cidr(Ipv4Addr::new(10, 0, 0, 1), network, 24));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ip_in_cidr_edge_cases() {
|
||||||
|
let network = Ipv4Addr::new(0, 0, 0, 0);
|
||||||
|
|
||||||
|
// /0 should match everything
|
||||||
|
assert!(ip_in_cidr(Ipv4Addr::new(1, 2, 3, 4), network, 0));
|
||||||
|
assert!(ip_in_cidr(Ipv4Addr::new(255, 255, 255, 255), network, 0));
|
||||||
|
|
||||||
|
// /32 should match only exact IP
|
||||||
|
let exact_ip = Ipv4Addr::new(192, 168, 1, 1);
|
||||||
|
assert!(ip_in_cidr(exact_ip, exact_ip, 32));
|
||||||
|
assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 1, 2), exact_ip, 32));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_format_ip_info() {
|
||||||
|
let info = format_ip_info(Ipv4Addr::new(192, 168, 1, 1));
|
||||||
|
assert!(info.contains("192.168.1.1"));
|
||||||
|
assert!(info.contains("Class C"));
|
||||||
|
assert!(info.contains("Private"));
|
||||||
|
|
||||||
|
let info = format_ip_info(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
|
assert!(info.contains("127.0.0.1"));
|
||||||
|
assert!(info.contains("Class A"));
|
||||||
|
assert!(info.contains("Loopback"));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,7 @@
|
|||||||
|
pub mod config;
|
||||||
|
pub mod event_handler;
|
||||||
|
pub mod ip_utils;
|
||||||
|
|
||||||
|
pub use config::TrafficMonitorConfig;
|
||||||
|
pub use event_handler::{EventHandler, TrafficEvent};
|
||||||
|
pub use ip_utils::{format_ip_info, ip_in_cidr, parse_cidr};
|
@ -0,0 +1,214 @@
|
|||||||
|
#![no_std]
|
||||||
|
#![no_main]
|
||||||
|
|
||||||
|
use aya_ebpf::{
|
||||||
|
bindings::{xdp_action, BPF_F_NO_PREALLOC},
|
||||||
|
macros::{map, xdp},
|
||||||
|
maps::{HashMap, RingBuf},
|
||||||
|
programs::XdpContext,
|
||||||
|
};
|
||||||
|
use aya_log_ebpf::info;
|
||||||
|
use core::mem;
|
||||||
|
use network_types::{
|
||||||
|
eth::{EthHdr, EtherType},
|
||||||
|
ip::{IpProto, Ipv4Hdr, Ipv6Hdr},
|
||||||
|
tcp::TcpHdr,
|
||||||
|
udp::UdpHdr,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Maximum number of CIDR ranges we can store
|
||||||
|
const MAX_CIDR_RANGES: u32 = 256;
|
||||||
|
|
||||||
|
// Configuration passed from userspace
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct Config {
|
||||||
|
pub drop_packets: u8, // 0 = log only, 1 = log and drop
|
||||||
|
}
|
||||||
|
|
||||||
|
// CIDR range for IPv4
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct CidrRange {
|
||||||
|
pub network: u32, // Network address in network byte order
|
||||||
|
pub prefix_len: u8, // Prefix length (0-32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traffic event to send to userspace
|
||||||
|
#[repr(C)]
|
||||||
|
pub struct TrafficEvent {
|
||||||
|
pub src_ip: u32,
|
||||||
|
pub dst_ip: u32,
|
||||||
|
pub src_port: u16,
|
||||||
|
pub dst_port: u16,
|
||||||
|
pub protocol: u8,
|
||||||
|
pub packet_size: u16,
|
||||||
|
pub action: u8, // 0 = allowed, 1 = dropped
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maps
|
||||||
|
#[map]
|
||||||
|
static CONFIG: HashMap<u32, Config> = HashMap::with_max_entries(1, BPF_F_NO_PREALLOC);
|
||||||
|
|
||||||
|
#[map]
|
||||||
|
static PERMITTED_CIDRS: HashMap<u32, CidrRange> = HashMap::with_max_entries(MAX_CIDR_RANGES, 0);
|
||||||
|
|
||||||
|
#[map]
|
||||||
|
static EVENTS: RingBuf = RingBuf::with_byte_size(256 * 1024, 0);
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn is_ip_in_cidr(ip: u32, cidr: &CidrRange) -> bool {
|
||||||
|
if cidr.prefix_len == 0 {
|
||||||
|
return true; // 0.0.0.0/0 matches everything
|
||||||
|
}
|
||||||
|
if cidr.prefix_len > 32 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mask = if cidr.prefix_len == 32 {
|
||||||
|
0xFFFFFFFF
|
||||||
|
} else {
|
||||||
|
!((1u32 << (32 - cidr.prefix_len)) - 1)
|
||||||
|
};
|
||||||
|
|
||||||
|
(ip & mask) == (cidr.network & mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn is_permitted_ip(ip: u32) -> bool {
|
||||||
|
// Check against all CIDR ranges
|
||||||
|
for i in 0..MAX_CIDR_RANGES {
|
||||||
|
if let Some(cidr) = unsafe { PERMITTED_CIDRS.get(&i) } {
|
||||||
|
if is_ip_in_cidr(ip, cidr) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn ptr_at<T>(ctx: &XdpContext, offset: usize) -> Result<*const T, ()> {
|
||||||
|
let start = ctx.data();
|
||||||
|
let end = ctx.data_end();
|
||||||
|
let len = mem::size_of::<T>();
|
||||||
|
|
||||||
|
if start + offset + len > end {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((start + offset) as *const T)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[xdp]
|
||||||
|
pub fn traffic_monitor(ctx: XdpContext) -> u32 {
|
||||||
|
match try_traffic_monitor(ctx) {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(_) => xdp_action::XDP_PASS,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_traffic_monitor(ctx: XdpContext) -> Result<u32, ()> {
|
||||||
|
// Get configuration
|
||||||
|
let config = unsafe { CONFIG.get(&0) }.unwrap_or(&Config { drop_packets: 0 });
|
||||||
|
|
||||||
|
// Parse Ethernet header
|
||||||
|
let ethhdr: *const EthHdr = ptr_at(&ctx, 0)?;
|
||||||
|
let eth_proto = unsafe { (*ethhdr).ether_type };
|
||||||
|
|
||||||
|
match eth_proto {
|
||||||
|
EtherType::Ipv4 => {
|
||||||
|
handle_ipv4(&ctx, config)?;
|
||||||
|
}
|
||||||
|
EtherType::Ipv6 => {
|
||||||
|
// For now, we'll pass IPv6 traffic through
|
||||||
|
// Could extend to support IPv6 CIDR ranges later
|
||||||
|
return Ok(xdp_action::XDP_PASS);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Non-IP traffic, pass through
|
||||||
|
return Ok(xdp_action::XDP_PASS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(xdp_action::XDP_PASS)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_ipv4(ctx: &XdpContext, config: &Config) -> Result<(), ()> {
|
||||||
|
let ipv4hdr: *const Ipv4Hdr = ptr_at(ctx, EthHdr::LEN)?;
|
||||||
|
let src_ip = unsafe { (*ipv4hdr).src_addr };
|
||||||
|
let dst_ip = unsafe { (*ipv4hdr).dst_addr };
|
||||||
|
let protocol = unsafe { (*ipv4hdr).proto };
|
||||||
|
let total_len = unsafe { u16::from_be((*ipv4hdr).tot_len) };
|
||||||
|
|
||||||
|
// Calculate IP header length
|
||||||
|
let ip_hdr_len = (unsafe { (*ipv4hdr).version_ihl() } & 0x0F) as usize * 4;
|
||||||
|
|
||||||
|
let (src_port, dst_port) = match protocol {
|
||||||
|
IpProto::Tcp => {
|
||||||
|
let tcphdr: *const TcpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?;
|
||||||
|
(
|
||||||
|
unsafe { u16::from_be((*tcphdr).source) },
|
||||||
|
unsafe { u16::from_be((*tcphdr).dest) },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
IpProto::Udp => {
|
||||||
|
let udphdr: *const UdpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?;
|
||||||
|
(
|
||||||
|
unsafe { u16::from_be((*udphdr).source) },
|
||||||
|
unsafe { u16::from_be((*udphdr).dest) },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => (0, 0), // Other protocols don't have ports
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if source IP is permitted
|
||||||
|
let is_permitted = is_permitted_ip(src_ip);
|
||||||
|
|
||||||
|
if !is_permitted {
|
||||||
|
// Log the event
|
||||||
|
let action = if config.drop_packets == 1 { 1 } else { 0 };
|
||||||
|
|
||||||
|
let event = TrafficEvent {
|
||||||
|
src_ip,
|
||||||
|
dst_ip,
|
||||||
|
src_port,
|
||||||
|
dst_port,
|
||||||
|
protocol: protocol as u8,
|
||||||
|
packet_size: total_len,
|
||||||
|
action,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send event to userspace
|
||||||
|
if let Some(mut entry) = EVENTS.reserve::<TrafficEvent>(0) {
|
||||||
|
unsafe {
|
||||||
|
*entry.as_mut_ptr() = event;
|
||||||
|
}
|
||||||
|
entry.submit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
ctx,
|
||||||
|
"Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {}, action: {})",
|
||||||
|
u32::from_be(src_ip),
|
||||||
|
src_port,
|
||||||
|
u32::from_be(dst_ip),
|
||||||
|
dst_port,
|
||||||
|
protocol as u8,
|
||||||
|
total_len,
|
||||||
|
if action == 1 { "DROP" } else { "ALLOW" }
|
||||||
|
);
|
||||||
|
|
||||||
|
// Drop packet if configured to do so
|
||||||
|
if config.drop_packets == 1 {
|
||||||
|
return Err(()); // This will cause XDP_DROP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[panic_handler]
|
||||||
|
fn panic(_info: &core::panic::PanicInfo) -> ! {
|
||||||
|
unsafe { core::hint::unreachable_unchecked() }
|
||||||
|
}
|
Loading…
Reference in New Issue