network/socket/
udp.rs

1use core::{
2    future::poll_fn,
3    task::{Context, Poll},
4};
5
6use embassy_futures::block_on;
7use smoltcp::socket::udp;
8
9use crate::{Error, IpAddress, Port, Result, SocketContext, UdpMetadata};
10
11pub struct UdpSocket {
12    context: SocketContext,
13}
14
15impl UdpSocket {
16    pub(crate) fn new(context: SocketContext) -> Self {
17        Self { context }
18    }
19
20    pub async fn with<F, R>(&self, f: F) -> R
21    where
22        F: FnOnce(&udp::Socket<'static>) -> R,
23    {
24        self.context.with(f).await
25    }
26
27    pub async fn with_mutable<F, R>(&self, f: F) -> R
28    where
29        F: FnOnce(&mut udp::Socket<'static>) -> R,
30    {
31        self.context.with_mutable(f).await
32    }
33
34    pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
35    where
36        F: FnOnce(&udp::Socket, &mut Context<'_>) -> Poll<R>,
37    {
38        self.context.poll_with(context, f)
39    }
40
41    pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
42    where
43        F: FnOnce(&mut udp::Socket, &mut Context<'_>) -> Poll<R>,
44    {
45        self.context.poll_with_mutable(context, f)
46    }
47
48    pub async fn bind(&mut self, port: Port) -> Result<()> {
49        let port = port.into_inner();
50
51        self.with_mutable(|socket| socket.bind(port)).await?;
52
53        Ok(())
54    }
55
56    pub fn poll_send_to(
57        &mut self,
58        context: &mut Context<'_>,
59        buffer: &[u8],
60        metadata: &UdpMetadata,
61    ) -> Poll<Result<()>> {
62        log::debug!("UDP poll_send_to: sending {} bytes", buffer.len());
63        self.poll_with_mutable(context, |socket, cx| {
64            let capacity = socket.payload_send_capacity();
65            log::debug!("UDP send capacity: {}, needed: {}", capacity, buffer.len());
66
67            if capacity < buffer.len() {
68                log::warning!(
69                    "UDP send buffer too small: capacity={}, needed={}",
70                    capacity,
71                    buffer.len()
72                );
73                return Poll::Ready(Err(Error::PacketTooLarge));
74            }
75
76            let metadata = metadata.to_smoltcp();
77
78            match socket.send_slice(buffer, metadata) {
79                Ok(()) => {
80                    log::debug!("UDP send_slice succeeded");
81                    Poll::Ready(Ok(()))
82                }
83                Err(udp::SendError::BufferFull) => {
84                    log::information!("UDP send buffer full, registering waker");
85                    socket.register_send_waker(cx.waker());
86                    Poll::Pending
87                }
88                Err(udp::SendError::Unaddressable) => {
89                    if socket.endpoint().port == 0 {
90                        log::error!("UDP send failed: socket not bound");
91                        Poll::Ready(Err(Error::SocketNotBound))
92                    } else {
93                        log::error!("UDP send failed: no route");
94                        Poll::Ready(Err(Error::NoRoute))
95                    }
96                }
97            }
98        })
99    }
100
101    pub fn poll_receive_from(
102        &self,
103        context: &mut Context<'_>,
104        buffer: &mut [u8],
105    ) -> Poll<Result<(usize, UdpMetadata)>> {
106        log::debug!("UDP poll_receive_from: buffer size={}", buffer.len());
107        self.poll_with_mutable(context, |socket, cx| {
108            log::debug!(
109                "UDP recv: checking socket for data (can_recv={}, buffered={})",
110                socket.can_recv(),
111                socket.recv_queue()
112            );
113            match socket.recv_slice(buffer) {
114                Ok((n, meta)) => {
115                    log::information!("UDP received {} bytes", n);
116                    Poll::Ready(Ok((n, UdpMetadata::from_smoltcp(&meta))))
117                }
118                Err(udp::RecvError::Truncated) => {
119                    log::warning!("UDP receive truncated");
120                    Poll::Ready(Err(Error::Truncated))
121                }
122                Err(udp::RecvError::Exhausted) => {
123                    log::information!(
124                        "UDP receive buffer exhausted (can_recv={}, buffered={})",
125                        socket.can_recv(),
126                        socket.recv_queue()
127                    );
128                    socket.register_recv_waker(cx.waker());
129                    log::information!("UDP receive waker registered");
130                    Poll::Pending
131                }
132            }
133        })
134    }
135
136    pub async fn write_to(&mut self, buffer: &[u8], metadata: &UdpMetadata) -> Result<()> {
137        log::debug!(
138            "UDP write_to: starting async send of {} bytes",
139            buffer.len()
140        );
141        let result = poll_fn(|cx| self.poll_send_to(cx, buffer, metadata)).await;
142        log::debug!("UDP write_to: completed with result {:?}", result.is_ok());
143        result
144    }
145
146    pub async fn read_from(&self, buffer: &mut [u8]) -> Result<(usize, UdpMetadata)> {
147        poll_fn(|cx| {
148            log::information!("Polling UDP read");
149            let r = self.poll_receive_from(cx, buffer);
150            log::information!("UDP read poll completed");
151            r
152        })
153        .await
154    }
155
156    pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
157        poll_fn(|cx| {
158            self.poll_with_mutable(cx, |socket, cx| {
159                if socket.can_send() {
160                    Poll::Ready(())
161                } else {
162                    socket.register_send_waker(cx.waker());
163                    Poll::Pending
164                }
165            })
166        })
167    }
168
169    pub async fn close(mut self) {
170        self.context.closed = true;
171        self.with_mutable(|s| {
172            log::information!("Closing UDP socket : {:?}", s.endpoint());
173            udp::Socket::close(s);
174            log::information!("UDP socket closed");
175        })
176        .await;
177    }
178
179    pub async fn get_endpoint(&self) -> Result<(Option<IpAddress>, Port)> {
180        let endpoint = self.with(udp::Socket::endpoint).await;
181
182        let ip_address = endpoint.addr.as_ref().map(IpAddress::from_smoltcp);
183        let port = Port::from_inner(endpoint.port);
184
185        Ok((ip_address, port))
186    }
187
188    pub async fn get_packet_read_capacity(&self) -> usize {
189        self.with(udp::Socket::packet_recv_capacity).await
190    }
191
192    pub async fn get_packet_write_capacity(&self) -> usize {
193        self.with(udp::Socket::packet_send_capacity).await
194    }
195
196    pub async fn get_payload_read_capacity(&self) -> usize {
197        self.with(udp::Socket::payload_recv_capacity).await
198    }
199
200    pub async fn get_payload_write_capacity(&self) -> usize {
201        self.with(udp::Socket::payload_send_capacity).await
202    }
203
204    pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
205        self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
206            .await
207    }
208}
209
210impl Drop for UdpSocket {
211    fn drop(&mut self) {
212        if self.context.closed {
213            return;
214        }
215        log::warning!("UDP socket dropped without being closed. Forcing closure.");
216        block_on(self.with_mutable(udp::Socket::close));
217    }
218}
219
#[cfg(test)]
mod tests {

    extern crate std;

    use crate::tests::initialize;

    use super::*;
    use smoltcp::phy::PacketMeta;

    /// Binding to an explicit port succeeds and is reflected by `get_endpoint`.
    #[task::test]
    async fn test_udp_bind() {
        let network_manager = initialize().await;

        let mut socket = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create UDP socket");

        let port = Port::from_inner(10001);
        let result = socket.bind(port).await;

        assert!(result.is_ok(), "Failed to bind UDP socket");

        let (_ip, bound_port) = socket.get_endpoint().await.expect("Failed to get endpoint");
        assert_eq!(bound_port, port, "Port mismatch");

        socket.close().await;
    }

    /// A datagram sent to 127.0.0.1 arrives at the bound receiver intact.
    #[task::test]
    async fn test_udp_send_receive() {
        let network_manager = initialize().await;

        // Create sender socket
        let mut sender = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create sender socket");

        // Create receiver socket
        let mut receiver = network_manager
            .new_udp_socket(65535, 65535, 10, 10, None)
            .await
            .expect("Failed to create receiver socket");

        // Bind receiver to a specific port
        let receiver_port = Port::from_inner(10003);
        receiver
            .bind(receiver_port)
            .await
            .expect("Failed to bind receiver");

        // Prepare test data
        let test_data = b"Hello, UDP!";

        let remote_ip: IpAddress = [127, 0, 0, 1].into();

        let metadata = UdpMetadata::new(remote_ip, receiver_port, None, PacketMeta::default());

        sender
            .bind(Port::from_inner(10002))
            .await
            .expect("Failed to bind sender");

        log::information!("Sending data");

        // Send data
        let send_result = sender.write_to(test_data, &metadata).await;
        assert_eq!(send_result, Ok(()));

        log::information!("Data sent, waiting to receive...");

        // Receive data
        let mut buffer = [0u8; 1024];
        let receive_result = receiver.read_from(&mut buffer).await;

        log::information!("Data received");

        // A failed receive must fail the test. Skipping the payload checks would let it
        // pass silently.
        let (size, _recv_metadata) = receive_result.expect("Failed to receive data");
        assert_eq!(size, test_data.len(), "Received data size mismatch");
        assert_eq!(&buffer[..size], test_data, "Received data mismatch");

        sender.close().await;
        receiver.close().await;
    }

    /// The endpoint port is 0 before binding and equals the bound port afterwards.
    #[task::test]
    async fn test_udp_endpoint() {
        let network_manager = initialize().await;

        let mut socket = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create UDP socket");

        // Before binding, endpoint should have port 0
        let (_ip, port) = socket
            .get_endpoint()
            .await
            .expect("Failed to get initial endpoint");
        assert_eq!(port.into_inner(), 0, "Initial port should be 0");

        // After binding, endpoint should have the bound port
        let bind_port = Port::from_inner(10004);
        socket.bind(bind_port).await.expect("Failed to bind");

        let (_, bound_port) = socket
            .get_endpoint()
            .await
            .expect("Failed to get bound endpoint");
        assert_eq!(bound_port, bind_port, "Bound port mismatch");

        socket.close().await;
    }

    /// Capacity getters report the sizes the socket was created with.
    #[task::test]
    async fn test_udp_capacities() {
        let network_manager = initialize().await;

        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let rx_meta = 15;
        let tx_meta = 20;

        let socket = network_manager
            .new_udp_socket(tx_buffer, rx_buffer, rx_meta, tx_meta, None)
            .await
            .expect("Failed to create UDP socket");

        let packet_read_cap = socket.get_packet_read_capacity().await;
        let packet_write_cap = socket.get_packet_write_capacity().await;
        let payload_read_cap = socket.get_payload_read_capacity().await;
        let payload_write_cap = socket.get_payload_write_capacity().await;

        assert_eq!(packet_read_cap, rx_meta, "Packet read capacity mismatch");
        assert_eq!(packet_write_cap, tx_meta, "Packet write capacity mismatch");
        assert_eq!(
            payload_read_cap, rx_buffer,
            "Payload read capacity mismatch"
        );
        assert_eq!(
            payload_write_cap, tx_buffer,
            "Payload write capacity mismatch"
        );

        socket.close().await;
    }
}