pub(super) mod handshake;
use std::time::Duration;
use self::handshake::{Handshaker, HandshakerOptions};
use super::{
conn::{ConnectionGeneration, LoadBalancedGeneration, PendingConnection},
Connection,
PoolGeneration,
};
use crate::{
client::{
auth::Credential,
options::{ClientOptions, ServerAddress, TlsOptions},
},
error::{Error as MongoError, ErrorKind, Result},
hello::HelloReply,
runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, AsyncStream, TlsConfig},
sdam::HandshakePhase,
};
#[derive(Clone)]
pub(crate) struct ConnectionEstablisher {
handshaker: Handshaker,
tls_config: Option<TlsConfig>,
connect_timeout: Duration,
}
pub(crate) struct EstablisherOptions {
handshake_options: HandshakerOptions,
tls_options: Option<TlsOptions>,
connect_timeout: Option<Duration>,
}
impl EstablisherOptions {
pub(crate) fn from_client_options(opts: &ClientOptions) -> Self {
Self {
handshake_options: HandshakerOptions {
app_name: opts.app_name.clone(),
compressors: opts.compressors.clone(),
driver_info: opts.driver_info.clone(),
server_api: opts.server_api.clone(),
load_balanced: opts.load_balanced.unwrap_or(false),
},
tls_options: opts.tls_options(),
connect_timeout: opts.connect_timeout,
}
}
}
impl ConnectionEstablisher {
pub(crate) fn new(options: EstablisherOptions) -> Result<Self> {
let handshaker = Handshaker::new(options.handshake_options);
let tls_config = if let Some(tls_options) = options.tls_options {
Some(TlsConfig::new(tls_options)?)
} else {
None
};
let connect_timeout = match options.connect_timeout {
Some(d) if d.is_zero() => Duration::MAX,
Some(d) => d,
None => DEFAULT_CONNECT_TIMEOUT,
};
Ok(Self {
handshaker,
tls_config,
connect_timeout,
})
}
async fn make_stream(&self, address: ServerAddress) -> Result<AsyncStream> {
runtime::timeout(
self.connect_timeout,
AsyncStream::connect(address, self.tls_config.as_ref()),
)
.await?
}
pub(super) async fn establish_connection(
&self,
pending_connection: PendingConnection,
credential: Option<&Credential>,
) -> std::result::Result<Connection, EstablishError> {
let pool_gen = pending_connection.generation.clone();
let address = pending_connection.address.clone();
let stream = self
.make_stream(address)
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
let mut connection = Connection::new_pooled(pending_connection, stream);
let handshake_result = self.handshaker.handshake(&mut connection, credential).await;
match (&pool_gen, connection.service_id()) {
(PoolGeneration::Normal(_), _) => {}
(PoolGeneration::LoadBalanced(gen_map), Some(service_id)) => {
connection.generation = LoadBalancedGeneration {
generation: *gen_map.get(&service_id).unwrap_or(&0),
service_id,
}
.into();
}
(PoolGeneration::LoadBalanced(_), None) => {
if handshake_result.is_ok() {
return Err(EstablishError::post_hello(
ErrorKind::IncompatibleServer {
message: "Driver attempted to initialize in load balancing mode, but \
the server does not support this mode."
.to_string(),
}
.into(),
connection.generation,
));
}
}
}
handshake_result.map_err(|e| {
if connection.stream_description.is_none() {
EstablishError::pre_hello(e, pool_gen)
} else {
EstablishError::post_hello(e, connection.generation)
}
})?;
Ok(connection)
}
pub(crate) async fn establish_monitoring_connection(
&self,
address: ServerAddress,
id: u32,
) -> Result<(Connection, HelloReply)> {
let stream = self.make_stream(address.clone()).await?;
let mut connection = Connection::new_monitoring(address, stream, id);
let hello_reply = self.handshaker.handshake(&mut connection, None).await?;
Ok((connection, hello_reply))
}
}
#[derive(Debug, Clone)]
pub(crate) struct EstablishError {
pub(crate) cause: MongoError,
pub(crate) handshake_phase: HandshakePhase,
}
impl EstablishError {
fn pre_hello(cause: MongoError, generation: PoolGeneration) -> Self {
Self {
cause,
handshake_phase: HandshakePhase::PreHello { generation },
}
}
fn post_hello(cause: MongoError, generation: ConnectionGeneration) -> Self {
Self {
cause,
handshake_phase: HandshakePhase::PostHello { generation },
}
}
}