Skip to main content

smoltcp/socket/
raw.rs

use core::cmp::min;
#[cfg(feature = "async")]
use core::task::Waker;

use crate::iface::Context;
#[cfg(feature = "proto-ipv4")]
use crate::phy::ChecksumCapabilities;
use crate::socket::PollAt;
#[cfg(feature = "async")]
use crate::socket::WakerRegistration;

use crate::storage::Empty;
use crate::wire::{IpProtocol, IpRepr, IpVersion};
#[cfg(feature = "proto-ipv4")]
use crate::wire::{Ipv4Packet, Ipv4Repr};
#[cfg(feature = "proto-ipv6")]
use crate::wire::{Ipv6Packet, Ipv6Repr};
16
/// Error type for binding a raw socket.
///
/// Note: the raw [`Socket`] in this module does not currently define a `bind`
/// method, so none of its methods return this error.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BindError {
    InvalidState,
    Unaddressable,
}

impl core::fmt::Display for BindError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match self {
            BindError::InvalidState => write!(f, "invalid state"),
            BindError::Unaddressable => write!(f, "unaddressable"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for BindError {}
36
/// Error produced when a packet cannot be enqueued for transmission.
///
/// Returned by [`Socket::send`], [`Socket::send_slice`] and [`Socket::send_with`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SendError {
    BufferFull,
}

impl core::fmt::Display for SendError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // Map each variant to its fixed human-readable text, then write it verbatim.
        let text = match *self {
            Self::BufferFull => "buffer full",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SendError {}
54
/// Error returned by [`Socket::recv`], [`Socket::peek`], [`Socket::recv_slice`]
/// and [`Socket::peek_slice`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RecvError {
    /// The receive buffer holds no packets.
    Exhausted,
    /// The slice passed to `recv_slice` / `peek_slice` is shorter than the
    /// queued packet.
    Truncated,
}

impl core::fmt::Display for RecvError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match self {
            RecvError::Exhausted => write!(f, "exhausted"),
            RecvError::Truncated => write!(f, "truncated"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RecvError {}
74
75/// A UDP packet metadata.
76pub type PacketMetadata = crate::storage::PacketMetadata<()>;
77
78/// A UDP packet ring buffer.
79pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>;
80
81/// A raw IP socket.
82///
83/// A raw socket may be bound to a specific IP protocol, and owns
84/// transmit and receive packet buffers.
85#[derive(Debug)]
86pub struct Socket<'a> {
87    ip_version: Option<IpVersion>,
88    ip_protocol: Option<IpProtocol>,
89    rx_buffer: PacketBuffer<'a>,
90    tx_buffer: PacketBuffer<'a>,
91    #[cfg(feature = "async")]
92    rx_waker: WakerRegistration,
93    #[cfg(feature = "async")]
94    tx_waker: WakerRegistration,
95}
96
97impl<'a> Socket<'a> {
98    /// Create a raw IP socket bound to the given IP version and datagram protocol,
99    /// with the given buffers.
100    pub fn new(
101        ip_version: Option<IpVersion>,
102        ip_protocol: Option<IpProtocol>,
103        rx_buffer: PacketBuffer<'a>,
104        tx_buffer: PacketBuffer<'a>,
105    ) -> Socket<'a> {
106        Socket {
107            ip_version,
108            ip_protocol,
109            rx_buffer,
110            tx_buffer,
111            #[cfg(feature = "async")]
112            rx_waker: WakerRegistration::new(),
113            #[cfg(feature = "async")]
114            tx_waker: WakerRegistration::new(),
115        }
116    }
117
118    /// Register a waker for receive operations.
119    ///
120    /// The waker is woken on state changes that might affect the return value
121    /// of `recv` method calls, such as receiving data, or the socket closing.
122    ///
123    /// Notes:
124    ///
125    /// - Only one waker can be registered at a time. If another waker was previously registered,
126    ///   it is overwritten and will no longer be woken.
127    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
128    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has
129    ///   necessarily changed.
130    #[cfg(feature = "async")]
131    pub fn register_recv_waker(&mut self, waker: &Waker) {
132        self.rx_waker.register(waker)
133    }
134
135    /// Register a waker for send operations.
136    ///
137    /// The waker is woken on state changes that might affect the return value
138    /// of `send` method calls, such as space becoming available in the transmit
139    /// buffer, or the socket closing.
140    ///
141    /// Notes:
142    ///
143    /// - Only one waker can be registered at a time. If another waker was previously registered,
144    ///   it is overwritten and will no longer be woken.
145    /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes.
146    /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has
147    ///   necessarily changed.
148    #[cfg(feature = "async")]
149    pub fn register_send_waker(&mut self, waker: &Waker) {
150        self.tx_waker.register(waker)
151    }
152
153    /// Return the IP version the socket is bound to.
154    #[inline]
155    pub fn ip_version(&self) -> Option<IpVersion> {
156        self.ip_version
157    }
158
159    /// Return the IP protocol the socket is bound to.
160    #[inline]
161    pub fn ip_protocol(&self) -> Option<IpProtocol> {
162        self.ip_protocol
163    }
164
165    /// Check whether the transmit buffer is full.
166    #[inline]
167    pub fn can_send(&self) -> bool {
168        !self.tx_buffer.is_full()
169    }
170
171    /// Check whether the receive buffer is not empty.
172    #[inline]
173    pub fn can_recv(&self) -> bool {
174        !self.rx_buffer.is_empty()
175    }
176
177    /// Return the maximum number packets the socket can receive.
178    #[inline]
179    pub fn packet_recv_capacity(&self) -> usize {
180        self.rx_buffer.packet_capacity()
181    }
182
183    /// Return the maximum number packets the socket can transmit.
184    #[inline]
185    pub fn packet_send_capacity(&self) -> usize {
186        self.tx_buffer.packet_capacity()
187    }
188
189    /// Return the maximum number of bytes inside the recv buffer.
190    #[inline]
191    pub fn payload_recv_capacity(&self) -> usize {
192        self.rx_buffer.payload_capacity()
193    }
194
195    /// Return the maximum number of bytes inside the transmit buffer.
196    #[inline]
197    pub fn payload_send_capacity(&self) -> usize {
198        self.tx_buffer.payload_capacity()
199    }
200
201    /// Enqueue a packet to send, and return a pointer to its payload.
202    ///
203    /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
204    /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity
205    /// to ever send this packet.
206    ///
207    /// If the buffer is filled in a way that does not match the socket's
208    /// IP version or protocol, the packet will be silently dropped.
209    ///
210    /// **Note:** The IP header is parsed and re-serialized, and may not match
211    /// the header actually transmitted bit for bit.
212    pub fn send(&mut self, size: usize) -> Result<&mut [u8], SendError> {
213        let packet_buf = self
214            .tx_buffer
215            .enqueue(size, ())
216            .map_err(|_| SendError::BufferFull)?;
217
218        net_trace!(
219            "raw:{:?}:{:?}: buffer to send {} octets",
220            self.ip_version,
221            self.ip_protocol,
222            packet_buf.len()
223        );
224        Ok(packet_buf)
225    }
226
227    /// Enqueue a packet to be send and pass the buffer to the provided closure.
228    /// The closure then returns the size of the data written into the buffer.
229    ///
230    /// Also see [send](#method.send).
231    pub fn send_with<F>(&mut self, max_size: usize, f: F) -> Result<usize, SendError>
232    where
233        F: FnOnce(&mut [u8]) -> usize,
234    {
235        let size = self
236            .tx_buffer
237            .enqueue_with_infallible(max_size, (), f)
238            .map_err(|_| SendError::BufferFull)?;
239
240        net_trace!(
241            "raw:{:?}:{:?}: buffer to send {} octets",
242            self.ip_version,
243            self.ip_protocol,
244            size
245        );
246
247        Ok(size)
248    }
249
250    /// Enqueue a packet to send, and fill it from a slice.
251    ///
252    /// See also [send](#method.send).
253    pub fn send_slice(&mut self, data: &[u8]) -> Result<(), SendError> {
254        self.send(data.len())?.copy_from_slice(data);
255        Ok(())
256    }
257
258    /// Dequeue a packet, and return a pointer to the payload.
259    ///
260    /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
261    ///
262    /// **Note:** The IP header is parsed and re-serialized, and may not match
263    /// the header actually received bit for bit.
264    pub fn recv(&mut self) -> Result<&[u8], RecvError> {
265        let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
266
267        net_trace!(
268            "raw:{:?}:{:?}: receive {} buffered octets",
269            self.ip_version,
270            self.ip_protocol,
271            packet_buf.len()
272        );
273        Ok(packet_buf)
274    }
275
276    /// Dequeue a packet, and copy the payload into the given slice.
277    ///
278    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
279    /// the packet is dropped and a `RecvError::Truncated` error is returned.
280    ///
281    /// See also [recv](#method.recv).
282    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
283        let buffer = self.recv()?;
284        if data.len() < buffer.len() {
285            return Err(RecvError::Truncated);
286        }
287
288        let length = min(data.len(), buffer.len());
289        data[..length].copy_from_slice(&buffer[..length]);
290        Ok(length)
291    }
292
293    /// Peek at a packet in the receive buffer and return a pointer to the
294    /// payload without removing the packet from the receive buffer.
295    /// This function otherwise behaves identically to [recv](#method.recv).
296    ///
297    /// It returns `Err(Error::Exhausted)` if the receive buffer is empty.
298    pub fn peek(&mut self) -> Result<&[u8], RecvError> {
299        let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?;
300
301        net_trace!(
302            "raw:{:?}:{:?}: receive {} buffered octets",
303            self.ip_version,
304            self.ip_protocol,
305            packet_buf.len()
306        );
307
308        Ok(packet_buf)
309    }
310
311    /// Peek at a packet in the receive buffer, copy the payload into the given slice,
312    /// and return the amount of octets copied without removing the packet from the receive buffer.
313    /// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
314    ///
315    /// **Note**: when the size of the provided buffer is smaller than the size of the payload,
316    /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned.
317    ///
318    /// See also [peek](#method.peek).
319    pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> {
320        let buffer = self.peek()?;
321        if data.len() < buffer.len() {
322            return Err(RecvError::Truncated);
323        }
324
325        let length = min(data.len(), buffer.len());
326        data[..length].copy_from_slice(&buffer[..length]);
327        Ok(length)
328    }
329
330    /// Return the amount of octets queued in the transmit buffer.
331    pub fn send_queue(&self) -> usize {
332        self.tx_buffer.payload_bytes_count()
333    }
334
335    /// Return the amount of octets queued in the receive buffer.
336    pub fn recv_queue(&self) -> usize {
337        self.rx_buffer.payload_bytes_count()
338    }
339
340    pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
341        if self
342            .ip_version
343            .is_some_and(|version| version != ip_repr.version())
344        {
345            return false;
346        }
347
348        if self
349            .ip_protocol
350            .is_some_and(|next_header| next_header != ip_repr.next_header())
351        {
352            return false;
353        }
354
355        true
356    }
357
358    pub(crate) fn process(&mut self, cx: &mut Context, ip_repr: &IpRepr, payload: &[u8]) {
359        debug_assert!(self.accepts(ip_repr));
360
361        let header_len = ip_repr.header_len();
362        let total_len = header_len + payload.len();
363
364        net_trace!(
365            "raw:{:?}:{:?}: receiving {} octets",
366            self.ip_version,
367            self.ip_protocol,
368            total_len
369        );
370
371        match self.rx_buffer.enqueue(total_len, ()) {
372            Ok(buf) => {
373                ip_repr.emit(&mut buf[..header_len], &cx.checksum_caps());
374                buf[header_len..].copy_from_slice(payload);
375            }
376            Err(_) => net_trace!(
377                "raw:{:?}:{:?}: buffer full, dropped incoming packet",
378                self.ip_version,
379                self.ip_protocol
380            ),
381        }
382
383        #[cfg(feature = "async")]
384        self.rx_waker.wake();
385    }
386
387    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
388    where
389        F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), E>,
390    {
391        let ip_protocol = self.ip_protocol;
392        let ip_version = self.ip_version;
393        let _checksum_caps = &cx.checksum_caps();
394        let res = self.tx_buffer.dequeue_with(|&mut (), buffer| {
395            match IpVersion::of_packet(buffer) {
396                #[cfg(feature = "proto-ipv4")]
397                Ok(IpVersion::Ipv4) => {
398                    let mut packet = match Ipv4Packet::new_checked(buffer) {
399                        Ok(x) => x,
400                        Err(_) => {
401                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
402                            return Ok(());
403                        }
404                    };
405                    if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
406                        net_trace!("raw: sent packet with wrong ip protocol, dropping.");
407                        return Ok(());
408                    }
409                    if _checksum_caps.ipv4.tx() {
410                        packet.fill_checksum();
411                    } else {
412                        // make sure we get a consistently zeroed checksum,
413                        // since implementations might rely on it
414                        packet.set_checksum(0);
415                    }
416
417                    let packet = Ipv4Packet::new_unchecked(&*packet.into_inner());
418                    let ipv4_repr = match Ipv4Repr::parse(&packet, _checksum_caps) {
419                        Ok(x) => x,
420                        Err(_) => {
421                            net_trace!("raw: malformed ipv4 packet in queue, dropping.");
422                            return Ok(());
423                        }
424                    };
425                    net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
426                    emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload()))
427                }
428                #[cfg(feature = "proto-ipv6")]
429                Ok(IpVersion::Ipv6) => {
430                    let packet = match Ipv6Packet::new_checked(buffer) {
431                        Ok(x) => x,
432                        Err(_) => {
433                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
434                            return Ok(());
435                        }
436                    };
437                    if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
438                        net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping.");
439                        return Ok(());
440                    }
441                    let packet = Ipv6Packet::new_unchecked(&*packet.into_inner());
442                    let ipv6_repr = match Ipv6Repr::parse(&packet) {
443                        Ok(x) => x,
444                        Err(_) => {
445                            net_trace!("raw: malformed ipv6 packet in queue, dropping.");
446                            return Ok(());
447                        }
448                    };
449
450                    net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
451                    emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload()))
452                }
453                Err(_) => {
454                    net_trace!("raw: sent packet with invalid IP version, dropping.");
455                    Ok(())
456                }
457            }
458        });
459        match res {
460            Err(Empty) => Ok(()),
461            Ok(Err(e)) => Err(e),
462            Ok(Ok(())) => {
463                #[cfg(feature = "async")]
464                self.tx_waker.wake();
465                Ok(())
466            }
467        }
468    }
469
470    pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt {
471        if self.tx_buffer.is_empty() {
472            PollAt::Ingress
473        } else {
474            PollAt::Now
475        }
476    }
477}
478
479#[cfg(test)]
480mod test {
481    use crate::phy::Medium;
482    use crate::tests::setup;
483    use rstest::*;
484
485    use super::*;
486    use crate::wire::IpRepr;
487    #[cfg(feature = "proto-ipv4")]
488    use crate::wire::{Ipv4Address, Ipv4Repr};
489    #[cfg(feature = "proto-ipv6")]
490    use crate::wire::{Ipv6Address, Ipv6Repr};
491
492    fn buffer(packets: usize) -> PacketBuffer<'static> {
493        PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets])
494    }
495
496    #[cfg(feature = "proto-ipv4")]
497    mod ipv4_locals {
498        use super::*;
499
500        pub fn socket(
501            rx_buffer: PacketBuffer<'static>,
502            tx_buffer: PacketBuffer<'static>,
503        ) -> Socket<'static> {
504            Socket::new(
505                Some(IpVersion::Ipv4),
506                Some(IpProtocol::Unknown(IP_PROTO)),
507                rx_buffer,
508                tx_buffer,
509            )
510        }
511
512        pub const IP_PROTO: u8 = 63;
513        pub const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
514            src_addr: Ipv4Address::new(10, 0, 0, 1),
515            dst_addr: Ipv4Address::new(10, 0, 0, 2),
516            next_header: IpProtocol::Unknown(IP_PROTO),
517            payload_len: 4,
518            hop_limit: 64,
519        });
520        pub const PACKET_BYTES: [u8; 24] = [
521            0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00, 0x40, 0x3f, 0x00, 0x00, 0x0a, 0x00,
522            0x00, 0x01, 0x0a, 0x00, 0x00, 0x02, 0xaa, 0x00, 0x00, 0xff,
523        ];
524        pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];
525    }
526
527    #[cfg(feature = "proto-ipv6")]
528    mod ipv6_locals {
529        use super::*;
530
531        pub fn socket(
532            rx_buffer: PacketBuffer<'static>,
533            tx_buffer: PacketBuffer<'static>,
534        ) -> Socket<'static> {
535            Socket::new(
536                Some(IpVersion::Ipv6),
537                Some(IpProtocol::Unknown(IP_PROTO)),
538                rx_buffer,
539                tx_buffer,
540            )
541        }
542
543        pub const IP_PROTO: u8 = 63;
544        pub const HEADER_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr {
545            src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1),
546            dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2),
547            next_header: IpProtocol::Unknown(IP_PROTO),
548            payload_len: 4,
549            hop_limit: 64,
550        });
551
552        pub const PACKET_BYTES: [u8; 44] = [
553            0x60, 0x00, 0x00, 0x00, 0x00, 0x04, 0x3f, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
554            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00,
555            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xaa, 0x00,
556            0x00, 0xff,
557        ];
558
559        pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];
560    }
561
562    macro_rules! reusable_ip_specific_tests {
563        ($module:ident, $socket:path, $hdr:path, $packet:path, $payload:path) => {
564            mod $module {
565                use super::*;
566
567                #[test]
568                fn test_send_truncated() {
569                    let mut socket = $socket(buffer(0), buffer(1));
570                    assert_eq!(socket.send_slice(&[0; 56][..]), Err(SendError::BufferFull));
571                }
572
573                #[rstest]
574                #[case::ip(Medium::Ip)]
575                #[cfg(feature = "medium-ip")]
576                #[case::ethernet(Medium::Ethernet)]
577                #[cfg(feature = "medium-ethernet")]
578                #[case::ieee802154(Medium::Ieee802154)]
579                #[cfg(feature = "medium-ieee802154")]
580                fn test_send_dispatch(#[case] medium: Medium) {
581                    let (mut iface, _, _) = setup(medium);
582                    let mut cx = iface.context();
583                    let mut socket = $socket(buffer(0), buffer(1));
584
585                    assert!(socket.can_send());
586                    assert_eq!(
587                        socket.dispatch(&mut cx, |_, _| unreachable!()),
588                        Ok::<_, ()>(())
589                    );
590
591                    assert_eq!(socket.send_slice(&$packet[..]), Ok(()));
592                    assert_eq!(socket.send_slice(b""), Err(SendError::BufferFull));
593                    assert!(!socket.can_send());
594
595                    assert_eq!(
596                        socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| {
597                            assert_eq!(ip_repr, $hdr);
598                            assert_eq!(ip_payload, &$payload);
599                            Err(())
600                        }),
601                        Err(())
602                    );
603                    assert!(!socket.can_send());
604
605                    assert_eq!(
606                        socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| {
607                            assert_eq!(ip_repr, $hdr);
608                            assert_eq!(ip_payload, &$payload);
609                            Ok::<_, ()>(())
610                        }),
611                        Ok(())
612                    );
613                    assert!(socket.can_send());
614                }
615
616                #[rstest]
617                #[case::ip(Medium::Ip)]
618                #[cfg(feature = "medium-ip")]
619                #[case::ethernet(Medium::Ethernet)]
620                #[cfg(feature = "medium-ethernet")]
621                #[case::ieee802154(Medium::Ieee802154)]
622                #[cfg(feature = "medium-ieee802154")]
623                fn test_recv_truncated_slice(#[case] medium: Medium) {
624                    let (mut iface, _, _) = setup(medium);
625                    let mut cx = iface.context();
626                    let mut socket = $socket(buffer(1), buffer(0));
627
628                    assert!(socket.accepts(&$hdr));
629                    socket.process(&mut cx, &$hdr, &$payload);
630
631                    let mut slice = [0; 4];
632                    assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
633                }
634
635                #[rstest]
636                #[case::ip(Medium::Ip)]
637                #[cfg(feature = "medium-ip")]
638                #[case::ethernet(Medium::Ethernet)]
639                #[cfg(feature = "medium-ethernet")]
640                #[case::ieee802154(Medium::Ieee802154)]
641                #[cfg(feature = "medium-ieee802154")]
642                fn test_recv_truncated_packet(#[case] medium: Medium) {
643                    let (mut iface, _, _) = setup(medium);
644                    let mut cx = iface.context();
645                    let mut socket = $socket(buffer(1), buffer(0));
646
647                    let mut buffer = vec![0; 128];
648                    buffer[..$packet.len()].copy_from_slice(&$packet[..]);
649
650                    assert!(socket.accepts(&$hdr));
651                    socket.process(&mut cx, &$hdr, &buffer);
652                }
653
654                #[rstest]
655                #[case::ip(Medium::Ip)]
656                #[cfg(feature = "medium-ip")]
657                #[case::ethernet(Medium::Ethernet)]
658                #[cfg(feature = "medium-ethernet")]
659                #[case::ieee802154(Medium::Ieee802154)]
660                #[cfg(feature = "medium-ieee802154")]
661                fn test_peek_truncated_slice(#[case] medium: Medium) {
662                    let (mut iface, _, _) = setup(medium);
663                    let mut cx = iface.context();
664                    let mut socket = $socket(buffer(1), buffer(0));
665
666                    assert!(socket.accepts(&$hdr));
667                    socket.process(&mut cx, &$hdr, &$payload);
668
669                    let mut slice = [0; 4];
670                    assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated));
671                    assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated));
672                    assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));
673                }
674            }
675        };
676    }
677
678    #[cfg(feature = "proto-ipv4")]
679    reusable_ip_specific_tests!(
680        ipv4,
681        ipv4_locals::socket,
682        ipv4_locals::HEADER_REPR,
683        ipv4_locals::PACKET_BYTES,
684        ipv4_locals::PACKET_PAYLOAD
685    );
686
687    #[cfg(feature = "proto-ipv6")]
688    reusable_ip_specific_tests!(
689        ipv6,
690        ipv6_locals::socket,
691        ipv6_locals::HEADER_REPR,
692        ipv6_locals::PACKET_BYTES,
693        ipv6_locals::PACKET_PAYLOAD
694    );
695
696    #[rstest]
697    #[case::ip(Medium::Ip)]
698    #[case::ethernet(Medium::Ethernet)]
699    #[cfg(feature = "medium-ethernet")]
700    #[case::ieee802154(Medium::Ieee802154)]
701    #[cfg(feature = "medium-ieee802154")]
702    fn test_send_illegal(#[case] medium: Medium) {
703        #[cfg(feature = "proto-ipv4")]
704        {
705            let (mut iface, _, _) = setup(medium);
706            let cx = iface.context();
707            let mut socket = ipv4_locals::socket(buffer(0), buffer(2));
708
709            let mut wrong_version = ipv4_locals::PACKET_BYTES;
710            Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6);
711
712            assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
713            assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));
714
715            let mut wrong_protocol = ipv4_locals::PACKET_BYTES;
716            Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp);
717
718            assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
719            assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));
720        }
721        #[cfg(feature = "proto-ipv6")]
722        {
723            let (mut iface, _, _) = setup(medium);
724            let cx = iface.context();
725            let mut socket = ipv6_locals::socket(buffer(0), buffer(2));
726
727            let mut wrong_version = ipv6_locals::PACKET_BYTES;
728            Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4);
729
730            assert_eq!(socket.send_slice(&wrong_version[..]), Ok(()));
731            assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));
732
733            let mut wrong_protocol = ipv6_locals::PACKET_BYTES;
734            Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp);
735
736            assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(()));
737            assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(()));
738        }
739    }
740
741    #[rstest]
742    #[case::ip(Medium::Ip)]
743    #[cfg(feature = "medium-ip")]
744    #[case::ethernet(Medium::Ethernet)]
745    #[cfg(feature = "medium-ethernet")]
746    #[case::ieee802154(Medium::Ieee802154)]
747    #[cfg(feature = "medium-ieee802154")]
748    fn test_recv_process(#[case] medium: Medium) {
749        #[cfg(feature = "proto-ipv4")]
750        {
751            let (mut iface, _, _) = setup(medium);
752            let cx = iface.context();
753            let mut socket = ipv4_locals::socket(buffer(1), buffer(0));
754            assert!(!socket.can_recv());
755
756            let mut cksumd_packet = ipv4_locals::PACKET_BYTES;
757            Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum();
758
759            assert_eq!(socket.recv(), Err(RecvError::Exhausted));
760            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
761            socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD);
762            assert!(socket.can_recv());
763
764            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
765            socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD);
766            assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
767            assert!(!socket.can_recv());
768        }
769        #[cfg(feature = "proto-ipv6")]
770        {
771            let (mut iface, _, _) = setup(medium);
772            let cx = iface.context();
773            let mut socket = ipv6_locals::socket(buffer(1), buffer(0));
774            assert!(!socket.can_recv());
775
776            assert_eq!(socket.recv(), Err(RecvError::Exhausted));
777            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
778            socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD);
779            assert!(socket.can_recv());
780
781            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
782            socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD);
783            assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..]));
784            assert!(!socket.can_recv());
785        }
786    }
787
788    #[rstest]
789    #[case::ip(Medium::Ip)]
790    #[case::ethernet(Medium::Ethernet)]
791    #[cfg(feature = "medium-ethernet")]
792    #[case::ieee802154(Medium::Ieee802154)]
793    #[cfg(feature = "medium-ieee802154")]
794    fn test_peek_process(#[case] medium: Medium) {
795        #[cfg(feature = "proto-ipv4")]
796        {
797            let (mut iface, _, _) = setup(medium);
798            let cx = iface.context();
799            let mut socket = ipv4_locals::socket(buffer(1), buffer(0));
800
801            let mut cksumd_packet = ipv4_locals::PACKET_BYTES;
802            Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum();
803
804            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
805            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
806            socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD);
807
808            assert!(socket.accepts(&ipv4_locals::HEADER_REPR));
809            socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD);
810            assert_eq!(socket.peek(), Ok(&cksumd_packet[..]));
811            assert_eq!(socket.recv(), Ok(&cksumd_packet[..]));
812            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
813        }
814        #[cfg(feature = "proto-ipv6")]
815        {
816            let (mut iface, _, _) = setup(medium);
817            let cx = iface.context();
818            let mut socket = ipv6_locals::socket(buffer(1), buffer(0));
819
820            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
821            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
822            socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD);
823
824            assert!(socket.accepts(&ipv6_locals::HEADER_REPR));
825            socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD);
826            assert_eq!(socket.peek(), Ok(&ipv6_locals::PACKET_BYTES[..]));
827            assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..]));
828            assert_eq!(socket.peek(), Err(RecvError::Exhausted));
829        }
830    }
831
832    #[test]
833    fn test_doesnt_accept_wrong_proto() {
834        #[cfg(feature = "proto-ipv4")]
835        {
836            let socket = Socket::new(
837                Some(IpVersion::Ipv4),
838                Some(IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1)),
839                buffer(1),
840                buffer(1),
841            );
842            assert!(!socket.accepts(&ipv4_locals::HEADER_REPR));
843            #[cfg(feature = "proto-ipv6")]
844            assert!(!socket.accepts(&ipv6_locals::HEADER_REPR));
845        }
846        #[cfg(feature = "proto-ipv6")]
847        {
848            let socket = Socket::new(
849                Some(IpVersion::Ipv6),
850                Some(IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1)),
851                buffer(1),
852                buffer(1),
853            );
854            assert!(!socket.accepts(&ipv6_locals::HEADER_REPR));
855            #[cfg(feature = "proto-ipv4")]
856            assert!(!socket.accepts(&ipv4_locals::HEADER_REPR));
857        }
858    }
859
860    fn check_dispatch(socket: &mut Socket<'_>, cx: &mut Context) {
861        // Check dispatch returns Ok(()) and calls the emit closure
862        let mut emitted = false;
863        assert_eq!(
864            socket.dispatch(cx, |_, _| {
865                emitted = true;
866                Ok(())
867            }),
868            Ok::<_, ()>(())
869        );
870        assert!(emitted);
871    }
872
873    #[rstest]
874    #[case::ip(Medium::Ip)]
875    #[case::ethernet(Medium::Ethernet)]
876    #[cfg(feature = "medium-ethernet")]
877    #[case::ieee802154(Medium::Ieee802154)]
878    #[cfg(feature = "medium-ieee802154")]
879    fn test_unfiltered_sends_all(#[case] medium: Medium) {
880        // Test a single unfiltered socket can send packets with different IP versions and next
881        // headers
882        let mut socket = Socket::new(None, None, buffer(0), buffer(2));
883        #[cfg(feature = "proto-ipv4")]
884        {
885            let (mut iface, _, _) = setup(medium);
886            let cx = iface.context();
887
888            let mut udp_packet = ipv4_locals::PACKET_BYTES;
889            Ipv4Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp);
890
891            assert_eq!(socket.send_slice(&udp_packet), Ok(()));
892            check_dispatch(&mut socket, cx);
893
894            let mut tcp_packet = ipv4_locals::PACKET_BYTES;
895            Ipv4Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp);
896
897            assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(()));
898            check_dispatch(&mut socket, cx);
899        }
900        #[cfg(feature = "proto-ipv6")]
901        {
902            let (mut iface, _, _) = setup(medium);
903            let cx = iface.context();
904
905            let mut udp_packet = ipv6_locals::PACKET_BYTES;
906            Ipv6Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp);
907
908            assert_eq!(socket.send_slice(&ipv6_locals::PACKET_BYTES), Ok(()));
909            check_dispatch(&mut socket, cx);
910
911            let mut tcp_packet = ipv6_locals::PACKET_BYTES;
912            Ipv6Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp);
913
914            assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(()));
915            check_dispatch(&mut socket, cx);
916        }
917    }
918
919    #[rstest]
920    #[case::proto(IpProtocol::Icmp)]
921    #[case::proto(IpProtocol::Tcp)]
922    #[case::proto(IpProtocol::Udp)]
923    fn test_unfiltered_accepts_all(#[case] proto: IpProtocol) {
924        // Test an unfiltered socket can accept packets with different IP versions and next headers
925        let socket = Socket::new(None, None, buffer(0), buffer(0));
926        #[cfg(feature = "proto-ipv4")]
927        {
928            let header_repr = IpRepr::Ipv4(Ipv4Repr {
929                src_addr: Ipv4Address::new(10, 0, 0, 1),
930                dst_addr: Ipv4Address::new(10, 0, 0, 2),
931                next_header: proto,
932                payload_len: 4,
933                hop_limit: 64,
934            });
935            assert!(socket.accepts(&header_repr));
936        }
937        #[cfg(feature = "proto-ipv6")]
938        {
939            let header_repr = IpRepr::Ipv6(Ipv6Repr {
940                src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1),
941                dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2),
942                next_header: proto,
943                payload_len: 4,
944                hop_limit: 64,
945            });
946            assert!(socket.accepts(&header_repr));
947        }
948    }
949}