use futures_util::FutureExt;
use mysql_common::{
io::ParseBuf,
named_params::parse_named_params,
packets::{ComStmtClose, StmtPacket},
};
use std::{borrow::Cow, sync::Arc};
use crate::{
conn::routines::{ExecRoutine, PrepareRoutine},
consts::CapabilityFlags,
error::*,
Column, Params,
};
use super::AsQuery;
/// Outcome of converting a [`StatementLike`] value into a [`Statement`].
pub enum ToStatementResult<'a> {
/// The statement is already available; nothing has to be awaited.
Immediate(Statement),
/// The statement is produced by awaiting this boxed future, which may
/// borrow data for `'a`.
Mediate(crate::BoxFuture<'a, Statement>),
}
/// A value that can be turned into a [`Statement`] with the help of a connection.
///
/// This module implements it for string-like and byte-like query types, for
/// [`Statement`] itself, and for references to cloneable implementors.
pub trait StatementLike: Send + Sync {
/// Converts `self` into a statement, using `conn` if work on the
/// connection is required.
///
/// The result is either an already-available statement or a future that
/// yields one (see [`ToStatementResult`]).
fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
where
Self: 'a;
}
/// Builds a deferred conversion of a query-like value into a [`Statement`].
///
/// The returned future extracts the query from `stmt`, runs it through
/// `parse_named_params`, and then asks `conn` for a cached statement matching
/// the resulting raw query. Only when no cached statement is found is the
/// query prepared on the connection.
///
/// Nothing happens until the future inside the returned
/// [`ToStatementResult::Mediate`] is awaited.
fn to_statement_move<'a, T: AsQuery + 'a>(
    stmt: T,
    conn: &'a mut crate::Conn,
) -> ToStatementResult<'a> {
    let pending = async move {
        let query = stmt.as_query();
        let (named_params, raw_query) = parse_named_params(query.as_ref())?;

        // Reuse a cached statement when there is one; otherwise prepare it now.
        let cached = conn.get_cached_stmt(&*raw_query);
        let inner_stmt = if let Some(hit) = cached {
            hit
        } else {
            conn.prepare_statement(raw_query).await?
        };

        Ok(Statement::new(inner_stmt, named_params))
    }
    .boxed();

    ToStatementResult::Mediate(pending)
}
/// A borrowed-or-owned string query; conversion is delegated to `to_statement_move`.
impl StatementLike for Cow<'_, str> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A borrowed string query; conversion is delegated to `to_statement_move`.
impl StatementLike for &str {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// An owned string query; conversion is delegated to `to_statement_move`.
impl StatementLike for String {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A boxed string query; conversion is delegated to `to_statement_move`.
impl StatementLike for Box<str> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A shared string query; conversion is delegated to `to_statement_move`.
impl StatementLike for Arc<str> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A borrowed-or-owned byte query; conversion is delegated to `to_statement_move`.
impl StatementLike for Cow<'_, [u8]> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A borrowed byte query; conversion is delegated to `to_statement_move`.
impl StatementLike for &[u8] {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// An owned byte query; conversion is delegated to `to_statement_move`.
impl StatementLike for Vec<u8> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A boxed byte query; conversion is delegated to `to_statement_move`.
impl StatementLike for Box<[u8]> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A shared byte query; conversion is delegated to `to_statement_move`.
impl StatementLike for Arc<[u8]> {
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        to_statement_move(self, conn)
    }
}
/// A [`Statement`] is already a statement, so it converts to itself.
impl StatementLike for Statement {
    /// Returns `self` wrapped in [`ToStatementResult::Immediate`].
    ///
    /// The connection is not consulted, which is why the result can carry the
    /// `'static` lifetime.
    fn to_statement<'a>(self, _conn: &'a mut crate::Conn) -> ToStatementResult<'static>
    where
        Self: 'a,
    {
        // `self` is taken by value, so it can be moved straight into the
        // result. Cloning here would deep-copy `named_params` and bump the
        // `Arc` reference count only to drop the original right after.
        ToStatementResult::Immediate(self)
    }
}
/// A reference to a cloneable implementor: the referent is cloned and the
/// clone is converted.
impl<T> StatementLike for &T
where
    T: StatementLike + Clone,
{
    fn to_statement<'conn>(self, conn: &'conn mut crate::Conn) -> ToStatementResult<'conn>
    where
        Self: 'conn,
    {
        let owned: T = T::clone(self);
        owned.to_statement(conn)
    }
}
/// Shared state behind a [`Statement`]: the parsed statement packet plus the
/// metadata attached to it.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StmtInner {
/// Raw query bytes supplied to `from_payload`.
pub(crate) raw_query: Arc<[u8]>,
/// Column definitions; `None` when none were attached or the attached list was empty.
columns: Option<Box<[Column]>>,
/// Parameter definitions; `None` when none were attached or the attached list was empty.
params: Option<Box<[Column]>>,
/// Packet parsed from the payload; source of the statement id and the
/// parameter/column counts.
stmt_packet: StmtPacket,
/// Connection id supplied to `from_payload` — presumably the connection
/// that prepared the statement; verify against callers.
connection_id: u32,
}
impl StmtInner {
pub(crate) fn from_payload(
pld: &[u8],
connection_id: u32,
raw_query: Arc<[u8]>,
) -> std::io::Result<Self> {
let stmt_packet = ParseBuf(pld).parse(())?;
Ok(Self {
raw_query,
columns: None,
params: None,
stmt_packet,
connection_id,
})
}
pub(crate) fn with_params(mut self, params: Vec<Column>) -> Self {
self.params = if params.is_empty() {
None
} else {
Some(params.into_boxed_slice())
};
self
}
pub(crate) fn with_columns(mut self, columns: Vec<Column>) -> Self {
self.columns = if columns.is_empty() {
None
} else {
Some(columns.into_boxed_slice())
};
self
}
pub(crate) fn columns(&self) -> &[Column] {
self.columns.as_ref().map(AsRef::as_ref).unwrap_or(&[])
}
pub(crate) fn params(&self) -> &[Column] {
self.params.as_ref().map(AsRef::as_ref).unwrap_or(&[])
}
pub(crate) fn id(&self) -> u32 {
self.stmt_packet.statement_id()
}
pub(crate) const fn connection_id(&self) -> u32 {
self.connection_id
}
pub(crate) fn num_params(&self) -> u16 {
self.stmt_packet.num_params()
}
pub(crate) fn num_columns(&self) -> u16 {
self.stmt_packet.num_columns()
}
}
/// A prepared statement handle: shared statement state plus the named
/// parameters extracted from the query text.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Statement {
/// Shared statement state (packet, column and parameter definitions).
pub(crate) inner: Arc<StmtInner>,
/// Named parameters as returned by `parse_named_params` — presumably
/// `None` when the query uses no named parameters; verify against
/// `mysql_common`.
pub(crate) named_params: Option<Vec<Vec<u8>>>,
}
impl Statement {
pub(crate) fn new(inner: Arc<StmtInner>, named_params: Option<Vec<Vec<u8>>>) -> Self {
Self {
inner,
named_params,
}
}
pub fn columns(&self) -> &[Column] {
self.inner.columns()
}
pub fn params(&self) -> &[Column] {
self.inner.params()
}
pub fn id(&self) -> u32 {
self.inner.id()
}
pub fn connection_id(&self) -> u32 {
self.inner.connection_id()
}
pub fn num_params(&self) -> u16 {
self.inner.num_params()
}
pub fn num_columns(&self) -> u16 {
self.inner.num_columns()
}
}
impl crate::Conn {
    /// Reads `num` packets from the connection and parses each one as a [`Column`].
    ///
    /// When the connection's capabilities do not include
    /// `CLIENT_DEPRECATE_EOF`, one additional packet is read and discarded
    /// after the definitions.
    ///
    /// In debug builds this asserts that `num` is greater than zero.
    ///
    /// # Errors
    ///
    /// Fails if reading a packet fails or a packet cannot be parsed as a column.
    pub(crate) async fn read_column_defs<U>(&mut self, num: U) -> Result<Vec<Column>>
    where
        U: Into<usize>,
    {
        let count = num.into();
        debug_assert!(count > 0);

        let mut defs = Vec::new();
        for packet in self.read_packets(count).await? {
            let column: Column = ParseBuf(&*packet).parse(())?;
            defs.push(column);
        }

        let eof_deprecated = self
            .capabilities()
            .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF);
        if !eof_deprecated {
            // Consume the packet that follows the definitions.
            self.read_packet().await?;
        }

        Ok(defs)
    }

    /// Resolves any [`StatementLike`] value into a [`Statement`], awaiting the
    /// conversion future when the value is not immediately available.
    pub(crate) async fn get_statement<U>(&mut self, stmt_like: U) -> Result<Statement>
    where
        U: StatementLike,
    {
        let resolved = stmt_like.to_statement(self);
        match resolved {
            ToStatementResult::Mediate(pending) => pending.await,
            ToStatementResult::Immediate(ready) => Ok(ready),
        }
    }

    /// Prepares `raw_query` on this connection and stores the result in the
    /// statement cache.
    ///
    /// If caching hands back a previously cached statement, that statement is
    /// closed before returning.
    async fn prepare_statement(&mut self, raw_query: Cow<'_, [u8]>) -> Result<Arc<StmtInner>> {
        let prepared = self.routine(PrepareRoutine::new(raw_query)).await?;

        let displaced_id = self.cache_stmt(&prepared).map(|old| old.id());
        if let Some(id) = displaced_id {
            self.close_statement(id).await?;
        }

        Ok(prepared)
    }

    /// Runs `statement` with `params` on this connection.
    pub(crate) async fn execute_statement<P>(
        &mut self,
        statement: &Statement,
        params: P,
    ) -> Result<()>
    where
        P: Into<Params>,
    {
        let params: Params = params.into();
        let routine = ExecRoutine::new(statement, params);
        self.routine(routine).await?;
        Ok(())
    }

    /// Drops the statement with the given `id` from the statement cache and
    /// writes a `ComStmtClose` command for it.
    pub(crate) async fn close_statement(&mut self, id: u32) -> Result<()> {
        let close_cmd = ComStmtClose::new(id);
        self.stmt_cache_mut().remove(id);
        self.write_command(&close_cmd).await
    }
}