network/socket/dns.rs

1use core::{
2    future::poll_fn,
3    task::{Context, Poll},
4};
5
6use alloc::{vec, vec::Vec};
7use embassy_futures::block_on;
8use smoltcp::{socket::dns, wire::DnsQueryType};
9
10use crate::{DnsQueryKind, IpAddress, Result, SocketContext};
11
12pub struct DnsSocket {
13    context: SocketContext,
14}
15
16impl DnsSocket {
17    pub fn new(context: SocketContext) -> Self {
18        Self { context }
19    }
20
21    fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
22    where
23        F: FnOnce(&mut dns::Socket<'static>, &mut Context<'_>) -> Poll<R>,
24    {
25        self.context.poll_with_mutable(context, f)
26    }
27
28    pub async fn update_servers(&self) -> Result<()> {
29        self.context
30            .stack
31            .with_mutable(|s| {
32                let dns_servers = s.get_dns_servers().to_vec();
33                let socket = s.sockets.get_mut::<dns::Socket>(self.context.handle);
34                socket.update_servers(&dns_servers);
35            })
36            .await;
37
38        Ok(())
39    }
40
41    pub async fn resolve_for_kind(&self, host: &str, kind: DnsQueryType) -> Result<Vec<IpAddress>> {
42        if let Ok(host) = IpAddress::try_from(host) {
43            return Ok(vec![host]);
44        }
45
46        let query = self
47            .context
48            .stack
49            .with_mutable(|s| {
50                let socket = s.sockets.get_mut::<dns::Socket>(self.context.handle);
51
52                socket.start_query(s.interface.context(), host, kind)
53            })
54            .await?;
55
56        self.context.stack.wake_up();
57
58        poll_fn(|cx| {
59            self.poll_with_mutable(cx, |socket, cx| match socket.get_query_result(query) {
60                Err(dns::GetQueryResultError::Pending) => {
61                    socket.register_query_waker(query, cx.waker());
62                    Poll::Pending
63                }
64                Err(e) => Poll::Ready(Err(e.into())),
65                Ok(ip_addresses) => {
66                    let ip_addresses = ip_addresses
67                        .into_iter()
68                        .map(|a| IpAddress::from_smoltcp(&a))
69                        .collect();
70
71                    Poll::Ready(Ok(ip_addresses))
72                }
73            })
74        })
75        .await
76    }
77
78    pub async fn resolve(&self, host: &str, kind: DnsQueryKind) -> Result<Vec<IpAddress>> {
79        let mut results = Vec::new();
80
81        if kind.contains(DnsQueryKind::A) {
82            let mut a_results = self.resolve_for_kind(host, DnsQueryType::A).await?;
83            results.append(&mut a_results);
84        }
85
86        if kind.contains(DnsQueryKind::Aaaa) {
87            let mut aaaa_results = self.resolve_for_kind(host, DnsQueryType::Aaaa).await?;
88            results.append(&mut aaaa_results);
89        }
90
91        if kind.contains(DnsQueryKind::Cname) {
92            let mut cname_results = self.resolve_for_kind(host, DnsQueryType::Cname).await?;
93            results.append(&mut cname_results);
94        }
95
96        if kind.contains(DnsQueryKind::Ns) {
97            let mut ns_results = self.resolve_for_kind(host, DnsQueryType::Ns).await?;
98            results.append(&mut ns_results);
99        }
100
101        if kind.contains(DnsQueryKind::Soa) {
102            let mut soa_results = self.resolve_for_kind(host, DnsQueryType::Soa).await?;
103            results.append(&mut soa_results);
104        }
105
106        Ok(results)
107    }
108
109    pub async fn close(mut self) -> Result<()> {
110        if self.context.closed {
111            return Ok(());
112        }
113
114        self.context
115            .stack
116            .with_mutable(|s| {
117                let _ = s.remove_socket(self.context.handle);
118            })
119            .await;
120
121        self.context.closed = true;
122
123        Ok(())
124    }
125}
126
127impl Drop for DnsSocket {
128    fn drop(&mut self) {
129        if self.context.closed {
130            return;
131        }
132
133        log::warning!("DNS socket dropped without being closed. Forcing closure...");
134
135        block_on(self.context.stack.with_mutable(|s| {
136            let _ = s.remove_socket(self.context.handle);
137        }));
138    }
139}