use bson::{RawDocument, RawDocumentBuf};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use super::wire::Message;
use crate::{
bson::{rawdoc, Document},
bson_util::extend_raw_document_buf,
client::{options::ServerApi, ClusterTime, HELLO_COMMAND_NAMES, REDACTED_COMMANDS},
error::{Error, ErrorKind, Result},
hello::{HelloCommandResponse, HelloReply},
operation::{CommandErrorBody, CommandResponse},
options::{ReadConcern, ReadConcernInternal, ReadConcernLevel, ServerAddress},
selection_criteria::ReadPreference,
ClientSession,
};
/// A command whose payload is carried as an already-built byte buffer.
#[derive(Debug)]
pub(crate) struct RawCommand {
/// Command name; `should_compress` lower-cases it before looking it up.
pub(crate) name: String,
/// Database name associated with the command (presumably the database it targets; confirm
/// at the construction site).
pub(crate) target_db: String,
/// Exhaust flag carried alongside the command; its consumer is not visible in this file.
pub(crate) exhaust_allowed: bool,
/// The command payload as bytes (presumably a serialized BSON document; confirm at the
/// construction site).
pub(crate) bytes: Vec<u8>,
}
impl RawCommand {
    /// Reports whether this command may be compressed.
    ///
    /// The command name is lower-cased and looked up in `REDACTED_COMMANDS` and
    /// `HELLO_COMMAND_NAMES`; a hit in either set yields `false`, anything else yields `true`.
    pub(crate) fn should_compress(&self) -> bool {
        let lowered = self.name.to_lowercase();
        let key = lowered.as_str();

        if REDACTED_COMMANDS.contains(key) {
            return false;
        }
        !HELLO_COMMAND_NAMES.contains(key)
    }
}
/// A database command: a typed `body` plus the additional fields serialized alongside it.
///
/// Serialization notes, all taken from the attributes on this type: field names are
/// camelCased unless explicitly renamed, `name` and `exhaust_allowed` are never serialized,
/// `body` and `server_api` are flattened into the top-level document, and
/// `serde_with::skip_serializing_none` leaves out every `Option` field that is `None`.
#[serde_with::skip_serializing_none]
#[derive(Clone, Debug, Serialize, Default)]
#[serde(rename_all = "camelCase")]
pub(crate) struct Command<T = Document> {
/// Command name. Skipped during serialization.
#[serde(skip)]
pub(crate) name: String,
/// Exhaust flag; both constructors initialise it to `false`. Skipped during serialization.
#[serde(skip)]
pub(crate) exhaust_allowed: bool,
/// Command-specific payload, flattened into the top level of the serialized document.
#[serde(flatten)]
pub(crate) body: T,
/// Target database name, serialized under the key `$db`.
#[serde(rename = "$db")]
pub(crate) target_db: String,
/// Session id document, filled in by `set_session`. Serialized as `lsid`.
lsid: Option<Document>,
/// Cluster time, filled in by `set_cluster_time`. Serialized as `$clusterTime`.
#[serde(rename = "$clusterTime")]
cluster_time: Option<ClusterTime>,
/// Server API options, filled in by `set_server_api`; flattened into the top level.
#[serde(flatten)]
server_api: Option<ServerApi>,
/// Read preference, filled in by `set_read_preference`. Serialized as `$readPreference`.
#[serde(rename = "$readPreference")]
read_preference: Option<ReadPreference>,
/// Transaction number, filled in by `set_txn_number`. Serialized as `txnNumber`.
txn_number: Option<i64>,
/// `Some(true)` once `set_start_transaction` has been called. Serialized as
/// `startTransaction`.
start_transaction: Option<bool>,
/// `Some(false)` once `set_autocommit` has been called. Serialized as `autocommit`.
autocommit: Option<bool>,
/// Read concern, seeded by `new_read` and adjusted by the read-concern setters.
/// Serialized as `readConcern`.
read_concern: Option<ReadConcernInternal>,
/// Recovery token, filled in by `set_recovery_token`. Serialized as `recoveryToken`.
recovery_token: Option<Document>,
}
impl<T> Command<T> {
pub(crate) fn new(name: String, target_db: String, body: T) -> Self {
Self {
name,
target_db,
exhaust_allowed: false,
body,
lsid: None,
cluster_time: None,
server_api: None,
read_preference: None,
txn_number: None,
start_transaction: None,
autocommit: None,
read_concern: None,
recovery_token: None,
}
}
pub(crate) fn new_read(
name: String,
target_db: String,
read_concern: Option<ReadConcern>,
body: T,
) -> Self {
Self {
name,
target_db,
exhaust_allowed: false,
body,
lsid: None,
cluster_time: None,
server_api: None,
read_preference: None,
txn_number: None,
start_transaction: None,
autocommit: None,
read_concern: read_concern.map(Into::into),
recovery_token: None,
}
}
pub(crate) fn set_session(&mut self, session: &ClientSession) {
self.lsid = Some(session.id().clone())
}
pub(crate) fn set_cluster_time(&mut self, cluster_time: &ClusterTime) {
self.cluster_time = Some(cluster_time.clone());
}
pub(crate) fn set_recovery_token(&mut self, recovery_token: &Document) {
self.recovery_token = Some(recovery_token.clone());
}
pub(crate) fn set_txn_number(&mut self, txn_number: i64) {
self.txn_number = Some(txn_number);
}
pub(crate) fn set_server_api(&mut self, server_api: &ServerApi) {
self.server_api = Some(server_api.clone());
}
pub(crate) fn set_read_preference(&mut self, read_preference: ReadPreference) {
self.read_preference = Some(read_preference);
}
pub(crate) fn set_start_transaction(&mut self) {
self.start_transaction = Some(true);
}
pub(crate) fn set_autocommit(&mut self) {
self.autocommit = Some(false);
}
pub(crate) fn set_read_concern_level(&mut self, level: ReadConcernLevel) {
let inner = self.read_concern.get_or_insert(ReadConcernInternal {
level: None,
at_cluster_time: None,
after_cluster_time: None,
});
inner.level = Some(level);
}
pub(crate) fn set_snapshot_read_concern(&mut self, session: &ClientSession) {
let inner = self.read_concern.get_or_insert(ReadConcernInternal {
level: Some(ReadConcernLevel::Snapshot),
at_cluster_time: None,
after_cluster_time: None,
});
inner.at_cluster_time = session.snapshot_time;
}
pub(crate) fn set_after_cluster_time(&mut self, session: &ClientSession) {
if let Some(operation_time) = session.operation_time {
let inner = self.read_concern.get_or_insert(ReadConcernInternal {
level: None,
at_cluster_time: None,
after_cluster_time: None,
});
inner.after_cluster_time = Some(operation_time);
}
}
}
impl Command<RawDocumentBuf> {
    /// Consumes the command and returns it as a single byte buffer.
    ///
    /// The body is taken out and replaced with an empty document before the remaining
    /// fields are serialized, so the flattened `body` contributes nothing to that second
    /// document. The serialized remainder is then appended onto the original body via
    /// `extend_raw_document_buf`, and the combined document's bytes are returned.
    ///
    /// Fails if serializing the remaining fields or extending the body fails.
    pub(crate) fn into_bson_bytes(mut self) -> Result<Vec<u8>> {
        let mut combined = std::mem::replace(&mut self.body, rawdoc! {});
        let envelope = bson::to_raw_document_buf(&self)?;
        extend_raw_document_buf(&mut combined, envelope)?;
        Ok(combined.into_bytes())
    }
}
/// A reply held as a raw document, together with the address it is attributed to.
#[derive(Debug, Clone)]
pub(crate) struct RawCommandResponse {
/// Address recorded for this reply; supplied by the caller of the constructors.
pub(crate) source: ServerAddress,
/// The reply document in raw form.
raw: RawDocumentBuf,
}
impl RawCommandResponse {
#[cfg(test)]
pub(crate) fn with_document_and_address(source: ServerAddress, doc: Document) -> Result<Self> {
let mut raw = Vec::new();
doc.to_writer(&mut raw)?;
Ok(Self {
source,
raw: RawDocumentBuf::from_bytes(raw)?,
})
}
#[cfg(test)]
pub(crate) fn with_document(doc: Document) -> Result<Self> {
Self::with_document_and_address(
ServerAddress::Tcp {
host: "localhost".to_string(),
port: None,
},
doc,
)
}
pub(crate) fn new(source: ServerAddress, message: Message) -> Result<Self> {
let raw = message.single_document_response()?;
Ok(Self::new_raw(source, RawDocumentBuf::from_bytes(raw)?))
}
pub(crate) fn new_raw(source: ServerAddress, raw: RawDocumentBuf) -> Self {
Self { source, raw }
}
pub(crate) fn body<'a, T: Deserialize<'a>>(&'a self) -> Result<T> {
bson::from_slice(self.raw.as_bytes()).map_err(|e| {
Error::from(ErrorKind::InvalidResponse {
message: format!("{}", e),
})
})
}
pub(crate) fn body_utf8_lossy<'a, T: Deserialize<'a>>(&'a self) -> Result<T> {
bson::from_slice_utf8_lossy(self.raw.as_bytes()).map_err(|e| {
Error::from(ErrorKind::InvalidResponse {
message: format!("{}", e),
})
})
}
pub(crate) fn raw_body(&self) -> &RawDocument {
&self.raw
}
pub(crate) fn as_bytes(&self) -> &[u8] {
self.raw.as_bytes()
}
pub(crate) fn auth_response_body<T: DeserializeOwned>(
&self,
mechanism_name: &str,
) -> Result<T> {
self.body()
.map_err(|_| Error::invalid_authentication_response(mechanism_name))
}
pub(crate) fn into_hello_reply(self) -> Result<HelloReply> {
match self.body::<CommandResponse<HelloCommandResponse>>() {
Ok(response) if response.is_success() => {
let server_address = self.source_address().clone();
let cluster_time = response.cluster_time().cloned();
Ok(HelloReply {
server_address,
command_response: response.body,
cluster_time,
raw_command_response: self.into_raw_document_buf(),
})
}
_ => match self.body::<CommandResponse<CommandErrorBody>>() {
Ok(command_error_body) => Err(Error::new(
ErrorKind::Command(command_error_body.body.command_error),
command_error_body.body.error_labels,
)),
Err(_) => Err(ErrorKind::InvalidResponse {
message: "invalid server response".into(),
}
.into()),
},
}
}
pub(crate) fn source_address(&self) -> &ServerAddress {
&self.source
}
pub(crate) fn into_raw_document_buf(self) -> RawDocumentBuf {
self.raw
}
}