Skip to main content

network/socket/
tcp.rs

1use core::{
2    future::poll_fn,
3    task::{Context, Poll},
4};
5
6use crate::{
7    Duration, Error, IpAddress, IpEndpoint, IpListenEndpoint, Port, Result, SocketContext,
8};
9use embassy_futures::block_on;
10use smoltcp::socket::tcp;
11
12#[repr(transparent)]
13pub struct TcpSocket {
14    context: SocketContext,
15}
16
17impl TcpSocket {
18    pub(crate) fn new(context: SocketContext) -> Self {
19        Self { context }
20    }
21    pub async fn with<F, R>(&self, f: F) -> R
22    where
23        F: FnOnce(&tcp::Socket<'static>) -> R,
24    {
25        self.context.with(f).await
26    }
27
28    pub async fn with_mutable<F, R>(&self, f: F) -> R
29    where
30        F: FnOnce(&mut tcp::Socket<'static>) -> R,
31    {
32        self.context.with_mutable(f).await
33    }
34
35    pub fn poll_with<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
36    where
37        F: FnOnce(&tcp::Socket, &mut Context<'_>) -> Poll<R>,
38    {
39        self.context.poll_with(context, f)
40    }
41
42    pub fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
43    where
44        F: FnOnce(&mut tcp::Socket, &mut Context<'_>) -> Poll<R>,
45    {
46        self.context.poll_with_mutable(context, f)
47    }
48
49    pub async fn set_timeout(&mut self, timeout: Option<Duration>) {
50        let timeout = timeout.map(Duration::into_smoltcp);
51
52        self.with_mutable(|socket| socket.set_timeout(timeout))
53            .await
54    }
55
56    pub async fn accept(
57        &mut self,
58        address: Option<impl Into<IpAddress>>,
59        port: impl Into<Port>,
60    ) -> Result<()> {
61        let endpoint = IpListenEndpoint::new(address.map(Into::into), port.into()).into_smoltcp();
62
63        self.with_mutable(|s| {
64            if s.state() == tcp::State::Closed {
65                s.listen(endpoint).map_err(|e| match e {
66                    tcp::ListenError::InvalidState => Error::InvalidState,
67                    tcp::ListenError::Unaddressable => Error::InvalidPort,
68                })
69            } else {
70                Ok(())
71            }
72        })
73        .await?;
74
75        poll_fn(|cx| {
76            self.poll_with_mutable(cx, |s, cx| match s.state() {
77                tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
78                    s.register_send_waker(cx.waker());
79                    Poll::Pending
80                }
81                _ => Poll::Ready(Ok(())),
82            })
83        })
84        .await
85    }
86
87    pub async fn connect(
88        &mut self,
89        address: impl Into<IpAddress>,
90        port: impl Into<Port>,
91    ) -> Result<()> {
92        let endpoint = IpEndpoint::new(address.into(), port.into()).into_smoltcp();
93
94        self.context
95            .stack
96            .with_mutable(|stack| {
97                let local_port = stack.get_next_port().into_inner();
98
99                let socket: &mut tcp::Socket<'static> = stack.sockets.get_mut(self.context.handle);
100
101                match socket.connect(stack.interface.context(), endpoint, local_port) {
102                    Ok(()) => Ok(()),
103                    Err(tcp::ConnectError::InvalidState) => Err(Error::InvalidState),
104                    Err(tcp::ConnectError::Unaddressable) => Err(Error::NoRoute),
105                }
106            })
107            .await?;
108
109        poll_fn(|cx| {
110            self.poll_with_mutable(cx, |socket, cx| match socket.state() {
111                tcp::State::Closed | tcp::State::TimeWait => {
112                    Poll::Ready(Err(Error::ConnectionReset))
113                }
114                tcp::State::Listen => unreachable!(),
115                tcp::State::SynSent | tcp::State::SynReceived => {
116                    socket.register_send_waker(cx.waker());
117                    Poll::Pending
118                }
119                _ => Poll::Ready(Ok(())),
120            })
121        })
122        .await
123    }
124
125    pub async fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
126        poll_fn(|cx| {
127            self.poll_with_mutable(cx, |s, cx| match s.recv_slice(buffer) {
128                Ok(0) if buffer.is_empty() => Poll::Ready(Ok(0)),
129                Ok(0) if !s.may_recv() => Poll::Ready(Ok(0)),
130                Ok(0) => {
131                    s.register_recv_waker(cx.waker());
132                    Poll::Pending
133                }
134                Ok(n) => Poll::Ready(Ok(n)),
135                Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
136                Err(tcp::RecvError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
137            })
138        })
139        .await
140    }
141
142    pub async fn write(&mut self, buffer: &[u8]) -> Result<usize> {
143        poll_fn(|cx| {
144            self.poll_with_mutable(cx, |s, cx| match s.send_slice(buffer) {
145                Ok(0) => {
146                    s.register_send_waker(cx.waker());
147                    Poll::Pending
148                }
149                Ok(n) => Poll::Ready(Ok(n)),
150                Err(tcp::SendError::InvalidState) => Poll::Ready(Err(Error::ConnectionReset)),
151            })
152        })
153        .await
154    }
155
156    pub async fn flush(&mut self) -> Result<()> {
157        poll_fn(|cx| {
158            self.poll_with_mutable(cx, |s, cx| {
159                let data_pending = (s.send_queue() > 0) && s.state() != tcp::State::Closed;
160                let fin_pending = matches!(
161                    s.state(),
162                    tcp::State::FinWait1 | tcp::State::Closing | tcp::State::LastAck
163                );
164                let rst_pending = s.state() == tcp::State::Closed && s.remote_endpoint().is_some();
165
166                if data_pending || fin_pending || rst_pending {
167                    s.register_send_waker(cx.waker());
168                    Poll::Pending
169                } else {
170                    Poll::Ready(Ok(()))
171                }
172            })
173        })
174        .await
175    }
176
177    pub async fn close(&mut self) {
178        self.context.closed = true;
179        self.with_mutable(tcp::Socket::close).await
180    }
181
182    pub async fn close_forced(&mut self) {
183        self.context.closed = true;
184        self.with_mutable(tcp::Socket::abort).await
185    }
186
187    pub async fn get_read_capacity(&self) -> usize {
188        self.with(tcp::Socket::recv_capacity).await
189    }
190
191    pub async fn get_write_capacity(&self) -> usize {
192        self.with(tcp::Socket::send_capacity).await
193    }
194
195    pub async fn get_write_queue_size(&self) -> usize {
196        self.with(tcp::Socket::send_queue).await
197    }
198
199    pub async fn get_read_queue_size(&self) -> usize {
200        self.with(tcp::Socket::recv_queue).await
201    }
202
203    pub async fn get_local_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
204        let endpoint = self.with(tcp::Socket::local_endpoint).await;
205
206        Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
207    }
208
209    pub async fn get_remote_endpoint(&self) -> Result<Option<(IpAddress, Port)>> {
210        let endpoint = self.with(tcp::Socket::remote_endpoint).await;
211
212        Ok(endpoint.map(|e| (IpAddress::from_smoltcp(&e.addr), Port::from_inner(e.port))))
213    }
214
215    pub async fn set_keep_alive(&mut self, keep_alive: Option<Duration>) {
216        let keep_alive = keep_alive.map(Duration::into_smoltcp);
217        self.with_mutable(|socket| socket.set_keep_alive(keep_alive))
218            .await
219    }
220
221    pub async fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
222        self.with_mutable(|socket| socket.set_hop_limit(hop_limit))
223            .await
224    }
225
226    pub async fn can_read(&self) -> bool {
227        self.with(tcp::Socket::can_recv).await
228    }
229
230    pub async fn can_write(&self) -> bool {
231        self.with(tcp::Socket::can_send).await
232    }
233
234    pub async fn may_read(&self) -> bool {
235        self.with(tcp::Socket::may_recv).await
236    }
237
238    pub async fn may_write(&self) -> bool {
239        self.with(tcp::Socket::may_send).await
240    }
241}
242
243impl Drop for TcpSocket {
244    fn drop(&mut self) {
245        if self.context.closed {
246            return;
247        }
248        log::warning!("TCP socket dropped without being closed. Forcing closure.");
249        block_on(self.with_mutable(tcp::Socket::close));
250    }
251}
252
// Loopback integration tests for `TcpSocket`.
//
// Each test holds `TEST_MUTEX` for its whole duration, so the tests in this module run one
// at a time, and each uses its own port (51001-51004). The server side of every
// connection runs in a task spawned through `task::get_instance()`.
//
// NOTE(review): when `accept`, `connect`, `read`, `write` or `flush` fails, these tests
// `return` early instead of failing. They can therefore pass without exercising the code
// after that point. Confirm this best-effort behavior is intended.
#[cfg(test)]
mod tests {

    extern crate std;

    use synchronization::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};

    use crate::tests::initialize;

    use super::*;

    // Serializes the tests in this module; each test keeps the guard alive until it returns.
    static TEST_MUTEX: Mutex<CriticalSectionRawMutex, ()> = Mutex::new(());

    // A client connects to a listener on 127.0.0.1:51001. The listener task keeps its side
    // open until the client signals `client_done`, and the test joins the listener task.
    #[task::test]
    async fn test_tcp_connect() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51001);
        use synchronization::{Arc, blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
        let server_ready = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let connection_established = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let client_done = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let server_ready_clone = server_ready.clone();
        let connection_established_clone = connection_established.clone();
        let client_done_clone = client_done.clone();
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener socket");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(current_task, "TCP Listen Task", None, move |_| async move {
                // Signalled *before* `accept` runs, so it does not guarantee the socket is
                // listening yet; the sleeps in the test body cover that gap.
                server_ready_clone.signal(());
                let accept_result = listener.accept(Some([127, 0, 0, 1]), port).await;
                if accept_result.is_err() {
                    return;
                }
                connection_established_clone.signal(());
                // Keep the server side open until the client has finished.
                client_done_clone.wait().await;
                listener.close_forced().await;
            })
            .await
            .unwrap();
        // Presumably gives the spawned listener task a chance to be scheduled.
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;
        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client socket");
        let connect_result = client.connect([127, 0, 0, 1], port).await;
        if connect_result.is_err() {
            return;
        }
        connection_established.wait().await;
        let _endpoint = client.get_local_endpoint().await;
        task::sleep(Duration::from_milliseconds(100)).await;
        client.close_forced().await;
        client_done.signal(());
        listen_task.join().await;
    }

    // Round trip on 127.0.0.1:51002. The client sends "Hello, TCP!" and the server checks
    // it and answers "Hello back!", which the client checks in turn. Both checks assume
    // the whole message arrives in a single `read`.
    #[task::test]
    async fn test_tcp_send_receive() {
        use synchronization::{Arc, blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51002);
        let mut listener = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create listener");
        let server_ready = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let server_ready_clone = server_ready.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        // NOTE(review): the server task handle is never joined. The test can therefore
        // return (and release `TEST_MUTEX`) before the server task, including its
        // "Received data mismatch" assertion, has finished.
        let (_server_task, _) = task_manager
            .spawn(current_task, "TCP Server Task", None, move |_| async move {
                server_ready_clone.signal(());
                let accept_result = listener.accept(Some([127, 0, 0, 1]), port).await;
                if accept_result.is_err() {
                    return;
                }
                let mut buffer = [0u8; 1024];
                match listener.read(&mut buffer).await {
                    Ok(size) => {
                        assert_eq!(&buffer[..size], b"Hello, TCP!", "Received data mismatch");
                    }
                    Err(_) => {
                        return;
                    }
                }
                let response = b"Hello back!";
                if let Err(_) = listener.write(response).await {
                    return;
                }
                if let Err(_) = listener.flush().await {
                    return;
                }
                task::sleep(Duration::from_milliseconds(100)).await;
                listener.close_forced().await;
            })
            .await
            .unwrap();
        server_ready.wait().await;
        let mut client = network_manager
            .new_tcp_socket(2048, 2048, None)
            .await
            .expect("Failed to create client");
        let connect_result = client.connect([127, 0, 0, 1], port).await;
        if connect_result.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        if let Err(_) = client.write(b"Hello, TCP!").await {
            return;
        }
        if let Err(_) = client.flush().await {
            return;
        }
        let mut response_buffer = [0u8; 1024];
        match client.read(&mut response_buffer).await {
            Ok(size) => {
                assert_eq!(
                    &response_buffer[..size],
                    b"Hello back!",
                    "Response data mismatch"
                );
            }
            Err(_) => {
                return;
            }
        }
        client.close_forced().await;
    }

    // Endpoints are `None` on a fresh socket. After connecting to 127.0.0.1:51003, both
    // are set and the remote one matches the listener's address and port.
    #[task::test]
    async fn test_tcp_endpoints() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let mut socket = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create socket");
        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_none(),
            "TCP endpoint | Local endpoint should be None before connection"
        );
        assert!(
            remote.is_none(),
            "Remote endpoint should be None before connection"
        );
        let port = Port::from_inner(51003);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        use synchronization::{Arc, blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
        let server_ready = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let connection_ready = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let endpoints_checked = Arc::new(Signal::<CriticalSectionRawMutex, ()>::new());
        let server_ready_clone = server_ready.clone();
        let connection_ready_clone = connection_ready.clone();
        let endpoints_checked_clone = endpoints_checked.clone();
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (listen_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Endpoint Listen",
                None,
                move |_| async move {
                    server_ready_clone.signal(());
                    let accept_result = listener.accept(Some([127, 0, 0, 1]), port).await;
                    if accept_result.is_err() {
                        return;
                    }
                    connection_ready_clone.signal(());
                    // Keep the connection up until the client has inspected its endpoints.
                    endpoints_checked_clone.wait().await;
                    task::sleep(Duration::from_milliseconds(100)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();
        for _ in 0..5 {
            task::sleep(Duration::from_milliseconds(20)).await;
        }
        server_ready.wait().await;
        task::sleep(Duration::from_milliseconds(200)).await;
        let connect_result = socket.connect([127, 0, 0, 1], port).await;
        if connect_result.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(50)).await;
        connection_ready.wait().await;
        task::sleep(Duration::from_milliseconds(50)).await;
        let local = socket
            .get_local_endpoint()
            .await
            .expect("Failed to get local endpoint");
        let remote = socket
            .get_remote_endpoint()
            .await
            .expect("Failed to get remote endpoint");
        assert!(
            local.is_some(),
            "Local endpoint should be set after connection"
        );
        assert!(
            remote.is_some(),
            "Remote endpoint should be set after connection"
        );
        if let Some((addr, p)) = remote {
            assert_eq!(
                addr,
                IpAddress::from([127, 0, 0, 1]),
                "Remote address mismatch"
            );
            assert_eq!(p, port, "Remote port mismatch");
        }
        endpoints_checked.signal(());
        task::sleep(Duration::from_milliseconds(100)).await;
        socket.close_forced().await;
        listen_task.join().await;
    }

    // Buffer capacities match the sizes requested at creation, and both queues start empty.
    // The assertions imply `new_tcp_socket` takes the transmit buffer size first.
    #[task::test]
    async fn test_tcp_capacities() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let tx_buffer = 2048;
        let rx_buffer = 1024;
        let mut socket = network_manager
            .new_tcp_socket(tx_buffer, rx_buffer, None)
            .await
            .expect("Failed to create socket");
        let read_cap = socket.get_read_capacity().await;
        let write_cap = socket.get_write_capacity().await;
        let read_queue = socket.get_read_queue_size().await;
        let write_queue = socket.get_write_queue_size().await;
        assert_eq!(read_cap, rx_buffer, "Read capacity mismatch");
        assert_eq!(write_cap, tx_buffer, "Write capacity mismatch");
        assert_eq!(read_queue, 0, "Read queue should be empty initially");
        assert_eq!(write_queue, 0, "Write queue should be empty initially");
        socket.close_forced().await;
    }

    // `flush` completes after the client writes to a peer on 127.0.0.1:51004 that reads
    // the data. There is no readiness signal here: the 300 ms sleep is what gives the
    // server task time to reach `accept`.
    // NOTE(review): the server task handle is not joined.
    #[task::test]
    async fn test_tcp_flush() {
        let _lock = TEST_MUTEX.lock().await;
        let network_manager = initialize().await;
        let port = Port::from_inner(51004);
        let mut listener = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create listener");
        let task_manager = task::get_instance();
        let current_task = task_manager.get_current_task_identifier().await;
        let (_server_task, _) = task_manager
            .spawn(
                current_task,
                "TCP Flush Server",
                None,
                move |_| async move {
                    let accept_result = listener.accept(Some([127, 0, 0, 1]), port).await;
                    if accept_result.is_err() {
                        return;
                    }
                    let mut buffer = [0u8; 1024];
                    let _ = listener.read(&mut buffer).await;
                    task::sleep(Duration::from_milliseconds(200)).await;
                    listener.close_forced().await;
                },
            )
            .await
            .unwrap();
        task::sleep(Duration::from_milliseconds(300)).await;
        let mut client = network_manager
            .new_tcp_socket(1024, 1024, None)
            .await
            .expect("Failed to create client");
        let connect_result = client.connect([127, 0, 0, 1], port).await;
        if connect_result.is_err() {
            return;
        }
        task::sleep(Duration::from_milliseconds(100)).await;
        let write_result = client.write(b"Test data").await;
        if write_result.is_err() {
            return;
        }
        if let Err(_) = client.flush().await {
            return;
        }
        task::sleep(Duration::from_milliseconds(200)).await;
        client.close_forced().await;
    }
}