1use core::{
2 future::poll_fn,
3 task::{Context, Poll},
4};
5
6use alloc::{vec, vec::Vec};
7use embassy_futures::block_on;
8use smoltcp::{socket::dns, wire::DnsQueryType};
9
10use crate::{DnsQueryKind, IpAddress, Result, SocketContext};
11
12pub struct DnsSocket {
13 context: SocketContext,
14}
15
16impl DnsSocket {
17 pub fn new(context: SocketContext) -> Self {
18 Self { context }
19 }
20
21 fn poll_with_mutable<F, R>(&self, context: &mut Context<'_>, f: F) -> Poll<R>
22 where
23 F: FnOnce(&mut dns::Socket<'static>, &mut Context<'_>) -> Poll<R>,
24 {
25 self.context.poll_with_mutable(context, f)
26 }
27
28 pub async fn update_servers(&self) -> Result<()> {
29 self.context
30 .stack
31 .with_mutable(|s| {
32 let dns_servers = s.get_dns_servers().to_vec();
33 let socket = s.sockets.get_mut::<dns::Socket>(self.context.handle);
34 socket.update_servers(&dns_servers);
35 })
36 .await;
37
38 Ok(())
39 }
40
41 pub async fn resolve_for_kind(&self, host: &str, kind: DnsQueryType) -> Result<Vec<IpAddress>> {
42 if let Ok(host) = IpAddress::try_from(host) {
43 return Ok(vec![host]);
44 }
45
46 let query = self
47 .context
48 .stack
49 .with_mutable(|s| {
50 let socket = s.sockets.get_mut::<dns::Socket>(self.context.handle);
51
52 socket.start_query(s.interface.context(), host, kind)
53 })
54 .await?;
55
56 self.context.stack.wake_up();
57
58 poll_fn(|cx| {
59 self.poll_with_mutable(cx, |socket, cx| match socket.get_query_result(query) {
60 Err(dns::GetQueryResultError::Pending) => {
61 socket.register_query_waker(query, cx.waker());
62 Poll::Pending
63 }
64 Err(e) => Poll::Ready(Err(e.into())),
65 Ok(ip_addresses) => {
66 let ip_addresses = ip_addresses
67 .into_iter()
68 .map(|a| IpAddress::from_smoltcp(&a))
69 .collect();
70
71 Poll::Ready(Ok(ip_addresses))
72 }
73 })
74 })
75 .await
76 }
77
78 pub async fn resolve(&self, host: &str, kind: DnsQueryKind) -> Result<Vec<IpAddress>> {
79 let mut results = Vec::new();
80
81 if kind.contains(DnsQueryKind::A) {
82 let mut a_results = self.resolve_for_kind(host, DnsQueryType::A).await?;
83 results.append(&mut a_results);
84 }
85
86 if kind.contains(DnsQueryKind::Aaaa) {
87 let mut aaaa_results = self.resolve_for_kind(host, DnsQueryType::Aaaa).await?;
88 results.append(&mut aaaa_results);
89 }
90
91 if kind.contains(DnsQueryKind::Cname) {
92 let mut cname_results = self.resolve_for_kind(host, DnsQueryType::Cname).await?;
93 results.append(&mut cname_results);
94 }
95
96 if kind.contains(DnsQueryKind::Ns) {
97 let mut ns_results = self.resolve_for_kind(host, DnsQueryType::Ns).await?;
98 results.append(&mut ns_results);
99 }
100
101 if kind.contains(DnsQueryKind::Soa) {
102 let mut soa_results = self.resolve_for_kind(host, DnsQueryType::Soa).await?;
103 results.append(&mut soa_results);
104 }
105
106 Ok(results)
107 }
108
109 pub async fn close(mut self) -> Result<()> {
110 if self.context.closed {
111 return Ok(());
112 }
113
114 self.context
115 .stack
116 .with_mutable(|s| {
117 let _ = s.remove_socket(self.context.handle);
118 })
119 .await;
120
121 self.context.closed = true;
122
123 Ok(())
124 }
125}
126
127impl Drop for DnsSocket {
128 fn drop(&mut self) {
129 if self.context.closed {
130 return;
131 }
132
133 log::warning!("DNS socket dropped without being closed. Forcing closure...");
134
135 block_on(self.context.stack.with_mutable(|s| {
136 let _ = s.remove_socket(self.context.handle);
137 }));
138 }
139}