pub(crate) mod server_selection;
#[cfg(test)]
pub(crate) mod test;
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use serde::{Deserialize, Serialize};
use crate::{
bson::oid::ObjectId,
client::ClusterTime,
cmap::Command,
error::{Error, Result},
options::{ClientOptions, ServerAddress},
sdam::{
description::server::{ServerDescription, ServerType},
DEFAULT_HEARTBEAT_FREQUENCY,
},
selection_criteria::{ReadPreference, SelectionCriteria},
};
use self::server_selection::IDLE_WRITE_PERIOD;
/// The kind of deployment a topology description is currently believed to describe.
///
/// New descriptions start out as `Unknown` (see the `Default` impl).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize, derive_more::Display)]
#[non_exhaustive]
pub enum TopologyType {
/// A single server.
Single,
/// A replica set in which no primary is currently known.
ReplicaSetNoPrimary,
/// A replica set in which a primary is currently known.
ReplicaSetWithPrimary,
/// A sharded deployment.
Sharded,
/// A load-balanced deployment.
LoadBalanced,
/// The deployment kind has not been determined yet.
Unknown,
}
#[cfg(test)]
impl TopologyType {
    /// Returns the name of the variant as a static string. Only compiled for tests.
    fn as_str(&self) -> &'static str {
        match *self {
            TopologyType::Single => "Single",
            TopologyType::ReplicaSetNoPrimary => "ReplicaSetNoPrimary",
            TopologyType::ReplicaSetWithPrimary => "ReplicaSetWithPrimary",
            TopologyType::Sharded => "Sharded",
            TopologyType::LoadBalanced => "LoadBalanced",
            TopologyType::Unknown => "Unknown",
        }
    }
}
impl Default for TopologyType {
fn default() -> Self {
TopologyType::Unknown
}
}
/// A snapshot of what is known about the deployment: its type, its servers, and the values
/// derived from those servers.
#[derive(Debug, Clone, Serialize)]
#[non_exhaustive]
pub(crate) struct TopologyDescription {
/// Whether exactly one server was registered by `initialize`.
#[serde(skip)]
pub(crate) single_seed: bool,
/// The current type of the topology.
pub(crate) topology_type: TopologyType,
/// The replica set name, taken from the client options or adopted from a replica set member.
pub(crate) set_name: Option<String>,
/// The highest replica set version recorded from a primary.
pub(crate) max_set_version: Option<i32>,
/// The election id most recently recorded from a primary that was not judged stale.
pub(crate) max_election_id: Option<ObjectId>,
/// The first compatibility error message reported by any server, if there is one.
pub(crate) compatibility_error: Option<String>,
/// The logical session timeout derived from the data-bearing servers.
pub(crate) logical_session_timeout: Option<Duration>,
/// Whether transactions are supported, unsupported, or not yet determined.
#[serde(skip)]
pub(crate) transaction_support_status: TransactionSupportStatus,
/// The greatest cluster time this description has been advanced to.
#[serde(skip)]
pub(crate) cluster_time: Option<ClusterTime>,
/// Copied from `ClientOptions::local_threshold` by `initialize`.
#[serde(skip)]
pub(crate) local_threshold: Option<Duration>,
/// Copied from `ClientOptions::heartbeat_freq` by `initialize`; `heartbeat_frequency` falls
/// back to `DEFAULT_HEARTBEAT_FREQUENCY` when this is `None`.
#[serde(skip)]
pub(crate) heartbeat_freq: Option<Duration>,
/// The description of every server in the topology, keyed by address.
pub(crate) servers: HashMap<ServerAddress, ServerDescription>,
/// Copied from `ClientOptions::srv_max_hosts`; `sync_hosts` uses it to cap how many servers
/// are tracked.
pub(crate) srv_max_hosts: Option<u32>,
}
impl PartialEq for TopologyDescription {
    /// Two descriptions are equal when their compatibility error, their servers, and their
    /// topology type all match. The remaining fields are deliberately not compared.
    fn eq(&self, other: &Self) -> bool {
        if self.compatibility_error != other.compatibility_error {
            return false;
        }
        if self.servers != other.servers {
            return false;
        }
        self.topology_type == other.topology_type
    }
}
impl Default for TopologyDescription {
fn default() -> Self {
Self {
single_seed: false,
topology_type: TopologyType::Unknown,
set_name: Default::default(),
max_set_version: Default::default(),
max_election_id: Default::default(),
compatibility_error: Default::default(),
logical_session_timeout: None,
transaction_support_status: TransactionSupportStatus::Undetermined,
cluster_time: Default::default(),
local_threshold: Default::default(),
heartbeat_freq: Default::default(),
servers: Default::default(),
srv_max_hosts: Default::default(),
}
}
}
impl TopologyDescription {
/// Populates a freshly created description from the client options.
///
/// The topology type is chosen in priority order: `direct_connection == Some(true)` yields
/// `Single`, otherwise a configured replica set name yields `ReplicaSetNoPrimary`, otherwise
/// `load_balanced == Some(true)` yields `LoadBalanced`, otherwise `Unknown`. A default server
/// description is registered for every configured host.
pub(crate) fn initialize(&mut self, options: &ClientOptions) {
    debug_assert!(
        self.servers.is_empty() && self.topology_type == TopologyType::Unknown,
        "new TopologyDescriptions should start empty"
    );

    let topology_type = if matches!(options.direct_connection, Some(true)) {
        TopologyType::Single
    } else if options.repl_set_name.is_some() {
        TopologyType::ReplicaSetNoPrimary
    } else if matches!(options.load_balanced, Some(true)) {
        TopologyType::LoadBalanced
    } else {
        TopologyType::Unknown
    };
    self.topology_type = topology_type;

    // Only a load-balanced topology is marked as supporting transactions up front.
    self.transaction_support_status = match topology_type {
        TopologyType::LoadBalanced => TransactionSupportStatus::Supported,
        _ => TransactionSupportStatus::Undetermined,
    };

    self.servers.extend(
        options
            .hosts
            .iter()
            .map(|address| (address.clone(), ServerDescription::new(address.clone()))),
    );

    self.single_seed = self.servers.len() == 1;
    self.set_name = options.repl_set_name.clone();
    self.local_threshold = options.local_threshold;
    self.heartbeat_freq = options.heartbeat_freq;
    self.srv_max_hosts = options.srv_max_hosts;
}
/// Returns the current type of this topology.
pub(crate) fn topology_type(&self) -> TopologyType {
self.topology_type
}
/// Iterates over the addresses of every server in this topology, in no particular order.
pub(crate) fn server_addresses(&self) -> impl Iterator<Item = &ServerAddress> {
self.servers.keys()
}
/// Returns the cluster time recorded on this topology, if any.
pub(crate) fn cluster_time(&self) -> Option<&ClusterTime> {
self.cluster_time.as_ref()
}
/// Looks up the description of the server at `address`, returning `None` if that address is
/// not part of this topology.
pub(crate) fn get_server_description(
&self,
address: &ServerAddress,
) -> Option<&ServerDescription> {
self.servers.get(address)
}
/// Attaches the appropriate read preference (if any) to `command`, based on the topology type,
/// the type of the server at `address`, and the caller's selection criteria.
///
/// A server that is not part of the topology is treated as `ServerType::Unknown`.
pub(crate) fn update_command_with_read_pref<T>(
    &self,
    address: &ServerAddress,
    command: &mut Command<T>,
    criteria: Option<&SelectionCriteria>,
) {
    let server_type = match self.get_server_description(address) {
        Some(description) => description.server_type,
        None => ServerType::Unknown,
    };

    match (self.topology_type, server_type) {
        // Mongos-style handling: sharded or single topologies talking to a mongos, and any
        // load-balanced topology.
        (TopologyType::LoadBalanced, _)
        | (TopologyType::Sharded, ServerType::Mongos)
        | (TopologyType::Single, ServerType::Mongos) => {
            self.update_command_read_pref_for_mongos(command, criteria);
        }
        // A standalone in a single topology never gets a read preference attached.
        (TopologyType::Single, ServerType::Standalone) => {}
        // Any other server in a single topology: a missing or `Primary` read preference is
        // upgraded to `PrimaryPreferred`.
        (TopologyType::Single, _) => {
            let requested = criteria.and_then(SelectionCriteria::as_read_pref);
            let effective = match requested {
                None | Some(ReadPreference::Primary) => ReadPreference::PrimaryPreferred {
                    options: Default::default(),
                },
                Some(other) => other.clone(),
            };
            if effective != ReadPreference::Primary {
                command.set_read_preference(effective);
            }
        }
        // Everything else: forward a non-primary read preference; a predicate is sent as
        // `PrimaryPreferred`.
        _ => {
            let effective = match criteria {
                None => ReadPreference::Primary,
                Some(SelectionCriteria::Predicate(_)) => ReadPreference::PrimaryPreferred {
                    options: Default::default(),
                },
                Some(SelectionCriteria::ReadPreference(read_pref)) => read_pref.clone(),
            };
            if effective != ReadPreference::Primary {
                command.set_read_preference(effective);
            }
        }
    }
}
/// Attaches the read preference from `criteria` to `command` when it is one of `Secondary`,
/// `PrimaryPreferred`, `Nearest` or `SecondaryPreferred`. Any other criteria (including a
/// predicate or none at all) leaves the command untouched.
fn update_command_read_pref_for_mongos<T>(
    &self,
    command: &mut Command<T>,
    criteria: Option<&SelectionCriteria>,
) {
    if let Some(SelectionCriteria::ReadPreference(read_pref)) = criteria {
        let should_forward = matches!(
            read_pref,
            ReadPreference::Secondary { .. }
                | ReadPreference::PrimaryPreferred { .. }
                | ReadPreference::Nearest { .. }
                | ReadPreference::SecondaryPreferred { .. }
        );
        if should_forward {
            command.set_read_preference(read_pref.clone());
        }
    }
}
/// Returns the configured heartbeat frequency, or `DEFAULT_HEARTBEAT_FREQUENCY` when none was
/// configured.
fn heartbeat_frequency(&self) -> Duration {
    match self.heartbeat_freq {
        Some(frequency) => frequency,
        None => DEFAULT_HEARTBEAT_FREQUENCY,
    }
}
/// Recomputes the topology's compatibility error: the first error message reported by any
/// server, or `None` if no server reports one.
fn check_compatibility(&mut self) {
    self.compatibility_error = self
        .servers
        .values()
        .find_map(|server| server.compatibility_error_message());
}
/// Returns the compatibility error message currently recorded on this topology, if any.
pub(crate) fn compatibility_error(&self) -> Option<&String> {
self.compatibility_error.as_ref()
}
/// Updates the topology's logical session timeout from a data-bearing server's description.
///
/// * Non-data-bearing servers are ignored.
/// * A server without a readable timeout clears the topology's timeout.
/// * Otherwise the topology keeps the smaller of its current timeout and the server's; if it
///   had none, the minimum over all data-bearing servers is used, which is `None` as soon as
///   any of them lacks a timeout.
fn update_logical_session_timeout(&mut self, server_description: &ServerDescription) {
    if !server_description.server_type.is_data_bearing() {
        return;
    }

    let new_timeout = match server_description.logical_session_timeout().ok().flatten() {
        Some(timeout) => timeout,
        None => {
            self.logical_session_timeout = None;
            return;
        }
    };

    self.logical_session_timeout = match self.logical_session_timeout {
        Some(current_timeout) => Some(std::cmp::min(current_timeout, new_timeout)),
        // `Option` orders `None` below every `Some`, so a single data-bearing server
        // without a timeout makes the overall minimum `None`.
        None => self
            .servers
            .values()
            .filter(|server| server.server_type.is_data_bearing())
            .map(|server| server.logical_session_timeout().ok().flatten())
            .min()
            .flatten(),
    };
}
/// Recomputes whether transactions are supported in light of `server_description`.
///
/// * If the topology has no logical session timeout, transactions are marked `Unsupported`
///   and the server's wire version is not consulted.
/// * Otherwise, if the server reports a max wire version, transactions are `Unsupported` when
///   that version is below 7 (below 8 for a sharded topology) and `Supported` otherwise.
/// * Otherwise the status is left unchanged.
fn update_transaction_support_status(&mut self, server_description: &ServerDescription) {
    if self.logical_session_timeout.is_none() {
        self.transaction_support_status = TransactionSupportStatus::Unsupported;
        // Return early so the wire version check below cannot overwrite this verdict with
        // `Supported`.
        return;
    }

    if let Ok(Some(max_wire_version)) = server_description.max_wire_version() {
        let wire_version_too_old = max_wire_version < 7
            || (max_wire_version < 8 && self.topology_type == TopologyType::Sharded);
        self.transaction_support_status = if wire_version_too_old {
            TransactionSupportStatus::Unsupported
        } else {
            TransactionSupportStatus::Supported
        };
    }
}
/// Replaces the stored cluster time with `cluster_time` unless the stored value already
/// compares greater than or equal to it.
pub(crate) fn advance_cluster_time(&mut self, cluster_time: &ClusterTime) {
    let already_current = self.cluster_time.as_ref() >= Some(cluster_time);
    if !already_current {
        self.cluster_time = Some(cluster_time.clone());
    }
}
pub(crate) fn diff<'a>(
&'a self,
other: &'a TopologyDescription,
) -> Option<TopologyDescriptionDiff> {
if self == other {
return None;
}
let addresses: HashSet<&ServerAddress> = self.server_addresses().collect();
let other_addresses: HashSet<&ServerAddress> = other.server_addresses().collect();
let changed_servers = self
.servers
.iter()
.filter_map(|(address, description)| match other.servers.get(address) {
Some(other_description) if description != other_description => {
Some((address, (description, other_description)))
}
_ => None,
});
Some(TopologyDescriptionDiff {
removed_addresses: addresses.difference(&other_addresses).cloned().collect(),
added_addresses: other_addresses.difference(&addresses).cloned().collect(),
changed_servers: changed_servers.collect(),
})
}
/// Makes the set of tracked servers match `hosts`.
///
/// Servers that are not in `hosts` are dropped, and hosts that are not yet tracked are added
/// with default descriptions. When `srv_max_hosts` is a positive limit that the result would
/// exceed, only a random selection of the new hosts is added so the total stays within it.
pub(crate) fn sync_hosts(&mut self, hosts: HashSet<ServerAddress>) {
    self.servers.retain(|address, _| hosts.contains(address));

    let mut additions: Vec<_> = hosts
        .into_iter()
        .filter(|address| !self.servers.contains_key(address))
        .map(|address| (address.clone(), ServerDescription::new(address)))
        .collect();

    match self.srv_max_hosts {
        Some(max) if max > 0 => {
            let limit = max as usize;
            let tracked = self.servers.len();
            if limit < tracked + additions.len() {
                additions = choose_n(&additions, limit.saturating_sub(tracked))
                    .cloned()
                    .collect();
            }
        }
        _ => {}
    }

    self.servers.extend(additions);
}
/// Returns the transaction support status currently recorded on this topology.
pub(crate) fn transaction_support_status(&self) -> TransactionSupportStatus {
self.transaction_support_status
}
/// Applies a newly observed server description to this topology.
///
/// The description is ignored (returning `Ok(())`) when its address is not part of the
/// topology, or when it is judged stale by its topology version. Otherwise it replaces the
/// stored description, the derived values (logical session timeout, transaction support,
/// cluster time) are refreshed, the update is dispatched according to the current topology
/// type, and finally the compatibility error is recomputed.
///
/// Errors from the per-topology-type update routines are propagated.
pub(crate) fn update(&mut self, mut server_description: ServerDescription) -> Result<()> {
match self.servers.get(&server_description.address) {
// Descriptions for servers that are not part of the topology are dropped.
None => return Ok(()),
Some(existing_sd) => {
if let Some(existing_tv) = existing_sd.topology_version() {
if let Some(new_tv) = server_description.topology_version() {
// Same process but an older counter: the incoming description is stale.
if existing_tv.process_id == new_tv.process_id
&& new_tv.counter < existing_tv.counter
{
return Ok(());
}
}
}
}
}
if let Some(expected_name) = &self.set_name {
if server_description.is_available() {
let got_name = server_description.set_name();
// In a `Single` topology with an expected set name, an available server that does
// not report exactly that name is replaced by an error description.
if self.topology_type() == TopologyType::Single
&& !matches!(
got_name.as_ref().map(|opt| opt.as_ref()),
Ok(Some(name)) if name == expected_name
)
{
let got_display = match got_name {
Ok(Some(s)) => format!("{:?}", s),
Ok(None) => "<none>".to_string(),
Err(s) => format!("<error: {}>", s),
};
server_description = ServerDescription::new_from_error(
server_description.address,
Error::invalid_argument(format!(
"Connection string replicaSet name {:?} does not match actual name {}",
expected_name, got_display,
)),
);
}
}
}
// Store the (possibly replaced) description before deriving anything from it.
self.servers.insert(
server_description.address.clone(),
server_description.clone(),
);
// Nothing further is derived for a load-balanced topology.
if let TopologyType::LoadBalanced = self.topology_type {
return Ok(());
}
self.update_logical_session_timeout(&server_description);
self.update_transaction_support_status(&server_description);
if let Some(ref cluster_time) = server_description.cluster_time().ok().flatten() {
self.advance_cluster_time(cluster_time);
}
// Dispatch on the topology type as it is at this point; the routines below may change it.
match self.topology_type {
TopologyType::Single | TopologyType::LoadBalanced => {}
TopologyType::Unknown => self.update_unknown_topology(server_description)?,
TopologyType::Sharded => self.update_sharded_topology(server_description),
TopologyType::ReplicaSetNoPrimary => {
self.update_replica_set_no_primary_topology(server_description)?
}
TopologyType::ReplicaSetWithPrimary => {
self.update_replica_set_with_primary_topology(server_description)?;
}
}
self.check_compatibility();
Ok(())
}
fn update_unknown_topology(&mut self, server_description: ServerDescription) -> Result<()> {
match server_description.server_type {
ServerType::Unknown | ServerType::RsGhost => {}
ServerType::Standalone => {
self.update_unknown_with_standalone_server(server_description)
}
ServerType::Mongos => self.topology_type = TopologyType::Sharded,
ServerType::RsPrimary => {
self.topology_type = TopologyType::ReplicaSetWithPrimary;
self.update_rs_from_primary_server(server_description)?;
}
ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
self.topology_type = TopologyType::ReplicaSetNoPrimary;
self.update_rs_without_primary_server(server_description)?;
}
ServerType::LoadBalancer => {
return Err(Error::internal("cannot transition to a load balancer"))
}
}
Ok(())
}
/// Handles a server update while the topology type is `Sharded`: any server that is neither
/// `Unknown` nor a `Mongos` is removed from the topology.
fn update_sharded_topology(&mut self, server_description: ServerDescription) {
    let belongs = matches!(
        server_description.server_type,
        ServerType::Unknown | ServerType::Mongos
    );
    if !belongs {
        self.servers.remove(&server_description.address);
    }
}
fn update_replica_set_no_primary_topology(
&mut self,
server_description: ServerDescription,
) -> Result<()> {
match server_description.server_type {
ServerType::Unknown | ServerType::RsGhost => {}
ServerType::Standalone | ServerType::Mongos => {
self.servers.remove(&server_description.address);
}
ServerType::RsPrimary => {
self.topology_type = TopologyType::ReplicaSetWithPrimary;
self.update_rs_from_primary_server(server_description)?
}
ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
self.update_rs_without_primary_server(server_description)?;
}
ServerType::LoadBalancer => {
return Err(Error::internal("cannot transition to a load balancer"))
}
}
Ok(())
}
fn update_replica_set_with_primary_topology(
&mut self,
server_description: ServerDescription,
) -> Result<()> {
match server_description.server_type {
ServerType::Unknown | ServerType::RsGhost => {
self.record_primary_state();
}
ServerType::Standalone | ServerType::Mongos => {
self.servers.remove(&server_description.address);
self.record_primary_state();
}
ServerType::RsPrimary => self.update_rs_from_primary_server(server_description)?,
ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
self.update_rs_with_primary_from_member(server_description)?;
}
ServerType::LoadBalancer => {
return Err(Error::internal("cannot transition to a load balancer"));
}
}
Ok(())
}
/// Handles a standalone discovered while the topology type is `Unknown`: with a single seed
/// the topology becomes `Single`; otherwise the standalone is removed from the topology.
fn update_unknown_with_standalone_server(&mut self, server_description: ServerDescription) {
    if !self.single_seed {
        self.servers.remove(&server_description.address);
        return;
    }
    self.topology_type = TopologyType::Single;
}
/// Processes a non-primary replica set member while no primary is known.
///
/// The member's set name is adopted if the topology has none; a member reporting a different
/// name is removed. Otherwise the hosts it knows about are added, and the member itself is
/// removed if `invalid_me` reports true for it.
fn update_rs_without_primary_server(
    &mut self,
    server_description: ServerDescription,
) -> Result<()> {
    let reported_name = server_description.set_name()?;
    if self.set_name.is_none() {
        self.set_name = reported_name;
    } else if self.set_name != reported_name {
        self.servers.remove(&server_description.address);
        return Ok(());
    }

    let known_hosts = server_description.known_hosts()?;
    self.add_new_servers(known_hosts)?;

    if server_description.invalid_me()? {
        self.servers.remove(&server_description.address);
    }
    Ok(())
}
/// Processes a non-primary replica set member while a primary is known.
///
/// A member that reports a different set name, or for which `invalid_me` reports true, is
/// removed, after which the presence of a primary is re-evaluated.
fn update_rs_with_primary_from_member(
    &mut self,
    server_description: ServerDescription,
) -> Result<()> {
    // `invalid_me` is only consulted when the set names match, thanks to short-circuiting.
    let must_remove =
        self.set_name != server_description.set_name()? || server_description.invalid_me()?;
    if must_remove {
        self.servers.remove(&server_description.address);
        self.record_primary_state();
    }
    Ok(())
}
/// Processes a description reported by a server claiming to be the replica set primary.
///
/// In order: reconciles the set name, rejects the server as a stale primary if the topology
/// has already seen a newer (set version, election id) pair, records the new maximums, resets
/// any other server still marked as primary, adds hosts the primary knows about, removes
/// previously tracked servers it does not list, and re-evaluates whether a primary is present.
fn update_rs_from_primary_server(
&mut self,
server_description: ServerDescription,
) -> Result<()> {
// Adopt the set name if none is known yet; a mismatching name removes the server.
if self.set_name.is_none() {
self.set_name = server_description.set_name()?;
} else if self.set_name != server_description.set_name()? {
self.servers.remove(&server_description.address);
self.record_primary_state();
return Ok(());
}
// The staleness check only applies when the server reports both a set version and an
// election id, and the topology has recorded maximums for both.
if let Some(server_set_version) = server_description.set_version()? {
if let Some(server_election_id) = server_description.election_id()? {
if let Some(topology_max_set_version) = self.max_set_version {
if let Some(ref topology_max_election_id) = self.max_election_id {
if topology_max_set_version > server_set_version
|| (topology_max_set_version == server_set_version
&& *topology_max_election_id > server_election_id)
{
// Stale primary: replace its stored description with a fresh one built
// from just its address.
self.servers.insert(
server_description.address.clone(),
ServerDescription::new(server_description.address),
);
self.record_primary_state();
return Ok(());
}
}
}
self.max_election_id = Some(server_election_id);
}
}
// Raise the recorded set version if the server's is greater, or if none was recorded.
if let Some(server_set_version) = server_description.set_version()? {
if self
.max_set_version
.as_ref()
.map(|topology_max_set_version| server_set_version > *topology_max_set_version)
.unwrap_or(true)
{
self.max_set_version = Some(server_set_version);
}
}
// Snapshot of the addresses tracked before any hosts from this primary are added.
let addresses: Vec<_> = self.servers.keys().cloned().collect();
// Any other server still recorded as primary is reset to a fresh description.
for address in addresses.clone() {
if address == server_description.address {
continue;
}
if let ServerType::RsPrimary = self.servers.get(&address).unwrap().server_type {
self.servers
.insert(address.clone(), ServerDescription::new(address));
}
}
self.add_new_servers(server_description.known_hosts()?)?;
// Remove previously tracked servers that the primary does not list among its known hosts.
let known_hosts: HashSet<_> = server_description.known_hosts()?.collect();
for address in addresses {
if !known_hosts.contains(&address.to_string()) {
self.servers.remove(&address);
}
}
self.record_primary_state();
Ok(())
}
/// Sets the topology type to `ReplicaSetWithPrimary` if any tracked server is currently a
/// primary, and to `ReplicaSetNoPrimary` otherwise.
fn record_primary_state(&mut self) {
    let has_primary = self
        .servers
        .values()
        .any(|server| server.server_type == ServerType::RsPrimary);

    self.topology_type = match has_primary {
        true => TopologyType::ReplicaSetWithPrimary,
        false => TopologyType::ReplicaSetNoPrimary,
    };
}
/// Parses each string in `servers` as a server address and adds those not yet tracked.
///
/// If any string fails to parse, the error is returned and no server is added.
fn add_new_servers<'a>(&mut self, servers: impl Iterator<Item = &'a String>) -> Result<()> {
    let parsed = servers
        .map(ServerAddress::parse)
        .collect::<Result<Vec<_>>>()?;
    self.add_new_servers_from_addresses(parsed.iter());
    Ok(())
}
/// Adds a default description for every address in `servers` that is not already tracked.
/// Existing descriptions are left untouched.
fn add_new_servers_from_addresses<'a>(
    &mut self,
    servers: impl Iterator<Item = &'a ServerAddress>,
) {
    for address in servers {
        if self.servers.contains_key(address) {
            continue;
        }
        let description = ServerDescription::new(address.clone());
        self.servers.insert(address.clone(), description);
    }
}
}
/// Picks up to `n` distinct elements of `values` at random, using a freshly seeded
/// `SmallRng`.
pub(crate) fn choose_n<T>(values: &[T], n: usize) -> impl Iterator<Item = &T> {
    use rand::{prelude::SliceRandom, SeedableRng};

    let mut rng = rand::rngs::SmallRng::from_entropy();
    values.choose_multiple(&mut rng, n)
}
/// Whether the topology is known to support transactions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum TransactionSupportStatus {
    /// Support has not been determined yet; this is the default.
    Undetermined,
    /// Transactions are not supported.
    Unsupported,
    /// Transactions are supported.
    Supported,
}

impl Default for TransactionSupportStatus {
    /// Support is undetermined until something establishes otherwise.
    fn default() -> Self {
        TransactionSupportStatus::Undetermined
    }
}
/// The difference between two topology descriptions, as produced by
/// `TopologyDescription::diff`. All references borrow from the two compared descriptions.
#[derive(Debug)]
pub(crate) struct TopologyDescriptionDiff<'a> {
/// Addresses present in the first description but not in the second.
pub(crate) removed_addresses: HashSet<&'a ServerAddress>,
/// Addresses present in the second description but not in the first.
pub(crate) added_addresses: HashSet<&'a ServerAddress>,
/// For each address present in both descriptions whose server descriptions differ: the
/// description from the first and the description from the second, in that order.
pub(crate) changed_servers:
HashMap<&'a ServerAddress, (&'a ServerDescription, &'a ServerDescription)>,
}
pub(crate) fn verify_max_staleness(
max_staleness: Duration,
heartbeat_frequency: Duration,
) -> crate::error::Result<()> {
let smallest_max_staleness = std::cmp::max(
Duration::from_secs(90),
heartbeat_frequency
.checked_add(IDLE_WRITE_PERIOD)
.unwrap_or(Duration::MAX),
);
if max_staleness < smallest_max_staleness {
return Err(Error::invalid_argument(format!(
"invalid max_staleness value: must be at least {} seconds",
smallest_max_staleness.as_secs()
)));
}
Ok(())
}