1use crate::{Duration, Error, IpAddress, Port, Result, SocketContext};
2use alloc::vec;
3use core::{
4 future::poll_fn,
5 task::{Context, Poll},
6};
7use smoltcp::{socket::icmp, wire};
8
/// Endpoint that an [`IcmpSocket`] is bound to via `IcmpSocket::bind`.
///
/// Converted to smoltcp's [`icmp::Endpoint`] by [`IcmpEndpoint::into_smoltcp`].
/// Variant order matters: it defines the derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum IcmpEndpoint {
    /// No endpoint (the `Default`); maps to `icmp::Endpoint::Unspecified`.
    #[default]
    Unspecified,
    /// An ICMP identifier; maps to `icmp::Endpoint::Ident`.
    Identifier(u16),
    /// An optional address plus a port; maps to `icmp::Endpoint::Udp`.
    Udp((Option<IpAddress>, Port)),
}
16
17impl IcmpEndpoint {
18 pub const fn into_smoltcp(&self) -> icmp::Endpoint {
19 match self {
20 IcmpEndpoint::Unspecified => icmp::Endpoint::Unspecified,
21 IcmpEndpoint::Identifier(id) => icmp::Endpoint::Ident(*id),
22 IcmpEndpoint::Udp((addr_opt, port)) => {
23 let addr = match addr_opt {
24 Some(a) => Some(a.into_smoltcp()),
25 None => None,
26 };
27
28 let endpoint = wire::IpListenEndpoint {
29 addr,
30 port: port.into_inner(),
31 };
32
33 icmp::Endpoint::Udp(endpoint)
34 }
35 }
36 }
37}
38
/// Async handle to an ICMP socket; every operation goes through its [`SocketContext`].
pub struct IcmpSocket {
    /// Gives access to the smoltcp socket (`handle`) and the owning network `stack`.
    context: SocketContext,
}
42
43impl IcmpSocket {
44 pub(crate) fn new(context: SocketContext) -> Self {
45 Self { context }
46 }
47
48 pub async fn with<F, R>(&self, f: F) -> R
49 where
50 F: FnOnce(&icmp::Socket<'static>) -> R,
51 {
52 self.context.with(f).await
53 }
54
55 pub async fn with_mutable<F, R>(&self, f: F) -> R
56 where
57 F: FnOnce(&mut icmp::Socket<'static>) -> R,
58 {
59 self.context.with_mutable(f).await
60 }
61
62 pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
63 where
64 F: FnOnce(&icmp::Socket, &mut Context<'_>) -> Poll<R>,
65 {
66 self.context.poll_with(context, f)
67 }
68
69 pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
70 where
71 F: FnOnce(&mut icmp::Socket, &mut Context<'_>) -> Poll<R>,
72 {
73 self.context.poll_with_mutable(context, f)
74 }
75
76 fn poll_send_to(
77 &self,
78 context: &mut Context<'_>,
79 buffer: &[u8],
80 remote_endpoint: &IpAddress,
81 ) -> Poll<Result<()>> {
82 self.poll_with_mutable(context, |socket, context| {
83 let send_capacity_too_small = socket.payload_send_capacity() < buffer.len();
84 if send_capacity_too_small {
85 return Poll::Ready(Err(Error::PacketTooLarge));
86 }
87
88 let remote_endpoint = remote_endpoint.into_smoltcp();
89
90 match socket.send_slice(buffer, remote_endpoint) {
91 Ok(()) => Poll::Ready(Ok(())),
92 Err(icmp::SendError::BufferFull) => {
93 socket.register_send_waker(context.waker());
94 Poll::Pending
95 }
96 Err(icmp::SendError::Unaddressable) => {
97 if socket.is_open() {
98 Poll::Ready(Err(Error::NoRoute))
99 } else {
100 Poll::Ready(Err(Error::SocketNotBound))
101 }
102 }
103 }
104 })
105 }
106
107 fn poll_receive_from(
108 &self,
109 context: &mut Context<'_>,
110 buffer: &mut [u8],
111 ) -> Poll<Result<(usize, IpAddress)>> {
112 self.poll_with_mutable(context, |socket, context| match socket.recv_slice(buffer) {
113 Ok((size, remote_endpoint)) => {
114 let remote_endpoint = IpAddress::from_smoltcp(&remote_endpoint);
115 Poll::Ready(Ok((size, remote_endpoint)))
116 }
117 Err(icmp::RecvError::Truncated) => Poll::Ready(Err(Error::Truncated)),
118 Err(icmp::RecvError::Exhausted) => {
119 socket.register_recv_waker(context.waker());
120
121 self.context.stack.wake_runner();
122 Poll::Pending
123 }
124 })
125 }
126
127 pub async fn bind(&self, endpoint: IcmpEndpoint) -> Result<()> {
128 let endpoint = endpoint.into_smoltcp();
129
130 self.with_mutable(|socket: &mut icmp::Socket| socket.bind(endpoint))
131 .await?;
132 Ok(())
133 }
134
135 pub async fn can_write(&self) -> bool {
136 self.with(icmp::Socket::can_send).await
137 }
138
139 pub async fn can_read(&self) -> bool {
140 self.with(icmp::Socket::can_recv).await
141 }
142
143 pub async fn write_to(&self, buffer: &[u8], endpoint: impl Into<IpAddress>) -> Result<()> {
144 let address: IpAddress = endpoint.into();
145
146 poll_fn(|context| self.poll_send_to(context, buffer, &address)).await
147 }
148
149 pub async fn read_from(&self, buffer: &mut [u8]) -> Result<(usize, IpAddress)> {
150 poll_fn(|context| self.poll_receive_from(context, buffer)).await
151 }
152
153 pub async fn read_from_with_timeout(
154 &self,
155 buffer: &mut [u8],
156 timeout: impl Into<Duration>,
157 ) -> Result<(usize, IpAddress)> {
158 use embassy_futures::select::{Either, select};
159
160 let receive = poll_fn(|context| self.poll_receive_from(context, buffer));
161 let sleep = task::sleep(timeout.into());
162
163 match select(receive, sleep).await {
164 Either::First(result) => result,
165 Either::Second(_) => Err(Error::TimedOut),
166 }
167 }
168
169 pub async fn ping(
176 &self,
177 remote_address: &IpAddress,
178 sequence_number: u16,
179 identifier: u16,
180 timeout: Duration,
181 payload_size: usize,
182 ) -> Result<Duration> {
183 use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv6Packet, Icmpv6Repr};
184
185 let mut echo_payload = vec![0u8; payload_size];
186 let start_time = crate::get_smoltcp_time();
187
188 let timestamp_millis = start_time.total_millis() as u64;
189 echo_payload[0..8].copy_from_slice(×tamp_millis.to_be_bytes());
190
191 let mut stack_lock = self.context.stack.lock().await;
192
193 let src_addr_v6 = if let IpAddress::IPv6(v6_addr) = remote_address {
194 Some(
195 stack_lock
196 .interface
197 .get_source_address_ipv6(&v6_addr.into_smoltcp()),
198 )
199 } else {
200 None
201 };
202
203 let socket = stack_lock
204 .sockets
205 .get_mut::<icmp::Socket>(self.context.handle);
206
207 let remote_endpoint = remote_address.into_smoltcp();
208
209 let checksum_caps = smoltcp::phy::ChecksumCapabilities::default();
210
211 match remote_address {
212 IpAddress::IPv4(_) => {
213 let icmp_repr = Icmpv4Repr::EchoRequest {
214 ident: identifier,
215 seq_no: sequence_number,
216 data: &echo_payload,
217 };
218
219 let icmp_payload = socket
220 .send(icmp_repr.buffer_len(), remote_endpoint)
221 .map_err(|e| match e {
222 icmp::SendError::BufferFull => Error::ResourceBusy,
223 icmp::SendError::Unaddressable => Error::NoRoute,
224 })?;
225
226 let mut icmp_packet = Icmpv4Packet::new_unchecked(icmp_payload);
227 icmp_repr.emit(&mut icmp_packet, &checksum_caps);
228 }
229 IpAddress::IPv6(v6_addr) => {
230 let icmp_repr = Icmpv6Repr::EchoRequest {
231 ident: identifier,
232 seq_no: sequence_number,
233 data: &echo_payload,
234 };
235
236 let icmp_payload = socket
237 .send(icmp_repr.buffer_len(), remote_endpoint)
238 .map_err(|e| match e {
239 icmp::SendError::BufferFull => Error::ResourceBusy,
240 icmp::SendError::Unaddressable => Error::NoRoute,
241 })?;
242
243 let src_addr = src_addr_v6.unwrap();
244 let mut icmp_packet = Icmpv6Packet::new_unchecked(icmp_payload);
245 icmp_repr.emit(
246 &src_addr,
247 &v6_addr.into_smoltcp(),
248 &mut icmp_packet,
249 &checksum_caps,
250 );
251 }
252 }
253
254 drop(stack_lock);
255
256 self.context.stack.wake_runner();
257
258 let timeout_end = start_time + timeout.into_smoltcp();
259
260 loop {
261 let now = crate::get_smoltcp_time();
262 if now >= timeout_end {
263 return Err(Error::TimedOut);
264 }
265
266 let mut recv_buffer = [0u8; 256];
267 let result = self.read_from_with_timeout(&mut recv_buffer, timeout).await;
268
269 match result {
270 Ok((size, addr)) if addr == *remote_address => {
271 let is_valid_reply = match remote_address {
273 IpAddress::IPv4(_) => {
274 if let Ok(packet) = Icmpv4Packet::new_checked(&recv_buffer[..size]) {
275 if let Ok(repr) = Icmpv4Repr::parse(&packet, &checksum_caps) {
276 matches!(
277 repr,
278 Icmpv4Repr::EchoReply {
279 ident: id,
280 seq_no,
281 ..
282 } if id == identifier && seq_no == sequence_number
283 )
284 } else {
285 false
286 }
287 } else {
288 false
289 }
290 }
291 IpAddress::IPv6(v6_addr) => {
292 if let Ok(packet) = Icmpv6Packet::new_checked(&recv_buffer[..size]) {
293 let src_addr = self
294 .context
295 .stack
296 .with_mutable(|s| s.get_source_ip_v6_address(*v6_addr))
297 .await;
298
299 if let Ok(repr) = Icmpv6Repr::parse(
300 &v6_addr.into_smoltcp(),
301 &src_addr.into_smoltcp(),
302 &packet,
303 &checksum_caps,
304 ) {
305 matches!(
306 repr,
307 Icmpv6Repr::EchoReply {
308 ident: id,
309 seq_no,
310 ..
311 } if id == identifier && seq_no == sequence_number
312 )
313 } else {
314 false
315 }
316 } else {
317 false
318 }
319 }
320 };
321
322 if is_valid_reply {
323 let end_time = crate::get_smoltcp_time();
324 let rtt = end_time - start_time;
325 return Ok(Duration::from_milliseconds(rtt.total_millis() as u64));
326 }
327 }
328 Ok(_) => {
329 continue;
330 }
331 Err(e) => return Err(e),
332 }
333 }
334 }
335
336 pub async fn flush(&self) -> () {
337 poll_fn(|context| {
338 self.poll_with_mutable(context, |socket, context| {
339 if socket.send_queue() == 0 {
340 Poll::Ready(())
341 } else {
342 socket.register_send_waker(context.waker());
343 Poll::Pending
344 }
345 })
346 })
347 .await
348 }
349
350 pub async fn is_open(&self) -> bool {
351 self.with(icmp::Socket::is_open).await
352 }
353
354 pub async fn get_packet_read_capacity(&self) -> usize {
355 self.with(icmp::Socket::packet_recv_capacity).await
356 }
357
358 pub async fn get_packet_write_capacity(&self) -> usize {
359 self.with(icmp::Socket::packet_send_capacity).await
360 }
361
362 pub async fn get_payload_read_capacity(&self) -> usize {
363 self.with(icmp::Socket::payload_recv_capacity).await
364 }
365
366 pub async fn get_payload_write_capacity(&self) -> usize {
367 self.with(icmp::Socket::payload_send_capacity).await
368 }
369
370 pub async fn get_hop_limit(&self) -> Option<u8> {
371 self.with(icmp::Socket::hop_limit).await
372 }
373
374 pub async fn set_hop_limit(&self, hop_limit: Option<u8>) -> () {
375 self.with_mutable(|socket: &mut icmp::Socket| {
376 socket.set_hop_limit(hop_limit);
377 })
378 .await
379 }
380}