network/socket/icmp.rs

1use crate::{Duration, Error, IpAddress, Port, Result, SocketContext};
2use alloc::vec;
3use core::{
4    future::poll_fn,
5    task::{Context, Poll},
6};
7use smoltcp::{socket::icmp, wire};
8
9#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone)]
10pub enum IcmpEndpoint {
11    #[default]
12    Unspecified,
13    Identifier(u16),
14    Udp((Option<IpAddress>, Port)),
15}
16
17impl IcmpEndpoint {
18    pub const fn into_smoltcp(&self) -> icmp::Endpoint {
19        match self {
20            IcmpEndpoint::Unspecified => icmp::Endpoint::Unspecified,
21            IcmpEndpoint::Identifier(id) => icmp::Endpoint::Ident(*id),
22            IcmpEndpoint::Udp((addr_opt, port)) => {
23                let addr = match addr_opt {
24                    Some(a) => Some(a.into_smoltcp()),
25                    None => None,
26                };
27
28                let endpoint = wire::IpListenEndpoint {
29                    addr,
30                    port: port.into_inner(),
31                };
32
33                icmp::Endpoint::Udp(endpoint)
34            }
35        }
36    }
37}
38
39pub struct IcmpSocket {
40    context: SocketContext,
41}
42
43impl IcmpSocket {
44    pub(crate) fn new(context: SocketContext) -> Self {
45        Self { context }
46    }
47
48    pub async fn with<F, R>(&self, f: F) -> R
49    where
50        F: FnOnce(&icmp::Socket<'static>) -> R,
51    {
52        self.context.with(f).await
53    }
54
55    pub async fn with_mutable<F, R>(&self, f: F) -> R
56    where
57        F: FnOnce(&mut icmp::Socket<'static>) -> R,
58    {
59        self.context.with_mutable(f).await
60    }
61
62    pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
63    where
64        F: FnOnce(&icmp::Socket, &mut Context<'_>) -> Poll<R>,
65    {
66        self.context.poll_with(context, f)
67    }
68
69    pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
70    where
71        F: FnOnce(&mut icmp::Socket, &mut Context<'_>) -> Poll<R>,
72    {
73        self.context.poll_with_mutable(context, f)
74    }
75
76    fn poll_send_to(
77        &self,
78        context: &mut Context<'_>,
79        buffer: &[u8],
80        remote_endpoint: &IpAddress,
81    ) -> Poll<Result<()>> {
82        self.poll_with_mutable(context, |socket, context| {
83            let send_capacity_too_small = socket.payload_send_capacity() < buffer.len();
84            if send_capacity_too_small {
85                return Poll::Ready(Err(Error::PacketTooLarge));
86            }
87
88            let remote_endpoint = remote_endpoint.into_smoltcp();
89
90            match socket.send_slice(buffer, remote_endpoint) {
91                Ok(()) => Poll::Ready(Ok(())),
92                Err(icmp::SendError::BufferFull) => {
93                    socket.register_send_waker(context.waker());
94                    Poll::Pending
95                }
96                Err(icmp::SendError::Unaddressable) => {
97                    if socket.is_open() {
98                        Poll::Ready(Err(Error::NoRoute))
99                    } else {
100                        Poll::Ready(Err(Error::SocketNotBound))
101                    }
102                }
103            }
104        })
105    }
106
107    fn poll_receive_from(
108        &self,
109        context: &mut Context<'_>,
110        buffer: &mut [u8],
111    ) -> Poll<Result<(usize, IpAddress)>> {
112        self.poll_with_mutable(context, |socket, context| match socket.recv_slice(buffer) {
113            Ok((size, remote_endpoint)) => {
114                let remote_endpoint = IpAddress::from_smoltcp(&remote_endpoint);
115                Poll::Ready(Ok((size, remote_endpoint)))
116            }
117            Err(icmp::RecvError::Truncated) => Poll::Ready(Err(Error::Truncated)),
118            Err(icmp::RecvError::Exhausted) => {
119                socket.register_recv_waker(context.waker());
120
121                self.context.stack.wake_runner();
122                Poll::Pending
123            }
124        })
125    }
126
127    pub async fn bind(&self, endpoint: IcmpEndpoint) -> Result<()> {
128        let endpoint = endpoint.into_smoltcp();
129
130        self.with_mutable(|socket: &mut icmp::Socket| socket.bind(endpoint))
131            .await?;
132        Ok(())
133    }
134
135    pub async fn can_write(&self) -> bool {
136        self.with(icmp::Socket::can_send).await
137    }
138
139    pub async fn can_read(&self) -> bool {
140        self.with(icmp::Socket::can_recv).await
141    }
142
143    pub async fn write_to(&self, buffer: &[u8], endpoint: impl Into<IpAddress>) -> Result<()> {
144        let address: IpAddress = endpoint.into();
145
146        poll_fn(|context| self.poll_send_to(context, buffer, &address)).await
147    }
148
149    pub async fn read_from(&self, buffer: &mut [u8]) -> Result<(usize, IpAddress)> {
150        poll_fn(|context| self.poll_receive_from(context, buffer)).await
151    }
152
153    pub async fn read_from_with_timeout(
154        &self,
155        buffer: &mut [u8],
156        timeout: impl Into<Duration>,
157    ) -> Result<(usize, IpAddress)> {
158        use embassy_futures::select::{Either, select};
159
160        let receive = poll_fn(|context| self.poll_receive_from(context, buffer));
161        let sleep = task::sleep(timeout.into());
162
163        match select(receive, sleep).await {
164            Either::First(result) => result,
165            Either::Second(_) => Err(Error::TimedOut),
166        }
167    }
168
169    /// Sends an ICMP echo request (ping) to the specified remote address and waits for a reply.
170    /// Returns the round-trip time if successful.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the ping request fails or times out.
175    pub async fn ping(
176        &self,
177        remote_address: &IpAddress,
178        sequence_number: u16,
179        identifier: u16,
180        timeout: Duration,
181        payload_size: usize,
182    ) -> Result<Duration> {
183        use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv6Packet, Icmpv6Repr};
184
185        let mut echo_payload = vec![0u8; payload_size];
186        let start_time = crate::get_smoltcp_time();
187
188        let timestamp_millis = start_time.total_millis() as u64;
189        echo_payload[0..8].copy_from_slice(&timestamp_millis.to_be_bytes());
190
191        let mut stack_lock = self.context.stack.lock().await;
192
193        let src_addr_v6 = if let IpAddress::IPv6(v6_addr) = remote_address {
194            Some(
195                stack_lock
196                    .interface
197                    .get_source_address_ipv6(&v6_addr.into_smoltcp()),
198            )
199        } else {
200            None
201        };
202
203        let socket = stack_lock
204            .sockets
205            .get_mut::<icmp::Socket>(self.context.handle);
206
207        let remote_endpoint = remote_address.into_smoltcp();
208
209        let checksum_caps = smoltcp::phy::ChecksumCapabilities::default();
210
211        match remote_address {
212            IpAddress::IPv4(_) => {
213                let icmp_repr = Icmpv4Repr::EchoRequest {
214                    ident: identifier,
215                    seq_no: sequence_number,
216                    data: &echo_payload,
217                };
218
219                let icmp_payload = socket
220                    .send(icmp_repr.buffer_len(), remote_endpoint)
221                    .map_err(|e| match e {
222                        icmp::SendError::BufferFull => Error::ResourceBusy,
223                        icmp::SendError::Unaddressable => Error::NoRoute,
224                    })?;
225
226                let mut icmp_packet = Icmpv4Packet::new_unchecked(icmp_payload);
227                icmp_repr.emit(&mut icmp_packet, &checksum_caps);
228            }
229            IpAddress::IPv6(v6_addr) => {
230                let icmp_repr = Icmpv6Repr::EchoRequest {
231                    ident: identifier,
232                    seq_no: sequence_number,
233                    data: &echo_payload,
234                };
235
236                let icmp_payload = socket
237                    .send(icmp_repr.buffer_len(), remote_endpoint)
238                    .map_err(|e| match e {
239                        icmp::SendError::BufferFull => Error::ResourceBusy,
240                        icmp::SendError::Unaddressable => Error::NoRoute,
241                    })?;
242
243                let src_addr = src_addr_v6.unwrap();
244                let mut icmp_packet = Icmpv6Packet::new_unchecked(icmp_payload);
245                icmp_repr.emit(
246                    &src_addr,
247                    &v6_addr.into_smoltcp(),
248                    &mut icmp_packet,
249                    &checksum_caps,
250                );
251            }
252        }
253
254        drop(stack_lock);
255
256        self.context.stack.wake_runner();
257
258        let timeout_end = start_time + timeout.into_smoltcp();
259
260        loop {
261            let now = crate::get_smoltcp_time();
262            if now >= timeout_end {
263                return Err(Error::TimedOut);
264            }
265
266            let mut recv_buffer = [0u8; 256];
267            let result = self.read_from_with_timeout(&mut recv_buffer, timeout).await;
268
269            match result {
270                Ok((size, addr)) if addr == *remote_address => {
271                    // Parse the received packet
272                    let is_valid_reply = match remote_address {
273                        IpAddress::IPv4(_) => {
274                            if let Ok(packet) = Icmpv4Packet::new_checked(&recv_buffer[..size]) {
275                                if let Ok(repr) = Icmpv4Repr::parse(&packet, &checksum_caps) {
276                                    matches!(
277                                        repr,
278                                        Icmpv4Repr::EchoReply {
279                                            ident: id,
280                                            seq_no,
281                                            ..
282                                        } if id == identifier && seq_no == sequence_number
283                                    )
284                                } else {
285                                    false
286                                }
287                            } else {
288                                false
289                            }
290                        }
291                        IpAddress::IPv6(v6_addr) => {
292                            if let Ok(packet) = Icmpv6Packet::new_checked(&recv_buffer[..size]) {
293                                let src_addr = self
294                                    .context
295                                    .stack
296                                    .with_mutable(|s| s.get_source_ip_v6_address(*v6_addr))
297                                    .await;
298
299                                if let Ok(repr) = Icmpv6Repr::parse(
300                                    &v6_addr.into_smoltcp(),
301                                    &src_addr.into_smoltcp(),
302                                    &packet,
303                                    &checksum_caps,
304                                ) {
305                                    matches!(
306                                        repr,
307                                        Icmpv6Repr::EchoReply {
308                                            ident: id,
309                                            seq_no,
310                                            ..
311                                        } if id == identifier && seq_no == sequence_number
312                                    )
313                                } else {
314                                    false
315                                }
316                            } else {
317                                false
318                            }
319                        }
320                    };
321
322                    if is_valid_reply {
323                        let end_time = crate::get_smoltcp_time();
324                        let rtt = end_time - start_time;
325                        return Ok(Duration::from_milliseconds(rtt.total_millis() as u64));
326                    }
327                }
328                Ok(_) => {
329                    continue;
330                }
331                Err(e) => return Err(e),
332            }
333        }
334    }
335
336    pub async fn flush(&self) -> () {
337        poll_fn(|context| {
338            self.poll_with_mutable(context, |socket, context| {
339                if socket.send_queue() == 0 {
340                    Poll::Ready(())
341                } else {
342                    socket.register_send_waker(context.waker());
343                    Poll::Pending
344                }
345            })
346        })
347        .await
348    }
349
350    pub async fn is_open(&self) -> bool {
351        self.with(icmp::Socket::is_open).await
352    }
353
354    pub async fn get_packet_read_capacity(&self) -> usize {
355        self.with(icmp::Socket::packet_recv_capacity).await
356    }
357
358    pub async fn get_packet_write_capacity(&self) -> usize {
359        self.with(icmp::Socket::packet_send_capacity).await
360    }
361
362    pub async fn get_payload_read_capacity(&self) -> usize {
363        self.with(icmp::Socket::payload_recv_capacity).await
364    }
365
366    pub async fn get_payload_write_capacity(&self) -> usize {
367        self.with(icmp::Socket::payload_send_capacity).await
368    }
369
370    pub async fn get_hop_limit(&self) -> Option<u8> {
371        self.with(icmp::Socket::hop_limit).await
372    }
373
374    pub async fn set_hop_limit(&self, hop_limit: Option<u8>) -> () {
375        self.with_mutable(|socket: &mut icmp::Socket| {
376            socket.set_hop_limit(hop_limit);
377        })
378        .await
379    }
380}