1use core::{
2 future::poll_fn,
3 task::{Context, Poll},
4};
5
6use embassy_futures::block_on;
7use smoltcp::socket::udp;
8
9use crate::{Error, IpAddress, Port, Result, SocketContext, UdpMetadata};
10
11pub struct UdpSocket {
12 context: SocketContext,
13}
14
15impl UdpSocket {
16 pub(crate) fn new(context: SocketContext) -> Self {
17 Self { context }
18 }
19
20 pub async fn with<F, R>(&self, f: F) -> R
21 where
22 F: FnOnce(&udp::Socket<'static>) -> R,
23 {
24 self.context.with(f).await
25 }
26
27 pub async fn with_mutable<F, R>(&self, f: F) -> R
28 where
29 F: FnOnce(&mut udp::Socket<'static>) -> R,
30 {
31 self.context.with_mutable(f).await
32 }
33
34 pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
35 where
36 F: FnOnce(&udp::Socket, &mut Context<'_>) -> Poll<R>,
37 {
38 self.context.poll_with(context, f)
39 }
40
41 pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
42 where
43 F: FnOnce(&mut udp::Socket, &mut Context<'_>) -> Poll<R>,
44 {
45 self.context.poll_with_mutable(context, f)
46 }
47
48 pub async fn bind(&mut self, port: Port) -> Result<()> {
49 let port = port.into_inner();
50
51 self.with_mutable(|socket| socket.bind(port)).await?;
52
53 Ok(())
54 }
55
56 pub fn poll_send_to(
57 &mut self,
58 context: &mut Context<'_>,
59 buffer: &[u8],
60 metadata: &UdpMetadata,
61 ) -> Poll<Result<()>> {
62 log::debug!("UDP poll_send_to: sending {} bytes", buffer.len());
63 self.poll_with_mutable(context, |socket, cx| {
64 let capacity = socket.payload_send_capacity();
65 log::debug!("UDP send capacity: {}, needed: {}", capacity, buffer.len());
66
67 if capacity < buffer.len() {
68 log::warning!(
69 "UDP send buffer too small: capacity={}, needed={}",
70 capacity,
71 buffer.len()
72 );
73 return Poll::Ready(Err(Error::PacketTooLarge));
74 }
75
76 let metadata = metadata.to_smoltcp();
77
78 match socket.send_slice(buffer, metadata) {
79 Ok(()) => {
80 log::debug!("UDP send_slice succeeded");
81 Poll::Ready(Ok(()))
82 }
83 Err(udp::SendError::BufferFull) => {
84 log::information!("UDP send buffer full, registering waker");
85 socket.register_send_waker(cx.waker());
86 Poll::Pending
87 }
88 Err(udp::SendError::Unaddressable) => {
89 if socket.endpoint().port == 0 {
90 log::error!("UDP send failed: socket not bound");
91 Poll::Ready(Err(Error::SocketNotBound))
92 } else {
93 log::error!("UDP send failed: no route");
94 Poll::Ready(Err(Error::NoRoute))
95 }
96 }
97 }
98 })
99 }
100
101 pub fn poll_receive_from(
102 &self,
103 context: &mut Context<'_>,
104 buffer: &mut [u8],
105 ) -> Poll<Result<(usize, UdpMetadata)>> {
106 log::debug!("UDP poll_receive_from: buffer size={}", buffer.len());
107 self.poll_with_mutable(context, |socket, cx| {
108 log::debug!(
109 "UDP recv: checking socket for data (can_recv={}, buffered={})",
110 socket.can_recv(),
111 socket.recv_queue()
112 );
113 match socket.recv_slice(buffer) {
114 Ok((n, meta)) => {
115 log::information!("UDP received {} bytes", n);
116 Poll::Ready(Ok((n, UdpMetadata::from_smoltcp(&meta))))
117 }
118 Err(udp::RecvError::Truncated) => {
119 log::warning!("UDP receive truncated");
120 Poll::Ready(Err(Error::Truncated))
121 }
122 Err(udp::RecvError::Exhausted) => {
123 log::information!(
124 "UDP receive buffer exhausted (can_recv={}, buffered={})",
125 socket.can_recv(),
126 socket.recv_queue()
127 );
128 socket.register_recv_waker(cx.waker());
129 log::information!("UDP receive waker registered");
130 Poll::Pending
131 }
132 }
133 })
134 }
135
136 pub async fn write_to(&mut self, buffer: &[u8], metadata: &UdpMetadata) -> Result<()> {
137 log::debug!(
138 "UDP write_to: starting async send of {} bytes",
139 buffer.len()
140 );
141 let result = poll_fn(|cx| self.poll_send_to(cx, buffer, metadata)).await;
142 log::debug!("UDP write_to: completed with result {:?}", result.is_ok());
143 result
144 }
145
146 pub async fn read_from(&self, buffer: &mut [u8]) -> Result<(usize, UdpMetadata)> {
147 poll_fn(|cx| {
148 log::information!("Polling UDP read");
149 let r = self.poll_receive_from(cx, buffer);
150 log::information!("UDP read poll completed");
151 r
152 })
153 .await
154 }
155
156 pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
157 poll_fn(|cx| {
158 self.poll_with_mutable(cx, |socket, cx| {
159 if socket.can_send() {
160 Poll::Ready(())
161 } else {
162 socket.register_send_waker(cx.waker());
163 Poll::Pending
164 }
165 })
166 })
167 }
168
169 pub async fn close(mut self) {
170 self.context.closed = true;
171 self.with_mutable(|s| {
172 log::information!("Closing UDP socket : {:?}", s.endpoint());
173 udp::Socket::close(s);
174 log::information!("UDP socket closed");
175 })
176 .await;
177 }
178
179 pub async fn get_endpoint(&self) -> Result<(Option<IpAddress>, Port)> {
180 let endpoint = self.with(udp::Socket::endpoint).await;
181
182 let ip_address = endpoint.addr.as_ref().map(IpAddress::from_smoltcp);
183 let port = Port::from_inner(endpoint.port);
184
185 Ok((ip_address, port))
186 }
187
188 pub async fn get_packet_read_capacity(&self) -> usize {
189 self.with(udp::Socket::packet_recv_capacity).await
190 }
191
192 pub async fn get_packet_write_capacity(&self) -> usize {
193 self.with(udp::Socket::packet_send_capacity).await
194 }
195
196 pub async fn get_payload_read_capacity(&self) -> usize {
197 self.with(udp::Socket::payload_recv_capacity).await
198 }
199
200 pub async fn get_payload_write_capacity(&self) -> usize {
201 self.with(udp::Socket::payload_send_capacity).await
202 }
203
204 pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
205 self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
206 .await
207 }
208}
209
210impl Drop for UdpSocket {
211 fn drop(&mut self) {
212 if self.context.closed {
213 return;
214 }
215 log::warning!("UDP socket dropped without being closed. Forcing closure.");
216 block_on(self.with_mutable(udp::Socket::close));
217 }
218}
219
#[cfg(test)]
mod tests {

    extern crate std;

    use crate::tests::initialize;

    use super::*;
    use smoltcp::phy::PacketMeta;

    /// Binding to an explicit port succeeds and is reflected by `get_endpoint`.
    #[task::test]
    async fn test_udp_bind() {
        let network_manager = initialize().await;

        let mut socket = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create UDP socket");

        let port = Port::from_inner(10001);
        let result = socket.bind(port).await;

        assert!(result.is_ok(), "Failed to bind UDP socket");

        let (_ip, bound_port) = socket.get_endpoint().await.expect("Failed to get endpoint");
        assert_eq!(bound_port, port, "Port mismatch");

        socket.close().await;
    }

    /// A datagram sent over loopback arrives intact at the bound receiver.
    #[task::test]
    async fn test_udp_send_receive() {
        let network_manager = initialize().await;

        let mut sender = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create sender socket");

        let mut receiver = network_manager
            .new_udp_socket(65535, 65535, 10, 10, None)
            .await
            .expect("Failed to create receiver socket");

        let receiver_port = Port::from_inner(10003);
        receiver
            .bind(receiver_port)
            .await
            .expect("Failed to bind receiver");

        let test_data = b"Hello, UDP!";

        let remote_ip: IpAddress = [127, 0, 0, 1].into();

        let metadata = UdpMetadata::new(remote_ip, receiver_port, None, PacketMeta::default());

        sender
            .bind(Port::from_inner(10002))
            .await
            .expect("Failed to bind sender");

        log::information!("Sending data");

        let send_result = sender.write_to(test_data, &metadata).await;
        assert_eq!(send_result, Ok(()));

        log::information!("Data sent, waiting to receive...");

        let mut buffer = [0u8; 1024];
        let receive_result = receiver.read_from(&mut buffer).await;

        log::information!("Data received");

        // Unwrap instead of `if let Ok(..)`: a failed receive must fail the test
        // rather than silently skipping the assertions.
        let (size, _recv_metadata) = receive_result.expect("Failed to receive data");
        assert_eq!(size, test_data.len(), "Received data size mismatch");
        assert_eq!(&buffer[..size], test_data, "Received data mismatch");

        sender.close().await;
        receiver.close().await;
    }

    /// An unbound socket reports port 0; after binding it reports the bound port.
    #[task::test]
    async fn test_udp_endpoint() {
        let network_manager = initialize().await;

        let mut socket = network_manager
            .new_udp_socket(1024, 1024, 10, 10, None)
            .await
            .expect("Failed to create UDP socket");

        let (_ip, port) = socket
            .get_endpoint()
            .await
            .expect("Failed to get initial endpoint");
        assert_eq!(port.into_inner(), 0, "Initial port should be 0");

        let bind_port = Port::from_inner(10004);
        socket.bind(bind_port).await.expect("Failed to bind");

        let (_, bound_port) = socket
            .get_endpoint()
            .await
            .expect("Failed to get bound endpoint");
        assert_eq!(bound_port, bind_port, "Bound port mismatch");

        socket.close().await;
    }

    /// The capacity getters report the sizes passed at socket creation.
    #[task::test]
    async fn test_udp_capacities() {
        let network_manager = initialize().await;

        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let rx_meta = 15;
        let tx_meta = 20;

        let socket = network_manager
            .new_udp_socket(tx_buffer, rx_buffer, rx_meta, tx_meta, None)
            .await
            .expect("Failed to create UDP socket");

        let packet_read_cap = socket.get_packet_read_capacity().await;
        let packet_write_cap = socket.get_packet_write_capacity().await;
        let payload_read_cap = socket.get_payload_read_capacity().await;
        let payload_write_cap = socket.get_payload_write_capacity().await;

        assert_eq!(packet_read_cap, rx_meta, "Packet read capacity mismatch");
        assert_eq!(packet_write_cap, tx_meta, "Packet write capacity mismatch");
        assert_eq!(
            payload_read_cap, rx_buffer,
            "Payload read capacity mismatch"
        );
        assert_eq!(
            payload_write_cap, tx_buffer,
            "Payload write capacity mismatch"
        );

        socket.close().await;
    }
}