diff --git a/traffic-monitor/src/config.rs b/traffic-monitor/src/config.rs new file mode 100644 index 00000000..3f7688b2 --- /dev/null +++ b/traffic-monitor/src/config.rs @@ -0,0 +1,86 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::{fs, path::Path}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct TrafficMonitorConfig { + pub permitted_cidrs: Vec, +} + +impl TrafficMonitorConfig { + pub fn load>(path: P) -> Result { + let content = fs::read_to_string(path.as_ref()) + .with_context(|| format!("Failed to read config file: {:?}", path.as_ref()))?; + + serde_json::from_str(&content) + .with_context(|| "Failed to parse config file as JSON") + } + + pub fn save>(&self, path: P) -> Result<()> { + let content = serde_json::to_string_pretty(self) + .context("Failed to serialize config to JSON")?; + + fs::write(path.as_ref(), content) + .with_context(|| format!("Failed to write config file: {:?}", path.as_ref()))?; + + Ok(()) + } +} + +impl Default for TrafficMonitorConfig { + fn default() -> Self { + Self { + permitted_cidrs: vec![ + "127.0.0.0/8".to_string(), // Localhost + "10.0.0.0/8".to_string(), // Private network + "172.16.0.0/12".to_string(), // Private network + "192.168.0.0/16".to_string(), // Private network + ], + } + } +} + +// Re-export for convenience +pub use TrafficMonitorConfig as Config; + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn test_config_serialization() { + let config = TrafficMonitorConfig { + permitted_cidrs: vec![ + "192.168.1.0/24".to_string(), + "10.0.0.0/8".to_string(), + ], + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: TrafficMonitorConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(config.permitted_cidrs, deserialized.permitted_cidrs); + } + + #[test] + fn test_config_file_operations() { + let config = TrafficMonitorConfig::default(); + let temp_file = NamedTempFile::new().unwrap(); + + // Save config + config.save(temp_file.path()).unwrap(); + + // Load config + let loaded_config = TrafficMonitorConfig::load(temp_file.path()).unwrap(); + + assert_eq!(config.permitted_cidrs, loaded_config.permitted_cidrs); + } + + #[test] + fn test_default_config() { + let config = TrafficMonitorConfig::default(); + assert!(!config.permitted_cidrs.is_empty()); + assert!(config.permitted_cidrs.contains(&"127.0.0.0/8".to_string())); + } +} \ No newline at end of file diff --git a/traffic-monitor/src/event_handler.rs b/traffic-monitor/src/event_handler.rs new file mode 100644 index 00000000..36499736 --- /dev/null +++ b/traffic-monitor/src/event_handler.rs @@ -0,0 +1,212 @@ +use log::{info, warn}; +use crate::TrafficEvent; +use std::{ + collections::HashMap, + net::Ipv4Addr, + time::{Duration, Instant}, +}; + +// Mirror of the eBPF TrafficEvent structure +#[repr(C)] +pub struct TrafficEvent { + pub src_ip: u32, + pub dst_ip: u32, + pub src_port: u16, + pub dst_port: u16, + pub protocol: u8, + pub packet_size: u16, + pub action: u8, // 0 = allowed, 1 = dropped +} + +#[derive(Debug, Clone)] +pub struct EventStats { + pub count: u64, + pub total_bytes: u64, + pub first_seen: Instant, + pub last_seen: Instant, + pub protocols: HashMap, +} + +impl EventStats { + fn new() -> Self { + Self { + count: 0, + total_bytes: 0, + first_seen: Instant::now(), + last_seen: Instant::now(), + protocols: HashMap::new(), + } + } + + fn update(&mut self, event: &TrafficEvent) { + self.count += 1; + self.total_bytes += event.packet_size as u64; + self.last_seen = Instant::now(); + *self.protocols.entry(event.protocol).or_insert(0) += 1; + } +} + +pub struct EventHandler { + stats: HashMap, // keyed by source IP + last_summary: Instant, + summary_interval: Duration, +} + +impl EventHandler { + pub fn new() -> Self { + Self { + stats: HashMap::new(), + last_summary: Instant::now(), + summary_interval: Duration::from_secs(60), // Print summary every minute + } + } + + pub fn handle_event(&mut self, event: TrafficEvent) { + let src_ip = Ipv4Addr::from(u32::from_be(event.src_ip)); + let dst_ip = Ipv4Addr::from(u32::from_be(event.dst_ip)); + + // Update statistics + let stats = self.stats.entry(event.src_ip).or_insert_with(EventStats::new); + stats.update(&event); + + // Log the event + let protocol_name = protocol_to_string(event.protocol); + let action_str = if event.action == 1 { "DROPPED" } else { "LOGGED" }; + + if event.src_port != 0 && event.dst_port != 0 { + info!( + "[{}] Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {} bytes)", + action_str, src_ip, event.src_port, dst_ip, event.dst_port, + protocol_name, event.packet_size + ); + } else { + info!( + "[{}] Non-permitted traffic: {} -> {} (proto: {}, size: {} bytes)", + action_str, src_ip, dst_ip, protocol_name, event.packet_size + ); + } + + // Print periodic summary + if self.last_summary.elapsed() >= self.summary_interval { + self.print_summary(); + self.last_summary = Instant::now(); + } + } + + fn print_summary(&self) { + if self.stats.is_empty() { + info!("No non-permitted traffic detected in the last minute"); + return; + } + + info!("=== Traffic Summary (last {} seconds) ===", self.summary_interval.as_secs()); + + let mut sorted_ips: Vec<_> = self.stats.iter().collect(); + sorted_ips.sort_by(|a, b| b.1.count.cmp(&a.1.count)); + + for (ip, stats) in sorted_ips.iter().take(10) { // Top 10 most active IPs + let src_ip = Ipv4Addr::from(u32::from_be(**ip)); + let duration = stats.last_seen.duration_since(stats.first_seen); + + info!( + " {}: {} packets, {} bytes, duration: {:.1}s", + src_ip, stats.count, stats.total_bytes, duration.as_secs_f64() + ); + + // Show protocol breakdown + for (proto, count) in &stats.protocols { + info!(" {}: {} packets", protocol_to_string(*proto), count); + } + } + + let total_ips = self.stats.len(); + let total_packets: u64 = self.stats.values().map(|s| s.count).sum(); + let total_bytes: u64 = self.stats.values().map(|s| s.total_bytes).sum(); + + info!( + "Total: {} unique IPs, {} packets, {} bytes", + total_ips, total_packets, total_bytes + ); + info!("=== End Summary ==="); + } + + pub fn get_stats(&self) -> &HashMap { + &self.stats + } + + pub fn clear_stats(&mut self) { + self.stats.clear(); + } +} + +impl Default for EventHandler { + fn default() -> Self { + Self::new() + } +} + +fn protocol_to_string(protocol: u8) -> &'static str { + match protocol { + 1 => "ICMP", + 6 => "TCP", + 17 => "UDP", + 47 => "GRE", + 50 => "ESP", + 51 => "AH", + 58 => "ICMPv6", + 132 => "SCTP", + _ => "Unknown", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_event_stats_update() { + let mut stats = EventStats::new(); + let event = TrafficEvent { + src_ip: 0x0100007f, // 127.0.0.1 in network byte order + dst_ip: 0x0200a8c0, // 192.168.0.2 in network byte order + src_port: 12345, + dst_port: 80, + protocol: 6, // TCP + packet_size: 1500, + action: 0, + }; + + stats.update(&event); + + assert_eq!(stats.count, 1); + assert_eq!(stats.total_bytes, 1500); + assert_eq!(stats.protocols.get(&6), Some(&1)); + } + + #[test] + fn test_protocol_names() { + assert_eq!(protocol_to_string(1), "ICMP"); + assert_eq!(protocol_to_string(6), "TCP"); + assert_eq!(protocol_to_string(17), "UDP"); + assert_eq!(protocol_to_string(255), "Unknown"); + } + + #[test] + fn test_event_handler_basic() { + let mut handler = EventHandler::new(); + let event = TrafficEvent { + src_ip: 0x0100007f, + dst_ip: 0x0200a8c0, + src_port: 12345, + dst_port: 80, + protocol: 6, + packet_size: 1500, + action: 0, + }; + + handler.handle_event(event); + + assert_eq!(handler.stats.len(), 1); + assert!(handler.stats.contains_key(&0x0100007f)); + } +} \ No newline at end of file diff --git a/traffic-monitor/src/ip_utils.rs b/traffic-monitor/src/ip_utils.rs new file mode 100644 index 00000000..895f2d39 --- /dev/null +++ b/traffic-monitor/src/ip_utils.rs @@ -0,0 +1,160 @@ +use anyhow::{anyhow, Result}; +use std::net::Ipv4Addr; + +/// Parse a CIDR notation string into network address and prefix length +pub fn parse_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> { + let parts: Vec<&str> = cidr.split('/').collect(); + + if parts.len() != 2 { + return Err(anyhow!("Invalid CIDR format: {}", cidr)); + } + + let ip: Ipv4Addr = parts[0].parse() + .map_err(|_| anyhow!("Invalid IP address: {}", parts[0]))?; + + let prefix_len: u8 = parts[1].parse() + .map_err(|_| anyhow!("Invalid prefix length: {}", parts[1]))?; + + if prefix_len > 32 { + return Err(anyhow!("Prefix length must be 0-32, got: {}", prefix_len)); + } + + // Calculate the network address by applying the subnet mask + let ip_u32 = u32::from(ip); + let mask = if prefix_len == 0 { + 0 + } else { + !((1u32 << (32 - prefix_len)) - 1) + }; + let network_u32 = ip_u32 & mask; + let network_ip = Ipv4Addr::from(network_u32); + + Ok((network_ip, prefix_len)) +} + +/// Check if an IP address is within a CIDR range +pub fn ip_in_cidr(ip: Ipv4Addr, network: Ipv4Addr, prefix_len: u8) -> bool { + if prefix_len == 0 { + return true; // 0.0.0.0/0 matches everything + } + if prefix_len > 32 { + return false; + } + + let ip_u32 = u32::from(ip); + let network_u32 = u32::from(network); + + let mask = if prefix_len == 32 { + 0xFFFFFFFF + } else { + !((1u32 << (32 - prefix_len)) - 1) + }; + + (ip_u32 & mask) == (network_u32 & mask) +} + +/// Convert IP address to human-readable string with additional info +pub fn format_ip_info(ip: Ipv4Addr) -> String { + let octets = ip.octets(); + let class = match octets[0] { + 1..=126 => "Class A", + 128..=191 => "Class B", + 192..=223 => "Class C", + 224..=239 => "Class D (Multicast)", + 240..=255 => "Class E (Reserved)", + _ => "Invalid", + }; + + let special = if ip.is_private() { + " (Private)" + } else if ip.is_loopback() { + " (Loopback)" + } else if ip.is_multicast() { + " (Multicast)" + } else if ip.is_broadcast() { + " (Broadcast)" + } else { + " (Public)" + }; + + format!("{} [{}{}]", ip, class, special) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_cidr_valid() { + let (network, prefix) = parse_cidr("192.168.1.0/24").unwrap(); + assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0)); + assert_eq!(prefix, 24); + } + + #[test] + fn test_parse_cidr_network_calculation() { + // Input IP is not the network address, should be calculated + let (network, prefix) = parse_cidr("192.168.1.100/24").unwrap(); + assert_eq!(network, Ipv4Addr::new(192, 168, 1, 0)); + assert_eq!(prefix, 24); + } + + #[test] + fn test_parse_cidr_invalid_format() { + assert!(parse_cidr("192.168.1.0").is_err()); + assert!(parse_cidr("192.168.1.0/24/extra").is_err()); + } + + #[test] + fn test_parse_cidr_invalid_ip() { + assert!(parse_cidr("256.1.1.1/24").is_err()); + assert!(parse_cidr("invalid.ip/24").is_err()); + } + + #[test] + fn test_parse_cidr_invalid_prefix() { + assert!(parse_cidr("192.168.1.0/33").is_err()); + assert!(parse_cidr("192.168.1.0/abc").is_err()); + } + + #[test] + fn test_ip_in_cidr() { + let network = Ipv4Addr::new(192, 168, 1, 0); + + // Test IPs within the range + assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 1), network, 24)); + assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 100), network, 24)); + assert!(ip_in_cidr(Ipv4Addr::new(192, 168, 1, 255), network, 24)); + + // Test IPs outside the range + assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 2, 1), network, 24)); + assert!(!ip_in_cidr(Ipv4Addr::new(10, 0, 0, 1), network, 24)); + } + + #[test] + fn test_ip_in_cidr_edge_cases() { + let network = Ipv4Addr::new(0, 0, 0, 0); + + // /0 should match everything + assert!(ip_in_cidr(Ipv4Addr::new(1, 2, 3, 4), network, 0)); + assert!(ip_in_cidr(Ipv4Addr::new(255, 255, 255, 255), network, 0)); + + // /32 should match only exact IP + let exact_ip = Ipv4Addr::new(192, 168, 1, 1); + assert!(ip_in_cidr(exact_ip, exact_ip, 32)); + assert!(!ip_in_cidr(Ipv4Addr::new(192, 168, 1, 2), exact_ip, 32)); + } + + #[test] + fn test_format_ip_info() { + let info = format_ip_info(Ipv4Addr::new(192, 168, 1, 1)); + assert!(info.contains("192.168.1.1")); + assert!(info.contains("Class C")); + assert!(info.contains("Private")); + + let info = format_ip_info(Ipv4Addr::new(127, 0, 0, 1)); + assert!(info.contains("127.0.0.1")); + assert!(info.contains("Class A")); + assert!(info.contains("Loopback")); + } +} \ No newline at end of file diff --git a/traffic-monitor/src/lib.rs b/traffic-monitor/src/lib.rs new file mode 100644 index 00000000..de1f6b47 --- /dev/null +++ b/traffic-monitor/src/lib.rs @@ -0,0 +1,7 @@ +pub mod config; +pub mod event_handler; +pub mod ip_utils; + +pub use config::TrafficMonitorConfig; +pub use event_handler::{EventHandler, TrafficEvent}; +pub use ip_utils::{format_ip_info, ip_in_cidr, parse_cidr}; \ No newline at end of file diff --git a/traffic-monitor/src/traffic_monitor.bpf.rs b/traffic-monitor/src/traffic_monitor.bpf.rs new file mode 100644 index 00000000..89afebc0 --- /dev/null +++ b/traffic-monitor/src/traffic_monitor.bpf.rs @@ -0,0 +1,214 @@ +#![no_std] +#![no_main] + +use aya_ebpf::{ + bindings::{xdp_action, BPF_F_NO_PREALLOC}, + macros::{map, xdp}, + maps::{HashMap, RingBuf}, + programs::XdpContext, +}; +use aya_log_ebpf::info; +use core::mem; +use network_types::{ + eth::{EthHdr, EtherType}, + ip::{IpProto, Ipv4Hdr, Ipv6Hdr}, + tcp::TcpHdr, + udp::UdpHdr, +}; + +// Maximum number of CIDR ranges we can store +const MAX_CIDR_RANGES: u32 = 256; + +// Configuration passed from userspace +#[repr(C)] +#[derive(Clone, Copy)] +pub struct Config { + pub drop_packets: u8, // 0 = log only, 1 = log and drop +} + +// CIDR range for IPv4 +#[repr(C)] +#[derive(Clone, Copy)] +pub struct CidrRange { + pub network: u32, // Network address in network byte order + pub prefix_len: u8, // Prefix length (0-32) +} + +// Traffic event to send to userspace +#[repr(C)] +pub struct TrafficEvent { + pub src_ip: u32, + pub dst_ip: u32, + pub src_port: u16, + pub dst_port: u16, + pub protocol: u8, + pub packet_size: u16, + pub action: u8, // 0 = allowed, 1 = dropped +} + +// Maps +#[map] +static CONFIG: HashMap = HashMap::with_max_entries(1, BPF_F_NO_PREALLOC); + +#[map] +static PERMITTED_CIDRS: HashMap = HashMap::with_max_entries(MAX_CIDR_RANGES, 0); + +#[map] +static EVENTS: RingBuf = RingBuf::with_byte_size(256 * 1024, 0); + +#[inline(always)] +fn is_ip_in_cidr(ip: u32, cidr: &CidrRange) -> bool { + if cidr.prefix_len == 0 { + return true; // 0.0.0.0/0 matches everything + } + if cidr.prefix_len > 32 { + return false; + } + + let mask = if cidr.prefix_len == 32 { + 0xFFFFFFFF + } else { + !((1u32 << (32 - cidr.prefix_len)) - 1) + }; + + (ip & mask) == (cidr.network & mask) +} + +#[inline(always)] +fn is_permitted_ip(ip: u32) -> bool { + // Check against all CIDR ranges + for i in 0..MAX_CIDR_RANGES { + if let Some(cidr) = unsafe { PERMITTED_CIDRS.get(&i) } { + if is_ip_in_cidr(ip, cidr) { + return true; + } + } + } + false +} + +#[inline(always)] +fn ptr_at(ctx: &XdpContext, offset: usize) -> Result<*const T, ()> { + let start = ctx.data(); + let end = ctx.data_end(); + let len = mem::size_of::(); + + if start + offset + len > end { + return Err(()); + } + + Ok((start + offset) as *const T) +} + +#[xdp] +pub fn traffic_monitor(ctx: XdpContext) -> u32 { + match try_traffic_monitor(ctx) { + Ok(ret) => ret, + Err(_) => xdp_action::XDP_PASS, + } +} + +fn try_traffic_monitor(ctx: XdpContext) -> Result { + // Get configuration + let config = unsafe { CONFIG.get(&0) }.unwrap_or(&Config { drop_packets: 0 }); + + // Parse Ethernet header + let ethhdr: *const EthHdr = ptr_at(&ctx, 0)?; + let eth_proto = unsafe { (*ethhdr).ether_type }; + + match eth_proto { + EtherType::Ipv4 => { + handle_ipv4(&ctx, config)?; + } + EtherType::Ipv6 => { + // For now, we'll pass IPv6 traffic through + // Could extend to support IPv6 CIDR ranges later + return Ok(xdp_action::XDP_PASS); + } + _ => { + // Non-IP traffic, pass through + return Ok(xdp_action::XDP_PASS); + } + } + + Ok(xdp_action::XDP_PASS) +} + +fn handle_ipv4(ctx: &XdpContext, config: &Config) -> Result<(), ()> { + let ipv4hdr: *const Ipv4Hdr = ptr_at(ctx, EthHdr::LEN)?; + let src_ip = unsafe { (*ipv4hdr).src_addr }; + let dst_ip = unsafe { (*ipv4hdr).dst_addr }; + let protocol = unsafe { (*ipv4hdr).proto }; + let total_len = unsafe { u16::from_be((*ipv4hdr).tot_len) }; + + // Calculate IP header length + let ip_hdr_len = (unsafe { (*ipv4hdr).version_ihl() } & 0x0F) as usize * 4; + + let (src_port, dst_port) = match protocol { + IpProto::Tcp => { + let tcphdr: *const TcpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?; + ( + unsafe { u16::from_be((*tcphdr).source) }, + unsafe { u16::from_be((*tcphdr).dest) }, + ) + } + IpProto::Udp => { + let udphdr: *const UdpHdr = ptr_at(ctx, EthHdr::LEN + ip_hdr_len)?; + ( + unsafe { u16::from_be((*udphdr).source) }, + unsafe { u16::from_be((*udphdr).dest) }, + ) + } + _ => (0, 0), // Other protocols don't have ports + }; + + // Check if source IP is permitted + let is_permitted = is_permitted_ip(src_ip); + + if !is_permitted { + // Log the event + let action = if config.drop_packets == 1 { 1 } else { 0 }; + + let event = TrafficEvent { + src_ip, + dst_ip, + src_port, + dst_port, + protocol: protocol as u8, + packet_size: total_len, + action, + }; + + // Send event to userspace + if let Some(mut entry) = EVENTS.reserve::(0) { + unsafe { + *entry.as_mut_ptr() = event; + } + entry.submit(0); + } + + info!( + ctx, + "Non-permitted traffic: {}:{} -> {}:{} (proto: {}, size: {}, action: {})", + u32::from_be(src_ip), + src_port, + u32::from_be(dst_ip), + dst_port, + protocol as u8, + total_len, + if action == 1 { "DROP" } else { "ALLOW" } + ); + + // Drop packet if configured to do so + if config.drop_packets == 1 { + return Err(()); // This will cause XDP_DROP + } + } + + Ok(()) +} + +#[panic_handler] +fn panic(_info: &core::panic::PanicInfo) -> ! { + unsafe { core::hint::unreachable_unchecked() } +} \ No newline at end of file