1use core::{
2 future::poll_fn,
3 task::{Context, Poll},
4};
5
6use crate::{
7 Duration, Error, IpAddress, IpEndpoint, IpListenEndpoint, Port, Result, SocketContext,
8};
9use embassy_futures::block_on;
10use smoltcp::socket::tcp;
11
12#[repr(transparent)]
13pub struct TcpSocket {
14 context: SocketContext,
15}
16
17impl TcpSocket {
18 pub(crate) fn new(context: SocketContext) -> Self {
19 Self { context }
20 }
21 pub async fn with<F, R>(&self, f: F) -> R
22 where
23 F: FnOnce(&tcp::Socket<'static>) -> R,
24 {
25 self.context.with(f).await
26 }
27
28 pub async fn with_mutable<F, R>(&self, f: F) -> R
29 where
30 F: FnOnce(&mut tcp::Socket<'static>) -> R,
31 {
32 self.context.with_mutable(f).await
33 }
34
35 pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
36 where
37 F: FnOnce(&tcp::Socket, &mut Context<'_>) -> Poll<R>,
38 {
39 self.context.poll_with(context, f)
40 }
41
42 pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
43 where
44 F: FnOnce(&mut tcp::Socket, &mut Context<'_>) -> Poll<R>,
45 {
46 self.context.poll_with_mutable(context, f)
47 }
48
49 pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
50 let timeout = timeout.map(Duration::into_smoltcp);
51
52 self.with_mutable(|socket| socket.set_timeout(timeout))
53 .await
54 }
55
56 pub async fn accept(
57 &mut self,
58 address: Option<impl Into<IpAddress>>,
59 port: impl Into<Port>,
60 ) -> Result<()> {
61 let endpoint = IpListenEndpoint::new(address.map(Into::into), port.into()).into_smoltcp();
62
63 self.with_mutable(|s| {
64 if s.state() == tcp::State::Closed {
65 s.listen(endpoint).map_err(|e| match e {
66 tcp::ListenError::InvalidState => Error::InvalidState,
67 tcp::ListenError::Unaddressable => Error::InvalidPort,
68 })
69 } else {
70 Ok(())
71 }
72 })
73 .await?;
74
75 poll_fn(|cx| {
76 self.poll_with_mutable(cx, |s, cx| match s.state() {
77 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
78 s.register_send_waker(cx.waker());
79 Poll::Pending
80 }
81 _ => Poll::Ready(Ok(())),
82 })
83 })
84 .await
85 }
86
87 pub async fn connect(
88 &mut self,
89 address: impl Into<IpAddress>,
90 port: impl Into<Port>,
91 ) -> Result<()> {
92 let endpoint = IpEndpoint::new(address.into(), port.into()).into_smoltcp();
93
94 self.context
95 .stack
96 .with_mutable(|stack| {
97 let local_port = stack.get_next_port().into_inner();
98
99 let socket: &mut tcp::Socket<'static> = stack.sockets.get_mut(self.context.handle);
100
101 match socket.connect(stack.interface.context(), endpoint, local_port) {
102 Ok(()) => Ok(()),
103 Err(tcp::ConnectError::InvalidState) => Err(Error::InvalidState),
104 Err(tcp::ConnectError::Unaddressable) => Err(Error::NoRoute),
105 }
106 })
107 .await?;
108
109 poll_fn(|cx| {
110 self.poll_with_mutable(cx, |socket, cx| match socket.state() {
111 tcp::State::Closed | tcp::State::TimeWait => {
112 Poll::Ready(Err(Error::ConnectionReset))
113 }
114 tcp::State::Listen => unreachable!(),
115 tcp::State::SynSent | tcp::State::SynReceived => {
116 socket.register_send_waker(cx.waker());
117 Poll::Pending
118 }
119 _ => Poll::Ready(Ok(())),
120 })
121 })
122 .await
123 }
124
125 pub async fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
126 poll_fn(|cx| {
127 self.poll_with_mutable(cx, |s, cx| match s.recv_slice(buffer) {
128 Ok(0) if buffer.is_empty() => Poll::Ready(Ok(0)),
129 Ok(0) => {
130 s.register_recv_waker(cx.waker());
131 Poll::Pending
132 }
133 Ok(n) => Poll::Ready(Ok(n)),
134 Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
135 Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
136 })
137 })
138 .await
139 }
140
141 pub async fn write(&mut self, buffer: &[u8]) -> Result<usize> {
142 poll_fn(|cx| {
143 self.poll_with_mutable(cx, |s, cx| match s.send_slice(buffer) {
144 Ok(0) => {
145 s.register_send_waker(cx.waker());
146 Poll::Pending
147 }
148 Ok(n) => Poll::Ready(Ok(n)),
149 Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
150 })
151 })
152 .await
153 }
154
155 pub async fn flush(&mut self) -> Result<()> {
156 poll_fn(|cx| {
157 self.poll_with_mutable(cx, |s, cx| {
158 let data_pending = (s.send_queue() > 0) && s.state() != tcp::State::Closed;
159 let fin_pending = matches!(
160 s.state(),
161 tcp::State::FinWait1 | tcp::State::Closing | tcp::State::LastAck
162 );
163 let rst_pending = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
164
165 if data_pending || fin_pending || rst_pending {
166 s.register_send_waker(cx.waker());
167 Poll::Pending
168 } else {
169 Poll::Ready(Ok(()))
170 }
171 })
172 })
173 .await
174 }
175
176 pub async fn close(&mut self) {
177 self.context.closed = true;
178 self.with_mutable(tcp::Socket::close).await
179 }
180
181 pub async fn close_forced(&mut self) {
182 self.context.closed = true;
183 self.with_mutable(tcp::Socket::abort).await
184 }
185
186 pub async fn get_read_capacity(&self) -> usize {
187 self.with(tcp::Socket::recv_capacity).await
188 }
189
190 pub async fn get_write_capacity(&self) -> usize {
191 self.with(tcp::Socket::send_capacity).await
192 }
193
194 pub async fn get_write_queue_size(&self) -> usize {
195 self.with(tcp::Socket::send_queue).await
196 }
197
198 pub async fn get_read_queue_size(&self) -> usize {
199 self.with(tcp::Socket::recv_queue).await
200 }
201
202 pub async fn get_local_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
203 let endpoint = self.with(tcp::Socket::local_endpoint).await;
204
205 Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
206 }
207
208 pub async fn get_remote_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
209 let endpoint = self.with(tcp::Socket::remote_endpoint).await;
210
211 Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
212 }
213
214 pub async fn set_keep_alive(&mut self, keep_alive: Option<Duration>) {
215 let keep_alive = keep_alive.map(Duration::into_smoltcp);
216 self.with_mutable(|socket| socket.set_keep_alive(keep_alive))
217 .await
218 }
219
220 pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
221 self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
222 .await
223 }
224
225 pub async fn can_read(&self) -> bool {
226 self.with(tcp::Socket::can_recv).await
227 }
228
229 pub async fn can_write(&self) -> bool {
230 self.with(tcp::Socket::can_send).await
231 }
232
233 pub async fn may_read(&self) -> bool {
234 self.with(tcp::Socket::may_recv).await
235 }
236
237 pub async fn may_write(&self) -> bool {
238 self.with(tcp::Socket::may_send).await
239 }
240}
241
242impl Drop for TcpSocket {
243 fn drop(&mut self) {
244 if self.context.closed {
245 return;
246 }
247 log::warning!("TCP socket dropped without being closed. Forcing closure.");
248 block_on(self.with_mutable(tcp::Socket::close));
249 }
250}
251
#[cfg(test)]
mod tests {

    extern crate std;

    use synchronization::{
        Arc, blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex, signal::Signal,
    };

    use crate::tests::initialize;

    use super::*;

    /// Held for the whole duration of each test so that the TCP tests in this
    /// module never run concurrently.
    static TEST_MUTEX: Mutex<CriticalSectionRawMutex, ()> = Mutex::new(());

    /// Unit signal shared between a test body and the task it spawns.
    type SharedSignal = Arc<Signal<CriticalSectionRawMutex, ()>>;

    /// Creates a fresh, unsignaled [`SharedSignal`].
    fn new_signal() -> SharedSignal {
        Arc::new(Signal::<CriticalSectionRawMutex, ()>::new())
    }

    /// Connects a client to a listener on the loopback address and tears both
    /// sockets down. Accept/connect failures end the test early without
    /// failing it.
    #[task::test]
    async fn test_tcp_connect() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51001);

        let server_ready = new_signal();
        let connection_established = new_signal();
        let client_done = new_signal();
        let server_ready_clone = server_ready.clone();
        let connection_established_clone = connection_established.clone();
        let client_done_clone = client_done.clone();

        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener socket");

        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(current_task, "TCP Listen Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }
                connection_established_clone.signal(());
                // Keep the accepted connection alive until the client is done.
                client_done_clone.wait().await;
                listener.close_forced().await;
            })
            .await
            .unwrap();

        // NOTE(review): presumably gives the spawned task several chances to
        // run before the client connects — confirm.
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;

        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client socket");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }

        connection_established.wait().await;
        let _endpoint = client.get_local_endpoint().await;
        task::sleep(Duration::from_milliseconds(100)).await;
        client.close_forced().await;
        client_done.signal(());
        listen_task.join().await;
    }

    /// Sends a request from the client to the server and checks the payload
    /// received on both sides. I/O failures end the test early without
    /// failing it.
    #[task::test]
    async fn test_tcp_send_receive() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51002);

        let mut listener = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create listener");

        let server_ready = new_signal();
        let server_ready_clone = server_ready.clone();

        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        // NOTE(review): the server task is never joined, so a failed assertion
        // inside it may not fail this test — confirm whether that is intended.
        let (_server_task, _) = task_manager
            .spawn(current_task, "TCP Server Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }

                let mut buffer = [0u8; 1024];
                let Ok(size) = listener.read(&mut buffer).await else {
                    return;
                };
                assert_eq!(&buffer[..size], b"Hello, TCP!", "Received data mismatch");

                let response = b"Hello back!";
                if listener.write(response).await.is_err() {
                    return;
                }
                if listener.flush().await.is_err() {
                    return;
                }
                task::sleep(Duration::from_milliseconds(100)).await;
                listener.close_forced().await;
            })
            .await
            .unwrap();

        server_ready.wait().await;

        let mut client = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;

        if client.write(b"Hello, TCP!").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }

        let mut response_buffer = [0u8; 1024];
        let Ok(size) = client.read(&mut response_buffer).await else {
            return;
        };
        assert_eq!(
            &response_buffer[..size],
            b"Hello back!",
            "Response data mismatch"
        );

        client.close_forced().await;
    }

    /// Checks that both endpoints are unset on a fresh socket and set once the
    /// socket is connected, and that the remote endpoint matches the listener.
    #[task::test]
    async fn test_tcp_endpoints() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;

        let mut socket = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create socket");

        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_none(),
            "TCP endpoint | Local endpoint should be None before connection"
        );
        assert!(
            remote.is_none(),
            "Remote endpoint should be None before connection"
        );

        let port = Port::from_inner(51003);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");

        let server_ready = new_signal();
        let connection_ready = new_signal();
        let endpoints_checked = new_signal();
        let server_ready_clone = server_ready.clone();
        let connection_ready_clone = connection_ready.clone();
        let endpoints_checked_clone = endpoints_checked.clone();

        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Endpoint Listen",
                None,
                move |_| async move {
                    server_ready_clone.signal(());
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    connection_ready_clone.signal(());
                    // Keep the connection open until the client has inspected
                    // its endpoints.
                    endpoints_checked_clone.wait().await;
                    task::sleep(Duration::from_milliseconds(100)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();

        // NOTE(review): presumably gives the spawned task several chances to
        // run before the client connects — confirm.
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;

        if socket.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(50)).await;
        connection_ready.wait().await;
        task::sleep(Duration::from_milliseconds(50)).await;

        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_some(),
            "Local endpoint should be set after connection"
        );
        assert!(
            remote.is_some(),
            "Remote endpoint should be set after connection"
        );
        if let Some((addr, p)) = remote {
            assert_eq!(
                addr,
                IpAddress::from([127, 0, 0, 1]),
                "Remote address mismatch"
            );
            assert_eq!(p, port, "Remote port mismatch");
        }

        endpoints_checked.signal(());
        task::sleep(Duration::from_milliseconds(100)).await;
        socket.close_forced().await;
        listen_task.join().await;
    }

    /// Checks that a fresh socket reports the requested buffer capacities and
    /// empty queues.
    #[task::test]
    async fn test_tcp_capacities() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;

        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let mut socket = network_manager
            .new_tcp_socket(tx_buffer, rx_buffer, None)
            .await
            .expect("Failed to create socket");

        let read_cap = socket.get_read_capacity().await;
        let write_cap = socket.get_write_capacity().await;
        let read_queue = socket.get_read_queue_size().await;
        let write_queue = socket.get_write_queue_size().await;

        assert_eq!(read_cap, rx_buffer, "Read capacity mismatch");
        assert_eq!(write_cap, tx_buffer, "Write capacity mismatch");
        assert_eq!(read_queue, 0, "Read queue should be empty initially");
        assert_eq!(write_queue, 0, "Write queue should be empty initially");

        socket.close_forced().await;
    }

    /// Writes data from a client and flushes it while a server task reads on
    /// the other side. I/O failures end the test early without failing it.
    #[task::test]
    async fn test_tcp_flush() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51004);

        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");

        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (_server_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Flush Server",
                None,
                move |_| async move {
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    let mut buffer = [0u8; 1024];
                    let _ = listener.read(&mut buffer).await;
                    task::sleep(Duration::from_milliseconds(200)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();

        // No ready signal in this test: a fixed delay is used before the
        // client connects.
        task::sleep(Duration::from_milliseconds(300)).await;

        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;

        if client.write(b"Test data").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }

        task::sleep(Duration::from_milliseconds(200)).await;
        client.close_forced().await;
    }
}