network/socket/
tcp.rs

1use core::{
2    future::poll_fn,
3    task::{Context, Poll},
4};
5
6use crate::{
7    Duration, Error, IpAddress, IpEndpoint, IpListenEndpoint, Port, Result, SocketContext,
8};
9use embassy_futures::block_on;
10use smoltcp::socket::tcp;
11
12#[repr(transparent)]
13pub struct TcpSocket {
14    context: SocketContext,
15}
16
17impl TcpSocket {
18    pub(crate) fn new(context: SocketContext) -> Self {
19        Self { context }
20    }
21    pub async fn with<F, R>(&self, f: F) -> R
22    where
23        F: FnOnce(&tcp::Socket<'static>) -> R,
24    {
25        self.context.with(f).await
26    }
27
28    pub async fn with_mutable<F, R>(&self, f: F) -> R
29    where
30        F: FnOnce(&mut tcp::Socket<'static>) -> R,
31    {
32        self.context.with_mutable(f).await
33    }
34
35    pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
36    where
37        F: FnOnce(&tcp::Socket, &mut Context<'_>) -> Poll<R>,
38    {
39        self.context.poll_with(context, f)
40    }
41
42    pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
43    where
44        F: FnOnce(&mut tcp::Socket, &mut Context<'_>) -> Poll<R>,
45    {
46        self.context.poll_with_mutable(context, f)
47    }
48
49    pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
50        let timeout = timeout.map(Duration::into_smoltcp);
51
52        self.with_mutable(|socket| socket.set_timeout(timeout))
53            .await
54    }
55
56    pub async fn accept(
57        &mut self,
58        address: Option<impl Into<IpAddress>>,
59        port: impl Into<Port>,
60    ) -> Result<()> {
61        let endpoint = IpListenEndpoint::new(address.map(Into::into), port.into()).into_smoltcp();
62
63        self.with_mutable(|s| {
64            if s.state() == tcp::State::Closed {
65                s.listen(endpoint).map_err(|e| match e {
66                    tcp::ListenError::InvalidState => Error::InvalidState,
67                    tcp::ListenError::Unaddressable => Error::InvalidPort,
68                })
69            } else {
70                Ok(())
71            }
72        })
73        .await?;
74
75        poll_fn(|cx| {
76            self.poll_with_mutable(cx, |s, cx| match s.state() {
77                tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
78                    s.register_send_waker(cx.waker());
79                    Poll::Pending
80                }
81                _ => Poll::Ready(Ok(())),
82            })
83        })
84        .await
85    }
86
87    pub async fn connect(
88        &mut self,
89        address: impl Into<IpAddress>,
90        port: impl Into<Port>,
91    ) -> Result<()> {
92        let endpoint = IpEndpoint::new(address.into(), port.into()).into_smoltcp();
93
94        self.context
95            .stack
96            .with_mutable(|stack| {
97                let local_port = stack.get_next_port().into_inner();
98
99                let socket: &mut tcp::Socket<'static> = stack.sockets.get_mut(self.context.handle);
100
101                match socket.connect(stack.interface.context(), endpoint, local_port) {
102                    Ok(()) => Ok(()),
103                    Err(tcp::ConnectError::InvalidState) => Err(Error::InvalidState),
104                    Err(tcp::ConnectError::Unaddressable) => Err(Error::NoRoute),
105                }
106            })
107            .await?;
108
109        poll_fn(|cx| {
110            self.poll_with_mutable(cx, |socket, cx| match socket.state() {
111                tcp::State::Closed | tcp::State::TimeWait => {
112                    Poll::Ready(Err(Error::ConnectionReset))
113                }
114                tcp::State::Listen => unreachable!(),
115                tcp::State::SynSent | tcp::State::SynReceived => {
116                    socket.register_send_waker(cx.waker());
117                    Poll::Pending
118                }
119                _ => Poll::Ready(Ok(())),
120            })
121        })
122        .await
123    }
124
125    pub async fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
126        poll_fn(|cx| {
127            self.poll_with_mutable(cx, |s, cx| match s.recv_slice(buffer) {
128                Ok(0) if buffer.is_empty() => Poll::Ready(Ok(0)),
129                Ok(0) => {
130                    s.register_recv_waker(cx.waker());
131                    Poll::Pending
132                }
133                Ok(n) => Poll::Ready(Ok(n)),
134                Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
135                Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
136            })
137        })
138        .await
139    }
140
141    pub async fn write(&mut self, buffer: &[u8]) -> Result<usize> {
142        poll_fn(|cx| {
143            self.poll_with_mutable(cx, |s, cx| match s.send_slice(buffer) {
144                Ok(0) => {
145                    s.register_send_waker(cx.waker());
146                    Poll::Pending
147                }
148                Ok(n) => Poll::Ready(Ok(n)),
149                Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
150            })
151        })
152        .await
153    }
154
155    pub async fn flush(&mut self) -> Result<()> {
156        poll_fn(|cx| {
157            self.poll_with_mutable(cx, |s, cx| {
158                let data_pending = (s.send_queue() > 0) && s.state() != tcp::State::Closed;
159                let fin_pending = matches!(
160                    s.state(),
161                    tcp::State::FinWait1 | tcp::State::Closing | tcp::State::LastAck
162                );
163                let rst_pending = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
164
165                if data_pending || fin_pending || rst_pending {
166                    s.register_send_waker(cx.waker());
167                    Poll::Pending
168                } else {
169                    Poll::Ready(Ok(()))
170                }
171            })
172        })
173        .await
174    }
175
176    pub async fn close(&mut self) {
177        self.context.closed = true;
178        self.with_mutable(tcp::Socket::close).await
179    }
180
181    pub async fn close_forced(&mut self) {
182        self.context.closed = true;
183        self.with_mutable(tcp::Socket::abort).await
184    }
185
186    pub async fn get_read_capacity(&self) -> usize {
187        self.with(tcp::Socket::recv_capacity).await
188    }
189
190    pub async fn get_write_capacity(&self) -> usize {
191        self.with(tcp::Socket::send_capacity).await
192    }
193
194    pub async fn get_write_queue_size(&self) -> usize {
195        self.with(tcp::Socket::send_queue).await
196    }
197
198    pub async fn get_read_queue_size(&self) -> usize {
199        self.with(tcp::Socket::recv_queue).await
200    }
201
202    pub async fn get_local_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
203        let endpoint = self.with(tcp::Socket::local_endpoint).await;
204
205        Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
206    }
207
208    pub async fn get_remote_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
209        let endpoint = self.with(tcp::Socket::remote_endpoint).await;
210
211        Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
212    }
213
214    pub async fn set_keep_alive(&mut self, keep_alive: Option<Duration>) {
215        let keep_alive = keep_alive.map(Duration::into_smoltcp);
216        self.with_mutable(|socket| socket.set_keep_alive(keep_alive))
217            .await
218    }
219
220    pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
221        self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
222            .await
223    }
224
225    pub async fn can_read(&self) -> bool {
226        self.with(tcp::Socket::can_recv).await
227    }
228
229    pub async fn can_write(&self) -> bool {
230        self.with(tcp::Socket::can_send).await
231    }
232
233    pub async fn may_read(&self) -> bool {
234        self.with(tcp::Socket::may_recv).await
235    }
236
237    pub async fn may_write(&self) -> bool {
238        self.with(tcp::Socket::may_send).await
239    }
240}
241
242impl Drop for TcpSocket {
243    fn drop(&mut self) {
244        if self.context.closed {
245            return;
246        }
247        log::warning!("TCP socket dropped without being closed. Forcing closure.");
248        block_on(self.with_mutable(tcp::Socket::close));
249    }
250}
251
#[cfg(test)]
mod tests {
    //! Loopback tests for `TcpSocket`.
    //!
    //! Every test holds `TEST_MUTEX` for its whole duration so the tests run one at a time.
    //! Network failures (accept/connect/read/write/flush errors) make a test return early
    //! instead of failing; data is only asserted when the loopback exchange succeeds.
    //! Tests that spawn a server task join it before finishing so it cannot outlive the lock.

    extern crate std;

    use synchronization::{
        Arc, blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex, signal::Signal,
    };

    use crate::tests::initialize;

    use super::*;

    /// Serializes the tests of this module.
    static TEST_MUTEX: Mutex<CriticalSectionRawMutex, ()> = Mutex::new(());

    /// Notification used to synchronize a test with the server task it spawns.
    type TestSignal = Signal<CriticalSectionRawMutex, ()>;

    /// A client can connect to a listener on the loopback address.
    #[task::test]
    async fn test_tcp_connect() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51001);
        let server_ready = Arc::new(TestSignal::new());
        let connection_established = Arc::new(TestSignal::new());
        let client_done = Arc::new(TestSignal::new());
        let server_ready_clone = server_ready.clone();
        let connection_established_clone = connection_established.clone();
        let client_done_clone = client_done.clone();
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener socket");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(current_task, "TCP Listen Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }
                connection_established_clone.signal(());
                // Keep the server side open until the client is finished with the connection.
                client_done_clone.wait().await;
                listener.close_forced().await;
            })
            .await
            .unwrap();
        // Give the listener task time to start (presumably so it is listening before the
        // client connects).
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;
        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client socket");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        connection_established.wait().await;
        let _endpoint = client.get_local_endpoint().await;
        task::sleep(Duration::from_milliseconds(100)).await;
        client.close_forced().await;
        client_done.signal(());
        listen_task.join().await;
    }

    /// Data written by the client reaches the server and the server's reply reaches the client.
    #[task::test]
    async fn test_tcp_send_receive() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51002);
        let mut listener = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create listener");
        let server_ready = Arc::new(TestSignal::new());
        let server_ready_clone = server_ready.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (server_task, _) = task_manager
            .spawn(current_task, "TCP Server Task", None, move |_| async move {
                server_ready_clone.signal(());
                if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                    return;
                }
                let mut buffer = [0u8; 1024];
                let Ok(size) = listener.read(&mut buffer).await else {
                    return;
                };
                assert_eq!(&buffer[..size], b"Hello, TCP!", "Received data mismatch");
                let response = b"Hello back!";
                if listener.write(response).await.is_err() {
                    return;
                }
                if listener.flush().await.is_err() {
                    return;
                }
                task::sleep(Duration::from_milliseconds(100)).await;
                listener.close_forced().await;
            })
            .await
            .unwrap();
        server_ready.wait().await;
        let mut client = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        if client.write(b"Hello, TCP!").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }
        let mut response_buffer = [0u8; 1024];
        let Ok(size) = client.read(&mut response_buffer).await else {
            return;
        };
        assert_eq!(
            &response_buffer[..size],
            b"Hello back!",
            "Response data mismatch"
        );
        client.close_forced().await;
        // Wait for the server task so it neither outlives `TEST_MUTEX` nor hides its assertions.
        server_task.join().await;
    }

    /// Endpoints are unset before connecting and point at the listener afterwards.
    #[task::test]
    async fn test_tcp_endpoints() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let mut socket = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create socket");
        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_none(),
            "TCP endpoint | Local endpoint should be None before connection"
        );
        assert!(
            remote.is_none(),
            "Remote endpoint should be None before connection"
        );
        let port = Port::from_inner(51003);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        let server_ready = Arc::new(TestSignal::new());
        let connection_ready = Arc::new(TestSignal::new());
        let endpoints_checked = Arc::new(TestSignal::new());
        let server_ready_clone = server_ready.clone();
        let connection_ready_clone = connection_ready.clone();
        let endpoints_checked_clone = endpoints_checked.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Endpoint Listen",
                None,
                move |_| async move {
                    server_ready_clone.signal(());
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    connection_ready_clone.signal(());
                    // Keep the connection up until the client has inspected its endpoints.
                    endpoints_checked_clone.wait().await;
                    task::sleep(Duration::from_milliseconds(100)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;
        if socket.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(50)).await;
        connection_ready.wait().await;
        task::sleep(Duration::from_milliseconds(50)).await;
        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_some(),
            "Local endpoint should be set after connection"
        );
        assert!(
            remote.is_some(),
            "Remote endpoint should be set after connection"
        );
        if let Some((addr, p)) = remote {
            assert_eq!(
                addr,
                IpAddress::from([127, 0, 0, 1]),
                "Remote address mismatch"
            );
            assert_eq!(p, port, "Remote port mismatch");
        }
        endpoints_checked.signal(());
        task::sleep(Duration::from_milliseconds(100)).await;
        socket.close_forced().await;
        listen_task.join().await;
    }

    /// A fresh socket reports the requested buffer capacities and empty queues.
    #[task::test]
    async fn test_tcp_capacities() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let mut socket = network_manager
            .new_tcp_socket(tx_buffer, rx_buffer, None)
            .await
            .expect("Failed to create socket");
        let read_cap = socket.get_read_capacity().await;
        let write_cap = socket.get_write_capacity().await;
        let read_queue = socket.get_read_queue_size().await;
        let write_queue = socket.get_write_queue_size().await;
        assert_eq!(read_cap, rx_buffer, "Read capacity mismatch");
        assert_eq!(write_cap, tx_buffer, "Write capacity mismatch");
        assert_eq!(read_queue, 0, "Read queue should be empty initially");
        assert_eq!(write_queue, 0, "Write queue should be empty initially");
        socket.close_forced().await;
    }

    /// `flush` completes after writing to a connected peer.
    #[task::test]
    async fn test_tcp_flush() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51004);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (server_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Flush Server",
                None,
                move |_| async move {
                    if listener.accept(Some([127, 0, 0, 1]), port).await.is_err() {
                        return;
                    }
                    let mut buffer = [0u8; 1024];
                    let _ = listener.read(&mut buffer).await;
                    task::sleep(Duration::from_milliseconds(200)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();
        task::sleep(Duration::from_milliseconds(300)).await;
        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client");
        if client.connect([127, 0, 0, 1], port).await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        if client.write(b"Test data").await.is_err() {
            return;
        }
        if client.flush().await.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(200)).await;
        client.close_forced().await;
        // Wait for the server task so it does not outlive `TEST_MUTEX`.
        server_task.join().await;
    }
}