mod conversion;
mod error;
pub(crate) use crate::connector::mssql::MssqlUrl;
use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions};
use crate::{
ast::{Query, Value},
connector::{metrics, queryable::*, DefaultTransaction, ResultSet},
visitor::{self, Visitor},
};
use async_trait::async_trait;
use futures::lock::Mutex;
use std::{
convert::TryFrom,
future::Future,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use tiberius::*;
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
#[cfg(feature = "expose-drivers")]
pub use tiberius;
/// Isolation level `start_transaction` falls back to when neither the caller
/// nor the connection URL specifies one.
static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted;
#[async_trait]
impl TransactionCapable for Mssql {
    /// Opens a new transaction on this connection.
    ///
    /// The isolation level is resolved in this order: the `isolation`
    /// argument, then the level configured on the connection URL, then
    /// `SQL_SERVER_DEFAULT_ISOLATION`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while creating the `DefaultTransaction`.
    async fn start_transaction<'a>(
        &'a self,
        isolation: Option<IsolationLevel>,
    ) -> crate::Result<Box<dyn Transaction + 'a>> {
        // Explicit argument wins, then the URL setting, then the fallback.
        let level = isolation
            .or(self.url.query_params.transaction_isolation_level)
            .unwrap_or(SQL_SERVER_DEFAULT_ISOLATION);

        let isolation_first = self.requires_isolation_first();
        let options = TransactionOptions::new(Some(level), isolation_first);

        let begin = self.begin_statement();
        let tx = DefaultTransaction::new(self, begin, options).await?;

        Ok(Box::new(tx))
    }
}
/// A connector to a Microsoft SQL Server database, backed by a single
/// `tiberius` client.
#[derive(Debug)]
pub struct Mssql {
/// The underlying driver client; every query locks this mutex first.
client: Mutex<Client<Compat<TcpStream>>>,
/// The connection URL this connector was created from.
url: MssqlUrl,
/// Timeout handed to `timeout::socket` for each I/O operation, if any.
socket_timeout: Option<Duration>,
/// Starts out `true`; set to `false` once an I/O error reports the
/// connection as closed.
is_healthy: AtomicBool,
}
impl Mssql {
    /// Creates a new connection to SQL Server from the given URL.
    ///
    /// The whole connection attempt — the TCP connect, the driver login and,
    /// if the server answers with a routing redirect, the reconnect to the
    /// redirected host and port — runs under the URL's connect timeout.
    ///
    /// When the URL configures a transaction isolation level, it is applied
    /// to the fresh connection before it is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection string cannot be parsed, the
    /// connection cannot be established within the connect timeout, or
    /// setting the isolation level fails.
    pub async fn new(url: MssqlUrl) -> crate::Result<Self> {
        let config = Config::from_jdbc_string(&url.connection_string)?;
        let socket_timeout = url.socket_timeout();

        let connecting = async {
            // The TCP connect lives inside the timed future so that an
            // unreachable host is bounded by the connect timeout as well.
            let tcp = TcpStream::connect_named(&config).await?;

            match Client::connect(config, tcp.compat_write()).await {
                Ok(client) => Ok(client),
                // The server asked us to connect elsewhere: rebuild the
                // configuration and retry once against the new endpoint.
                Err(tiberius::error::Error::Routing { host, port }) => {
                    let mut config = Config::from_jdbc_string(&url.connection_string)?;
                    config.host(host);
                    config.port(port);

                    let tcp = TcpStream::connect_named(&config).await?;
                    Client::connect(config, tcp.compat_write()).await
                }
                Err(e) => Err(e),
            }
        };

        let client = timeout::connect(url.connect_timeout(), connecting).await?;

        let this = Self {
            client: Mutex::new(client),
            url,
            socket_timeout,
            is_healthy: AtomicBool::new(true),
        };

        if let Some(isolation) = this.url.transaction_isolation_level() {
            this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}"))
                .await?;
        }

        Ok(this)
    }

    /// The underlying tiberius client. Only available with the
    /// `expose-drivers` Cargo feature.
    #[cfg(feature = "expose-drivers")]
    pub fn client(&self) -> &Mutex<Client<Compat<TcpStream>>> {
        &self.client
    }

    /// Drives `fut` under the socket timeout.
    ///
    /// If the resulting error reports the connection as closed, the
    /// connector is flagged unhealthy before the error is passed on.
    async fn perform_io<F, T>(&self, fut: F) -> crate::Result<T>
    where
        F: Future<Output = std::result::Result<T, tiberius::error::Error>>,
    {
        match timeout::socket(self.socket_timeout, fut).await {
            Err(e) if e.is_closed() => {
                self.is_healthy.store(false, Ordering::SeqCst);
                Err(e)
            }
            res => res,
        }
    }
}
#[async_trait]
impl Queryable for Mssql {
    /// Renders `q` with the SQL Server visitor and runs it through
    /// [`query_raw`](Self::query_raw).
    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
        let (sql, params) = visitor::Mssql::build(q)?;
        self.query_raw(&sql, &params[..]).await
    }

    /// Runs `sql` with the given parameters and returns the rows of the last
    /// result set the server produced.
    ///
    /// Column names are read from the first row, so a result set without
    /// rows also has no column names. If the server produced no result set
    /// at all, an empty `ResultSet` is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or a column value cannot be
    /// converted into a `Value`.
    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
        metrics::query("mssql.query_raw", sql, params, move || async move {
            let mut query = tiberius::Query::new(sql);

            for param in params {
                query.bind(param);
            }

            // Lock only once the statement is fully prepared.
            let mut client = self.client.lock().await;
            let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?;

            // Only the last result set is of interest.
            let Some(rows) = results.pop() else {
                return Ok(ResultSet::new(Vec::new(), Vec::new()));
            };

            let columns: Vec<String> = rows
                .first()
                .map(|row| row.columns().iter().map(|c| c.name().to_string()).collect())
                .unwrap_or_default();

            let mut result_rows = Vec::with_capacity(rows.len());

            for row in rows {
                let mut values: Vec<Value<'_>> = Vec::with_capacity(row.len());

                for val in row {
                    values.push(Value::try_from(val)?);
                }

                result_rows.push(values);
            }

            Ok(ResultSet::new(columns, result_rows))
        })
        .await
    }

    /// Same as [`query_raw`](Self::query_raw); this connector does no extra
    /// handling for the typed variant.
    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
        self.query_raw(sql, params).await
    }

    /// Renders `q` with the SQL Server visitor and runs it through
    /// [`execute_raw`](Self::execute_raw).
    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
        let (sql, params) = visitor::Mssql::build(q)?;
        self.execute_raw(&sql, &params[..]).await
    }

    /// Executes `sql` with the given parameters and returns the total count
    /// reported by the driver's execute result.
    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
        metrics::query("mssql.execute_raw", sql, params, move || async move {
            let mut query = tiberius::Query::new(sql);

            for param in params {
                query.bind(param);
            }

            let mut client = self.client.lock().await;
            let changes = self.perform_io(query.execute(&mut client)).await?.total();

            Ok(changes)
        })
        .await
    }

    /// Same as [`execute_raw`](Self::execute_raw); this connector does no
    /// extra handling for the typed variant.
    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
        self.execute_raw(sql, params).await
    }

    /// Sends `cmd` as an unparameterised statement, reads its results to
    /// completion and discards them.
    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
        metrics::query("mssql.raw_cmd", cmd, &[], move || async move {
            let mut client = self.client.lock().await;
            self.perform_io(client.simple_query(cmd)).await?.into_results().await?;

            Ok(())
        })
        .await
    }

    /// Returns the server's `@@VERSION` string, or `None` if the query
    /// yields no row or the value cannot be rendered as a string.
    async fn version(&self) -> crate::Result<Option<String>> {
        let query = r#"SELECT @@VERSION AS version"#;
        let rows = self.query_raw(query, &[]).await?;

        let version_string = rows
            .first()
            .and_then(|row| row.get("version").and_then(|version| version.to_string()));

        Ok(version_string)
    }

    /// `false` once an I/O error has reported the connection as closed.
    fn is_healthy(&self) -> bool {
        self.is_healthy.load(Ordering::SeqCst)
    }

    /// Issues `SET TRANSACTION ISOLATION LEVEL` for the given level.
    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
        self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"))
            .await
    }

    /// The statement used to open a transaction.
    fn begin_statement(&self) -> &'static str {
        "BEGIN TRAN"
    }

    /// Always `true` for this connector; the value is passed to
    /// `TransactionOptions::new` when a transaction is started.
    fn requires_isolation_first(&self) -> bool {
        true
    }
}