use super::{
check::Check, database_inspection_results::DatabaseInspectionResults,
unexecutable_step_check::UnexecutableStepCheck, warning_check::SqlMigrationWarningCheck,
};
use crate::flavour::SqlFlavour;
use schema_connector::{
ConnectorError, ConnectorResult, DestructiveChangeDiagnostics, MigrationWarning, UnexecutableMigration,
};
use std::time::Duration;
use tokio::time::{error::Elapsed, timeout};
/// Upper bound on the time spent inspecting the database before destructive checks are evaluated.
const DESTRUCTIVE_TIMEOUT_DURATION: Duration = Duration::from_millis(60 * 1_000);
/// The destructive-change checks to run for a migration.
///
/// Each check is stored together with the index of the migration step it was registered for,
/// so that any resulting diagnostic can point back at that step.
#[derive(Debug)]
pub(crate) struct DestructiveCheckPlan {
    /// Checks that may produce a migration warning, paired with their step index.
    warnings: Vec<(SqlMigrationWarningCheck, usize)>,
    /// Checks that may flag a step as unexecutable, paired with their step index.
    unexecutable_migrations: Vec<(UnexecutableStepCheck, usize)>,
}
impl DestructiveCheckPlan {
    /// Creates an empty plan with no registered checks.
    pub(super) fn new() -> Self {
        DestructiveCheckPlan {
            warnings: Vec::new(),
            unexecutable_migrations: Vec::new(),
        }
    }

    /// Registers a warning check for the migration step at `step_index`.
    pub(super) fn push_warning(&mut self, warning: SqlMigrationWarningCheck, step_index: usize) {
        self.warnings.push((warning, step_index))
    }

    /// Registers an unexecutable-step check for the migration step at `step_index`.
    pub(super) fn push_unexecutable(&mut self, unexecutable_migration: UnexecutableStepCheck, step_index: usize) {
        self.unexecutable_migrations.push((unexecutable_migration, step_index))
    }

    /// Gathers the database data every registered check needs, then evaluates all checks.
    ///
    /// Inspection runs unexecutable-step checks first, then warning checks. It is bounded by
    /// `DESTRUCTIVE_TIMEOUT_DURATION`. If that deadline passes, the checks are evaluated against
    /// whatever data was collected up to that point instead of failing.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `flavour` while counting rows or column values.
    #[tracing::instrument(skip(flavour), level = "debug")]
    pub(super) async fn execute(
        &self,
        flavour: &mut (dyn SqlFlavour + Send + Sync),
    ) -> ConnectorResult<DestructiveChangeDiagnostics> {
        let mut results = DatabaseInspectionResults::default();

        let inspection = async {
            for (unexecutable, _idx) in &self.unexecutable_migrations {
                self.inspect_for_check(unexecutable, flavour, &mut results).await?;
            }
            for (warning, _idx) in &self.warnings {
                self.inspect_for_check(warning, flavour, &mut results).await?;
            }
            Ok::<(), ConnectorError>(())
        };

        // A timeout is intentionally NOT treated as an error: the partially filled `results`
        // still allow us to report useful diagnostics. Only real inspection errors propagate.
        match timeout(DESTRUCTIVE_TIMEOUT_DURATION, inspection).await {
            Ok(Ok(())) | Err(Elapsed { .. }) => (),
            Ok(Err(err)) => return Err(err),
        };

        Ok(self.diagnostics_from(&results))
    }

    /// Fetches the row count and/or column value count that `check` asks for.
    ///
    /// Counts already present in `results` are not fetched again, so several checks touching
    /// the same table or column share a single query.
    ///
    /// # Errors
    ///
    /// Returns any error from `flavour.count_rows_in_table` or `flavour.count_values_in_column`.
    pub(super) async fn inspect_for_check(
        &self,
        check: &(dyn Check + Send + Sync + 'static),
        flavour: &mut (dyn SqlFlavour + Send + Sync),
        results: &mut DatabaseInspectionResults,
    ) -> ConnectorResult<()> {
        if let Some(table) = check.needed_table_row_count() {
            if results.get_row_count(&table).is_none() {
                let count = flavour.count_rows_in_table(&table).await?;
                results.set_row_count(table.to_owned(), count)
            }
        }

        if let Some(column) = check.needed_column_value_count() {
            // Only the non-null value count decides whether a query is needed here.
            if let (_, None) = results.get_row_and_non_null_value_count(&column) {
                let count = flavour.count_values_in_column(&column).await?;
                results.set_value_count(column, count);
            }
        }

        Ok(())
    }

    /// Evaluates all checks without touching the database.
    ///
    /// Every check sees an empty `DatabaseInspectionResults`.
    pub(super) fn pure_check(&self) -> DestructiveChangeDiagnostics {
        self.diagnostics_from(&DatabaseInspectionResults::default())
    }

    /// Evaluates every registered check against `results`.
    ///
    /// Each message a check produces becomes a diagnostic tagged with that check's step index.
    /// Shared by `execute` and `pure_check` so that both paths report diagnostics identically.
    fn diagnostics_from(&self, results: &DatabaseInspectionResults) -> DestructiveChangeDiagnostics {
        let mut diagnostics = DestructiveChangeDiagnostics::new();

        for (unexecutable, step_index) in &self.unexecutable_migrations {
            if let Some(message) = unexecutable.evaluate(results) {
                diagnostics.unexecutable_migrations.push(UnexecutableMigration {
                    description: message,
                    step_index: *step_index,
                })
            }
        }

        for (warning, step_index) in &self.warnings {
            if let Some(message) = warning.evaluate(results) {
                diagnostics.warnings.push(MigrationWarning {
                    description: message,
                    step_index: *step_index,
                })
            }
        }

        diagnostics
    }
}