From 4251f7cd233ab3b287fa03ecb75351276f15a701 Mon Sep 17 00:00:00 2001
From: Michal Rostecki <vadorovsky@protonmail.com>
Date: Tue, 4 Jun 2024 11:54:45 +0200
Subject: [PATCH] aya-log: Allow logging `core::net::Ipv4Addr` and
 `core::net::Ipv6Addr`

IP address types are available in `core`, so they can be used also in
eBPF programs. This change adds support of these types in aya-log.

* Add implementation of `WriteTuBuf` to these types.
* Support these types in `Ipv4Formatter` and `Ipv6Formatter`.
* Support them with `DisplayHint::Ip`.
---
 aya-log-common/src/lib.rs              |  32 +++++-
 aya-log/src/lib.rs                     | 136 +++++++++++++++++++++++++
 test/integration-ebpf/src/log.rs       |  10 +-
 test/integration-test/src/tests/log.rs |  20 +++-
 4 files changed, 195 insertions(+), 3 deletions(-)

diff --git a/aya-log-common/src/lib.rs b/aya-log-common/src/lib.rs
index 24f41610..e28c782c 100644
--- a/aya-log-common/src/lib.rs
+++ b/aya-log-common/src/lib.rs
@@ -1,6 +1,9 @@
 #![no_std]
 
-use core::num::{NonZeroUsize, TryFromIntError};
+use core::{
+    net::{IpAddr, Ipv4Addr, Ipv6Addr},
+    num::{NonZeroUsize, TryFromIntError},
+};
 
 use num_enum::IntoPrimitive;
 
@@ -75,6 +78,9 @@ impl_formatter_for_types!(
 );
 
 pub trait IpFormatter {}
+impl IpFormatter for IpAddr {}
+impl IpFormatter for Ipv4Addr {}
+impl IpFormatter for Ipv6Addr {}
 impl IpFormatter for u32 {}
 impl IpFormatter for [u8; 16] {}
 impl IpFormatter for [u16; 8] {}
@@ -118,6 +124,9 @@ pub enum Argument {
     F32,
     F64,
 
+    Ipv4Addr,
+    Ipv6Addr,
+
     /// `[u8; 6]` array which represents a MAC address.
     ArrU8Len6,
     /// `[u8; 16]` array which represents an IPv6 address.
@@ -203,6 +212,27 @@ impl_write_to_buf!(usize, Argument::Usize);
 impl_write_to_buf!(f32, Argument::F32);
 impl_write_to_buf!(f64, Argument::F64);
 
+impl WriteToBuf for IpAddr {
+    fn write(self, buf: &mut [u8]) -> Option<NonZeroUsize> {
+        match self {
+            IpAddr::V4(ipv4_addr) => write(Argument::Ipv4Addr.into(), &ipv4_addr.octets(), buf),
+            IpAddr::V6(ipv6_addr) => write(Argument::Ipv6Addr.into(), &ipv6_addr.octets(), buf),
+        }
+    }
+}
+
+impl WriteToBuf for Ipv4Addr {
+    fn write(self, buf: &mut [u8]) -> Option<NonZeroUsize> {
+        write(Argument::Ipv4Addr.into(), &self.octets(), buf)
+    }
+}
+
+impl WriteToBuf for Ipv6Addr {
+    fn write(self, buf: &mut [u8]) -> Option<NonZeroUsize> {
+        write(Argument::Ipv6Addr.into(), &self.octets(), buf)
+    }
+}
+
 impl WriteToBuf for [u8; 16] {
     // This need not be inlined because the return value is Option<N> where N is 16, which is a
     // compile-time constant.
diff --git a/aya-log/src/lib.rs b/aya-log/src/lib.rs
index 8c89c115..98e0d87b 100644
--- a/aya-log/src/lib.rs
+++ b/aya-log/src/lib.rs
@@ -310,6 +310,34 @@ impl Format for u32 {
     }
 }
 
+impl Format for Ipv4Addr {
+    fn format(&self, last_hint: Option<DisplayHintWrapper>) -> Result<String, ()> {
+        match last_hint.map(|DisplayHintWrapper(dh)| dh) {
+            Some(DisplayHint::Default) => Ok(Ipv4Formatter::format(*self)),
+            Some(DisplayHint::LowerHex) => Err(()),
+            Some(DisplayHint::UpperHex) => Err(()),
+            Some(DisplayHint::Ip) => Ok(Ipv4Formatter::format(*self)),
+            Some(DisplayHint::LowerMac) => Err(()),
+            Some(DisplayHint::UpperMac) => Err(()),
+            _ => Ok(Ipv4Formatter::format(*self)),
+        }
+    }
+}
+
+impl Format for Ipv6Addr {
+    fn format(&self, last_hint: Option<DisplayHintWrapper>) -> Result<String, ()> {
+        match last_hint.map(|DisplayHintWrapper(dh)| dh) {
+            Some(DisplayHint::Default) => Ok(Ipv6Formatter::format(*self)),
+            Some(DisplayHint::LowerHex) => Err(()),
+            Some(DisplayHint::UpperHex) => Err(()),
+            Some(DisplayHint::Ip) => Ok(Ipv6Formatter::format(*self)),
+            Some(DisplayHint::LowerMac) => Err(()),
+            Some(DisplayHint::UpperMac) => Err(()),
+            _ => Ok(Ipv6Formatter::format(*self)),
+        }
+    }
+}
+
 impl Format for [u8; 6] {
     fn format(&self, last_hint: Option<DisplayHintWrapper>) -> Result<String, ()> {
         match last_hint.map(|DisplayHintWrapper(dh)| dh) {
@@ -548,6 +576,16 @@ fn log_buf(mut buf: &[u8], logger: &dyn Log) -> Result<(), ()> {
                         .format(last_hint.take())?,
                 );
             }
+            Argument::Ipv4Addr => {
+                let value: [u8; 4] = value.try_into().map_err(|_| ())?;
+                let value = Ipv4Addr::from(value);
+                full_log_msg.push_str(&value.format(last_hint.take())?)
+            }
+            Argument::Ipv6Addr => {
+                let value: [u8; 16] = value.try_into().map_err(|_| ())?;
+                let value = Ipv6Addr::from(value);
+                full_log_msg.push_str(&value.format(last_hint.take())?)
+            }
             Argument::ArrU8Len6 => {
                 let value: [u8; 6] = value.try_into().map_err(|_| ())?;
                 full_log_msg.push_str(&value.format(last_hint.take())?);
@@ -615,6 +653,8 @@ fn try_read<T: Pod>(mut buf: &[u8]) -> Result<(T, &[u8], &[u8]), ()> {
 
 #[cfg(test)]
 mod test {
+    use std::net::IpAddr;
+
     use aya_log_common::{write_record_header, WriteToBuf};
     use log::{logger, Level};
 
@@ -794,6 +834,52 @@ mod test {
         testing_logger::setup();
         let (mut len, mut input) = new_log(3).unwrap();
 
+        len += "ipv4: ".write(&mut input[len..]).unwrap().get();
+        len += DisplayHint::Ip.write(&mut input[len..]).unwrap().get();
+        Ipv4Addr::new(10, 0, 0, 1)
+            .write(&mut input[len..])
+            .unwrap()
+            .get();
+
+        _ = len;
+
+        let logger = logger();
+        let () = log_buf(&input, logger).unwrap();
+        testing_logger::validate(|captured_logs| {
+            assert_eq!(captured_logs.len(), 1);
+            assert_eq!(captured_logs[0].body, "ipv4: 10.0.0.1");
+            assert_eq!(captured_logs[0].level, Level::Info);
+        });
+    }
+
+    #[test]
+    fn test_display_hint_ip_ipv4() {
+        testing_logger::setup();
+        let (mut len, mut input) = new_log(3).unwrap();
+
+        len += "ipv4: ".write(&mut input[len..]).unwrap().get();
+        len += DisplayHint::Ip.write(&mut input[len..]).unwrap().get();
+        IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))
+            .write(&mut input[len..])
+            .unwrap()
+            .get();
+
+        _ = len;
+
+        let logger = logger();
+        let () = log_buf(&input, logger).unwrap();
+        testing_logger::validate(|captured_logs| {
+            assert_eq!(captured_logs.len(), 1);
+            assert_eq!(captured_logs[0].body, "ipv4: 10.0.0.1");
+            assert_eq!(captured_logs[0].level, Level::Info);
+        });
+    }
+
+    #[test]
+    fn test_display_hint_ipv4_u32() {
+        testing_logger::setup();
+        let (mut len, mut input) = new_log(3).unwrap();
+
         len += "ipv4: ".write(&mut input[len..]).unwrap().get();
         len += DisplayHint::Ip.write(&mut input[len..]).unwrap().get();
         // 10.0.0.1 as u32
@@ -810,6 +896,56 @@ mod test {
         });
     }
 
+    #[test]
+    fn test_display_hint_ipv6() {
+        testing_logger::setup();
+        let (mut len, mut input) = new_log(3).unwrap();
+
+        len += "ipv6: ".write(&mut input[len..]).unwrap().get();
+        len += DisplayHint::Ip.write(&mut input[len..]).unwrap().get();
+        Ipv6Addr::new(
+            0x2001, 0x0db8, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0001,
+        )
+        .write(&mut input[len..])
+        .unwrap()
+        .get();
+
+        _ = len;
+
+        let logger = logger();
+        let () = log_buf(&input, logger).unwrap();
+        testing_logger::validate(|captured_logs| {
+            assert_eq!(captured_logs.len(), 1);
+            assert_eq!(captured_logs[0].body, "ipv6: 2001:db8::1:1");
+            assert_eq!(captured_logs[0].level, Level::Info);
+        });
+    }
+
+    #[test]
+    fn test_display_hint_ip_ipv6() {
+        testing_logger::setup();
+        let (mut len, mut input) = new_log(3).unwrap();
+
+        len += "ipv6: ".write(&mut input[len..]).unwrap().get();
+        len += DisplayHint::Ip.write(&mut input[len..]).unwrap().get();
+        IpAddr::V6(Ipv6Addr::new(
+            0x2001, 0x0db8, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0001,
+        ))
+        .write(&mut input[len..])
+        .unwrap()
+        .get();
+
+        _ = len;
+
+        let logger = logger();
+        let () = log_buf(&input, logger).unwrap();
+        testing_logger::validate(|captured_logs| {
+            assert_eq!(captured_logs.len(), 1);
+            assert_eq!(captured_logs[0].body, "ipv6: 2001:db8::1:1");
+            assert_eq!(captured_logs[0].level, Level::Info);
+        });
+    }
+
     #[test]
     fn test_display_hint_ipv6_arr_u8_len_16() {
         testing_logger::setup();
diff --git a/test/integration-ebpf/src/log.rs b/test/integration-ebpf/src/log.rs
index b4f71341..cec1840b 100644
--- a/test/integration-ebpf/src/log.rs
+++ b/test/integration-ebpf/src/log.rs
@@ -1,6 +1,8 @@
 #![no_std]
 #![no_main]
 
+use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+
 use aya_ebpf::{macros::uprobe, programs::ProbeContext};
 use aya_log_ebpf::{debug, error, info, trace, warn};
 
@@ -15,11 +17,17 @@ pub fn test_log(ctx: ProbeContext) {
         "wao",
         "wao".as_bytes()
     );
+    let ipv4 = Ipv4Addr::new(10, 0, 0, 1);
+    let ipv6 = Ipv6Addr::new(8193, 3512, 0, 0, 0, 0, 0, 1);
+    info!(&ctx, "ip structs: ipv4: {:i}, ipv6: {:i}", ipv4, ipv6); // 2001:db8::1
+    let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
+    let ipv6 = IpAddr::V6(Ipv6Addr::new(8193, 3512, 0, 0, 0, 0, 0, 1));
+    info!(&ctx, "ip enums: ipv4: {:i}, ipv6: {:i}", ipv4, ipv6);
     let ipv4 = 167772161u32; // 10.0.0.1
     let ipv6 = [
         32u8, 1u8, 13u8, 184u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 1u8,
     ]; // 2001:db8::1
-    info!(&ctx, "ipv4: {:i}, ipv6: {:i}", ipv4, ipv6);
+    info!(&ctx, "ip as primitives: ipv4: {:i}, ipv6: {:i}", ipv4, ipv6);
     let mac = [4u8, 32u8, 6u8, 9u8, 0u8, 64u8];
     trace!(&ctx, "mac lc: {:mac}, mac uc: {:MAC}", mac, mac);
     let hex = 0x2f;
diff --git a/test/integration-test/src/tests/log.rs b/test/integration-test/src/tests/log.rs
index 56f5edec..e039eff0 100644
--- a/test/integration-test/src/tests/log.rs
+++ b/test/integration-test/src/tests/log.rs
@@ -106,7 +106,25 @@ async fn log() {
     assert_eq!(
         records.next(),
         Some(&CapturedLog {
-            body: "ipv4: 10.0.0.1, ipv6: 2001:db8::1".into(),
+            body: "ip structs: ipv4: 10.0.0.1, ipv6: 2001:db8::1".into(),
+            level: Level::Info,
+            target: "log".into(),
+        })
+    );
+
+    assert_eq!(
+        records.next(),
+        Some(&CapturedLog {
+            body: "ip enums: ipv4: 10.0.0.1, ipv6: 2001:db8::1".into(),
+            level: Level::Info,
+            target: "log".into(),
+        })
+    );
+
+    assert_eq!(
+        records.next(),
+        Some(&CapturedLog {
+            body: "ip as primitives: ipv4: 10.0.0.1, ipv6: 2001:db8::1".into(),
             level: Level::Info,
             target: "log".into(),
         })