use futures_util::FutureExt;
pub use mysql_common::named_params;
use mysql_common::{
constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
crypto,
io::ParseBuf,
packets::{
binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket,
HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest,
OldEofPacket, ResultSetTerminator, SslRequest,
},
proto::MySerialize,
row::Row,
};
use std::{
borrow::Cow,
fmt,
future::Future,
mem::{self, replace},
pin::Pin,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};
use crate::{
buffer_pool::PooledBuf,
conn::{pool::Pool, stmt_cache::StmtCache},
consts::{CapabilityFlags, Command, StatusFlags},
error::*,
io::Stream,
opts::Opts,
queryable::{
query_result::{QueryResult, ResultSetMeta},
transaction::TxStatus,
BinaryProtocol, Queryable, TextProtocol,
},
BinlogStream, InfileData, OptsBuilder,
};
use self::routines::Routine;
pub mod binlog_stream;
pub mod pool;
pub mod routines;
pub mod stmt_cache;
fn disconnect(mut conn: Conn) {
let disconnected = conn.inner.disconnected;
conn.inner.disconnected = true;
if !disconnected {
if std::thread::panicking() {
return;
}
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn(async move {
if let Ok(conn) = conn.cleanup_for_pool().await {
let _ = conn.disconnect().await;
}
});
}
}
}
/// Metadata of the result set currently attached to a connection.
#[derive(Debug, Clone)]
pub(crate) enum PendingResult {
/// Metadata that has not been handed out by `Conn::take_pending_result` yet.
Pending(ResultSetMeta),
/// Metadata that `Conn::take_pending_result` wrapped in an `Arc` and shared with its caller.
Taken(Arc<ResultSetMeta>),
}
/// Mutable state of a single connection; `Conn` owns it behind a `Box`.
struct ConnInner {
/// Underlying I/O stream; `None` when no stream is attached.
stream: Option<Stream>,
/// Connection id taken from the server handshake (`0` before the handshake).
id: u32,
/// Set once the handshake yields a parsable MariaDB server version.
is_mariadb: bool,
/// Server version taken from the handshake; `(0, 0, 0)` when it could not be parsed.
version: (u16, u16, u16),
/// Socket path, taken from the options or read from the server's `@@socket`.
socket: Option<String>,
/// Capability flags; initialised from the options and intersected with the
/// server's flags during the handshake.
capabilities: CapabilityFlags,
/// Status flags from the handshake or the latest OK packet; cleared by an error packet.
status: StatusFlags,
/// Latest OK packet; cleared when an error packet is handled.
last_ok_packet: Option<OkPacket<'static>>,
/// Latest server error; cleared when an OK packet is handled.
last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
/// Pool handle stored on the connection, if any (`Conn::reset` preserves it).
pool: Option<Pool>,
/// Pending result set: `Ok(None)` means none, `Ok(Some(_))` carries its
/// metadata, `Err(_)` carries a stored server error.
pending_result: std::result::Result<Option<PendingResult>, ServerError>,
/// Transaction status tracked for this connection.
tx_status: TxStatus,
/// Options this connection was created with.
opts: Opts,
/// Instant recorded by `Conn::touch`; used to compute idle time.
last_io: Instant,
/// Wait timeout, taken from the options or from the server's `@@wait_timeout`.
wait_timeout: Duration,
/// Statement cache, sized by `Opts::stmt_cache_size`.
stmt_cache: StmtCache,
/// Scramble received in the handshake or in an auth switch request.
nonce: Vec<u8>,
/// Authentication plugin currently in use.
auth_plugin: AuthPlugin<'static>,
/// Whether an auth switch has already been performed on this connection.
auth_switched: bool,
/// Whether the connection is considered disconnected.
pub(crate) disconnected: bool,
/// Optional future producing the data for a local infile request.
infile_handler:
Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
impl fmt::Debug for ConnInner {
    /// Formats a subset of the connection state as a struct named `Conn`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Conn");
        out.field("connection id", &self.id);
        out.field("server version", &self.version);
        out.field("pool", &self.pool);
        out.field("pending_result", &self.pending_result);
        out.field("tx_status", &self.tx_status);
        out.field("stream", &self.stream);
        out.field("options", &self.opts);
        out.finish()
    }
}
impl ConnInner {
    /// Builds the state of a not-yet-connected connection configured by `opts`.
    ///
    /// No stream is attached, there is no pending result and no transaction,
    /// and the authentication plugin defaults to `MysqlNativePassword`.
    fn empty(opts: Opts) -> ConnInner {
        // Everything derived from `opts` is read before `opts` is moved into the struct.
        let capabilities = opts.get_capabilities();
        let stmt_cache = StmtCache::new(opts.stmt_cache_size());
        let socket: Option<String> = opts.socket().map(Into::into);

        ConnInner {
            stream: None,
            id: 0,
            is_mariadb: false,
            version: (0, 0, 0),
            socket,
            capabilities,
            status: StatusFlags::empty(),
            last_ok_packet: None,
            last_err_packet: None,
            pool: None,
            pending_result: Ok(None),
            tx_status: TxStatus::None,
            opts,
            last_io: Instant::now(),
            wait_timeout: Duration::from_secs(0),
            stmt_cache,
            nonce: Vec::default(),
            auth_plugin: AuthPlugin::MysqlNativePassword,
            auth_switched: false,
            disconnected: false,
            infile_handler: None,
        }
    }

    /// Returns the attached stream, or `DriverError::ConnectionClosed` when there is none.
    fn stream_mut(&mut self) -> Result<&mut Stream> {
        match self.stream.as_mut() {
            Some(stream) => Ok(stream),
            None => Err(DriverError::ConnectionClosed.into()),
        }
    }
}
/// A connection handle; all of its state lives in a boxed `ConnInner`.
#[derive(Debug)]
pub struct Conn {
/// Boxed connection state.
inner: Box<ConnInner>,
}
impl Conn {
pub fn id(&self) -> u32 {
self.inner.id
}
pub fn last_insert_id(&self) -> Option<u64> {
self.inner
.last_ok_packet
.as_ref()
.and_then(|ok| ok.last_insert_id())
}
pub fn affected_rows(&self) -> u64 {
self.inner
.last_ok_packet
.as_ref()
.map(|ok| ok.affected_rows())
.unwrap_or_default()
}
/// Returns the info string of the latest OK packet.
///
/// Yields an empty string when no OK packet is stored or it has no info.
pub fn info(&self) -> Cow<'_, str> {
    let info = self
        .inner
        .last_ok_packet
        .as_ref()
        .and_then(|ok| ok.info_str());
    match info {
        Some(text) => text,
        None => Cow::Borrowed(""),
    }
}
pub fn get_warnings(&self) -> u16 {
self.inner
.last_ok_packet
.as_ref()
.map(|ok| ok.warnings())
.unwrap_or_default()
}
pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
self.inner.last_ok_packet.as_ref()
}
pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
self.inner.stream_mut()
}
pub(crate) fn capabilities(&self) -> CapabilityFlags {
self.inner.capabilities
}
/// Records the current instant as the time of the last I/O activity.
pub(crate) fn touch(&mut self) {
    let now = Instant::now();
    self.inner.last_io = now;
}
pub(crate) fn reset_seq_id(&mut self) {
if let Some(stream) = self.inner.stream.as_mut() {
stream.reset_seq_id();
}
}
pub(crate) fn sync_seq_id(&mut self) {
if let Some(stream) = self.inner.stream.as_mut() {
stream.sync_seq_id();
}
}
/// Stores `ok_packet` as the latest OK packet.
///
/// The connection status is replaced with the packet's status flags and any
/// previously stored server error is cleared.
pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
    // Read the flags before the packet is moved into the connection state.
    let status = ok_packet.status_flags();
    let inner = &mut *self.inner;
    inner.status = status;
    inner.last_err_packet = None;
    inner.last_ok_packet = Some(ok_packet);
}
pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
match err_packet {
ErrPacket::Error(err) => {
self.inner.status = StatusFlags::empty();
self.inner.last_ok_packet = None;
self.inner.last_err_packet = Some(err.clone().into_owned());
Err(Error::from(err))
}
ErrPacket::Progress(_) => Ok(()),
}
}
pub(crate) fn get_tx_status(&self) -> TxStatus {
self.inner.tx_status
}
/// Overwrites the transaction status tracked for this connection.
pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
    let slot = &mut self.inner.tx_status;
    *slot = tx_status;
}
pub(crate) fn use_pending_result(
&mut self,
) -> std::result::Result<Option<&PendingResult>, ServerError> {
if let Err(ref e) = self.inner.pending_result {
let e = e.clone();
self.inner.pending_result = Ok(None);
return Err(e);
} else {
Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
}
}
pub(crate) fn get_pending_result(
&self,
) -> std::result::Result<Option<&PendingResult>, &ServerError> {
self.inner.pending_result.as_ref().map(|x| x.as_ref())
}
/// Returns `true` when a result set or a stored error is pending, i.e. in
/// every state other than `Ok(None)`.
pub(crate) fn has_pending_result(&self) -> bool {
    !matches!(self.inner.pending_result, Ok(None))
}
pub(crate) fn set_pending_result(
&mut self,
meta: Option<ResultSetMeta>,
) -> std::result::Result<Option<PendingResult>, ServerError> {
replace(
&mut self.inner.pending_result,
Ok(meta.map(PendingResult::Pending)),
)
}
pub(crate) fn set_pending_result_error(
&mut self,
error: ServerError,
) -> std::result::Result<Option<PendingResult>, ServerError> {
replace(&mut self.inner.pending_result, Err(error))
}
/// Hands out the metadata of a pending result set.
///
/// A `Pending` entry is wrapped in an `Arc`, stored back as `Taken`, and a
/// clone of the `Arc` is returned. Any other non-error state is stored back
/// unchanged and `None` is returned. A stored error is returned as `Err`,
/// leaving `Ok(None)` behind.
pub(crate) fn take_pending_result(
    &mut self,
) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
    // The state is swapped out first, so an error return leaves `Ok(None)` behind.
    let previous = replace(&mut self.inner.pending_result, Ok(None))?;

    let (stored, output) = match previous {
        Some(PendingResult::Pending(meta)) => {
            let shared = Arc::new(meta);
            (Some(PendingResult::Taken(Arc::clone(&shared))), Some(shared))
        }
        other => (other, None),
    };

    self.inner.pending_result = Ok(stored);
    Ok(output)
}
pub(crate) fn status(&self) -> StatusFlags {
self.inner.status
}
/// Runs routine `f` against this connection.
///
/// The connection is flagged as disconnected while the routine runs. The flag
/// is cleared again when the routine succeeds or fails with a server error;
/// for any other error the stream (if still attached) is closed and the flag
/// stays set.
///
/// # Errors
///
/// Returns the routine's error, or the error from closing the stream when
/// that close fails.
pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
where
F: Routine<T> + 'a,
{
// Pessimistically flag the connection until the outcome is known.
self.inner.disconnected = true;
let result = f.call(&mut *self).await;
match result {
// Success or a server-side error: the connection is treated as usable.
result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
self.inner.disconnected = false;
result
}
// Any other error: close the stream, keep the disconnected flag set.
Err(err) => {
if self.inner.stream.is_some() {
self.take_stream().close().await?;
}
Err(err)
}
}
}
pub fn server_version(&self) -> (u16, u16, u16) {
self.inner.version
}
pub fn opts(&self) -> &Opts {
&self.inner.opts
}
pub fn set_infile_handler<T>(&mut self, handler: T)
where
T: Future<Output = crate::Result<InfileData>>,
T: Send + Sync + 'static,
{
self.inner.infile_handler = Some(Box::pin(handler));
}
/// Detaches and returns the stream.
///
/// # Panics
///
/// Panics when no stream is attached.
fn take_stream(&mut self) -> Stream {
    let detached = self.inner.stream.take();
    detached.unwrap()
}
/// Sends `COM_QUIT` and closes the stream, consuming the connection.
///
/// Does nothing when the connection is already flagged as disconnected.
///
/// # Errors
///
/// Returns the error from writing the command or from closing the stream.
pub async fn disconnect(mut self) -> Result<()> {
    if self.inner.disconnected {
        return Ok(());
    }

    // Flag first, so the connection counts as disconnected even if the write fails.
    self.inner.disconnected = true;
    self.write_command_data(Command::COM_QUIT, &[]).await?;
    self.take_stream().close().await?;
    Ok(())
}
/// Runs `cleanup_for_pool` on the connection and then disconnects it.
async fn close_conn(self) -> Result<()> {
    let cleaned = self.cleanup_for_pool().await?;
    cleaned.disconnect().await
}
/// Reports whether the attached stream is secure.
///
/// With the `native-tls-tls` or `rustls-tls` feature enabled this delegates
/// to the stream's `is_secure` and yields `false` when no stream is attached.
/// Without either feature it is always `false`.
fn is_secure(&self) -> bool {
// Exactly one of the two cfg-gated expressions below is compiled in.
#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
if let Some(ref stream) = self.inner.stream {
stream.is_secure()
} else {
false
}
#[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
false
}
fn take(&mut self) -> Conn {
mem::replace(self, Conn::empty(Default::default()))
}
fn empty(opts: Opts) -> Self {
Self {
inner: Box::new(ConnInner::empty(opts)),
}
}
fn setup_stream(&mut self) -> Result<()> {
debug_assert!(self.inner.stream.is_some());
if let Some(stream) = self.inner.stream.as_mut() {
stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
}
Ok(())
}
/// Reads the server handshake packet and initialises the connection state from it.
///
/// Stores the scramble, the negotiated capabilities, the server version, the
/// connection id, the status flags and the authentication plugin.
///
/// # Errors
///
/// Fails when the packet cannot be read or parsed, or with
/// `DriverError::UnknownAuthPlugin` when the server names an unsupported plugin.
async fn handle_handshake(&mut self) -> Result<()> {
    let packet = self.read_packet().await?;
    let handshake = ParseBuf(&*packet).parse::<HandshakePacket>(())?;

    // Join both scramble parts, then force the length to exactly 20 bytes
    // (truncating or zero-padding as needed).
    let mut nonce = Vec::from(handshake.scramble_1_ref());
    nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
    nonce.resize(20, 0);
    self.inner.nonce = nonce;

    // Only capabilities supported by both sides remain.
    self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();

    // A parsable MariaDB version takes precedence over the generic server version.
    let version = match handshake.maria_db_server_version_parsed() {
        Some(version) => {
            self.inner.is_mariadb = true;
            Some(version)
        }
        None => handshake.server_version_parsed(),
    };
    self.inner.version = version.unwrap_or((0, 0, 0));
    self.inner.id = handshake.connection_id();
    self.inner.status = handshake.status_flags();

    let auth_plugin = match handshake.auth_plugin() {
        Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
        Some(AuthPlugin::Other(ref name)) => {
            let name = String::from_utf8_lossy(name).into();
            return Err(DriverError::UnknownAuthPlugin { name }.into());
        }
        // The old-password plugin and a missing plugin both map to native password.
        Some(AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword) => {
            AuthPlugin::MysqlNativePassword
        }
        None => AuthPlugin::MysqlNativePassword,
    };
    self.inner.auth_plugin = auth_plugin;

    Ok(())
}
/// Upgrades the stream to TLS when the options request `CLIENT_SSL`.
///
/// Does nothing when the options do not request it. Otherwise an SSL request
/// packet is sent and the stream is made secure using the configured SSL options.
///
/// # Errors
///
/// Returns `DriverError::NoClientSslFlagFromServer` when `CLIENT_SSL` is
/// missing from the negotiated capabilities, or any I/O error from the upgrade.
///
/// # Panics
///
/// Panics when the options request `CLIENT_SSL` but carry no SSL options.
async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
    let ssl_requested = self
        .inner
        .opts
        .get_capabilities()
        .contains(CapabilityFlags::CLIENT_SSL);
    if !ssl_requested {
        return Ok(());
    }

    let ssl_negotiated = self
        .inner
        .capabilities
        .contains(CapabilityFlags::CLIENT_SSL);
    if !ssl_negotiated {
        return Err(DriverError::NoClientSslFlagFromServer.into());
    }

    // utf8mb4 is used only for servers at version 5.5.3 or above.
    let collation = if self.inner.version >= (5, 5, 3) {
        UTF8MB4_GENERAL_CI
    } else {
        UTF8_GENERAL_CI
    };

    let ssl_request = SslRequest::new(
        self.inner.capabilities,
        DEFAULT_MAX_ALLOWED_PACKET as u32,
        collation as u8,
    );
    self.write_struct(&ssl_request).await?;

    let ssl_opts = self.opts().ssl_opts().cloned().expect("unreachable");
    let domain = self.opts().ip_or_hostname().into();
    let stream = self.stream_mut()?;
    stream.make_secure(domain, ssl_opts).await?;
    Ok(())
}
/// Builds the handshake response and sends it to the server.
///
/// The response carries the authentication data generated by the current
/// plugin from the password and scramble, plus the user, database name,
/// plugin and capabilities.
async fn do_handshake_response(&mut self) -> Result<()> {
    let password = self.inner.opts.pass();
    let auth_data = self
        .inner
        .auth_plugin
        .gen_data(password, &*self.inner.nonce);

    let user = self.inner.opts.user().map(|name| name.as_bytes());
    let db_name = self.inner.opts.db_name().map(|name| name.as_bytes());

    let handshake_response = HandshakeResponse::new(
        auth_data.as_deref(),
        self.inner.version,
        user,
        db_name,
        Some(self.inner.auth_plugin.borrow()),
        self.capabilities(),
        Default::default(),
    );

    let mut packet = crate::BUFFER_POOL.get();
    handshake_response.serialize(packet.as_mut());
    self.write_packet(packet).await?;
    Ok(())
}
/// Switches to the authentication plugin requested by the server and
/// continues authentication with it.
///
/// # Errors
///
/// Returns `DriverError::MysqlOldPasswordDisabled` when the server requests
/// the old-password plugin while `secure_auth` is enabled, or any error from
/// the following authentication exchange.
///
/// # Panics
///
/// Panics when an auth switch was already performed; callers must check
/// `auth_switched` first.
async fn perform_auth_switch(
    &mut self,
    auth_switch_request: AuthSwitchRequest<'_>,
) -> Result<()> {
    if self.inner.auth_switched {
        unreachable!("auth_switched flag should be checked by caller");
    }

    self.inner.auth_switched = true;
    self.inner.nonce = auth_switch_request.plugin_data().to_vec();

    let old_password_requested = matches!(
        auth_switch_request.auth_plugin(),
        AuthPlugin::MysqlOldPassword
    );
    if old_password_requested && self.inner.opts.secure_auth() {
        return Err(DriverError::MysqlOldPasswordDisabled.into());
    }

    self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();

    let plugin_data = self
        .inner
        .auth_plugin
        .gen_data(self.inner.opts.pass(), &*self.inner.nonce);

    // Without plugin data an empty packet is sent instead.
    match plugin_data {
        Some(data) => self.write_struct(&data).await?,
        None => self.write_packet(crate::BUFFER_POOL.get()).await?,
    }

    self.continue_auth().await
}
/// Continues authentication according to the current plugin.
///
/// The future is boxed because the authentication steps can call back into
/// this method after an auth switch.
///
/// # Errors
///
/// Resolves to `DriverError::UnknownAuthPlugin` for an unsupported plugin, or
/// to the error of the plugin-specific exchange.
fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
    Box::pin(async move {
        match self.inner.auth_plugin {
            AuthPlugin::CachingSha2Password => {
                self.continue_caching_sha2_password_auth().await
            }
            AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
                self.continue_mysql_native_password_auth().await
            }
            AuthPlugin::Other(ref name) => {
                let name = String::from_utf8_lossy(name.as_ref()).to_string();
                Err(DriverError::UnknownAuthPlugin { name }.into())
            }
        }
    })
}
fn switch_to_compression(&mut self) -> Result<()> {
if self
.capabilities()
.contains(CapabilityFlags::CLIENT_COMPRESS)
{
if let Some(compression) = self.inner.opts.compression() {
if let Some(stream) = self.inner.stream.as_mut() {
stream.compress(compression);
}
}
}
Ok(())
}
/// Continues authentication for the `caching_sha2_password` plugin.
///
/// Reads the server's reply and acts on its leading bytes:
/// * `0x00` — accepted as is;
/// * `0x01 0x03` — one more packet is read and dropped;
/// * `0x01 0x04` — the password is sent (see below), then one more packet is
///   read and dropped;
/// * `0xfe` (only if no auth switch happened yet) — an auth switch is performed.
///
/// # Errors
///
/// Any other packet yields `DriverError::UnexpectedPacket`.
async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
let packet = self.read_packet().await?;
match packet.get(0) {
Some(0x00) => {
// Nothing more to do for this reply.
Ok(())
}
Some(0x01) => match packet.get(1) {
Some(0x03) => {
// Consume the packet that follows this marker.
self.drop_packet().await
}
Some(0x04) => {
// The server asks for the password itself; send it NUL-terminated.
let pass = self.inner.opts.pass().unwrap_or_default();
let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
pass.as_mut().push(0);
if self.is_secure() {
// On a secure stream the password is sent as is.
self.write_packet(pass).await?;
} else {
// Otherwise send a single `0x02` byte and treat everything after the
// first byte of the reply as the encryption key.
self.write_bytes(&[0x02][..]).await?;
let packet = self.read_packet().await?;
let key = &packet[1..];
// XOR the password with the (repeating) nonce before encrypting it.
for (i, byte) in pass.as_mut().iter_mut().enumerate() {
*byte ^= self.inner.nonce[i % self.inner.nonce.len()];
}
let encrypted_pass = crypto::encrypt(&*pass, key);
self.write_bytes(&*encrypted_pass).await?;
};
// Consume the server's reply to the password.
self.drop_packet().await?;
Ok(())
}
_ => Err(DriverError::UnexpectedPacket {
payload: packet.to_vec(),
}
.into()),
},
// An auth switch is honoured only once per connection.
Some(0xfe) if !self.inner.auth_switched => {
let auth_switch_request = ParseBuf(&*packet).parse::<AuthSwitchRequest>(())?;
self.perform_auth_switch(auth_switch_request).await?;
Ok(())
}
_ => Err(DriverError::UnexpectedPacket {
payload: packet.to_vec(),
}
.into()),
}
}
/// Continues authentication for the native (or old) password plugin.
///
/// Reads the server's reply: `0x00` completes authentication; `0xfe` (only if
/// no auth switch happened yet) triggers an auth switch.
///
/// # Errors
///
/// Any other packet yields `DriverError::UnexpectedPacket`.
async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
let packet = self.read_packet().await?;
match packet.get(0) {
Some(0x00) => Ok(()),
Some(0xfe) if !self.inner.auth_switched => {
let auth_switch = if packet.len() > 1 {
// A packet longer than one byte is parsed as a regular auth switch request.
ParseBuf(&*packet).parse(())?
} else {
// A single-byte packet is validated as an old-style request and mapped
// to a switch to `mysql_old_password` that reuses the current nonce.
let _ = ParseBuf(&*packet).parse::<OldAuthSwitchRequest>(())?;
AuthSwitchRequest::new(
"mysql_old_password".as_bytes(),
self.inner.nonce.clone(),
)
};
self.perform_auth_switch(auth_switch).await
}
_ => Err(DriverError::UnexpectedPacket {
payload: packet.to_vec(),
}
.into()),
}
}
/// Updates the connection state from `packet` if it is an OK or error packet.
///
/// The OK-packet flavour that is tried depends on whether a result is pending
/// and on `CLIENT_DEPRECATE_EOF`. Returns `Ok(true)` only for a progress
/// report, and `Ok(false)` for OK packets and for packets that are neither
/// OK nor error packets.
///
/// # Errors
///
/// Returns the server error carried by an error packet.
fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
    let capabilities = self.capabilities();

    let parsed_ok = if !self.has_pending_result() {
        ParseBuf(&*packet)
            .parse::<OkPacketDeserializer<CommonOkPacket>>(capabilities)
            .map(|x| x.into_inner())
    } else if capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
        ParseBuf(&*packet)
            .parse::<OkPacketDeserializer<ResultSetTerminator>>(capabilities)
            .map(|x| x.into_inner())
    } else {
        ParseBuf(&*packet)
            .parse::<OkPacketDeserializer<OldEofPacket>>(capabilities)
            .map(|x| x.into_inner())
    };

    match parsed_ok {
        Ok(ok_packet) => {
            self.handle_ok(ok_packet.into_owned());
            Ok(false)
        }
        // Not an OK packet: try it as an error packet, ignoring parse failures.
        Err(_) => match ParseBuf(&*packet).parse::<ErrPacket>(capabilities) {
            Ok(err_packet) => {
                // A real error is propagated; a progress report falls through.
                self.handle_err(err_packet)?;
                Ok(true)
            }
            Err(_) => Ok(false),
        },
    }
}
pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
loop {
let packet = crate::io::ReadPacket::new(&mut *self)
.await
.map_err(|io_err| {
self.inner.stream.take();
self.inner.disconnected = true;
Error::from(io_err)
})?;
if self.handle_packet(&packet)? {
continue;
} else {
return Ok(packet);
}
}
}
/// Reads exactly `n` packets, stopping at the first error.
pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
    let mut packets = Vec::with_capacity(n);
    while packets.len() < n {
        let packet = self.read_packet().await?;
        packets.push(packet);
    }
    Ok(packets)
}
pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
crate::io::WritePacket::new(&mut *self, data)
.await
.map_err(|io_err| {
self.inner.stream.take();
self.inner.disconnected = true;
From::from(io_err)
})
}
/// Copies `bytes` into a pooled buffer and writes it as a packet.
pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
    let packet = crate::BUFFER_POOL.get_with(bytes);
    self.write_packet(packet).await
}
/// Serializes `x` into a pooled buffer and writes it as a packet.
pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
    let mut packet = crate::BUFFER_POOL.get();
    {
        let payload = packet.as_mut();
        x.serialize(payload);
    }
    self.write_packet(packet).await
}
pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
self.clean_dirty().await?;
self.reset_seq_id();
self.write_struct(cmd).await
}
pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
debug_assert!(!body.is_empty());
self.clean_dirty().await?;
self.reset_seq_id();
self.write_packet(body).await
}
/// Sends a command consisting of the command byte `cmd` followed by `cmd_data`.
pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
where
    T: AsRef<[u8]>,
{
    let mut packet = crate::BUFFER_POOL.get();
    {
        let payload = packet.as_mut();
        payload.push(cmd as u8);
        payload.extend_from_slice(cmd_data.as_ref());
    }
    self.write_command_raw(packet).await
}
/// Reads one packet and discards it.
async fn drop_packet(&mut self) -> Result<()> {
    self.read_packet().await.map(drop)
}
/// Executes the init commands from the options, in the order they were configured.
///
/// The previous implementation drained the list with `pop()`, which ran the
/// commands back to front and broke init scripts whose statements depend on
/// earlier ones.
///
/// # Errors
///
/// Stops at and returns the first query error.
async fn run_init_commands(&mut self) -> Result<()> {
    // Take an owned copy so `self` can be mutably borrowed while the queries run.
    let init = self.inner.opts.init().to_vec();
    for query in init {
        self.query_drop(query).await?;
    }
    Ok(())
}
/// Returns a future that resolves to a fully initialised connection.
///
/// The future connects via the socket when the options carry one (this
/// yields `DriverError::NamedPipesDisabled` on Windows) and via TCP otherwise.
/// It then, in order: configures the stream, handles the server handshake,
/// optionally switches to TLS, sends the handshake response, completes
/// authentication, optionally enables compression, reads the server
/// settings, optionally reconnects via the socket, and runs the init commands.
///
/// # Errors
///
/// The future resolves to the first error raised by any of these steps.
pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
let opts = opts.into();
async move {
let mut conn = Conn::empty(opts.clone());
let stream = if let Some(_path) = opts.socket() {
// Socket connections are compiled in on unix only.
#[cfg(unix)]
{
Stream::connect_socket(_path.to_owned()).await?
}
#[cfg(target_os = "windows")]
return Err(crate::DriverError::NamedPipesDisabled.into());
} else {
// The keepalive option is interpreted as milliseconds.
let keepalive = opts
.tcp_keepalive()
.map(|x| std::time::Duration::from_millis(x.into()));
Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
};
conn.inner.stream = Some(stream);
conn.setup_stream()?;
conn.handle_handshake().await?;
conn.switch_to_ssl_if_needed().await?;
conn.do_handshake_response().await?;
conn.continue_auth().await?;
conn.switch_to_compression()?;
conn.read_settings().await?;
conn.reconnect_via_socket_if_needed().await?;
conn.run_init_commands().await?;
Ok(conn)
}
.boxed()
}
/// Parses `url` into options and connects with them.
///
/// # Errors
///
/// Returns the URL parsing error, or any error from `Conn::new`.
pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
    let opts = Opts::from_str(url.as_ref())?;
    Conn::new(opts).await
}
async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
if let Some(socket) = self.inner.socket.as_ref() {
let opts = self.inner.opts.clone();
if opts.socket().is_none() {
let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
if let Ok(conn) = Conn::new(opts).await {
let old_conn = std::mem::replace(self, conn);
old_conn.close_conn().await?;
}
}
}
Ok(())
}
/// Loads connection settings that the options leave unspecified.
///
/// One query fetches `@@socket`, `@@max_allowed_packet` and `@@wait_timeout`,
/// and only when at least one of them is needed:
/// * the socket path is stored when `prefer_socket` is set and no socket is
///   known yet;
/// * `max_allowed_packet` is applied to the stream, falling back to
///   `DEFAULT_MAX_ALLOWED_PACKET`;
/// * `wait_timeout` is stored, falling back to 28800 seconds.
///
/// A missing row or column falls back to those defaults instead of panicking.
///
/// # Errors
///
/// Returns the error of the settings query.
async fn read_settings(&mut self) -> Result<()> {
    let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none();
    let read_max_allowed_packet = self.opts().max_allowed_packet().is_none();
    let read_wait_timeout = self.opts().wait_timeout().is_none();

    let settings: Option<Row> = if read_socket || read_max_allowed_packet || read_wait_timeout {
        self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout")
            .await?
    } else {
        None
    };

    if read_socket {
        self.inner.socket = settings.as_ref().and_then(|s| s.get("@@socket"));
    }

    // Options take precedence; the server value is used only when they are silent.
    let max_allowed_packet = if read_max_allowed_packet {
        settings
            .as_ref()
            .and_then(|s| s.get("@@max_allowed_packet"))
    } else {
        self.opts().max_allowed_packet()
    };
    if let Some(stream) = self.inner.stream.as_mut() {
        stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET));
    }

    let wait_timeout = if read_wait_timeout {
        settings.as_ref().and_then(|s| s.get("@@wait_timeout"))
    } else {
        self.opts().wait_timeout()
    };
    self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64);

    Ok(())
}
/// Reports whether the connection has idled for longer than its time-to-live.
///
/// The TTL is `conn_ttl` from the options, falling back to the stored wait
/// timeout. A zero TTL never expires.
fn expired(&self) -> bool {
    let ttl = match self.inner.opts.conn_ttl() {
        Some(ttl) => ttl,
        None => self.inner.wait_timeout,
    };
    if ttl.is_zero() {
        return false;
    }
    self.idling() > ttl
}
/// Returns the time elapsed since the last recorded I/O activity.
fn idling(&self) -> Duration {
    let last_io = self.inner.last_io;
    last_io.elapsed()
}
pub async fn reset(&mut self) -> Result<()> {
let pool = self.inner.pool.clone();
let supports_com_reset_connection = if self.inner.is_mariadb {
self.inner.version >= (10, 2, 4)
} else {
self.inner.version > (5, 7, 2)
};
if supports_com_reset_connection {
self.routine(routines::ResetRoutine).await?;
} else {
let opts = self.inner.opts.clone();
let old_conn = std::mem::replace(self, Conn::new(opts).await?);
old_conn.close_conn().await?;
};
self.inner.stmt_cache.clear();
self.inner.infile_handler = None;
self.inner.pool = pool;
Ok(())
}
/// Marks the connection as being outside a transaction and issues `ROLLBACK`.
///
/// A transaction is expected to be open (checked in debug builds only).
async fn rollback_transaction(&mut self) -> Result<()> {
    debug_assert_ne!(self.inner.tx_status, TxStatus::None);
    // The status is cleared before the query runs, so it stays cleared even on failure.
    self.set_tx_status(TxStatus::None);
    self.query_drop("ROLLBACK").await
}
/// Reports whether `SERVER_MORE_RESULTS_EXISTS` is set in the current status flags.
pub(crate) fn more_results_exists(&self) -> bool {
    let flags = self.status();
    flags.contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
}
/// Drops the pending result set, if there is one.
///
/// The pending state is first normalised back to `Pending`, then the result
/// is dropped through a `QueryResult` of the matching protocol.
///
/// # Errors
///
/// Returns a stored pending error, or the error from dropping the result.
///
/// # Panics
///
/// Panics when the pending metadata was taken and is still shared elsewhere.
pub(crate) async fn drop_result(&mut self) -> Result<()> {
// Take the current state out; a stored error is returned right here.
let meta = match self.set_pending_result(None)? {
Some(PendingResult::Pending(meta)) => Some(meta),
Some(PendingResult::Taken(meta)) => {
// Requires this to be the only remaining reference to the metadata.
Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
}
None => None,
};
// Put the metadata back as `Pending`; the previous state was just cleared.
let _ = self.set_pending_result(meta);
match self.use_pending_result() {
Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
QueryResult::<'_, '_, TextProtocol>::new(self)
.drop_result()
.await
}
Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
QueryResult::<'_, '_, BinaryProtocol>::new(self)
.drop_result()
.await
}
Ok(None) => Ok(()),
// `Taken` and `Err` were both eliminated by the normalisation above.
Ok(Some(PendingResult::Taken(_))) | Err(_) => {
unreachable!("this case must be handled earlier in this function")
}
}
}
/// Brings the connection into a clean state and returns it.
///
/// Pending results are dropped and an open transaction is rolled back,
/// repeating until neither remains. Non-fatal errors from these steps are
/// ignored.
///
/// # Errors
///
/// Returns the first fatal error.
async fn cleanup_for_pool(mut self) -> Result<Self> {
    loop {
        let outcome = if self.has_pending_result() {
            self.drop_result().await
        } else if self.inner.tx_status != TxStatus::None {
            self.rollback_transaction().await
        } else {
            return Ok(self);
        };

        match outcome {
            Err(err) if err.is_fatal() => return Err(err),
            _ => {}
        }
    }
}
/// Registers this connection as a replica with the given `server_id`.
///
/// Sets `@master_binlog_checksum` to `'ALL'`, sends `ComRegisterSlave`, and
/// reads the server's reply.
async fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
    use mysql_common::packets::ComRegisterSlave;

    self.query_drop("SET @master_binlog_checksum='ALL'").await?;

    let register = ComRegisterSlave::new(server_id);
    self.write_command(&register).await?;

    // The reply is read and discarded.
    self.read_packet().await.map(drop)
}
/// Registers as a replica using the request's server id, then sends the
/// binlog request command. The response is not read here.
async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
    let server_id = request.server_id();
    self.register_as_slave(server_id).await?;

    let dump_cmd = request.as_cmd();
    self.write_command(&dump_cmd).await?;
    Ok(())
}
/// Turns this connection into a binlog stream for `request`.
///
/// # Errors
///
/// Returns the error from requesting the binlog; the connection is consumed
/// either way.
pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
    let requested = self.request_binlog(request).await;
    requested.map(|()| BinlogStream::new(self))
}
}
#[cfg(test)]
mod test {
use bytes::Bytes;
use futures_util::stream::{self, StreamExt};
use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN};
use tokio::time::timeout;
use std::time::Duration;
use crate::{
from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn,
Error, OptsBuilder, Pool, WhiteListFsHandler,
};
async fn gen_dummy_data() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
"CREATE TABLE IF NOT EXISTS customers (customer_id int not null)"
.ignore(&mut conn)
.await?;
for i in 0_u8..100 {
"INSERT INTO customers(customer_id) VALUES (?)"
.with((i,))
.ignore(&mut conn)
.await?;
}
"DROP TABLE customers".ignore(&mut conn).await?;
Ok(())
}
/// Prepares a connection for a binlog test.
///
/// Takes a connection from `pool` (or opens a new one), panics when the server
/// reports a GTID mode that does not start with `ON`, reads the first row of
/// `SHOW BINARY LOGS`, and then generates dummy data on a separate connection.
/// Returns the connection plus the first two columns of that row.
async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec<u8>, u64)> {
let mut conn = match pool {
None => Conn::new(get_opts()).await.unwrap(),
Some(pool) => pool.get_conn().await.unwrap(),
};
// The check is skipped when the variable cannot be read or yields no row.
if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE"
.first::<String, _>(&mut conn)
.await
{
if !gtid_mode.starts_with("ON") {
panic!(
"GTID_MODE is disabled \
(enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
);
}
}
let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap();
let filename = row.get(0).unwrap();
let position = row.get(1).unwrap();
// The dummy data is generated after the row above has been read.
gen_dummy_data().await.unwrap();
Ok((conn, filename, position))
}
/// Reads binlog streams first over standalone connections and then over pooled
/// ones, and checks that the pool disconnects within 10 seconds afterwards.
#[tokio::test]
async fn should_read_binlog() -> super::Result<()> {
read_binlog_streams_and_close_their_connections(None, (12, 13, 14))
.await
.unwrap();
let pool = Pool::new(get_opts());
read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17))
.await
.unwrap();
timeout(Duration::from_secs(10), pool.disconnect())
.await
.unwrap()
.unwrap();
Ok(())
}
/// Opens up to three binlog streams in sequence, one per server id in
/// `binlog_server_ids`, and closes each within 10 seconds.
///
/// 1. A plain request, read until a 10 second timeout or the end of the stream.
/// 2. A GTID request (skipped on MariaDB), read the same way.
/// 3. A `BINLOG_DUMP_NON_BLOCK` request, read until the stream ends.
///
/// Each stream must yield at least one event.
async fn read_binlog_streams_and_close_their_connections(
pool: Option<&Pool>,
binlog_server_ids: (u32, u32, u32),
) -> super::Result<()> {
// Stream 1: plain filename/position request.
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let is_mariadb = conn.inner.is_mariadb;
let mut binlog_stream = conn
.get_binlog_stream(
BinlogRequest::new(binlog_server_ids.0)
.with_filename(filename)
.with_pos(pos),
)
.await
.unwrap();
let mut events_num = 0;
// The loop ends on a timeout as well as at the end of the stream.
while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await {
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
match event.read_data()?.unwrap() {
EventData::RowsEvent(re) => {
let tme = binlog_stream.get_tme(re.table_id());
for row in re.rows(tme.unwrap()) {
row.unwrap();
}
}
_ => (),
}
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
// Stream 2: GTID-based request, skipped on MariaDB.
if !is_mariadb {
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let mut binlog_stream = conn
.get_binlog_stream(
BinlogRequest::new(binlog_server_ids.1)
.with_use_gtid(true)
.with_filename(filename)
.with_pos(pos),
)
.await
.unwrap();
events_num = 0;
while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await
{
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
match event.read_data()?.unwrap() {
EventData::RowsEvent(re) => {
let tme = binlog_stream.get_tme(re.table_id());
for row in re.rows(tme.unwrap()) {
row.unwrap();
}
}
_ => (),
}
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
}
// Stream 3: non-blocking dump; no timeout is applied to the reads here.
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let mut binlog_stream = conn
.get_binlog_stream(
BinlogRequest::new(binlog_server_ids.2)
.with_filename(filename)
.with_pos(pos)
.with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
)
.await
.unwrap();
events_num = 0;
while let Some(event) = binlog_stream.next().await {
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
event.read_data().unwrap();
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
Ok(())
}
/// Compile-time check that the value returned by `get_opts` is `Send + Sync`.
#[test]
fn opts_should_satisfy_send_and_sync() {
    struct SendSyncWitness<T>(T)
    where
        T: Sync + Send;

    SendSyncWitness(get_opts());
}
/// Connects, pings and disconnects, first with no database name and then with
/// an empty one.
#[tokio::test]
async fn should_connect_without_database() -> super::Result<()> {
// No database name at all.
let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
conn.ping().await?;
conn.disconnect().await?;
// An empty database name.
let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
conn.ping().await?;
conn.disconnect().await?;
Ok(())
}
/// Checks that the connection stays usable after a query result and a
/// transaction are dropped without being consumed, and that the row inserted
/// inside the dropped transaction is not visible afterwards.
#[tokio::test]
async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
let mut conn: Conn = Conn::new(get_opts()).await?;
conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
.await?;
// The query result is dropped immediately without being read.
conn.query_iter("SELECT 1").await?;
conn.ping().await?;
let mut tx = conn.start_transaction(Default::default()).await?;
tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
.await?;
// Both the result and the transaction are dropped without commit or rollback.
tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
drop(tx);
conn.ping().await?;
let count: u8 = conn
.query_first("SELECT COUNT(*) FROM mysql.foo")
.await?
.unwrap_or_default();
assert_eq!(count, 0);
Ok(())
}
/// Connects with the default options, then for every listed authentication
/// plugin that the server has installed creates `test_user` with an empty and
/// a non-empty password and checks that a connection as that user succeeds.
/// Finally checks the `Debug` output when compression or TLS testing is enabled.
#[tokio::test]
async fn should_connect() -> super::Result<()> {
let mut conn: Conn = Conn::new(get_opts()).await?;
conn.ping().await?;
let plugins: Vec<String> = conn
.query_map("SHOW PLUGINS", |mut row: crate::Row| {
row.take("Name").unwrap()
})
.await?;
// (plugin name, value for `old_passwords`, password) — only installed plugins are kept.
let variants = vec![
("caching_sha2_password", 2_u8, "non-empty"),
("caching_sha2_password", 2_u8, ""),
("mysql_native_password", 0_u8, "non-empty"),
("mysql_native_password", 0_u8, ""),
]
.into_iter()
.filter(|variant| plugins.iter().any(|p| p == variant.0));
for (plug, val, pass) in variants {
// The user may not exist yet, so this result is ignored.
let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
conn.query_drop(query).await.unwrap();
// Servers within this version range get the plain `SET PASSWORD` form;
// all others go through `old_passwords` and `PASSWORD()`.
if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
.await
.unwrap();
} else {
conn.query_drop(format!("SET old_passwords = {}", val))
.await
.unwrap();
conn.query_drop(format!(
"SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
pass
))
.await
.unwrap();
};
let opts = get_opts()
.user(Some("test_user"))
.pass(Some(pass))
.db_name(None::<String>);
let result = Conn::new(opts).await;
// The user is dropped before the connection result is inspected.
conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
result?.disconnect().await?;
}
if crate::test_misc::test_compression() {
assert!(format!("{:?}", conn).contains("Compression"));
}
if crate::test_misc::test_ssl() {
assert!(format!("{:?}", conn).contains("Tls"));
}
conn.disconnect().await?;
Ok(())
}
/// Creates the connection future outside of any runtime, then drives it to
/// completion on a freshly built runtime and drops the resulting connection.
#[test]
fn should_not_panic_if_dropped_without_tokio_runtime() {
// The future is constructed before the runtime exists.
let fut = Conn::new(get_opts());
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
// The resulting `Conn` is a temporary and is dropped at the end of this statement.
fut.await.unwrap();
});
}
/// Checks that both init queries from the options have run by the time the
/// connection is handed out.
#[tokio::test]
async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
let mut conn = Conn::new(opts).await?;
let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
conn.disconnect().await?;
assert_eq!(result, vec![(42, "foo".into())]);
Ok(())
}
/// Checks that the same statement can be executed before and after `Conn::reset`.
#[tokio::test]
async fn should_reset_the_connection() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
conn.exec_drop("SELECT ?", (1_u8,)).await?;
conn.reset().await?;
// The same statement text is executed again after the reset.
conn.exec_drop("SELECT ?", (1_u8,)).await?;
conn.disconnect().await?;
Ok(())
}
/// With a statement cache size of zero, runs a mix of implicitly and
/// explicitly prepared statements and asserts that the session's
/// `Com_stmt_close` counter is 1 and that the cache stays empty.
#[tokio::test]
async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
let mut conn = Conn::new(opts).await?;
conn.exec_drop("DO ?", (1_u8,)).await?;
// An explicitly prepared statement, executed twice and closed explicitly.
let stmt = conn.prep("DO 2").await?;
conn.exec_drop(&stmt, ()).await?;
conn.exec_drop(&stmt, ()).await?;
conn.close(stmt).await?;
conn.exec_drop("DO 3", ()).await?;
conn.exec_batch("DO 4", vec![(), ()]).await?;
conn.exec_first::<u8, _, _>("DO 5", ()).await?;
let row: Option<(crate::Value, usize)> = conn
.query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
.await?;
assert_eq!(row.unwrap().1, 1);
assert_eq!(conn.inner.stmt_cache.len(), 0);
conn.disconnect().await?;
Ok(())
}
/// With a statement cache size of three, executes six distinct statements
/// (two of them twice) and asserts that three statements were closed and that
/// the cache ends up holding `DO 6`, `DO 5`, `DO 3`, in that order.
#[tokio::test]
async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
let mut conn = Conn::new(opts).await?;
// The expectations below match least-recently-used eviction:
// `DO 2`, `DO 1` and `DO 4` are the three statements expected to be evicted.
conn.exec_drop("DO 1", ()).await?;
conn.exec_drop("DO 2", ()).await?;
conn.exec_drop("DO 3", ()).await?;
conn.exec_drop("DO 1", ()).await?;
conn.exec_drop("DO 4", ()).await?;
conn.exec_drop("DO 3", ()).await?;
conn.exec_drop("DO 5", ()).await?;
conn.exec_drop("DO 6", ()).await?;
let row_opt = conn
.query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
.await?;
let (_, count): (String, usize) = row_opt.unwrap();
assert_eq!(count, 3);
let order = conn
.stmt_cache_ref()
.iter()
.map(|item| item.1.query.0.as_ref())
.collect::<Vec<&[u8]>>();
assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
conn.disconnect().await?;
Ok(())
}
/// Round-trips string literals whose lengths range from `MAX_PAYLOAD_LEN - 2`
/// to `MAX_PAYLOAD_LEN + 2` through a text query and checks the returned row.
#[tokio::test]
async fn should_perform_queries() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
// Lengths on both sides of `MAX_PAYLOAD_LEN` are covered.
for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
let long_string = ::std::iter::repeat('A').take(x).collect::<String>();
let result: Vec<(String, u8)> = conn
.query(format!(r"SELECT '{}', 231", long_string))
.await?;
assert_eq!((long_string, 231_u8), result[0]);
}
conn.disconnect().await?;
Ok(())
}
/// Creates a temporary table and inserts one row with `query_drop`, then
/// checks that the table holds exactly one row.
#[tokio::test]
async fn should_query_drop() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
.await?;
conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
conn.disconnect().await?;
assert_eq!(result, Some(1_u8));
Ok(())
}
/// Prepares and closes statements with positional and named placeholders, including
/// the same query text prepared on two different connections.
#[tokio::test]
async fn should_prepare_statement() -> super::Result<()> {
    // Positional placeholder on a short-lived connection.
    let mut first_conn = Conn::new(get_opts()).await?;
    let positional = first_conn.prep(r"SELECT ?").await?;
    first_conn.close(positional).await?;
    first_conn.disconnect().await?;

    // Named placeholder; this statement stays open while others come and go.
    let mut conn = Conn::new(get_opts()).await?;
    let named = conn.prep(r"SELECT :foo").await?;

    // Query text borrowed from an owned `String`.
    let query = String::from("SELECT ?, ?");
    let two_params = conn.prep(query.as_str()).await?;
    conn.close(two_params).await?;

    // Same query text on an independent connection.
    let mut other_conn = Conn::new(get_opts()).await?;
    let other_stmt = other_conn.prep(query.as_str()).await?;
    other_conn.close(other_stmt).await?;
    other_conn.disconnect().await?;

    conn.close(named).await?;
    conn.disconnect().await?;
    Ok(())
}
#[tokio::test]
async fn should_execute_statement() -> super::Result<()> {
let long_string = ::std::iter::repeat('A')
.take(18 * 1024 * 1024)
.collect::<String>();
let mut conn = Conn::new(get_opts()).await?;
let stmt = conn.prep(r"SELECT ?").await?;
let result = conn.exec_iter(&stmt, (&long_string,)).await?;
let mut mapped = result
.map_and_drop(|row| from_row::<(String,)>(row))
.await?;
assert_eq!(mapped.len(), 1);
assert_eq!(mapped.pop(), Some((long_string,)));
let result = conn.exec_iter(&stmt, (42_u8,)).await?;
let collected = result.collect_and_drop::<(u8,)>().await?;
assert_eq!(collected, vec![(42u8,)]);
let result = conn.exec_iter(&stmt, (8_u8,)).await?;
let reduced = result
.reduce_and_drop(2, |mut acc, row| {
acc += from_row::<i32>(row);
acc
})
.await?;
conn.close(stmt).await?;
conn.disconnect().await?;
assert_eq!(reduced, 10);
let mut conn = Conn::new(get_opts()).await?;
let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
let result = conn
.exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
.await?;
let mut mapped = result
.map_and_drop(|row| from_row::<(String, String, String, u8)>(row))
.await?;
assert_eq!(mapped.len(), 1);
assert_eq!(
mapped.pop(),
Some(("quux".into(), "baz".into(), "quux".into(), 3))
);
let result = conn
.exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
.await?;
let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
assert_eq!(collected, vec![(2, 3, 2, 3)]);
let result = conn
.exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
.await?;
let reduced = result
.reduce_and_drop(0, |acc, row| {
let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
acc + a + b + c + d
})
.await?;
conn.close(stmt).await?;
conn.disconnect().await?;
assert_eq!(reduced, 10);
Ok(())
}
/// Executes a query string directly with named parameters, without an explicit
/// `prep` call, and maps the resulting row.
#[tokio::test]
async fn should_prep_exec_statement() -> super::Result<()> {
    let mut conn = Conn::new(get_opts()).await?;

    let products = conn
        .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
        .await?
        .map_and_drop(|row| {
            let (a, b, c) = from_row::<(u8, u8, u8)>(row);
            a * b * c
        })
        .await?;

    conn.disconnect().await?;

    // 2 * 3 * 2
    assert_eq!(products[0], 12u8);
    Ok(())
}
/// `exec_first` on a two-row result must yield the value selected first.
#[tokio::test]
async fn should_first_exec_statement() -> super::Result<()> {
    let mut conn = Conn::new(get_opts()).await?;

    let first: Option<u8> = conn
        .exec_first(
            r"SELECT :a UNION ALL SELECT :b",
            params! { "a" => 2, "b" => 3 },
        )
        .await?;

    conn.disconnect().await?;
    assert_eq!(first, Some(2u8));
    Ok(())
}
#[tokio::test]
async fn issue_107() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await?;
conn.query_drop(
r"CREATE TEMPORARY TABLE mysql.issue (
a BIGINT(20) UNSIGNED,
b VARBINARY(16),
c BINARY(32),
d BIGINT(20) UNSIGNED,
e BINARY(32)
)",
)
.await?;
conn.query_drop(
r"INSERT INTO mysql.issue VALUES (
0,
0xC066F966B0860000,
0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
0,
''
), (
1,
'',
0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
0,
''
)",
)
.await?;
let q = "SELECT b, c, d, e FROM mysql.issue";
let result = conn.query_iter(q).await?;
let loaded_structs = result
.map_and_drop(|row| crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>(row))
.await?;
conn.disconnect().await?;
assert_eq!(loaded_structs.len(), 2);
Ok(())
}
/// Exercises the three ways a transaction can end — commit, explicit rollback and
/// being dropped — and checks the row count visible on the connection after each.
#[tokio::test]
async fn should_run_transactions() -> super::Result<()> {
    let mut conn = Conn::new(get_opts()).await?;
    conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
        .await?;

    // Committed work must persist.
    let mut transaction = conn.start_transaction(Default::default()).await?;
    transaction
        .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
        .await?;
    assert_eq!(transaction.last_insert_id(), None);
    assert_eq!(transaction.affected_rows(), 2);
    assert_eq!(transaction.get_warnings(), 0);
    // The server's insert-info message separates its fields with TWO spaces.
    assert_eq!(transaction.info(), "Records: 2  Duplicates: 0  Warnings: 0");
    transaction.commit().await?;
    let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
    assert_eq!(output_opt, Some((2u8,)));

    // Explicitly rolled-back work is visible inside the transaction only.
    let mut transaction = conn.start_transaction(Default::default()).await?;
    transaction
        .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
        .await?;
    let output_opt = transaction
        .exec_first("SELECT COUNT(*) FROM tmp", ())
        .await?;
    assert_eq!(output_opt, Some((4u8,)));
    transaction.rollback().await?;
    let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
    assert_eq!(output_opt, Some((2u8,)));

    // A transaction dropped without commit must not leave its insert behind.
    let mut transaction = conn.start_transaction(Default::default()).await?;
    transaction
        .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
        .await?;
    drop(transaction);
    let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
    assert_eq!(output_opt, Some((2u8,)));

    conn.disconnect().await?;
    Ok(())
}
#[tokio::test]
async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
let mut conn = Conn::new(get_opts()).await.unwrap();
let result = QUERY_FIRST.run(&mut conn).await;
assert!(matches!(result, Err(Error::Server(_))));
let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
let result_set: Vec<u8> = result.collect().await.unwrap();
assert_eq!(result_set, vec![1]);
let result_set: super::Result<Vec<u8>> = result.collect().await;
assert!(matches!(result_set, Err(Error::Server(_))));
assert!(result.is_empty());
conn.ping().await?;
conn.disconnect().await?;
Ok(())
}
#[tokio::test]
async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
const PROC_DEF_FIRST: &str =
r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
const PROC_DEF_MIDDLE: &str =
r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
let mut conn = Conn::new(get_opts()).await.unwrap();
conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
.await?;
conn.query_iter(PROC_DEF_FIRST).await?;
conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
.await?;
conn.query_iter(PROC_DEF_MIDDLE).await?;
let result = conn.query_iter("CALL err_first()").await;
assert!(matches!(result, Err(Error::Server(_))));
let mut result = conn.query_iter("CALL err_middle()").await?;
let result_set: Vec<u8> = result.collect().await.unwrap();
assert_eq!(result_set, vec![1]);
let result_set: super::Result<Vec<u8>> = result.collect().await;
assert!(matches!(result_set, Err(Error::Server(_))));
assert!(result.is_empty());
conn.ping().await?;
conn.disconnect().await?;
Ok(())
}
#[tokio::test]
async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
use std::fs::write;
let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
let file_path = file_path.path();
let file_name = file_path.file_name().unwrap();
write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
let mut conn = Conn::new(opts).await.unwrap();
"CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
let query = format!(
r#"SELECT * FROM tmp;
LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
SELECT * FROM tmp"#,
file_name.to_str().unwrap(),
file_name.to_str().unwrap(),
);
let mut result = query.run(&mut conn).await?;
let result_set = result.collect::<String>().await?;
assert_eq!(result_set.len(), 0);
let mut no_local_infile = false;
for _ in 0..2 {
match result.collect::<String>().await {
Ok(result_set) => {
assert_eq!(result.affected_rows(), 3);
assert!(result_set.is_empty())
}
Err(Error::Server(ref err)) if err.code == 1148 => {
no_local_infile = true;
break;
}
Err(Error::Server(ref err)) if err.code == 3948 => {
no_local_infile = true;
break;
}
Err(err) => return Err(err),
}
}
if no_local_infile {
assert!(result.is_empty());
assert_eq!(result_set.len(), 0);
} else {
let result_set = result.collect::<String>().await?;
assert_eq!(result_set.len(), 6);
assert_eq!(result_set[0], "AAAAAA");
assert_eq!(result_set[1], "BBBBBB");
assert_eq!(result_set[2], "CCCCCC");
assert_eq!(result_set[3], "AAAAAA");
assert_eq!(result_set[4], "BBBBBB");
assert_eq!(result_set[5], "CCCCCC");
}
conn.ping().await?;
conn.disconnect().await?;
Ok(())
}
/// Column metadata must follow the current result set of a multi-statement query,
/// including an empty set and a statement (`DO`) that yields no columns.
#[tokio::test]
async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
    let mut c = Conn::new(get_opts()).await?;
    c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
        .await?;

    let mut result = c
        .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
        .await?;

    // Expected column count of each of the first three sets; every set is drained
    // after its check so that the next one becomes current.
    for expected_columns in [1, 2, 0] {
        let column_count = result.columns().map(|cols| cols.len()).unwrap_or_default();
        assert_eq!(column_count, expected_columns);
        result.for_each(drop).await?;
    }

    // Final `SELECT 2`.
    let column_count = result.columns().map(|cols| cols.len()).unwrap_or_default();
    assert_eq!(column_count, 1);

    c.disconnect().await?;
    Ok(())
}
/// `last_insert_id` must be observable on a query result, on the connection and on
/// a transaction, advancing with each insert into an auto-increment table.
#[tokio::test]
async fn should_expose_query_result_metadata() -> super::Result<()> {
    const CREATE_TABLE: &str = r"
CREATE TEMPORARY TABLE `foo`
( `id` SERIAL
, `bar_id` varchar(36) NOT NULL
, `baz_id` varchar(36) NOT NULL
, `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
, PRIMARY KEY (`id`)
, KEY `bar_idx` (`bar_id`)
, KEY `baz_idx` (`baz_id`)
);";
    const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";

    let pool = Pool::new(get_opts());
    let mut c = pool.get_conn().await?;
    c.query_drop(CREATE_TABLE).await?;

    let args = ("qwerty", "data.employee_id");

    // Via the query result itself.
    let insert_result = c.exec_iter(QUERY, args).await?;
    assert_eq!(insert_result.last_insert_id(), Some(1));
    insert_result.drop_result().await?;

    // Via the connection.
    c.exec_drop(QUERY, args).await?;
    assert_eq!(c.last_insert_id(), Some(2));

    // Via a transaction.
    let mut tx = c.start_transaction(Default::default()).await?;
    tx.exec_drop(QUERY, args).await?;
    assert_eq!(tx.last_insert_id(), Some(3));

    Ok(())
}
#[tokio::test]
async fn should_handle_local_infile_locally() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await.unwrap();
conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
.await
.unwrap();
conn.set_infile_handler(async move {
Ok(
stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
.map(Ok)
.boxed(),
)
});
match conn
.query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
.await
{
Ok(_) => (),
Err(super::Error::Server(ref err)) if err.code == 1148 => {
return Ok(());
}
Err(super::Error::Server(ref err)) if err.code == 3948 => {
return Ok(());
}
e @ Err(_) => e.unwrap(),
};
let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
assert_eq!(result.len(), 3);
assert_eq!(result[0], "AAAAAA");
assert_eq!(result[1], "BBBBBB");
assert_eq!(result[2], "CCCCCC");
Ok(())
}
#[tokio::test]
async fn should_handle_local_infile_globally() -> super::Result<()> {
use std::fs::write;
let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
let file_path = file_path.path();
let file_name = file_path.file_name().unwrap();
write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
let mut conn = Conn::new(opts).await.unwrap();
conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
.await
.unwrap();
match conn
.query_drop(format!(
r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
file_name.to_str().unwrap(),
))
.await
{
Ok(_) => (),
Err(super::Error::Server(ref err)) if err.code == 1148 => {
return Ok(());
}
Err(super::Error::Server(ref err)) if err.code == 3948 => {
return Ok(());
}
e @ Err(_) => e.unwrap(),
};
let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
assert_eq!(result.len(), 3);
assert_eq!(result[0], "AAAAAA");
assert_eq!(result[1], "BBBBBB");
assert_eq!(result[2], "CCCCCC");
Ok(())
}
/// Micro-benchmarks, compiled only with the `nightly` feature. Each one drives a
/// single connection from a dedicated runtime via `block_on`.
#[cfg(feature = "nightly")]
mod bench {
    use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};

    /// Creates a runtime and opens a connection on it; panics on failure.
    fn setup() -> (tokio::runtime::Runtime, Conn) {
        let mut runtime = tokio::runtime::Runtime::new().unwrap();
        let conn = runtime.block_on(Conn::new(get_opts())).unwrap();
        (runtime, conn)
    }

    /// Text-protocol statement with no result set.
    #[bench]
    fn simple_exec(bencher: &mut test::Bencher) {
        let (mut runtime, mut conn) = setup();
        bencher.iter(|| {
            runtime.block_on(conn.query_drop("DO 1")).unwrap();
        });
        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Text-protocol query returning one 10000-character string.
    #[bench]
    fn select_large_string(bencher: &mut test::Bencher) {
        let (mut runtime, mut conn) = setup();
        bencher.iter(|| {
            runtime
                .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
                .unwrap();
        });
        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Repeated execution of a statement prepared once, outside the measured loop.
    #[bench]
    fn prepared_exec(bencher: &mut test::Bencher) {
        let (mut runtime, mut conn) = setup();
        let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
        bencher.iter(|| {
            runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
        });
        runtime.block_on(conn.close(stmt)).unwrap();
        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Execution from query text on every iteration, with one parameter.
    #[bench]
    fn prepare_and_exec(bencher: &mut test::Bencher) {
        let (mut runtime, mut conn) = setup();
        bencher.iter(|| {
            runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
        });
        runtime.block_on(conn.disconnect()).unwrap();
    }
}
}