smoltcp/socket/
udp.rs

1use core::cmp::min;
2#[cfg(feature = "async")]
3use core::task::Waker;
4
5use crate::iface::Context;
6use crate::phy::PacketMeta;
7use crate::socket::PollAt;
8#[cfg(feature = "async")]
9use crate::socket::WakerRegistration;
10use crate::storage::Empty;
11use crate::wire::{IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr};
12
/// Metadata for a sent or received UDP packet.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct UdpMetadata {
    /// The IP endpoint from which an incoming datagram was received, or to which an outgoing
    /// datagram will be sent.
    pub endpoint: IpEndpoint,
    /// The IP address to which an incoming datagram was sent, or from which an outgoing datagram
    /// will be sent. Incoming datagrams always have this set. On outgoing datagrams, if it is not
    /// set, and the socket is not bound to a single address anyway, a suitable address will be
    /// determined using the algorithms of RFC 6724 (candidate source address selection) or some
    /// heuristic (for IPv4).
    pub local_address: Option<IpAddress>,
    /// Packet-level metadata from the `phy` layer ([`PacketMeta`]).
    ///
    /// For incoming datagrams this is the metadata the packet arrived with. For outgoing
    /// datagrams it is handed back unchanged when the packet is emitted. Values built through
    /// the `From` conversion use `PacketMeta::default()`.
    pub meta: PacketMeta,
}
28
29impl<T: Into<IpEndpoint>> From<T> for UdpMetadata {
30    fn from(value: T) -> Self {
31        Self {
32            endpoint: value.into(),
33            local_address: None,
34            meta: PacketMeta::default(),
35        }
36    }
37}
38
39impl core::fmt::Display for UdpMetadata {
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        #[cfg(feature = "packetmeta-id")]
42        return write!(f, "{}, PacketID: {:?}", self.endpoint, self.meta);
43
44        #[cfg(not(feature = "packetmeta-id"))]
45        write!(f, "{}", self.endpoint)
46    }
47}
48
/// Per-packet metadata entry for a UDP [`PacketBuffer`].
///
/// Each entry carries the [`UdpMetadata`] header of one stored datagram.
pub type PacketMetadata = crate::storage::PacketMetadata<UdpMetadata>;

/// A UDP packet ring buffer.
///
/// It stores whole datagrams together with their [`UdpMetadata`]. A [`Socket`] owns one of
/// these for receiving and one for transmitting.
pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, UdpMetadata>;
54
/// Error returned by [`Socket::bind`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BindError {
    /// The socket is already open (bound) and cannot be bound again.
    InvalidState,
    /// The requested endpoint has a port of zero.
    Unaddressable,
}

impl core::fmt::Display for BindError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match *self {
            Self::InvalidState => "invalid state",
            Self::Unaddressable => "unaddressable",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for BindError {}
74
/// Error returned by [`Socket::send`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SendError {
    /// The socket is not bound, or the destination address is unspecified or its port is zero.
    Unaddressable,
    /// The transmit buffer could not accept the packet.
    BufferFull,
}

impl core::fmt::Display for SendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match *self {
            Self::Unaddressable => "unaddressable",
            Self::BufferFull => "buffer full",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SendError {}
94
/// Error returned by [`Socket::recv`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// The receive buffer holds no packets.
    Exhausted,
    /// The caller-provided buffer is smaller than the packet payload.
    Truncated,
}

impl core::fmt::Display for RecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let description = match *self {
            Self::Exhausted => "exhausted",
            Self::Truncated => "truncated",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RecvError {}
114
/// A User Datagram Protocol socket.
///
/// A UDP socket is bound to a specific endpoint, and owns transmit and receive
/// packet buffers.
#[derive(Debug)]
pub struct Socket<'a> {
    /// Local endpoint the socket is bound to. A port of 0 means the socket is not bound
    /// (this is what `is_open` checks).
    endpoint: IpListenEndpoint,
    /// Received datagrams waiting to be read with `recv`/`peek`; filled by `process`.
    rx_buffer: PacketBuffer<'a>,
    /// Datagrams enqueued by the `send*` methods and drained by `dispatch`.
    tx_buffer: PacketBuffer<'a>,
    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
    /// `None` means the default is used when dispatching.
    hop_limit: Option<u8>,
    // Woken when receive-side state changes (incoming data, bind, close).
    #[cfg(feature = "async")]
    rx_waker: WakerRegistration,
    // Woken when send-side state changes (a packet was dispatched, bind, close).
    #[cfg(feature = "async")]
    tx_waker: WakerRegistration,
}
131
132impl<'a> Socket<'a> {
133    /// Create an UDP socket with the given buffers.
134    pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> {
135        Socket {
136            endpoint: IpListenEndpoint::default(),
137            rx_buffer,
138            tx_buffer,
139            hop_limit: None,
140            #[cfg(feature = "async")]
141            rx_waker: WakerRegistration::new(),
142            #[cfg(feature = "async")]
143            tx_waker: WakerRegistration::new(),
144        }
145    }
146
147    /// Register a waker for receive operations.
148    ///
149    /// The waker is woken on state changes that might affect the return value
150    /// of `recv` method calls, such as receiving data, or the socket closing.
151    ///
152    /// Notes:
153    ///
154    /// - Only one waker can be registered at a time. If another waker was previously registered,
155    ///   it is overwritten and will no longer be woken.
156    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
157    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has
158    ///   necessarily changed.
159    #[cfg(feature = "async")]
160    pub fn register_recv_waker(&mut self, waker: &Waker) {
161        self.rx_waker.register(waker)
162    }
163
164    /// Register a waker for send operations.
165    ///
166    /// The waker is woken on state changes that might affect the return value
167    /// of `send` method calls, such as space becoming available in the transmit
168    /// buffer, or the socket closing.
169    ///
170    /// Notes:
171    ///
172    /// - Only one waker can be registered at a time. If another waker was previously registered,
173    ///   it is overwritten and will no longer be woken.
174    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
175    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has
176    ///   necessarily changed.
177    #[cfg(feature = "async")]
178    pub fn register_send_waker(&mut self, waker: &Waker) {
179        self.tx_waker.register(waker)
180    }
181
182    /// Return the bound endpoint.
183    #[inline]
184    pub fn endpoint(&self) -> IpListenEndpoint {
185        self.endpoint
186    }
187
188    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
189    ///
190    /// See also the [set_hop_limit](#method.set_hop_limit) method
191    pub fn hop_limit(&self) -> Option<u8> {
192        self.hop_limit
193    }
194
195    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
196    ///
197    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
198    /// value (64).
199    ///
200    /// # Panics
201    ///
202    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
203    ///
204    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
205    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
206    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
207        // A host MUST NOT send a datagram with a hop limit value of 0
208        if let Some(0) = hop_limit {
209            panic!("the time-to-live value of a packet must not be zero")
210        }
211
212        self.hop_limit = hop_limit
213    }
214
215    /// Bind the socket to the given endpoint.
216    ///
217    /// This function returns `Err(Error::Illegal)` if the socket was open
218    /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
219    /// if the port in the given endpoint is zero.
220    pub fn bind<T: Into<IpListenEndpoint>>(&mut self, endpoint: T) -> Result<(), BindError> {
221        let endpoint = endpoint.into();
222        if endpoint.port == 0 {
223            return Err(BindError::Unaddressable);
224        }
225
226        if self.is_open() {
227            return Err(BindError::InvalidState);
228        }
229
230        self.endpoint = endpoint;
231
232        #[cfg(feature = "async")]
233        {
234            self.rx_waker.wake();
235            self.tx_waker.wake();
236        }
237
238        Ok(())
239    }
240
241    /// Close the socket.
242    pub fn close(&mut self) {
243        // Clear the bound endpoint of the socket.
244        self.endpoint = IpListenEndpoint::default();
245
246        // Reset the RX and TX buffers of the socket.
247        self.tx_buffer.reset();
248        self.rx_buffer.reset();
249
250        #[cfg(feature = "async")]
251        {
252            self.rx_waker.wake();
253            self.tx_waker.wake();
254        }
255    }
256
257    /// Check whether the socket is open.
258    #[inline]
259    pub fn is_open(&self) -> bool {
260        self.endpoint.port != 0
261    }
262
263    /// Check whether the transmit buffer is full.
264    #[inline]
265    pub fn can_send(&self) -> bool {
266        !self.tx_buffer.is_full()
267    }
268
269    /// Check whether the receive buffer is not empty.
270    #[inline]
271    pub fn can_recv(&self) -> bool {
272        !self.rx_buffer.is_empty()
273    }
274
275    /// Return the maximum number packets the socket can receive.
276    #[inline]
277    pub fn packet_recv_capacity(&self) -> usize {
278        self.rx_buffer.packet_capacity()
279    }
280
281    /// Return the maximum number packets the socket can transmit.
282    #[inline]
283    pub fn packet_send_capacity(&self) -> usize {
284        self.tx_buffer.packet_capacity()
285    }
286
287    /// Return the maximum number of bytes inside the recv buffer.
288    #[inline]
289    pub fn payload_recv_capacity(&self) -> usize {
290        self.rx_buffer.payload_capacity()
291    }
292
293    /// Return the maximum number of bytes inside the transmit buffer.
294    #[inline]
295    pub fn payload_send_capacity(&self) -> usize {
296        self.tx_buffer.payload_capacity()
297    }
298
299    /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
300    /// to its payload.
301    ///
302    /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
303    /// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified,
304    /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity
305    /// to ever send this packet.
306    pub fn send(
307        &mut self,
308        size: usize,
309        meta: impl Into<UdpMetadata>,
310    ) -> Result<&mut [u8], SendError> {
311        let meta = meta.into();
312        if self.endpoint.port == 0 {
313            return Err(SendError::Unaddressable);
314        }
315        if meta.endpoint.addr.is_unspecified() {
316            return Err(SendError::Unaddressable);
317        }
318        if meta.endpoint.port == 0 {
319            return Err(SendError::Unaddressable);
320        }
321
322        let payload_buf = self
323            .tx_buffer
324            .enqueue(size, meta)
325            .map_err(|_| SendError::BufferFull)?;
326
327        net_trace!(
328            "udp:{}:{}: buffer to send {} octets",
329            self.endpoint,
330            meta.endpoint,
331            size
332        );
333        Ok(payload_buf)
334    }
335
336    /// Enqueue a packet to be send to a given remote endpoint and pass the buffer
337    /// to the provided closure. The closure then returns the size of the data written
338    /// into the buffer.
339    ///
340    /// Also see [send](#method.send).
341    pub fn send_with<F>(
342        &mut self,
343        max_size: usize,
344        meta: impl Into<UdpMetadata>,
345        f: F,
346    ) -> Result<usize, SendError>
347    where
348        F: FnOnce(&mut [u8]) -> usize,
349    {
350        let meta = meta.into();
351        if self.endpoint.port == 0 {
352            return Err(SendError::Unaddressable);
353        }
354        if meta.endpoint.addr.is_unspecified() {
355            return Err(SendError::Unaddressable);
356        }
357        if meta.endpoint.port == 0 {
358            return Err(SendError::Unaddressable);
359        }
360
361        let size = self
362            .tx_buffer
363            .enqueue_with_infallible(max_size, meta, f)
364            .map_err(|_| SendError::BufferFull)?;
365
366        net_trace!(
367            "udp:{}:{}: buffer to send {} octets",
368            self.endpoint,
369            meta.endpoint,
370            size
371        );
372        Ok(size)
373    }
374
375    /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice.
376    ///
377    /// See also [send](#method.send).
378    pub fn send_slice(
379        &mut self,
380        data: &[u8],
381        meta: impl Into<UdpMetadata>,
382    ) -> Result<(), SendError> {
383        self.send(data.len(), meta)?.copy_from_slice(data);
384        Ok(())
385    }
386
387    /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
388    /// as a pointer to the payload.
389    ///
390    /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
391    pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> {
392        let (remote_endpoint, payload_buf) =
393            self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
394
395        net_trace!(
396            "udp:{}:{}: receive {} buffered octets",
397            self.endpoint,
398            remote_endpoint.endpoint,
399            payload_buf.len()
400        );
401        Ok((payload_buf, remote_endpoint))
402    }
403
404    /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
405    /// and return the amount of octets copied as well as the endpoint.
406    ///
407    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
408    /// the packet is dropped and a `RecvError::Truncated` error is returned.
409    ///
410    /// See also [recv](#method.recv).
411    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
412        let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?;
413
414        if data.len() < buffer.len() {
415            return Err(RecvError::Truncated);
416        }
417
418        let length = min(data.len(), buffer.len());
419        data[..length].copy_from_slice(&buffer[..length]);
420        Ok((length, endpoint))
421    }
422
423    /// Peek at a packet received from a remote endpoint, and return the endpoint as well
424    /// as a pointer to the payload without removing the packet from the receive buffer.
425    /// This function otherwise behaves identically to [recv](#method.recv).
426    ///
427    /// It returns `Err(Error::Exhausted)` if the receive buffer is empty.
428    pub fn peek(&mut self) -> Result<(&[u8], &UdpMetadata), RecvError> {
429        let endpoint = self.endpoint;
430        self.rx_buffer.peek().map_err(|_| RecvError::Exhausted).map(
431            |(remote_endpoint, payload_buf)| {
432                net_trace!(
433                    "udp:{}:{}: peek {} buffered octets",
434                    endpoint,
435                    remote_endpoint.endpoint,
436                    payload_buf.len()
437                );
438                (payload_buf, remote_endpoint)
439            },
440        )
441    }
442
443    /// Peek at a packet received from a remote endpoint, copy the payload into the given slice,
444    /// and return the amount of octets copied as well as the endpoint without removing the
445    /// packet from the receive buffer.
446    /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
447    ///
448    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
449    /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned.
450    ///
451    /// See also [peek](#method.peek).
452    pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> {
453        let (buffer, endpoint) = self.peek()?;
454
455        if data.len() < buffer.len() {
456            return Err(RecvError::Truncated);
457        }
458
459        let length = min(data.len(), buffer.len());
460        data[..length].copy_from_slice(&buffer[..length]);
461        Ok((length, endpoint))
462    }
463
464    /// Return the amount of octets queued in the transmit buffer.
465    ///
466    /// Note that the Berkeley sockets interface does not have an equivalent of this API.
467    pub fn send_queue(&self) -> usize {
468        self.tx_buffer.payload_bytes_count()
469    }
470
471    /// Return the amount of octets queued in the receive buffer. This value can be larger than
472    /// the slice read by the next `recv` or `peek` call because it includes all queued octets,
473    /// and not only the octets that may be returned as a contiguous slice.
474    ///
475    /// Note that the Berkeley sockets interface does not have an equivalent of this API.
476    pub fn recv_queue(&self) -> usize {
477        self.rx_buffer.payload_bytes_count()
478    }
479
480    pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool {
481        if self.endpoint.port != repr.dst_port {
482            return false;
483        }
484        if self.endpoint.addr.is_some()
485            && self.endpoint.addr != Some(ip_repr.dst_addr())
486            && !cx.is_broadcast(&ip_repr.dst_addr())
487            && !ip_repr.dst_addr().is_multicast()
488        {
489            return false;
490        }
491
492        true
493    }
494
495    pub(crate) fn process(
496        &mut self,
497        cx: &mut Context,
498        meta: PacketMeta,
499        ip_repr: &IpRepr,
500        repr: &UdpRepr,
501        payload: &[u8],
502    ) {
503        debug_assert!(self.accepts(cx, ip_repr, repr));
504
505        let size = payload.len();
506
507        let remote_endpoint = IpEndpoint {
508            addr: ip_repr.src_addr(),
509            port: repr.src_port,
510        };
511
512        net_trace!(
513            "udp:{}:{}: receiving {} octets",
514            self.endpoint,
515            remote_endpoint,
516            size
517        );
518
519        let metadata = UdpMetadata {
520            endpoint: remote_endpoint,
521            local_address: Some(ip_repr.dst_addr()),
522            meta,
523        };
524
525        match self.rx_buffer.enqueue(size, metadata) {
526            Ok(buf) => buf.copy_from_slice(payload),
527            Err(_) => net_trace!(
528                "udp:{}:{}: buffer full, dropped incoming packet",
529                self.endpoint,
530                remote_endpoint
531            ),
532        }
533
534        #[cfg(feature = "async")]
535        self.rx_waker.wake();
536    }
537
538    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
539    where
540        F: FnOnce(&mut Context, PacketMeta, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
541    {
542        let endpoint = self.endpoint;
543        let hop_limit = self.hop_limit.unwrap_or(64);
544
545        let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| {
546            let src_addr = if let Some(s) = packet_meta.local_address {
547                s
548            } else {
549                match endpoint.addr {
550                    Some(addr) => addr,
551                    None => match cx.get_source_address(&packet_meta.endpoint.addr) {
552                        Some(addr) => addr,
553                        None => {
554                            net_trace!(
555                                "udp:{}:{}: cannot find suitable source address, dropping.",
556                                endpoint,
557                                packet_meta.endpoint
558                            );
559                            return Ok(());
560                        }
561                    },
562                }
563            };
564
565            net_trace!(
566                "udp:{}:{}: sending {} octets",
567                endpoint,
568                packet_meta.endpoint,
569                payload_buf.len()
570            );
571
572            let repr = UdpRepr {
573                src_port: endpoint.port,
574                dst_port: packet_meta.endpoint.port,
575            };
576            let ip_repr = IpRepr::new(
577                src_addr,
578                packet_meta.endpoint.addr,
579                IpProtocol::Udp,
580                repr.header_len() + payload_buf.len(),
581                hop_limit,
582            );
583
584            emit(cx, packet_meta.meta, (ip_repr, repr, payload_buf))
585        });
586        match res {
587            Err(Empty) => Ok(()),
588            Ok(Err(e)) => Err(e),
589            Ok(Ok(())) => {
590                #[cfg(feature = "async")]
591                self.tx_waker.wake();
592                Ok(())
593            }
594        }
595    }
596
597    pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
598        if self.tx_buffer.is_empty() {
599            PollAt::Ingress
600        } else {
601            PollAt::Now
602        }
603    }
604}
605
606#[cfg(test)]
607mod test {
608    use super::*;
609    use crate::wire::{IpRepr, UdpRepr};
610
611    use crate::phy::Medium;
612    use crate::tests::setup;
613    use rstest::*;
614
615    fn buffer(packets: usize) -> PacketBuffer<'static> {
616        PacketBuffer::new(
617            (0..packets)
618                .map(|_| PacketMetadata::EMPTY)
619                .collect::<Vec<_>>(),
620            vec![0; 16 * packets],
621        )
622    }
623
624    fn socket(
625        rx_buffer: PacketBuffer<'static>,
626        tx_buffer: PacketBuffer<'static>,
627    ) -> Socket<'static> {
628        Socket::new(rx_buffer, tx_buffer)
629    }
630
631    const LOCAL_PORT: u16 = 53;
632    const REMOTE_PORT: u16 = 49500;
633
634    cfg_if::cfg_if! {
635        if #[cfg(feature = "proto-ipv4")] {
636            use crate::wire::Ipv4Address as IpvXAddress;
637            use crate::wire::Ipv4Repr as IpvXRepr;
638            use IpRepr::Ipv4 as IpReprIpvX;
639
640            const LOCAL_ADDR: IpvXAddress = IpvXAddress::new(192, 168, 1, 1);
641            const REMOTE_ADDR: IpvXAddress = IpvXAddress::new(192, 168, 1, 2);
642            const OTHER_ADDR: IpvXAddress = IpvXAddress::new(192, 168, 1, 3);
643
644            const LOCAL_END: IpEndpoint = IpEndpoint {
645                addr: IpAddress::Ipv4(LOCAL_ADDR),
646                port: LOCAL_PORT,
647            };
648            const REMOTE_END: IpEndpoint = IpEndpoint {
649                addr: IpAddress::Ipv4(REMOTE_ADDR),
650                port: REMOTE_PORT,
651            };
652        } else {
653            use crate::wire::Ipv6Address as IpvXAddress;
654            use crate::wire::Ipv6Repr as IpvXRepr;
655            use IpRepr::Ipv6 as IpReprIpvX;
656
657            const LOCAL_ADDR: IpvXAddress = IpvXAddress::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
658            const REMOTE_ADDR: IpvXAddress = IpvXAddress::new(0xfe80, 0, 0, 0, 0, 0, 0, 2);
659            const OTHER_ADDR: IpvXAddress = IpvXAddress::new(0xfe80, 0, 0, 0, 0, 0, 0, 3);
660
661            const LOCAL_END: IpEndpoint = IpEndpoint {
662                addr: IpAddress::Ipv6(LOCAL_ADDR),
663                port: LOCAL_PORT,
664            };
665            const REMOTE_END: IpEndpoint = IpEndpoint {
666                addr: IpAddress::Ipv6(REMOTE_ADDR),
667                port: REMOTE_PORT,
668            };
669        }
670    }
671
672    fn remote_metadata_with_local() -> UdpMetadata {
673        // Would be great as a const once we have const `.into()`.
674        UdpMetadata {
675            local_address: Some(LOCAL_ADDR.into()),
676            ..REMOTE_END.into()
677        }
678    }
679
680    pub const LOCAL_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr {
681        src_addr: LOCAL_ADDR,
682        dst_addr: REMOTE_ADDR,
683        next_header: IpProtocol::Udp,
684        payload_len: 8 + 6,
685        hop_limit: 64,
686    });
687
688    pub const REMOTE_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr {
689        src_addr: REMOTE_ADDR,
690        dst_addr: LOCAL_ADDR,
691        next_header: IpProtocol::Udp,
692        payload_len: 8 + 6,
693        hop_limit: 64,
694    });
695
696    pub const BAD_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr {
697        src_addr: REMOTE_ADDR,
698        dst_addr: OTHER_ADDR,
699        next_header: IpProtocol::Udp,
700        payload_len: 8 + 6,
701        hop_limit: 64,
702    });
703
704    const LOCAL_UDP_REPR: UdpRepr = UdpRepr {
705        src_port: LOCAL_PORT,
706        dst_port: REMOTE_PORT,
707    };
708
709    const REMOTE_UDP_REPR: UdpRepr = UdpRepr {
710        src_port: REMOTE_PORT,
711        dst_port: LOCAL_PORT,
712    };
713
714    const PAYLOAD: &[u8] = b"abcdef";
715
716    #[test]
717    fn test_bind_unaddressable() {
718        let mut socket = socket(buffer(0), buffer(0));
719        assert_eq!(socket.bind(0), Err(BindError::Unaddressable));
720    }
721
722    #[test]
723    fn test_bind_twice() {
724        let mut socket = socket(buffer(0), buffer(0));
725        assert_eq!(socket.bind(1), Ok(()));
726        assert_eq!(socket.bind(2), Err(BindError::InvalidState));
727    }
728
729    #[test]
730    #[should_panic(expected = "the time-to-live value of a packet must not be zero")]
731    fn test_set_hop_limit_zero() {
732        let mut s = socket(buffer(0), buffer(1));
733        s.set_hop_limit(Some(0));
734    }
735
736    #[test]
737    fn test_send_unaddressable() {
738        let mut socket = socket(buffer(0), buffer(1));
739
740        assert_eq!(
741            socket.send_slice(b"abcdef", REMOTE_END),
742            Err(SendError::Unaddressable)
743        );
744        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
745        assert_eq!(
746            socket.send_slice(
747                b"abcdef",
748                IpEndpoint {
749                    addr: IpvXAddress::UNSPECIFIED.into(),
750                    ..REMOTE_END
751                }
752            ),
753            Err(SendError::Unaddressable)
754        );
755        assert_eq!(
756            socket.send_slice(
757                b"abcdef",
758                IpEndpoint {
759                    port: 0,
760                    ..REMOTE_END
761                }
762            ),
763            Err(SendError::Unaddressable)
764        );
765        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
766    }
767
768    #[test]
769    fn test_send_with_source() {
770        let mut socket = socket(buffer(0), buffer(1));
771
772        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
773        assert_eq!(
774            socket.send_slice(b"abcdef", remote_metadata_with_local()),
775            Ok(())
776        );
777    }
778
779    #[rstest]
780    #[case::ip(Medium::Ip)]
781    #[cfg(feature = "medium-ip")]
782    #[case::ethernet(Medium::Ethernet)]
783    #[cfg(feature = "medium-ethernet")]
784    #[case::ieee802154(Medium::Ieee802154)]
785    #[cfg(feature = "medium-ieee802154")]
786    fn test_send_dispatch(#[case] medium: Medium) {
787        let (mut iface, _, _) = setup(medium);
788        let cx = iface.context();
789        let mut socket = socket(buffer(0), buffer(1));
790
791        assert_eq!(socket.bind(LOCAL_END), Ok(()));
792
793        assert!(socket.can_send());
794        assert_eq!(
795            socket.dispatch(cx, |_, _, _| unreachable!()),
796            Ok::<_, ()>(())
797        );
798
799        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(()));
800        assert_eq!(
801            socket.send_slice(b"123456", REMOTE_END),
802            Err(SendError::BufferFull)
803        );
804        assert!(!socket.can_send());
805
806        assert_eq!(
807            socket.dispatch(cx, |_, _, (ip_repr, udp_repr, payload)| {
808                assert_eq!(ip_repr, LOCAL_IP_REPR);
809                assert_eq!(udp_repr, LOCAL_UDP_REPR);
810                assert_eq!(payload, PAYLOAD);
811                Err(())
812            }),
813            Err(())
814        );
815        assert!(!socket.can_send());
816
817        assert_eq!(
818            socket.dispatch(cx, |_, _, (ip_repr, udp_repr, payload)| {
819                assert_eq!(ip_repr, LOCAL_IP_REPR);
820                assert_eq!(udp_repr, LOCAL_UDP_REPR);
821                assert_eq!(payload, PAYLOAD);
822                Ok::<_, ()>(())
823            }),
824            Ok(())
825        );
826        assert!(socket.can_send());
827    }
828
829    #[rstest]
830    #[case::ip(Medium::Ip)]
831    #[cfg(feature = "medium-ip")]
832    #[case::ethernet(Medium::Ethernet)]
833    #[cfg(feature = "medium-ethernet")]
834    #[case::ieee802154(Medium::Ieee802154)]
835    #[cfg(feature = "medium-ieee802154")]
836    fn test_recv_process(#[case] medium: Medium) {
837        let (mut iface, _, _) = setup(medium);
838        let cx = iface.context();
839
840        let mut socket = socket(buffer(1), buffer(0));
841
842        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
843
844        assert!(!socket.can_recv());
845        assert_eq!(socket.recv(), Err(RecvError::Exhausted));
846
847        assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR));
848        socket.process(
849            cx,
850            PacketMeta::default(),
851            &REMOTE_IP_REPR,
852            &REMOTE_UDP_REPR,
853            PAYLOAD,
854        );
855        assert!(socket.can_recv());
856
857        assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR));
858        socket.process(
859            cx,
860            PacketMeta::default(),
861            &REMOTE_IP_REPR,
862            &REMOTE_UDP_REPR,
863            PAYLOAD,
864        );
865
866        assert_eq!(
867            socket.recv(),
868            Ok((&b"abcdef"[..], remote_metadata_with_local()))
869        );
870        assert!(!socket.can_recv());
871    }
872
873    #[rstest]
874    #[case::ip(Medium::Ip)]
875    #[cfg(feature = "medium-ip")]
876    #[case::ethernet(Medium::Ethernet)]
877    #[cfg(feature = "medium-ethernet")]
878    #[case::ieee802154(Medium::Ieee802154)]
879    #[cfg(feature = "medium-ieee802154")]
880    fn test_peek_process(#[case] medium: Medium) {
881        let (mut iface, _, _) = setup(medium);
882        let cx = iface.context();
883
884        let mut socket = socket(buffer(1), buffer(0));
885
886        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
887
888        assert_eq!(socket.peek(), Err(RecvError::Exhausted));
889
890        socket.process(
891            cx,
892            PacketMeta::default(),
893            &REMOTE_IP_REPR,
894            &REMOTE_UDP_REPR,
895            PAYLOAD,
896        );
897        assert_eq!(
898            socket.peek(),
899            Ok((&b"abcdef"[..], &remote_metadata_with_local(),))
900        );
901        assert_eq!(
902            socket.recv(),
903            Ok((&b"abcdef"[..], remote_metadata_with_local(),))
904        );
905        assert_eq!(socket.peek(), Err(RecvError::Exhausted));
906    }
907
908    #[rstest]
909    #[case::ip(Medium::Ip)]
910    #[cfg(feature = "medium-ip")]
911    #[case::ethernet(Medium::Ethernet)]
912    #[cfg(feature = "medium-ethernet")]
913    #[case::ieee802154(Medium::Ieee802154)]
914    #[cfg(feature = "medium-ieee802154")]
915    fn test_recv_truncated_slice(#[case] medium: Medium) {
916        let (mut iface, _, _) = setup(medium);
917        let cx = iface.context();
918
919        let mut socket = socket(buffer(1), buffer(0));
920
921        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
922
923        assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR));
924        socket.process(
925            cx,
926            PacketMeta::default(),
927            &REMOTE_IP_REPR,
928            &REMOTE_UDP_REPR,
929            PAYLOAD,
930        );
931
932        let mut slice = [0; 4];
933        assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
934    }
935
936    #[rstest]
937    #[case::ip(Medium::Ip)]
938    #[cfg(feature = "medium-ip")]
939    #[case::ethernet(Medium::Ethernet)]
940    #[cfg(feature = "medium-ethernet")]
941    #[case::ieee802154(Medium::Ieee802154)]
942    #[cfg(feature = "medium-ieee802154")]
943    fn test_peek_truncated_slice(#[case] medium: Medium) {
944        let (mut iface, _, _) = setup(medium);
945        let cx = iface.context();
946
947        let mut socket = socket(buffer(1), buffer(0));
948
949        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
950
951        socket.process(
952            cx,
953            PacketMeta::default(),
954            &REMOTE_IP_REPR,
955            &REMOTE_UDP_REPR,
956            PAYLOAD,
957        );
958
959        let mut slice = [0; 4];
960        assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated));
961        assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
962        assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
963    }
964
965    #[rstest]
966    #[case::ip(Medium::Ip)]
967    #[cfg(feature = "medium-ip")]
968    #[case::ethernet(Medium::Ethernet)]
969    #[cfg(feature = "medium-ethernet")]
970    #[case::ieee802154(Medium::Ieee802154)]
971    #[cfg(feature = "medium-ieee802154")]
972    fn test_set_hop_limit(#[case] medium: Medium) {
973        let (mut iface, _, _) = setup(medium);
974        let cx = iface.context();
975
976        let mut s = socket(buffer(0), buffer(1));
977
978        assert_eq!(s.bind(LOCAL_END), Ok(()));
979
980        s.set_hop_limit(Some(0x2a));
981        assert_eq!(s.send_slice(b"abcdef", REMOTE_END), Ok(()));
982        assert_eq!(
983            s.dispatch(cx, |_, _, (ip_repr, _, _)| {
984                assert_eq!(
985                    ip_repr,
986                    IpReprIpvX(IpvXRepr {
987                        src_addr: LOCAL_ADDR,
988                        dst_addr: REMOTE_ADDR,
989                        next_header: IpProtocol::Udp,
990                        payload_len: 8 + 6,
991                        hop_limit: 0x2a,
992                    })
993                );
994                Ok::<_, ()>(())
995            }),
996            Ok(())
997        );
998    }
999
1000    #[rstest]
1001    #[case::ip(Medium::Ip)]
1002    #[cfg(feature = "medium-ip")]
1003    #[case::ethernet(Medium::Ethernet)]
1004    #[cfg(feature = "medium-ethernet")]
1005    #[case::ieee802154(Medium::Ieee802154)]
1006    #[cfg(feature = "medium-ieee802154")]
1007    fn test_doesnt_accept_wrong_port(#[case] medium: Medium) {
1008        let (mut iface, _, _) = setup(medium);
1009        let cx = iface.context();
1010
1011        let mut socket = socket(buffer(1), buffer(0));
1012
1013        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
1014
1015        let mut udp_repr = REMOTE_UDP_REPR;
1016        assert!(socket.accepts(cx, &REMOTE_IP_REPR, &udp_repr));
1017        udp_repr.dst_port += 1;
1018        assert!(!socket.accepts(cx, &REMOTE_IP_REPR, &udp_repr));
1019    }
1020
1021    #[rstest]
1022    #[case::ip(Medium::Ip)]
1023    #[cfg(feature = "medium-ip")]
1024    #[case::ethernet(Medium::Ethernet)]
1025    #[cfg(feature = "medium-ethernet")]
1026    #[case::ieee802154(Medium::Ieee802154)]
1027    #[cfg(feature = "medium-ieee802154")]
1028    fn test_doesnt_accept_wrong_ip(#[case] medium: Medium) {
1029        let (mut iface, _, _) = setup(medium);
1030        let cx = iface.context();
1031
1032        let mut port_bound_socket = socket(buffer(1), buffer(0));
1033        assert_eq!(port_bound_socket.bind(LOCAL_PORT), Ok(()));
1034        assert!(port_bound_socket.accepts(cx, &BAD_IP_REPR, &REMOTE_UDP_REPR));
1035
1036        let mut ip_bound_socket = socket(buffer(1), buffer(0));
1037        assert_eq!(ip_bound_socket.bind(LOCAL_END), Ok(()));
1038        assert!(!ip_bound_socket.accepts(cx, &BAD_IP_REPR, &REMOTE_UDP_REPR));
1039    }
1040
1041    #[test]
1042    fn test_send_large_packet() {
1043        // buffer(4) creates a payload buffer of size 16*4
1044        let mut socket = socket(buffer(0), buffer(4));
1045        assert_eq!(socket.bind(LOCAL_END), Ok(()));
1046
1047        let too_large = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefx";
1048        assert_eq!(
1049            socket.send_slice(too_large, REMOTE_END),
1050            Err(SendError::BufferFull)
1051        );
1052        assert_eq!(socket.send_slice(&too_large[..16 * 4], REMOTE_END), Ok(()));
1053    }
1054
1055    #[rstest]
1056    #[case::ip(Medium::Ip)]
1057    #[cfg(feature = "medium-ip")]
1058    #[case::ethernet(Medium::Ethernet)]
1059    #[cfg(feature = "medium-ethernet")]
1060    #[case::ieee802154(Medium::Ieee802154)]
1061    #[cfg(feature = "medium-ieee802154")]
1062    fn test_process_empty_payload(#[case] medium: Medium) {
1063        let meta = Box::leak(Box::new([PacketMetadata::EMPTY]));
1064        let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]);
1065        let mut socket = socket(recv_buffer, buffer(0));
1066
1067        let (mut iface, _, _) = setup(medium);
1068        let cx = iface.context();
1069
1070        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
1071
1072        let repr = UdpRepr {
1073            src_port: REMOTE_PORT,
1074            dst_port: LOCAL_PORT,
1075        };
1076        socket.process(cx, PacketMeta::default(), &REMOTE_IP_REPR, &repr, &[]);
1077        assert_eq!(socket.recv(), Ok((&[][..], remote_metadata_with_local())));
1078    }
1079
1080    #[test]
1081    fn test_closing() {
1082        let meta = Box::leak(Box::new([PacketMetadata::EMPTY]));
1083        let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]);
1084        let mut socket = socket(recv_buffer, buffer(0));
1085        assert_eq!(socket.bind(LOCAL_PORT), Ok(()));
1086
1087        assert!(socket.is_open());
1088        socket.close();
1089        assert!(!socket.is_open());
1090    }
1091}