1use core::{
2 future::poll_fn,
3 task::{Context, Poll},
4};
5
6use crate::{
7 Duration, Error, IpAddress, IpEndpoint, IpListenEndpoint, Port, Result, SocketContext,
8};
9use embassy_futures::block_on;
10use smoltcp::socket::tcp;
11
/// Async TCP socket handle operating on a smoltcp `tcp::Socket`.
///
/// Every operation is forwarded through the wrapped `SocketContext`.
#[repr(transparent)]
pub struct TcpSocket {
    // Gives access to the underlying smoltcp socket. Its `closed` flag is set by
    // `close`/`close_forced` and is checked when the handle is dropped.
    context: SocketContext,
}
16
17impl TcpSocket {
18 pub(crate) fn new(context: SocketContext) -> Self {
19 Self { context }
20 }
21 pub async fn with<F, R>(&self, f: F) -> R
22 where
23 F: FnOnce(&tcp::Socket<'static>) -> R,
24 {
25 self.context.with(f).await
26 }
27
28 pub async fn with_mutable<F, R>(&self, f: F) -> R
29 where
30 F: FnOnce(&mut tcp::Socket<'static>) -> R,
31 {
32 self.context.with_mutable(f).await
33 }
34
35 pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
36 where
37 F: FnOnce(&tcp::Socket, &mut Context<'_>) -> Poll<R>,
38 {
39 self.context.poll_with(context, f)
40 }
41
42 pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
43 where
44 F: FnOnce(&mut tcp::Socket, &mut Context<'_>) -> Poll<R>,
45 {
46 self.context.poll_with_mutable(context, f)
47 }
48
49 pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
50 let timeout = timeout.map(Duration::into_smoltcp);
51
52 self.with_mutable(|socket| socket.set_timeout(timeout))
53 .await
54 }
55
56 pub async fn accept(
57 &mut self,
58 address: Option<impl Into<IpAddress>>,
59 port: impl Into<Port>,
60 ) -> Result<()> {
61 let endpoint = IpListenEndpoint::new(address.map(Into::into), port.into()).into_smoltcp();
62
63 self.with_mutable(|s| {
64 if s.state() == tcp::State::Closed {
65 s.listen(endpoint).map_err(|e| match e {
66 tcp::ListenError::InvalidState => Error::InvalidState,
67 tcp::ListenError::Unaddressable => Error::InvalidPort,
68 })
69 } else {
70 Ok(())
71 }
72 })
73 .await?;
74
75 poll_fn(|cx| {
76 self.poll_with_mutable(cx, |s, cx| match s.state() {
77 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
78 s.register_send_waker(cx.waker());
79 Poll::Pending
80 }
81 _ => Poll::Ready(Ok(())),
82 })
83 })
84 .await
85 }
86
87 pub async fn connect(
88 &mut self,
89 address: impl Into<IpAddress>,
90 port: impl Into<Port>,
91 ) -> Result<()> {
92 let endpoint = IpEndpoint::new(address.into(), port.into()).into_smoltcp();
93
94 self.context
95 .stack
96 .with_mutable(|stack| {
97 let local_port = stack.get_next_port().into_inner();
98
99 let socket: &mut tcp::Socket<'static> = stack.sockets.get_mut(self.context.handle);
100
101 match socket.connect(stack.interface.context(), endpoint, local_port) {
102 Ok(()) => Ok(()),
103 Err(tcp::ConnectError::InvalidState) => Err(Error::InvalidState),
104 Err(tcp::ConnectError::Unaddressable) => Err(Error::NoRoute),
105 }
106 })
107 .await?;
108
109 poll_fn(|cx| {
110 self.poll_with_mutable(cx, |socket, cx| match socket.state() {
111 tcp::State::Closed | tcp::State::TimeWait => {
112 Poll::Ready(Err(Error::ConnectionReset))
113 }
114 tcp::State::Listen => unreachable!(),
115 tcp::State::SynSent | tcp::State::SynReceived => {
116 socket.register_send_waker(cx.waker());
117 Poll::Pending
118 }
119 _ => Poll::Ready(Ok(())),
120 })
121 })
122 .await
123 }
124
125 pub async fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
126 poll_fn(|cx| {
127 self.poll_with_mutable(cx, |s, cx| match s.recv_slice(buffer) {
128 Ok(0) if buffer.is_empty() => Poll::Ready(Ok(0)),
129 Ok(0) if !s.may_recv() => Poll::Ready(Ok(0)),
130 Ok(0) => {
131 s.register_recv_waker(cx.waker());
132 Poll::Pending
133 }
134 Ok(n) => Poll::Ready(Ok(n)),
135 Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
136 Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
137 })
138 })
139 .await
140 }
141
142 pub async fn write(&mut self, buffer: &[u8]) -> Result<usize> {
143 poll_fn(|cx| {
144 self.poll_with_mutable(cx, |s, cx| match s.send_slice(buffer) {
145 Ok(0) => {
146 s.register_send_waker(cx.waker());
147 Poll::Pending
148 }
149 Ok(n) => Poll::Ready(Ok(n)),
150 Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
151 })
152 })
153 .await
154 }
155
156 pub async fn flush(&mut self) -> Result<()> {
157 poll_fn(|cx| {
158 self.poll_with_mutable(cx, |s, cx| {
159 let data_pending = (s.send_queue() > 0) && s.state() != tcp::State::Closed;
160 let fin_pending = matches!(
161 s.state(),
162 tcp::State::FinWait1 | tcp::State::Closing | tcp::State::LastAck
163 );
164 let rst_pending = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
165
166 if data_pending || fin_pending || rst_pending {
167 s.register_send_waker(cx.waker());
168 Poll::Pending
169 } else {
170 Poll::Ready(Ok(()))
171 }
172 })
173 })
174 .await
175 }
176
177 pub async fn close(&mut self) {
178 self.context.closed = true;
179 self.with_mutable(tcp::Socket::close).await
180 }
181
182 pub async fn close_forced(&mut self) {
183 self.context.closed = true;
184 self.with_mutable(tcp::Socket::abort).await
185 }
186
187 pub async fn get_read_capacity(&self) -> usize {
188 self.with(tcp::Socket::recv_capacity).await
189 }
190
191 pub async fn get_write_capacity(&self) -> usize {
192 self.with(tcp::Socket::send_capacity).await
193 }
194
195 pub async fn get_write_queue_size(&self) -> usize {
196 self.with(tcp::Socket::send_queue).await
197 }
198
199 pub async fn get_read_queue_size(&self) -> usize {
200 self.with(tcp::Socket::recv_queue).await
201 }
202
203 pub async fn get_local_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
204 let endpoint = self.with(tcp::Socket::local_endpoint).await;
205
206 Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
207 }
208
209 pub async fn get_remote_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
210 let endpoint = self.with(tcp::Socket::remote_endpoint).await;
211
212 Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
213 }
214
215 pub async fn set_keep_alive(&mut self, keep_alive: Option<Duration>) {
216 let keep_alive = keep_alive.map(Duration::into_smoltcp);
217 self.with_mutable(|socket| socket.set_keep_alive(keep_alive))
218 .await
219 }
220
221 pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
222 self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
223 .await
224 }
225
226 pub async fn can_read(&self) -> bool {
227 self.with(tcp::Socket::can_recv).await
228 }
229
230 pub async fn can_write(&self) -> bool {
231 self.with(tcp::Socket::can_send).await
232 }
233
234 pub async fn may_read(&self) -> bool {
235 self.with(tcp::Socket::may_recv).await
236 }
237
238 pub async fn may_write(&self) -> bool {
239 self.with(tcp::Socket::may_send).await
240 }
241}
242
243impl Drop for TcpSocket {
244 fn drop(&mut self) {
245 if self.context.closed {
246 return;
247 }
248 log::warning!("TCP socket dropped without being closed. Forcing closure.");
249 block_on(self.with_mutable(tcp::Socket::close));
250 }
251}
252
#[cfg(test)]
mod tests {
    // Loopback tests for `TcpSocket`.
    //
    // Every test first takes `TEST_MUTEX`, so the tests in this module run one at a time.
    // Tests that spawn a server task join it on their success path, so the task cannot keep
    // running after the lock is released. Several tests return early (and therefore pass) when
    // a network operation fails. NOTE(review): presumably this tolerates hosts without usable
    // loopback networking. Confirm, since it also hides regressions.

    extern crate std;

    use synchronization::{Arc, blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex, signal::Signal};

    use crate::tests::initialize;

    use super::*;

    // Serializes the tests in this module.
    static TEST_MUTEX: Mutex<CriticalSectionRawMutex, ()> = Mutex::new(());

    // One-shot notification passed between a test body and the task it spawns.
    type Notification = Signal<CriticalSectionRawMutex, ()>;

    // A client can connect to a listener on 127.0.0.1.
    #[task::test]
    async fn test_tcp_connect() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51001);

        let server_ready = Arc::new(Notification::new());
        let connection_established = Arc::new(Notification::new());
        let client_done = Arc::new(Notification::new());
        let server_ready_clone = server_ready.clone();
        let connection_established_clone = connection_established.clone();
        let client_done_clone = client_done.clone();

        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener socket");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(current_task, "TCP Listen Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }
                connection_established_clone.signal(());
                client_done_clone.wait().await;
                listener.close_forced().await;
            })
            .await
            .unwrap();

        // Give the listener task time to reach `accept` before the client connects.
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;

        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client socket");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        connection_established.wait().await;
        let _endpoint = client.get_local_endpoint().await;
        task::sleep(Duration::from_milliseconds(100)).await;
        client.close_forced().await;
        client_done.signal(());
        listen_task.join().await;
    }

    // Data travels in both directions between client and server.
    #[task::test]
    async fn test_tcp_send_receive() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51002);
        let mut listener = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create listener");
        let server_ready = Arc::new(Notification::new());
        let server_ready_clone = server_ready.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (server_task, _) = task_manager
            .spawn(current_task, "TCP Server Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }
                let mut buffer = [0u8; 1024];
                match listener.read(&mut buffer).await {
                    Ok(size) => {
                        assert_eq!(&buffer[..size], b"Hello, TCP!", "Received data mismatch");
                    }
                    Err(_) => return,
                }
                if listener.write(b"Hello back!").await.is_err() {
                    return;
                }
                if listener.flush().await.is_err() {
                    return;
                }
                task::sleep(Duration::from_milliseconds(100)).await;
                listener.close_forced().await;
            })
            .await
            .unwrap();

        server_ready.wait().await;
        let mut client = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        if client.write(b"Hello, TCP!").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }
        let mut response_buffer = [0u8; 1024];
        match client.read(&mut response_buffer).await {
            Ok(size) => {
                assert_eq!(
                    &response_buffer[..size],
                    b"Hello back!",
                    "Response data mismatch"
                );
            }
            Err(_) => return,
        }
        client.close_forced().await;
        // Wait for the server so its own assertion has run and it cannot overlap the next test.
        server_task.join().await;
    }

    // Endpoints are absent on a fresh socket and populated once connected.
    #[task::test]
    async fn test_tcp_endpoints() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let mut socket = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create socket");
        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_none(),
            "TCP endpoint | Local endpoint should be None before connection"
        );
        assert!(
            remote.is_none(),
            "Remote endpoint should be None before connection"
        );

        let port = Port::from_inner(51003);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        let server_ready = Arc::new(Notification::new());
        let connection_ready = Arc::new(Notification::new());
        let endpoints_checked = Arc::new(Notification::new());
        let server_ready_clone = server_ready.clone();
        let connection_ready_clone = connection_ready.clone();
        let endpoints_checked_clone = endpoints_checked.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Endpoint Listen",
                None,
                move |_| async move {
                    server_ready_clone.signal(());
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    connection_ready_clone.signal(());
                    endpoints_checked_clone.wait().await;
                    task::sleep(Duration::from_milliseconds(100)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();

        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;
        if socket.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(50)).await;
        connection_ready.wait().await;
        task::sleep(Duration::from_milliseconds(50)).await;

        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_some(),
            "Local endpoint should be set after connection"
        );
        assert!(
            remote.is_some(),
            "Remote endpoint should be set after connection"
        );
        if let Some((addr, p)) = remote {
            assert_eq!(
                addr,
                IpAddress::from([127, 0, 0, 1]),
                "Remote address mismatch"
            );
            assert_eq!(p, port, "Remote port mismatch");
        }
        endpoints_checked.signal(());
        task::sleep(Duration::from_milliseconds(100)).await;
        socket.close_forced().await;
        listen_task.join().await;
    }

    // Buffer capacities match the requested sizes and queues start empty.
    #[task::test]
    async fn test_tcp_capacities() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let mut socket = network_manager
            .new_tcp_socket(tx_buffer, rx_buffer, None)
            .await
            .expect("Failed to create socket");
        let read_cap = socket.get_read_capacity().await;
        let write_cap = socket.get_write_capacity().await;
        let read_queue = socket.get_read_queue_size().await;
        let write_queue = socket.get_write_queue_size().await;
        assert_eq!(read_cap, rx_buffer, "Read capacity mismatch");
        assert_eq!(write_cap, tx_buffer, "Write capacity mismatch");
        assert_eq!(read_queue, 0, "Read queue should be empty initially");
        assert_eq!(write_queue, 0, "Write queue should be empty initially");
        socket.close_forced().await;
    }

    // `flush` completes after writing to a connected peer.
    #[task::test]
    async fn test_tcp_flush() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51004);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (server_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Flush Server",
                None,
                move |_| async move {
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    let mut buffer = [0u8; 1024];
                    let _ = listener.read(&mut buffer).await;
                    task::sleep(Duration::from_milliseconds(200)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();

        // No readiness signal here: give the server time to start listening.
        task::sleep(Duration::from_milliseconds(300)).await;
        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        if client.write(b"Test data").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(200)).await;
        client.close_forced().await;
        // Do not let the server task outlive `TEST_MUTEX`.
        server_task.join().await;
    }
}