smoltcp/socket/
icmp.rs

1use core::cmp;
2#[cfg(feature = "async")]
3use core::task::Waker;
4
5use crate::phy::ChecksumCapabilities;
6#[cfg(feature = "async")]
7use crate::socket::WakerRegistration;
8use crate::socket::{Context, PollAt};
9
10use crate::storage::Empty;
11use crate::wire::IcmpRepr;
12#[cfg(feature = "proto-ipv4")]
13use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr};
14#[cfg(feature = "proto-ipv6")]
15use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Repr};
16use crate::wire::{IpAddress, IpListenEndpoint, IpProtocol, IpRepr};
17use crate::wire::{UdpPacket, UdpRepr};
18
/// Error returned by [`Socket::bind`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BindError {
    /// The socket is already open (bound to an endpoint) and cannot be re-bound.
    InvalidState,
    /// The requested endpoint is not specified and therefore cannot be bound to.
    Unaddressable,
}

impl core::fmt::Display for BindError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // A short, lowercase description per variant.
        let description = match *self {
            Self::InvalidState => "invalid state",
            Self::Unaddressable => "unaddressable",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for BindError {}
38
/// Error returned by [`Socket::send`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SendError {
    /// The remote address is unspecified.
    Unaddressable,
    /// The transmit buffer cannot hold the packet, either because it is full
    /// or because the packet is larger than the buffer's capacity.
    BufferFull,
}

impl core::fmt::Display for SendError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // A short, lowercase description per variant.
        let description = match *self {
            Self::Unaddressable => "unaddressable",
            Self::BufferFull => "buffer full",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SendError {}
58
/// Error returned by [`Socket::recv`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// The receive buffer holds no packets.
    Exhausted,
    /// The destination slice passed to `recv_slice` was smaller than the
    /// received packet; the packet has been dropped.
    Truncated,
}

impl core::fmt::Display for RecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // A short, lowercase description per variant.
        let description = match *self {
            Self::Exhausted => "exhausted",
            Self::Truncated => "truncated",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RecvError {}
78
79/// Type of endpoint to bind the ICMP socket to. See [IcmpSocket::bind] for
80/// more details.
81///
82/// [IcmpSocket::bind]: struct.IcmpSocket.html#method.bind
83#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
84#[cfg_attr(feature = "defmt", derive(defmt::Format))]
85pub enum Endpoint {
86    #[default]
87    Unspecified,
88    Ident(u16),
89    Udp(IpListenEndpoint),
90}
91
92impl Endpoint {
93    pub fn is_specified(&self) -> bool {
94        match *self {
95            Endpoint::Ident(_) => true,
96            Endpoint::Udp(endpoint) => endpoint.port != 0,
97            Endpoint::Unspecified => false,
98        }
99    }
100}
101
/// An ICMP packet metadata.
///
/// Each queued packet is tagged with an [`IpAddress`]. In the transmit buffer it
/// is the destination address; in the receive buffer it is the sender's address.
pub type PacketMetadata = crate::storage::PacketMetadata<IpAddress>;

/// An ICMP packet ring buffer, used for both the receive and the transmit
/// buffer of a [`Socket`].
pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, IpAddress>;
107
108/// A ICMP socket
109///
110/// An ICMP socket is bound to a specific [IcmpEndpoint] which may
111/// be a specific UDP port to listen for ICMP error messages related
112/// to the port or a specific ICMP identifier value. See [bind] for
113/// more details.
114///
115/// [IcmpEndpoint]: enum.IcmpEndpoint.html
116/// [bind]: #method.bind
117#[derive(Debug)]
118pub struct Socket<'a> {
119    rx_buffer: PacketBuffer<'a>,
120    tx_buffer: PacketBuffer<'a>,
121    /// The endpoint this socket is communicating with
122    endpoint: Endpoint,
123    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
124    hop_limit: Option<u8>,
125    #[cfg(feature = "async")]
126    rx_waker: WakerRegistration,
127    #[cfg(feature = "async")]
128    tx_waker: WakerRegistration,
129}
130
131impl<'a> Socket<'a> {
132    /// Create an ICMP socket with the given buffers.
133    pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> {
134        Socket {
135            rx_buffer,
136            tx_buffer,
137            endpoint: Default::default(),
138            hop_limit: None,
139            #[cfg(feature = "async")]
140            rx_waker: WakerRegistration::new(),
141            #[cfg(feature = "async")]
142            tx_waker: WakerRegistration::new(),
143        }
144    }
145
146    /// Register a waker for receive operations.
147    ///
148    /// The waker is woken on state changes that might affect the return value
149    /// of `recv` method calls, such as receiving data, or the socket closing.
150    ///
151    /// Notes:
152    ///
153    /// - Only one waker can be registered at a time. If another waker was previously registered,
154    ///   it is overwritten and will no longer be woken.
155    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
156    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has
157    ///   necessarily changed.
158    #[cfg(feature = "async")]
159    pub fn register_recv_waker(&mut self, waker: &Waker) {
160        self.rx_waker.register(waker)
161    }
162
163    /// Register a waker for send operations.
164    ///
165    /// The waker is woken on state changes that might affect the return value
166    /// of `send` method calls, such as space becoming available in the transmit
167    /// buffer, or the socket closing.
168    ///
169    /// Notes:
170    ///
171    /// - Only one waker can be registered at a time. If another waker was previously registered,
172    ///   it is overwritten and will no longer be woken.
173    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
174    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has
175    ///   necessarily changed.
176    #[cfg(feature = "async")]
177    pub fn register_send_waker(&mut self, waker: &Waker) {
178        self.tx_waker.register(waker)
179    }
180
181    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
182    ///
183    /// See also the [set_hop_limit](#method.set_hop_limit) method
184    pub fn hop_limit(&self) -> Option<u8> {
185        self.hop_limit
186    }
187
188    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
189    ///
190    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
191    /// value (64).
192    ///
193    /// # Panics
194    ///
195    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
196    ///
197    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
198    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
199    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
200        // A host MUST NOT send a datagram with a hop limit value of 0
201        if let Some(0) = hop_limit {
202            panic!("the time-to-live value of a packet must not be zero")
203        }
204
205        self.hop_limit = hop_limit
206    }
207
208    /// Bind the socket to the given endpoint.
209    ///
210    /// This function returns `Err(Error::Illegal)` if the socket was open
211    /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
212    /// if `endpoint` is unspecified (see [is_specified]).
213    ///
214    /// # Examples
215    ///
216    /// ## Bind to ICMP Error messages associated with a specific UDP port:
217    ///
218    /// To [recv] ICMP error messages that are associated with a specific local
219    /// UDP port, the socket may be bound to a given port using [IcmpEndpoint::Udp].
220    /// This may be useful for applications using UDP attempting to detect and/or
221    /// diagnose connection problems.
222    ///
223    /// ```
224    /// use smoltcp::wire::IpListenEndpoint;
225    /// use smoltcp::socket::icmp;
226    /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]);
227    /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]);
228    ///
229    /// let mut icmp_socket = // ...
230    /// # icmp::Socket::new(rx_buffer, tx_buffer);
231    ///
232    /// // Bind to ICMP error responses for UDP packets sent from port 53.
233    /// let endpoint = IpListenEndpoint::from(53);
234    /// icmp_socket.bind(icmp::Endpoint::Udp(endpoint)).unwrap();
235    /// ```
236    ///
237    /// ## Bind to a specific ICMP identifier:
238    ///
239    /// To [send] and [recv] ICMP packets that are not associated with a specific UDP
240    /// port, the socket may be bound to a specific ICMP identifier using
241    /// [IcmpEndpoint::Ident]. This is useful for sending and receiving Echo Request/Reply
242    /// messages.
243    ///
244    /// ```
245    /// use smoltcp::wire::IpListenEndpoint;
246    /// use smoltcp::socket::icmp;
247    /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]);
248    /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]);
249    ///
250    /// let mut icmp_socket = // ...
251    /// # icmp::Socket::new(rx_buffer, tx_buffer);
252    ///
253    /// // Bind to ICMP messages with the ICMP identifier 0x1234
254    /// icmp_socket.bind(icmp::Endpoint::Ident(0x1234)).unwrap();
255    /// ```
256    ///
257    /// [is_specified]: enum.IcmpEndpoint.html#method.is_specified
258    /// [IcmpEndpoint::Ident]: enum.IcmpEndpoint.html#variant.Ident
259    /// [IcmpEndpoint::Udp]: enum.IcmpEndpoint.html#variant.Udp
260    /// [send]: #method.send
261    /// [recv]: #method.recv
262    pub fn bind<T: Into<Endpoint>>(&mut self, endpoint: T) -> Result<(), BindError> {
263        let endpoint = endpoint.into();
264        if !endpoint.is_specified() {
265            return Err(BindError::Unaddressable);
266        }
267
268        if self.is_open() {
269            return Err(BindError::InvalidState);
270        }
271
272        self.endpoint = endpoint;
273
274        #[cfg(feature = "async")]
275        {
276            self.rx_waker.wake();
277            self.tx_waker.wake();
278        }
279
280        Ok(())
281    }
282
283    /// Check whether the transmit buffer is full.
284    #[inline]
285    pub fn can_send(&self) -> bool {
286        !self.tx_buffer.is_full()
287    }
288
289    /// Check whether the receive buffer is not empty.
290    #[inline]
291    pub fn can_recv(&self) -> bool {
292        !self.rx_buffer.is_empty()
293    }
294
295    /// Return the maximum number packets the socket can receive.
296    #[inline]
297    pub fn packet_recv_capacity(&self) -> usize {
298        self.rx_buffer.packet_capacity()
299    }
300
301    /// Return the maximum number packets the socket can transmit.
302    #[inline]
303    pub fn packet_send_capacity(&self) -> usize {
304        self.tx_buffer.packet_capacity()
305    }
306
307    /// Return the maximum number of bytes inside the recv buffer.
308    #[inline]
309    pub fn payload_recv_capacity(&self) -> usize {
310        self.rx_buffer.payload_capacity()
311    }
312
313    /// Return the maximum number of bytes inside the transmit buffer.
314    #[inline]
315    pub fn payload_send_capacity(&self) -> usize {
316        self.tx_buffer.payload_capacity()
317    }
318
319    /// Check whether the socket is open.
320    #[inline]
321    pub fn is_open(&self) -> bool {
322        self.endpoint != Endpoint::Unspecified
323    }
324
325    /// Enqueue a packet to be sent to a given remote address, and return a pointer
326    /// to its payload.
327    ///
328    /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
329    /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer
330    /// size, and `Err(Error::Unaddressable)` if the remote address is unspecified.
331    pub fn send(&mut self, size: usize, endpoint: IpAddress) -> Result<&mut [u8], SendError> {
332        if endpoint.is_unspecified() {
333            return Err(SendError::Unaddressable);
334        }
335
336        let packet_buf = self
337            .tx_buffer
338            .enqueue(size, endpoint)
339            .map_err(|_| SendError::BufferFull)?;
340
341        net_trace!("icmp:{}: buffer to send {} octets", endpoint, size);
342        Ok(packet_buf)
343    }
344
345    /// Enqueue a packet to be send to a given remote address and pass the buffer
346    /// to the provided closure. The closure then returns the size of the data written
347    /// into the buffer.
348    ///
349    /// Also see [send](#method.send).
350    pub fn send_with<F>(
351        &mut self,
352        max_size: usize,
353        endpoint: IpAddress,
354        f: F,
355    ) -> Result<usize, SendError>
356    where
357        F: FnOnce(&mut [u8]) -> usize,
358    {
359        if endpoint.is_unspecified() {
360            return Err(SendError::Unaddressable);
361        }
362
363        let size = self
364            .tx_buffer
365            .enqueue_with_infallible(max_size, endpoint, f)
366            .map_err(|_| SendError::BufferFull)?;
367
368        net_trace!("icmp:{}: buffer to send {} octets", endpoint, size);
369        Ok(size)
370    }
371
372    /// Enqueue a packet to be sent to a given remote address, and fill it from a slice.
373    ///
374    /// See also [send](#method.send).
375    pub fn send_slice(&mut self, data: &[u8], endpoint: IpAddress) -> Result<(), SendError> {
376        let packet_buf = self.send(data.len(), endpoint)?;
377        packet_buf.copy_from_slice(data);
378        Ok(())
379    }
380
381    /// Dequeue a packet received from a remote endpoint, and return the `IpAddress` as well
382    /// as a pointer to the payload.
383    ///
384    /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
385    pub fn recv(&mut self) -> Result<(&[u8], IpAddress), RecvError> {
386        let (endpoint, packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
387
388        net_trace!(
389            "icmp:{}: receive {} buffered octets",
390            endpoint,
391            packet_buf.len()
392        );
393        Ok((packet_buf, endpoint))
394    }
395
396    /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
397    /// and return the amount of octets copied as well as the `IpAddress`
398    ///
399    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
400    /// the packet is dropped and a `RecvError::Truncated` error is returned.
401    ///
402    /// See also [recv](#method.recv).
403    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> {
404        let (buffer, endpoint) = self.recv()?;
405
406        if data.len() < buffer.len() {
407            return Err(RecvError::Truncated);
408        }
409
410        let length = cmp::min(data.len(), buffer.len());
411        data[..length].copy_from_slice(&buffer[..length]);
412        Ok((length, endpoint))
413    }
414
415    /// Return the amount of octets queued in the transmit buffer.
416    pub fn send_queue(&self) -> usize {
417        self.tx_buffer.payload_bytes_count()
418    }
419
420    /// Return the amount of octets queued in the receive buffer.
421    pub fn recv_queue(&self) -> usize {
422        self.rx_buffer.payload_bytes_count()
423    }
424
425    /// Fitler determining whether the socket accepts a given ICMPv4 packet.
426    /// Accepted packets are enqueued into the socket's receive buffer.
427    #[cfg(feature = "proto-ipv4")]
428    #[inline]
429    pub(crate) fn accepts_v4(
430        &self,
431        cx: &mut Context,
432        ip_repr: &Ipv4Repr,
433        icmp_repr: &Icmpv4Repr,
434    ) -> bool {
435        match (&self.endpoint, icmp_repr) {
436            // If we are bound to ICMP errors associated to a UDP port, only
437            // accept Destination Unreachable or Time Exceeded messages with
438            // the data containing a UDP packet send from the local port we
439            // are bound to.
440            (
441                &Endpoint::Udp(endpoint),
442                &Icmpv4Repr::DstUnreachable { data, header, .. }
443                | &Icmpv4Repr::TimeExceeded { data, header, .. },
444            ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr.into()) => {
445                let packet = UdpPacket::new_unchecked(data);
446                match UdpRepr::parse(
447                    &packet,
448                    &header.src_addr.into(),
449                    &header.dst_addr.into(),
450                    &cx.checksum_caps(),
451                ) {
452                    Ok(repr) => endpoint.port == repr.src_port,
453                    Err(_) => false,
454                }
455            }
456            // If we are bound to a specific ICMP identifier value, only accept an
457            // Echo Request/Reply with the identifier field matching the endpoint
458            // port.
459            (&Endpoint::Ident(bound_ident), &Icmpv4Repr::EchoRequest { ident, .. })
460            | (&Endpoint::Ident(bound_ident), &Icmpv4Repr::EchoReply { ident, .. }) => {
461                ident == bound_ident
462            }
463            _ => false,
464        }
465    }
466
467    /// Fitler determining whether the socket accepts a given ICMPv6 packet.
468    /// Accepted packets are enqueued into the socket's receive buffer.
469    #[cfg(feature = "proto-ipv6")]
470    #[inline]
471    pub(crate) fn accepts_v6(
472        &self,
473        cx: &mut Context,
474        ip_repr: &Ipv6Repr,
475        icmp_repr: &Icmpv6Repr,
476    ) -> bool {
477        match (&self.endpoint, icmp_repr) {
478            // If we are bound to ICMP errors associated to a UDP port, only
479            // accept Destination Unreachable or Time Exceeded messages with
480            // the data containing a UDP packet send from the local port we
481            // are bound to.
482            (
483                &Endpoint::Udp(endpoint),
484                &Icmpv6Repr::DstUnreachable { data, header, .. }
485                | &Icmpv6Repr::TimeExceeded { data, header, .. },
486            ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr.into()) => {
487                let packet = UdpPacket::new_unchecked(data);
488                match UdpRepr::parse(
489                    &packet,
490                    &header.src_addr.into(),
491                    &header.dst_addr.into(),
492                    &cx.checksum_caps(),
493                ) {
494                    Ok(repr) => endpoint.port == repr.src_port,
495                    Err(_) => false,
496                }
497            }
498            // If we are bound to a specific ICMP identifier value, only accept an
499            // Echo Request/Reply with the identifier field matching the endpoint
500            // port.
501            (
502                &Endpoint::Ident(bound_ident),
503                &Icmpv6Repr::EchoRequest { ident, .. } | &Icmpv6Repr::EchoReply { ident, .. },
504            ) => ident == bound_ident,
505            _ => false,
506        }
507    }
508
509    #[cfg(feature = "proto-ipv4")]
510    pub(crate) fn process_v4(
511        &mut self,
512        _cx: &mut Context,
513        ip_repr: &Ipv4Repr,
514        icmp_repr: &Icmpv4Repr,
515    ) {
516        net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len());
517
518        match self
519            .rx_buffer
520            .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr.into())
521        {
522            Ok(packet_buf) => {
523                icmp_repr.emit(
524                    &mut Icmpv4Packet::new_unchecked(packet_buf),
525                    &ChecksumCapabilities::default(),
526                );
527            }
528            Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"),
529        }
530
531        #[cfg(feature = "async")]
532        self.rx_waker.wake();
533    }
534
535    #[cfg(feature = "proto-ipv6")]
536    pub(crate) fn process_v6(
537        &mut self,
538        _cx: &mut Context,
539        ip_repr: &Ipv6Repr,
540        icmp_repr: &Icmpv6Repr,
541    ) {
542        net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len());
543
544        match self
545            .rx_buffer
546            .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr.into())
547        {
548            Ok(packet_buf) => icmp_repr.emit(
549                &ip_repr.src_addr,
550                &ip_repr.dst_addr,
551                &mut Icmpv6Packet::new_unchecked(packet_buf),
552                &ChecksumCapabilities::default(),
553            ),
554            Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"),
555        }
556
557        #[cfg(feature = "async")]
558        self.rx_waker.wake();
559    }
560
561    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
562    where
563        F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), E>,
564    {
565        let hop_limit = self.hop_limit.unwrap_or(64);
566        let res = self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| {
567            net_trace!(
568                "icmp:{}: sending {} octets",
569                remote_endpoint,
570                packet_buf.len()
571            );
572            match *remote_endpoint {
573                #[cfg(feature = "proto-ipv4")]
574                IpAddress::Ipv4(dst_addr) => {
575                    let src_addr = match cx.get_source_address_ipv4(&dst_addr) {
576                        Some(addr) => addr,
577                        None => {
578                            net_trace!(
579                                "icmp:{}: not find suitable source address, dropping",
580                                remote_endpoint
581                            );
582                            return Ok(());
583                        }
584                    };
585                    let packet = Icmpv4Packet::new_unchecked(&*packet_buf);
586                    let repr = match Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored()) {
587                        Ok(x) => x,
588                        Err(_) => {
589                            net_trace!(
590                                "icmp:{}: malformed packet in queue, dropping",
591                                remote_endpoint
592                            );
593                            return Ok(());
594                        }
595                    };
596                    let ip_repr = IpRepr::Ipv4(Ipv4Repr {
597                        src_addr,
598                        dst_addr,
599                        next_header: IpProtocol::Icmp,
600                        payload_len: repr.buffer_len(),
601                        hop_limit,
602                    });
603                    emit(cx, (ip_repr, IcmpRepr::Ipv4(repr)))
604                }
605                #[cfg(feature = "proto-ipv6")]
606                IpAddress::Ipv6(dst_addr) => {
607                    let src_addr = cx.get_source_address_ipv6(&dst_addr);
608
609                    let packet = Icmpv6Packet::new_unchecked(&*packet_buf);
610                    let repr = match Icmpv6Repr::parse(
611                        &src_addr,
612                        &dst_addr,
613                        &packet,
614                        &ChecksumCapabilities::ignored(),
615                    ) {
616                        Ok(x) => x,
617                        Err(_) => {
618                            net_trace!(
619                                "icmp:{}: malformed packet in queue, dropping",
620                                remote_endpoint
621                            );
622                            return Ok(());
623                        }
624                    };
625                    let ip_repr = IpRepr::Ipv6(Ipv6Repr {
626                        src_addr,
627                        dst_addr,
628                        next_header: IpProtocol::Icmpv6,
629                        payload_len: repr.buffer_len(),
630                        hop_limit,
631                    });
632                    emit(cx, (ip_repr, IcmpRepr::Ipv6(repr)))
633                }
634            }
635        });
636        match res {
637            Err(Empty) => Ok(()),
638            Ok(Err(e)) => Err(e),
639            Ok(Ok(())) => {
640                #[cfg(feature = "async")]
641                self.tx_waker.wake();
642                Ok(())
643            }
644        }
645    }
646
647    pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
648        if self.tx_buffer.is_empty() {
649            PollAt::Ingress
650        } else {
651            PollAt::Now
652        }
653    }
654}
655
#[cfg(test)]
mod tests_common {
    pub use super::*;
    pub use crate::wire::IpAddress;

    /// Build a packet buffer with room for `packets` packets and 66 payload
    /// octets of storage per packet.
    pub fn buffer(packets: usize) -> PacketBuffer<'static> {
        let metadata: Vec<PacketMetadata> = vec![PacketMetadata::EMPTY; packets];
        let payload: Vec<u8> = vec![0; 66 * packets];
        PacketBuffer::new(metadata, payload)
    }

    /// Build a socket from the given receive and transmit buffers.
    pub fn socket(
        rx_buffer: PacketBuffer<'static>,
        tx_buffer: PacketBuffer<'static>,
    ) -> Socket<'static> {
        Socket::new(rx_buffer, tx_buffer)
    }

    /// Local UDP port used by the UDP-bound tests.
    pub const LOCAL_PORT: u16 = 53;

    /// UDP header of a datagram sent from `LOCAL_PORT` to a remote port.
    pub static UDP_REPR: UdpRepr = UdpRepr {
        src_port: LOCAL_PORT,
        dst_port: 9090,
    };

    /// Payload carried by the UDP datagram in the UDP-bound tests.
    pub static UDP_PAYLOAD: &[u8] = &[0xff; 10];
}
681
682#[cfg(all(test, feature = "proto-ipv4"))]
683mod test_ipv4 {
684    use crate::phy::Medium;
685    use crate::tests::setup;
686    use rstest::*;
687
688    use super::tests_common::*;
689    use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address};
690
691    const REMOTE_IPV4: Ipv4Address = Ipv4Address::new(192, 168, 1, 2);
692    const LOCAL_IPV4: Ipv4Address = Ipv4Address::new(192, 168, 1, 1);
693    const LOCAL_END_V4: IpEndpoint = IpEndpoint {
694        addr: IpAddress::Ipv4(LOCAL_IPV4),
695        port: LOCAL_PORT,
696    };
697
698    static ECHOV4_REPR: Icmpv4Repr = Icmpv4Repr::EchoRequest {
699        ident: 0x1234,
700        seq_no: 0x5678,
701        data: &[0xff; 16],
702    };
703
704    static LOCAL_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
705        src_addr: LOCAL_IPV4,
706        dst_addr: REMOTE_IPV4,
707        next_header: IpProtocol::Icmp,
708        payload_len: 24,
709        hop_limit: 0x40,
710    });
711
712    static REMOTE_IPV4_REPR: Ipv4Repr = Ipv4Repr {
713        src_addr: REMOTE_IPV4,
714        dst_addr: LOCAL_IPV4,
715        next_header: IpProtocol::Icmp,
716        payload_len: 24,
717        hop_limit: 0x40,
718    };
719
720    #[test]
721    fn test_send_unaddressable() {
722        let mut socket = socket(buffer(0), buffer(1));
723        assert_eq!(
724            socket.send_slice(b"abcdef", IpAddress::Ipv4(Ipv4Address::new(0, 0, 0, 0))),
725            Err(SendError::Unaddressable)
726        );
727        assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV4.into()), Ok(()));
728    }
729
730    #[rstest]
731    #[case::ethernet(Medium::Ethernet)]
732    #[cfg(feature = "medium-ethernet")]
733    fn test_send_dispatch(#[case] medium: Medium) {
734        let (mut iface, _, _) = setup(medium);
735        let cx = iface.context();
736
737        let mut socket = socket(buffer(0), buffer(1));
738        let checksum = ChecksumCapabilities::default();
739
740        assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));
741
742        // This buffer is too long
743        assert_eq!(
744            socket.send_slice(&[0xff; 67], REMOTE_IPV4.into()),
745            Err(SendError::BufferFull)
746        );
747        assert!(socket.can_send());
748
749        let mut bytes = [0xff; 24];
750        let mut packet = Icmpv4Packet::new_unchecked(&mut bytes);
751        ECHOV4_REPR.emit(&mut packet, &checksum);
752
753        assert_eq!(
754            socket.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()),
755            Ok(())
756        );
757        assert_eq!(
758            socket.send_slice(b"123456", REMOTE_IPV4.into()),
759            Err(SendError::BufferFull)
760        );
761        assert!(!socket.can_send());
762
763        assert_eq!(
764            socket.dispatch(cx, |_, (ip_repr, icmp_repr)| {
765                assert_eq!(ip_repr, LOCAL_IPV4_REPR);
766                assert_eq!(icmp_repr, ECHOV4_REPR.into());
767                Err(())
768            }),
769            Err(())
770        );
771        // buffer is not taken off of the tx queue due to the error
772        assert!(!socket.can_send());
773
774        assert_eq!(
775            socket.dispatch(cx, |_, (ip_repr, icmp_repr)| {
776                assert_eq!(ip_repr, LOCAL_IPV4_REPR);
777                assert_eq!(icmp_repr, ECHOV4_REPR.into());
778                Ok::<_, ()>(())
779            }),
780            Ok(())
781        );
782        // buffer is taken off of the queue this time
783        assert!(socket.can_send());
784    }
785
786    #[rstest]
787    #[case::ethernet(Medium::Ethernet)]
788    #[cfg(feature = "medium-ethernet")]
789    fn test_set_hop_limit_v4(#[case] medium: Medium) {
790        let (mut iface, _, _) = setup(medium);
791        let cx = iface.context();
792
793        let mut s = socket(buffer(0), buffer(1));
794        let checksum = ChecksumCapabilities::default();
795
796        let mut bytes = [0xff; 24];
797        let mut packet = Icmpv4Packet::new_unchecked(&mut bytes);
798        ECHOV4_REPR.emit(&mut packet, &checksum);
799
800        s.set_hop_limit(Some(0x2a));
801
802        assert_eq!(
803            s.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()),
804            Ok(())
805        );
806        assert_eq!(
807            s.dispatch(cx, |_, (ip_repr, _)| {
808                assert_eq!(
809                    ip_repr,
810                    IpRepr::Ipv4(Ipv4Repr {
811                        src_addr: LOCAL_IPV4,
812                        dst_addr: REMOTE_IPV4,
813                        next_header: IpProtocol::Icmp,
814                        payload_len: ECHOV4_REPR.buffer_len(),
815                        hop_limit: 0x2a,
816                    })
817                );
818                Ok::<_, ()>(())
819            }),
820            Ok(())
821        );
822    }
823
824    #[rstest]
825    #[case::ethernet(Medium::Ethernet)]
826    #[cfg(feature = "medium-ethernet")]
827    fn test_recv_process(#[case] medium: Medium) {
828        let (mut iface, _, _) = setup(medium);
829        let cx = iface.context();
830
831        let mut socket = socket(buffer(1), buffer(1));
832        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));
833
834        assert!(!socket.can_recv());
835        assert_eq!(socket.recv(), Err(RecvError::Exhausted));
836
837        let checksum = ChecksumCapabilities::default();
838
839        let mut bytes = [0xff; 24];
840        let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]);
841        ECHOV4_REPR.emit(&mut packet, &checksum);
842        let data = &*packet.into_inner();
843
844        assert!(socket.accepts_v4(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR));
845        socket.process_v4(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR);
846        assert!(socket.can_recv());
847
848        assert!(socket.accepts_v4(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR));
849        socket.process_v4(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR);
850
851        assert_eq!(socket.recv(), Ok((data, REMOTE_IPV4.into())));
852        assert!(!socket.can_recv());
853    }
854
855    #[rstest]
856    #[case::ethernet(Medium::Ethernet)]
857    #[cfg(feature = "medium-ethernet")]
858    fn test_accept_bad_id(#[case] medium: Medium) {
859        let (mut iface, _, _) = setup(medium);
860        let cx = iface.context();
861
862        let mut socket = socket(buffer(1), buffer(1));
863        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));
864
865        let checksum = ChecksumCapabilities::default();
866        let mut bytes = [0xff; 20];
867        let mut packet = Icmpv4Packet::new_unchecked(&mut bytes);
868        let icmp_repr = Icmpv4Repr::EchoRequest {
869            ident: 0x4321,
870            seq_no: 0x5678,
871            data: &[0xff; 16],
872        };
873        icmp_repr.emit(&mut packet, &checksum);
874
875        // Ensure that a packet with an identifier that isn't the bound
876        // ID is not accepted
877        assert!(!socket.accepts_v4(cx, &REMOTE_IPV4_REPR, &icmp_repr));
878    }
879
880    #[rstest]
881    #[case::ethernet(Medium::Ethernet)]
882    #[cfg(feature = "medium-ethernet")]
883    fn test_accepts_udp(#[case] medium: Medium) {
884        let (mut iface, _, _) = setup(medium);
885        let cx = iface.context();
886
887        let mut socket = socket(buffer(1), buffer(1));
888        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4.into())), Ok(()));
889
890        let checksum = ChecksumCapabilities::default();
891
892        let mut bytes = [0xff; 18];
893        let mut packet = UdpPacket::new_unchecked(&mut bytes);
894        UDP_REPR.emit(
895            &mut packet,
896            &REMOTE_IPV4.into(),
897            &LOCAL_IPV4.into(),
898            UDP_PAYLOAD.len(),
899            |buf| buf.copy_from_slice(UDP_PAYLOAD),
900            &checksum,
901        );
902
903        let data = &*packet.into_inner();
904
905        let icmp_repr = Icmpv4Repr::DstUnreachable {
906            reason: Icmpv4DstUnreachable::PortUnreachable,
907            header: Ipv4Repr {
908                src_addr: LOCAL_IPV4,
909                dst_addr: REMOTE_IPV4,
910                next_header: IpProtocol::Icmp,
911                payload_len: 12,
912                hop_limit: 0x40,
913            },
914            data,
915        };
916        let ip_repr = Ipv4Repr {
917            src_addr: REMOTE_IPV4,
918            dst_addr: LOCAL_IPV4,
919            next_header: IpProtocol::Icmp,
920            payload_len: icmp_repr.buffer_len(),
921            hop_limit: 0x40,
922        };
923
924        assert!(!socket.can_recv());
925
926        // Ensure we can accept ICMP error response to the bound
927        // UDP port
928        assert!(socket.accepts_v4(cx, &ip_repr, &icmp_repr));
929        socket.process_v4(cx, &ip_repr, &icmp_repr);
930        assert!(socket.can_recv());
931
932        let mut bytes = [0x00; 46];
933        let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]);
934        icmp_repr.emit(&mut packet, &checksum);
935        assert_eq!(
936            socket.recv(),
937            Ok((&*packet.into_inner(), REMOTE_IPV4.into()))
938        );
939        assert!(!socket.can_recv());
940    }
941}
942
/// IPv6 variants of the ICMP socket tests: address validation on send, the
/// transmit/dispatch path, hop-limit override, receive filtering by echo
/// identifier, truncated reads, and ICMPv6 errors for a bound UDP endpoint.
#[cfg(all(test, feature = "proto-ipv6"))]
mod test_ipv6 {
    use crate::phy::Medium;
    use crate::tests::setup;
    use rstest::*;

    use super::tests_common::*;

    use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address};

    // Addresses of the remote peer and of the local side.
    const REMOTE_IPV6: Ipv6Address = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2);
    const LOCAL_IPV6: Ipv6Address = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
    const LOCAL_END_V6: IpEndpoint = IpEndpoint {
        addr: IpAddress::Ipv6(LOCAL_IPV6),
        port: LOCAL_PORT,
    };
    // Echo request with identifier 0x1234 and a 16-byte payload.
    static ECHOV6_REPR: Icmpv6Repr = Icmpv6Repr::EchoRequest {
        ident: 0x1234,
        seq_no: 0x5678,
        data: &[0xff; 16],
    };

    // IP header expected on an outgoing ECHOV6_REPR (LOCAL -> REMOTE);
    // payload_len equals ECHOV6_REPR's encoded length.
    static LOCAL_IPV6_REPR: Ipv6Repr = Ipv6Repr {
        src_addr: LOCAL_IPV6,
        dst_addr: REMOTE_IPV6,
        next_header: IpProtocol::Icmpv6,
        payload_len: 24,
        hop_limit: 0x40,
    };

    // IP header for an incoming ECHOV6_REPR (REMOTE -> LOCAL).
    static REMOTE_IPV6_REPR: Ipv6Repr = Ipv6Repr {
        src_addr: REMOTE_IPV6,
        dst_addr: LOCAL_IPV6,
        next_header: IpProtocol::Icmpv6,
        payload_len: 24,
        hop_limit: 0x40,
    };

    // Sending to the unspecified address is rejected as unaddressable, while
    // a concrete remote address is accepted.
    #[test]
    fn test_send_unaddressable() {
        let mut socket = socket(buffer(0), buffer(1));
        assert_eq!(
            socket.send_slice(b"abcdef", IpAddress::Ipv6(Ipv6Address::UNSPECIFIED)),
            Err(SendError::Unaddressable)
        );
        assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV6.into()), Ok(()));
    }

    // Transmit path: an empty queue dispatches nothing, oversized payloads are
    // refused, a queued echo request reaches `dispatch` with the expected IP
    // and ICMP reprs, and it stays queued when the dispatch closure fails.
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_send_dispatch(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut socket = socket(buffer(0), buffer(1));
        let checksum = ChecksumCapabilities::default();

        // Nothing is queued, so the dispatch closure must never run.
        assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));

        // This buffer is too long
        assert_eq!(
            socket.send_slice(&[0xff; 67], REMOTE_IPV6.into()),
            Err(SendError::BufferFull)
        );
        assert!(socket.can_send());

        let mut bytes = vec![0xff; 24];
        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes);
        ECHOV6_REPR.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);

        assert_eq!(
            socket.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()),
            Ok(())
        );
        // The transmit buffer is now full.
        assert_eq!(
            socket.send_slice(b"123456", REMOTE_IPV6.into()),
            Err(SendError::BufferFull)
        );
        assert!(!socket.can_send());

        assert_eq!(
            socket.dispatch(cx, |_, (ip_repr, icmp_repr)| {
                assert_eq!(ip_repr, LOCAL_IPV6_REPR.into());
                assert_eq!(icmp_repr, ECHOV6_REPR.into());
                Err(())
            }),
            Err(())
        );
        // buffer is not taken off of the tx queue due to the error
        assert!(!socket.can_send());

        assert_eq!(
            socket.dispatch(cx, |_, (ip_repr, icmp_repr)| {
                assert_eq!(ip_repr, LOCAL_IPV6_REPR.into());
                assert_eq!(icmp_repr, ECHOV6_REPR.into());
                Ok::<_, ()>(())
            }),
            Ok(())
        );
        // buffer is taken off of the queue this time
        assert!(socket.can_send());
    }

    // A hop limit configured with `set_hop_limit` is used in the IP header of
    // dispatched packets.
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_set_hop_limit(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut s = socket(buffer(0), buffer(1));
        let checksum = ChecksumCapabilities::default();

        let mut bytes = vec![0xff; 24];
        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes);
        ECHOV6_REPR.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);

        s.set_hop_limit(Some(0x2a));

        assert_eq!(
            s.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()),
            Ok(())
        );
        assert_eq!(
            s.dispatch(cx, |_, (ip_repr, _)| {
                assert_eq!(
                    ip_repr,
                    IpRepr::Ipv6(Ipv6Repr {
                        src_addr: LOCAL_IPV6,
                        dst_addr: REMOTE_IPV6,
                        next_header: IpProtocol::Icmpv6,
                        payload_len: ECHOV6_REPR.buffer_len(),
                        hop_limit: 0x2a,
                    })
                );
                Ok::<_, ()>(())
            }),
            Ok(())
        );
    }

    // Receive path: an echo request with the bound identifier is accepted and
    // queued. A second identical packet passes the filter too, but only one
    // copy is read back.
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_recv_process(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut socket = socket(buffer(1), buffer(1));
        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));

        assert!(!socket.can_recv());
        assert_eq!(socket.recv(), Err(RecvError::Exhausted));

        let checksum = ChecksumCapabilities::default();

        let mut bytes = [0xff; 24];
        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]);
        ECHOV6_REPR.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);
        let data = &*packet.into_inner();

        assert!(socket.accepts_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR));
        socket.process_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR);
        assert!(socket.can_recv());

        assert!(socket.accepts_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR));
        socket.process_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR);

        assert_eq!(socket.recv(), Ok((data, REMOTE_IPV6.into())));
        assert!(!socket.can_recv());
    }

    // `recv_slice` into a buffer smaller than the queued packet reports
    // `Truncated`, and the packet is consumed rather than left queued.
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_truncated_recv_slice(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut socket = socket(buffer(1), buffer(1));
        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));

        let checksum = ChecksumCapabilities::default();

        let mut bytes = [0xff; 24];
        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]);
        ECHOV6_REPR.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);

        assert!(socket.accepts_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR));
        socket.process_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR);
        assert!(socket.can_recv());

        assert!(socket.accepts_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR));
        socket.process_v6(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR);

        let mut buffer = [0u8; 1];
        assert_eq!(
            socket.recv_slice(&mut buffer[..]),
            Err(RecvError::Truncated)
        );
        assert!(!socket.can_recv());
    }

    // Filtering: a socket bound to identifier 0x1234 must not accept an echo
    // request whose identifier is different (0x4321).
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_accept_bad_id(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut socket = socket(buffer(1), buffer(1));
        assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(()));

        let checksum = ChecksumCapabilities::default();
        let icmp_repr = Icmpv6Repr::EchoRequest {
            ident: 0x4321,
            seq_no: 0x5678,
            data: &[0xff; 16],
        };
        // Size the buffer from the repr itself: a hard-coded 20 bytes is too
        // small for an echo request with a 16-byte payload, so the emitted
        // packet would be incomplete.
        let mut bytes = vec![0xff; icmp_repr.buffer_len()];
        let mut packet = Icmpv6Packet::new_unchecked(&mut bytes);
        icmp_repr.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);

        // Ensure that a packet with an identifier that isn't the bound
        // ID is not accepted
        assert!(!socket.accepts_v6(cx, &REMOTE_IPV6_REPR, &icmp_repr));
    }

    // ICMPv6 errors for a bound UDP endpoint: a Destination Unreachable whose
    // quoted datagram was sent from the bound local port is accepted, and the
    // full ICMPv6 message is delivered through `recv`.
    #[rstest]
    #[case::ethernet(Medium::Ethernet)]
    #[cfg(feature = "medium-ethernet")]
    fn test_accepts_udp(#[case] medium: Medium) {
        let (mut iface, _, _) = setup(medium);
        let cx = iface.context();

        let mut socket = socket(buffer(1), buffer(1));
        assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6.into())), Ok(()));

        let checksum = ChecksumCapabilities::default();

        // The datagram that triggered the error: sent LOCAL -> REMOTE, so emit
        // it with the same addresses the quoted IP header below claims.
        let mut bytes = [0xff; 18];
        let mut packet = UdpPacket::new_unchecked(&mut bytes);
        UDP_REPR.emit(
            &mut packet,
            &LOCAL_IPV6.into(),
            &REMOTE_IPV6.into(),
            UDP_PAYLOAD.len(),
            |buf| buf.copy_from_slice(UDP_PAYLOAD),
            &checksum,
        );

        let data = &*packet.into_inner();

        let icmp_repr = Icmpv6Repr::DstUnreachable {
            reason: Icmpv6DstUnreachable::PortUnreachable,
            // Header of the offending datagram: it carried UDP, and its
            // payload is exactly the quoted UDP packet.
            header: Ipv6Repr {
                src_addr: LOCAL_IPV6,
                dst_addr: REMOTE_IPV6,
                next_header: IpProtocol::Udp,
                payload_len: data.len(),
                hop_limit: 0x40,
            },
            data,
        };
        let ip_repr = Ipv6Repr {
            src_addr: REMOTE_IPV6,
            dst_addr: LOCAL_IPV6,
            next_header: IpProtocol::Icmpv6,
            payload_len: icmp_repr.buffer_len(),
            hop_limit: 0x40,
        };

        assert!(!socket.can_recv());

        // Ensure we can accept ICMP error response to the bound
        // UDP port
        assert!(socket.accepts_v6(cx, &ip_repr, &icmp_repr));
        socket.process_v6(cx, &ip_repr, &icmp_repr);
        assert!(socket.can_recv());

        // Expected `recv` payload: the serialized ICMPv6 error, sized from
        // the repr rather than a hard-coded length.
        let mut expected = vec![0x00; icmp_repr.buffer_len()];
        let mut packet = Icmpv6Packet::new_unchecked(&mut expected[..]);
        icmp_repr.emit(&LOCAL_IPV6, &REMOTE_IPV6, &mut packet, &checksum);
        assert_eq!(
            socket.recv(),
            Ok((&*packet.into_inner(), REMOTE_IPV6.into()))
        );
        assert!(!socket.can_recv());
    }
}