1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#[cfg(test)]
mod test;

use std::time::Duration;

use super::{
    description::topology::TopologyType,
    monitor::DEFAULT_HEARTBEAT_FREQUENCY,
    TopologyUpdater,
    TopologyWatcher,
};
use crate::{
    error::{Error, Result},
    options::ClientOptions,
    runtime,
    srv::{LookupHosts, SrvResolver},
};

/// Lower bound on the delay between SRV lookups: `SrvPollingMonitor::rescan_interval` never
/// returns a duration shorter than this.
const MIN_RESCAN_SRV_INTERVAL: Duration = Duration::from_secs(60);

/// Background monitor that periodically re-resolves the SRV records of the hostname the client
/// was originally configured with and forwards the resulting host list to the topology.
pub(crate) struct SrvPollingMonitor {
    /// Hostname taken from the client's original SRV info; queried on every lookup.
    initial_hostname: String,
    /// Resolver used for lookups. Starts as `None` and is created on the first lookup.
    resolver: Option<SrvResolver>,
    /// Handle used to push the freshly resolved host list to the topology (`sync_hosts`).
    topology_updater: TopologyUpdater,
    /// Handle used to check whether the topology is still alive and what type it currently has.
    topology_watcher: TopologyWatcher,
    /// Requested delay between lookups. Seeded from the original SRV info's `min_ttl`, replaced
    /// by each successful lookup's `min_ttl`, and set to the heartbeat frequency when a lookup
    /// fails or yields no hosts. `MIN_RESCAN_SRV_INTERVAL` is applied as a floor when it is read
    /// through `rescan_interval()`.
    rescan_interval: Duration,
    /// Client options, consulted for the resolver configuration, the heartbeat frequency and
    /// (in test builds) a mocked lookup result.
    client_options: ClientOptions,
}

impl SrvPollingMonitor {
    /// Builds a monitor from the given topology handles and client options.
    ///
    /// Returns `None` when `client_options.original_srv_info` is unset. Otherwise the SRV info is
    /// moved out of the options: its hostname becomes the name to poll and its `min_ttl` the
    /// initial rescan interval. The resolver is left unset until the first lookup needs it.
    pub(crate) fn new(
        topology_updater: TopologyUpdater,
        topology_watcher: TopologyWatcher,
        mut client_options: ClientOptions,
    ) -> Option<Self> {
        // No original SRV info means there is nothing to poll.
        let srv_info = client_options.original_srv_info.take()?;

        let monitor = Self {
            initial_hostname: srv_info.hostname,
            rescan_interval: srv_info.min_ttl,
            resolver: None,
            topology_updater,
            topology_watcher,
            client_options,
        };
        Some(monitor)
    }

    /// Launches the SRV polling task when the client options call for one.
    ///
    /// A monitor is only built when `Self::new` returns `Some`, i.e. when the options carry
    /// original SRV info; otherwise this is a no-op. The monitor's `execute` future, which
    /// periodically performs SRV record lookups, is handed to `runtime::execute`. Its loop only
    /// keeps going while the topology watcher reports the topology as alive.
    pub(super) fn start(
        topology: TopologyUpdater,
        topology_watcher: TopologyWatcher,
        client_options: ClientOptions,
    ) {
        let monitor = match Self::new(topology, topology_watcher, client_options) {
            Some(monitor) => monitor,
            None => return,
        };

        runtime::execute(monitor.execute());
    }

    /// Returns the delay to wait before the next SRV lookup: the stored interval, raised to
    /// `MIN_RESCAN_SRV_INTERVAL` whenever it is shorter than that.
    fn rescan_interval(&self) -> Duration {
        self.rescan_interval.max(MIN_RESCAN_SRV_INTERVAL)
    }

    /// Runs the polling loop until the topology watcher reports the topology is no longer alive.
    ///
    /// Each iteration first waits for `rescan_interval()`, then performs an SRV lookup and applies
    /// the result, but only while the topology type is `Sharded` or `Unknown`.
    async fn execute(mut self) {
        /// Whether SRV polling applies to a topology of the given type.
        fn is_pollable(kind: TopologyType) -> bool {
            matches!(kind, TopologyType::Sharded | TopologyType::Unknown)
        }

        while self.topology_watcher.is_alive() {
            runtime::delay_for(self.rescan_interval()).await;

            if !is_pollable(self.topology_watcher.topology_type()) {
                continue;
            }

            let lookup_result = self.lookup_hosts().await;

            // The topology type may have changed while the lookup was in flight, so check it
            // again before applying the result.
            if is_pollable(self.topology_watcher.topology_type()) {
                self.update_hosts(lookup_result).await;
            }
        }
    }

    /// Applies the outcome of an SRV lookup.
    ///
    /// A failed lookup, or one that returned no hosts, is routed to `no_valid_hosts` and nothing
    /// else happens. Otherwise the rescan interval is replaced by the lookup's `min_ttl` and the
    /// resolved hosts are passed to the topology updater via `sync_hosts`.
    async fn update_hosts(&mut self, lookup: Result<LookupHosts>) {
        let resolved = match lookup {
            Ok(found) if !found.hosts.is_empty() => found,
            Ok(_) => {
                self.no_valid_hosts(None);

                return;
            }
            Err(failure) => {
                self.no_valid_hosts(Some(failure));

                return;
            }
        };

        self.rescan_interval = resolved.min_ttl;

        // TODO: RUST-230 Log error with host that was returned.
        self.topology_updater
            .sync_hosts(resolved.hosts.into_iter().collect())
            .await;
    }

    /// Resolves the SRV records for the initial hostname.
    ///
    /// In test builds, a mocked result configured in the client's test options is returned
    /// instead, without touching the resolver. Errors from creating the resolver or from the
    /// lookup itself are returned to the caller.
    async fn lookup_hosts(&mut self) -> Result<LookupHosts> {
        #[cfg(test)]
        {
            let mocked = self
                .client_options
                .test_options
                .as_ref()
                .and_then(|opts| opts.mock_lookup_hosts.as_ref());
            if let Some(result) = mocked {
                return result.clone();
            }
        }

        // Copy the hostname up front: the resolver reference obtained below keeps `self`
        // mutably borrowed for as long as it is in use.
        let hostname = self.initial_hostname.clone();
        let srv_resolver = self.get_or_create_srv_resolver().await?;
        srv_resolver
            .get_srv_hosts(hostname.as_str(), crate::srv::DomainMismatch::Skip)
            .await
    }

    /// Returns the cached SRV resolver, creating and caching one on first use.
    ///
    /// A new resolver is built from the client's `resolver_config` (if any). If that construction
    /// fails, the error is returned and nothing is cached.
    async fn get_or_create_srv_resolver(&mut self) -> Result<&SrvResolver> {
        match self.resolver {
            Some(ref cached) => Ok(cached),
            None => {
                let config = self
                    .client_options
                    .resolver_config
                    .clone()
                    .map(|c| c.inner);
                let created = SrvResolver::new(config).await?;

                // `self.resolver` is `None` in this arm, so `get_or_insert` always stores the
                // new resolver and returns a reference to it.
                Ok(self.resolver.get_or_insert(created))
            }
        }
    }

    /// Handles a lookup that failed or produced no hosts by switching the stored rescan interval
    /// to the heartbeat frequency. The error, if any, is currently unused.
    ///
    /// Note that `rescan_interval()` still applies the `MIN_RESCAN_SRV_INTERVAL` floor to the
    /// value stored here.
    fn no_valid_hosts(&mut self, _error: Option<Error>) {
        // TODO RUST-230: Log error/lack of valid results.

        let fallback = self.heartbeat_freq();
        self.rescan_interval = fallback;
    }

    /// Returns the heartbeat frequency set in the client options, or
    /// `DEFAULT_HEARTBEAT_FREQUENCY` when none is configured.
    fn heartbeat_freq(&self) -> Duration {
        match self.client_options.heartbeat_freq {
            Some(configured) => configured,
            None => DEFAULT_HEARTBEAT_FREQUENCY,
        }
    }
}