#![deny(rust_2018_idioms, unsafe_code)]
#![allow(clippy::derive_partial_eq_without_eq)]
pub mod mssql;
pub mod mysql;
pub mod postgres;
pub mod sqlite;
pub mod walkers;
mod connector_data;
mod error;
mod getters;
mod ids;
mod parsers;
pub use self::{
error::{DescriberError, DescriberErrorKind, DescriberResult},
ids::*,
walkers::*,
};
pub use either::Either;
pub use prisma_value::PrismaValue;
use enumflags2::{BitFlag, BitFlags};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
any::Any,
fmt::{self, Debug},
};
/// Asynchronous interface for introspecting a SQL database.
///
/// NOTE(review): presumably implemented by the per-database modules (`mssql`, `mysql`,
/// `postgres`, `sqlite`) — confirm against those modules.
#[async_trait::async_trait]
pub trait SqlSchemaDescriberBackend: Send + Sync {
    /// Lists the database names reported by the backend.
    async fn list_databases(&self) -> DescriberResult<Vec<String>>;
    /// Returns summary metadata (table count and size in bytes) for `schema`.
    async fn get_metadata(&self, schema: &str) -> DescriberResult<SqlMetadata>;
    /// Describes all of `schemas` into a single [`SqlSchema`].
    async fn describe(&self, schemas: &[&str]) -> DescriberResult<SqlSchema>;
    /// Returns the server version string, or `None` when the backend does not provide one.
    async fn version(&self) -> DescriberResult<Option<String>>;
}
/// Summary metadata about a schema, as returned by
/// [`SqlSchemaDescriberBackend::get_metadata`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlMetadata {
    /// Number of tables reported for the schema.
    pub table_count: usize,
    /// Size reported for the schema, in bytes.
    pub size_in_bytes: usize,
}
/// The described contents of a SQL database, stored as flat vectors.
///
/// Objects are addressed by typed ids (e.g. [`TableId`], [`TableColumnId`]) whose value is
/// the object's position in the corresponding vector, as assigned by the `push_*` methods.
/// Field order is kept stable because it determines the serialized layout.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct SqlSchema {
    /// Namespace names, indexed by `NamespaceId`, in insertion order.
    namespaces: Vec<String>,
    /// Tables, indexed by `TableId`.
    tables: Vec<Table>,
    /// Enums, indexed by `EnumId`.
    enums: Vec<Enum>,
    /// Enum variants (each refers to its enum), indexed by `EnumVariantId`.
    enum_variants: Vec<EnumVariant>,
    /// Columns of all tables paired with their owning table, indexed by `TableColumnId`.
    table_columns: Vec<(TableId, Column)>,
    /// Foreign keys, indexed by `ForeignKeyId`.
    foreign_keys: Vec<ForeignKey>,
    /// Default values of table columns, indexed by `TableDefaultValueId`.
    table_default_values: Vec<(TableColumnId, DefaultValue)>,
    /// Default values of view columns, indexed by `ViewDefaultValueId`.
    view_default_values: Vec<(ViewColumnId, DefaultValue)>,
    /// (constrained, referenced) column pairs of the foreign keys.
    foreign_key_columns: Vec<ForeignKeyColumn>,
    /// Indexes, including primary keys and unique constraints, indexed by `IndexId`.
    indexes: Vec<Index>,
    /// Columns of the indexes, indexed by `IndexColumnId`.
    index_columns: Vec<IndexColumn>,
    /// Check constraints paired with their table.
    /// NOTE(review): the `String` is presumably the constraint name — confirm.
    check_constraints: Vec<(TableId, String)>,
    /// Views, indexed by `ViewId`.
    views: Vec<View>,
    /// Columns of all views paired with their owning view, indexed by `ViewColumnId`.
    view_columns: Vec<(ViewId, Column)>,
    /// Stored procedures.
    procedures: Vec<Procedure>,
    /// User-defined types, indexed by `UdtId`.
    user_defined_types: Vec<UserDefinedType>,
    /// Connector-specific extra data (see `set_connector_data` / `downcast_connector_data`).
    connector_data: connector_data::ConnectorData,
}
impl SqlSchema {
#[track_caller]
pub fn downcast_connector_data<T: 'static>(&self) -> &T {
self.connector_data.data.as_ref().unwrap().downcast_ref().unwrap()
}
pub fn next_table_column_id(&self) -> TableColumnId {
TableColumnId(self.table_columns.len() as u32)
}
pub fn next_view_column_id(&self) -> ViewColumnId {
ViewColumnId(self.view_columns.len() as u32)
}
pub fn enum_used_in_tables(&self, id: EnumId) -> bool {
self.table_columns
.iter()
.any(|col| col.1.tpe.family == ColumnTypeFamily::Enum(id))
}
#[track_caller]
pub fn downcast_connector_data_mut<T: 'static>(&mut self) -> &mut T {
self.connector_data.data.as_mut().unwrap().downcast_mut().unwrap()
}
pub fn clear_namespaces(&mut self) {
self.namespaces.clear();
}
pub fn set_connector_data(&mut self, data: Box<dyn Any + Send + Sync>) {
self.connector_data.data = Some(data);
}
pub fn get_view(&self, name: &str) -> Option<&View> {
self.views.iter().find(|v| v.name == name)
}
pub fn find_enum(&self, name: &str, namespace: Option<&str>) -> Option<EnumId> {
let ns_id = namespace.and_then(|ns| self.get_namespace(ns));
self.enums
.iter()
.position(|e| e.name == name && ns_id.map(|id| id == e.namespace_id).unwrap_or(true))
.map(|i| EnumId(i as u32))
}
fn get_namespace(&self, name: &str) -> Option<NamespaceId> {
self.namespaces
.iter()
.position(|ns| ns == name)
.map(|i| NamespaceId(i as u32))
}
pub fn find_table(&self, name: &str, namespace: Option<&str>) -> Option<TableId> {
let ns_id = namespace.and_then(|ns| self.get_namespace(ns));
self.tables
.iter()
.position(|t| t.name == name && ns_id.map(|id| id == t.namespace_id).unwrap_or(true))
.map(|i| TableId(i as u32))
}
pub fn find_view(&self, name: &str, namespace: Option<&str>) -> Option<ViewId> {
let ns_id = namespace.and_then(|ns| self.get_namespace(ns));
self.views
.iter()
.position(|t| t.name == name && ns_id.map(|id| id == t.namespace_id).unwrap_or(true))
.map(|i| ViewId(i as u32))
}
pub fn get_procedure(&self, name: &str) -> Option<&Procedure> {
self.procedures.iter().find(|x| x.name == name)
}
pub fn get_user_defined_type(&self, name: &str) -> Option<&UserDefinedType> {
self.user_defined_types.iter().find(|x| x.name == name)
}
pub fn get_namespace_id(&self, name: &str) -> Option<NamespaceId> {
self.namespaces
.binary_search_by(|ns_name| ns_name.as_str().cmp(name))
.ok()
.map(|pos| NamespaceId(pos as u32))
}
pub fn indexes_count(&self) -> usize {
self.indexes.len()
}
pub fn make_fulltext_indexes_normal(&mut self) {
for idx in self.indexes.iter_mut() {
if matches!(idx.tpe, IndexType::Fulltext) {
idx.tpe = IndexType::Normal;
}
}
}
pub fn push_table_column(&mut self, table_id: TableId, column: Column) -> TableColumnId {
let id = TableColumnId(self.table_columns.len() as u32);
self.table_columns.push((table_id, column));
id
}
pub fn push_view_column(&mut self, view_id: ViewId, column: Column) -> ViewColumnId {
let id = ViewColumnId(self.view_columns.len() as u32);
self.view_columns.push((view_id, column));
id
}
pub fn push_enum(&mut self, namespace_id: NamespaceId, enum_name: String, description: Option<String>) -> EnumId {
let id = EnumId(self.enums.len() as u32);
self.enums.push(Enum {
namespace_id,
name: enum_name,
description,
});
id
}
pub fn push_enum_variant(&mut self, enum_id: EnumId, variant_name: String) -> EnumVariantId {
let id = EnumVariantId(self.enum_variants.len() as u32);
self.enum_variants.push(EnumVariant { enum_id, variant_name });
id
}
pub fn push_fulltext_index(&mut self, table_id: TableId, index_name: String) -> IndexId {
let id = IndexId(self.indexes.len() as u32);
self.indexes.push(Index {
table_id,
index_name,
tpe: IndexType::Fulltext,
});
id
}
pub fn push_index(&mut self, table_id: TableId, index_name: String) -> IndexId {
let id = IndexId(self.indexes.len() as u32);
self.indexes.push(Index {
table_id,
index_name,
tpe: IndexType::Normal,
});
id
}
pub fn push_table_default_value(&mut self, column_id: TableColumnId, value: DefaultValue) -> TableDefaultValueId {
let id = TableDefaultValueId(self.table_default_values.len() as u32);
self.table_default_values.push((column_id, value));
id
}
pub fn push_view_default_value(&mut self, column_id: ViewColumnId, value: DefaultValue) -> ViewDefaultValueId {
let id = ViewDefaultValueId(self.view_default_values.len() as u32);
self.view_default_values.push((column_id, value));
id
}
pub fn push_primary_key(&mut self, table_id: TableId, index_name: String) -> IndexId {
let id = IndexId(self.indexes.len() as u32);
self.indexes.push(Index {
table_id,
index_name,
tpe: IndexType::PrimaryKey,
});
id
}
pub fn push_unique_constraint(&mut self, table_id: TableId, index_name: String) -> IndexId {
let id = IndexId(self.indexes.len() as u32);
self.indexes.push(Index {
table_id,
index_name,
tpe: IndexType::Unique,
});
id
}
pub fn push_index_column(&mut self, column: IndexColumn) -> IndexColumnId {
let id = IndexColumnId(self.index_columns.len() as u32);
self.index_columns.push(column);
id
}
pub fn push_foreign_key(
&mut self,
constraint_name: Option<String>,
[constrained_table, referenced_table]: [TableId; 2],
[on_delete_action, on_update_action]: [ForeignKeyAction; 2],
) -> ForeignKeyId {
let id = ForeignKeyId(self.foreign_keys.len() as u32);
self.foreign_keys.push(ForeignKey {
constrained_table,
constraint_name,
referenced_table,
on_delete_action,
on_update_action,
});
id
}
pub fn push_foreign_key_column(
&mut self,
foreign_key_id: ForeignKeyId,
[constrained_column, referenced_column]: [TableColumnId; 2],
) {
self.foreign_key_columns.push(ForeignKeyColumn {
foreign_key_id,
constrained_column,
referenced_column,
});
}
pub fn push_namespace(&mut self, name: String) -> NamespaceId {
let id = NamespaceId(self.namespaces.len() as u32);
self.namespaces.push(name);
id
}
pub fn push_table(&mut self, name: String, namespace_id: NamespaceId, description: Option<String>) -> TableId {
let id = TableId(self.tables.len() as u32);
self.tables.push(Table {
namespace_id,
name,
properties: TableProperties::empty(),
description,
});
id
}
pub fn push_view(
&mut self,
name: String,
namespace_id: NamespaceId,
definition: Option<String>,
description: Option<String>,
) -> ViewId {
let id = ViewId(self.views.len() as u32);
self.views.push(View {
namespace_id,
name,
definition,
description,
});
id
}
pub fn push_table_with_properties(
&mut self,
name: String,
namespace_id: NamespaceId,
properties: BitFlags<TableProperties>,
description: Option<String>,
) -> TableId {
let id = TableId(self.tables.len() as u32);
self.tables.push(Table {
namespace_id,
name,
properties,
description,
});
id
}
pub fn namespaces_count(&self) -> usize {
self.namespaces.len()
}
pub fn namespace_walker<'a>(&'a self, name: &str) -> Option<NamespaceWalker<'a>> {
let namespace_idx = self.namespaces.iter().position(|ns| ns == name)?;
Some(self.walk(NamespaceId(namespace_idx as u32)))
}
pub fn tables_count(&self) -> usize {
self.tables.len()
}
pub fn views_count(&self) -> usize {
self.views.len()
}
pub fn table_walker<'a>(&'a self, name: &str) -> Option<TableWalker<'a>> {
let table_idx = self.tables.iter().position(|table| table.name == name)?;
Some(self.walk(TableId(table_idx as u32)))
}
pub fn table_walker_ns<'a>(&'a self, namespace: &str, name: &str) -> Option<TableWalker<'a>> {
let namespace_idx = self.namespace_walker(namespace)?.id;
let table_idx = self
.tables
.iter()
.position(|table| table.name == name && table.namespace_id == namespace_idx)?;
Some(self.walk(TableId(table_idx as u32)))
}
pub fn view_walker<'a>(&'a self, name: &str) -> Option<ViewWalker<'a>> {
let view_idx = self.views.iter().position(|view| view.name == name)?;
Some(self.walk(ViewId(view_idx as u32)))
}
pub fn view_walker_ns<'a>(&'a self, namespace: &str, name: &str) -> Option<ViewWalker<'a>> {
let namespace_idx = self.namespace_walker(namespace)?.id;
let view_idx = self
.views
.iter()
.position(|view| view.name == name && view.namespace_id == namespace_idx)?;
Some(self.walk(ViewId(view_idx as u32)))
}
pub fn table_walkers(&self) -> impl ExactSizeIterator<Item = TableWalker<'_>> {
(0..self.tables.len()).map(move |table_index| self.walk(TableId(table_index as u32)))
}
pub fn view_walkers(&self) -> impl ExactSizeIterator<Item = ViewWalker<'_>> {
(0..self.views.len()).map(move |view_index| self.walk(ViewId(view_index as u32)))
}
pub fn udt_walkers(&self) -> impl Iterator<Item = UserDefinedTypeWalker<'_>> {
(0..self.user_defined_types.len()).map(move |udt_index| self.walk(UdtId(udt_index as u32)))
}
pub fn enum_walkers(&self) -> impl ExactSizeIterator<Item = EnumWalker<'_>> {
(0..self.enums.len()).map(move |enum_index| self.walk(EnumId(enum_index as u32)))
}
pub fn walk_foreign_keys(&self) -> impl Iterator<Item = ForeignKeyWalker<'_>> {
(0..self.foreign_keys.len()).map(move |fk_idx| ForeignKeyWalker {
schema: self,
id: ForeignKeyId(fk_idx as u32),
})
}
pub fn walk<I>(&self, id: I) -> Walker<'_, I> {
Walker { id, schema: self }
}
pub fn walk_table_columns(&self) -> impl Iterator<Item = TableColumnWalker<'_>> {
(0..self.table_columns.len()).map(|idx| self.walk(TableColumnId(idx as u32)))
}
pub fn walk_view_columns(&self) -> impl Iterator<Item = ViewColumnWalker<'_>> {
(0..self.view_columns.len()).map(|idx| self.walk(ViewColumnId(idx as u32)))
}
pub fn walk_namespaces(&self) -> impl ExactSizeIterator<Item = NamespaceWalker<'_>> {
(0..self.namespaces.len()).map(|idx| self.walk(NamespaceId(idx as u32)))
}
pub fn is_empty(&self) -> bool {
self.tables.is_empty() && self.enums.is_empty()
}
}
/// Per-table property flags, stored as a `BitFlags<TableProperties>` on each [`Table`].
/// Each variant is one bit, so do not reorder variants.
#[enumflags2::bitflags]
#[repr(u8)]
#[derive(Clone, Copy, Debug)]
pub enum TableProperties {
    /// The table is a partition.
    /// NOTE(review): presumably a partition of a partitioned table — confirm.
    IsPartition,
    /// The table has subclasses.
    /// NOTE(review): presumably tables inheriting from it — confirm.
    HasSubclass,
    /// Row-level security is enabled (inferred from the name — confirm).
    HasRowLevelSecurity,
}
/// A table stored in [`SqlSchema`]; its columns, indexes and keys live in separate vectors.
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
pub struct Table {
    /// Namespace the table belongs to.
    namespace_id: NamespaceId,
    /// Table name.
    name: String,
    /// Property flags (see [`TableProperties`]).
    properties: BitFlags<TableProperties>,
    /// Optional description (presumably the table comment — confirm).
    description: Option<String>,
}
/// Kind of an index entry in [`SqlSchema`]. Primary keys and unique constraints are
/// stored as indexes too.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
pub enum IndexType {
    /// Unique constraint / unique index.
    Unique,
    /// Plain (non-unique) index.
    Normal,
    /// Fulltext index; `SqlSchema::make_fulltext_indexes_normal` can downgrade these.
    Fulltext,
    /// Primary key.
    PrimaryKey,
}
/// Sort order of an index column; renders as `ASC` / `DESC`.
#[derive(Serialize, Deserialize, PartialEq, Debug, Copy, Clone)]
pub enum SQLSortOrder {
    /// Ascending (`ASC`); the default.
    Asc,
    /// Descending (`DESC`).
    Desc,
}
impl Default for SQLSortOrder {
fn default() -> Self {
Self::Asc
}
}
impl AsRef<str> for SQLSortOrder {
fn as_ref(&self) -> &str {
match self {
SQLSortOrder::Asc => "ASC",
SQLSortOrder::Desc => "DESC",
}
}
}
impl fmt::Display for SQLSortOrder {
    /// Writes the SQL keyword (`ASC` / `DESC`) verbatim; width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword: &str = self.as_ref();
        f.write_str(keyword)
    }
}
/// One column of an index, primary key or unique constraint.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct IndexColumn {
    /// The index this column belongs to.
    pub index_id: IndexId,
    /// The indexed table column.
    pub column_id: TableColumnId,
    /// Sort order of the column in the index, if known.
    pub sort_order: Option<SQLSortOrder>,
    /// Optional indexed length (presumably a prefix length — confirm per connector).
    pub length: Option<u32>,
}
/// An index, primary key or unique constraint on a table; its columns are `IndexColumn`s.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct Index {
    /// Indexed table.
    table_id: TableId,
    /// Index / constraint name.
    index_name: String,
    /// Kind of index.
    tpe: IndexType,
}
/// A stored procedure.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct Procedure {
    /// Namespace the procedure belongs to.
    namespace_id: NamespaceId,
    /// Procedure name.
    pub name: String,
    /// Procedure definition, if available (presumably its SQL source — confirm).
    pub definition: Option<String>,
}
/// A user-defined type.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct UserDefinedType {
    /// Namespace the type belongs to.
    namespace_id: NamespaceId,
    /// Type name.
    pub name: String,
    /// Type definition, if available (format is connector-specific — confirm).
    pub definition: Option<String>,
}
/// A table or view column; ownership is recorded next to it in `SqlSchema`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Column {
    /// Column name.
    pub name: String,
    /// Column type.
    pub tpe: ColumnType,
    /// Whether the column is auto-incrementing.
    pub auto_increment: bool,
    /// Optional description (presumably the column comment — confirm).
    pub description: Option<String>,
}
/// The type of a column.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ColumnType {
    /// Full data type as reported by the database; empty when built with `ColumnType::pure`.
    pub full_data_type: String,
    /// Type family.
    pub family: ColumnTypeFamily,
    /// Required, nullable or list.
    pub arity: ColumnArity,
    /// Connector-specific native type. Not serialized; deserializes as `None`.
    #[serde(skip)]
    pub native_type: Option<psl::datamodel_connector::NativeTypeInstance>,
}
impl ColumnType {
pub fn pure(family: ColumnTypeFamily, arity: ColumnArity) -> Self {
ColumnType {
full_data_type: "".to_string(),
family,
arity,
native_type: None,
}
}
pub fn with_full_data_type(family: ColumnTypeFamily, arity: ColumnArity, full_data_type: String) -> Self {
ColumnType {
full_data_type,
family,
arity,
native_type: None,
}
}
}
/// Connector-independent classification of a column type.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub enum ColumnTypeFamily {
    Int,
    BigInt,
    Float,
    Decimal,
    Boolean,
    String,
    DateTime,
    Binary,
    Json,
    Uuid,
    /// A column typed by the enum with this id in the same `SqlSchema`.
    Enum(EnumId),
    /// A type with no mapping to the families above.
    /// NOTE(review): the `String` is presumably the database type name — confirm.
    Unsupported(String),
}
impl ColumnTypeFamily {
pub fn as_enum(&self) -> Option<EnumId> {
match self {
ColumnTypeFamily::Enum(id) => Some(*id),
_ => None,
}
}
pub fn is_bigint(&self) -> bool {
matches!(self, ColumnTypeFamily::BigInt)
}
pub fn is_boolean(&self) -> bool {
matches!(self, ColumnTypeFamily::Boolean)
}
pub fn is_datetime(&self) -> bool {
matches!(self, ColumnTypeFamily::DateTime)
}
pub fn is_enum(&self) -> bool {
matches!(self, ColumnTypeFamily::Enum(_))
}
pub fn is_int(&self) -> bool {
matches!(self, ColumnTypeFamily::Int)
}
pub fn is_json(&self) -> bool {
matches!(self, ColumnTypeFamily::Json)
}
pub fn is_string(&self) -> bool {
matches!(self, ColumnTypeFamily::String)
}
pub fn is_unsupported(&self) -> bool {
matches!(self, ColumnTypeFamily::Unsupported(_))
}
}
/// Whether a column is required, nullable, or holds a list.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
pub enum ColumnArity {
    /// Not nullable.
    Required,
    /// Nullable.
    Nullable,
    /// List / array column.
    List,
}
impl ColumnArity {
pub fn is_list(&self) -> bool {
matches!(self, ColumnArity::List)
}
pub fn is_nullable(&self) -> bool {
matches!(self, ColumnArity::Nullable)
}
pub fn is_required(&self) -> bool {
matches!(self, ColumnArity::Required)
}
}
/// Referential action of a foreign key. Used for both `on_delete_action` and
/// `on_update_action` on `ForeignKey`; the variants mirror SQL's `NO ACTION`, `RESTRICT`,
/// `CASCADE`, `SET NULL` and `SET DEFAULT`.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
pub enum ForeignKeyAction {
    NoAction,
    Restrict,
    Cascade,
    SetNull,
    SetDefault,
}
impl ForeignKeyAction {
pub fn is_cascade(&self) -> bool {
matches!(self, ForeignKeyAction::Cascade)
}
}
/// A foreign key; its column pairs are stored as `ForeignKeyColumn`s.
#[derive(Serialize, Deserialize, Debug)]
struct ForeignKey {
    /// Table holding the foreign key.
    constrained_table: TableId,
    /// Table the foreign key points to.
    referenced_table: TableId,
    /// Constraint name, if any.
    constraint_name: Option<String>,
    /// Action on delete of the referenced row.
    on_delete_action: ForeignKeyAction,
    /// Action on update of the referenced row.
    on_update_action: ForeignKeyAction,
}
/// One (constrained, referenced) column pair of a foreign key.
#[derive(Serialize, Deserialize, Debug)]
struct ForeignKeyColumn {
    /// Owning foreign key.
    foreign_key_id: ForeignKeyId,
    /// Column in the constrained table.
    constrained_column: TableColumnId,
    /// Column in the referenced table.
    referenced_column: TableColumnId,
}
/// A database enum type; its variants are stored as `EnumVariant`s.
#[derive(Serialize, Deserialize, Debug)]
struct Enum {
    /// Namespace the enum belongs to.
    namespace_id: NamespaceId,
    /// Enum name.
    name: String,
    /// Optional description (presumably the type comment — confirm).
    description: Option<String>,
}
/// One variant (value) of an enum.
#[derive(Serialize, Deserialize, Debug)]
struct EnumVariant {
    /// Owning enum.
    enum_id: EnumId,
    /// Variant name.
    variant_name: String,
}
/// A database view; its columns are stored separately in `SqlSchema`.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct View {
    /// Namespace the view belongs to.
    namespace_id: NamespaceId,
    /// View name.
    pub name: String,
    /// View definition, if available (presumably its SQL query — confirm).
    pub definition: Option<String>,
    /// Optional description (presumably the view comment — confirm).
    pub description: Option<String>,
}
/// A column default value.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct DefaultValue {
    /// What the default is.
    kind: DefaultKind,
    /// Name of the default constraint, if the database reports one.
    constraint_name: Option<String>,
}
/// The kind of a column default.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub enum DefaultKind {
    /// A literal value.
    Value(PrismaValue),
    /// The current time (inferred from the name — confirm per connector).
    Now,
    /// A sequence, identified by name.
    Sequence(String),
    /// `unique_rowid()`-style generated id.
    /// NOTE(review): presumably CockroachDB's `unique_rowid()` — confirm.
    UniqueRowid,
    /// A database-generated expression; `None` means no expression text.
    DbGenerated(Option<String>),
}
impl DefaultValue {
pub fn db_generated(val: impl Into<String>) -> Self {
Self::new(DefaultKind::DbGenerated(Some(val.into())))
}
pub fn constraint_name(&self) -> Option<&str> {
self.constraint_name.as_deref()
}
pub fn now() -> Self {
Self::new(DefaultKind::Now)
}
pub fn value(val: impl Into<PrismaValue>) -> Self {
Self::new(DefaultKind::Value(val.into()))
}
pub fn sequence(val: impl ToString) -> Self {
Self::new(DefaultKind::Sequence(val.to_string()))
}
pub fn kind(&self) -> &DefaultKind {
&self.kind
}
pub fn new(kind: DefaultKind) -> Self {
Self {
kind,
constraint_name: None,
}
}
pub fn set_constraint_name(&mut self, name: impl ToString) {
self.constraint_name = Some(name.to_string())
}
pub(crate) fn as_value(&self) -> Option<&PrismaValue> {
match self.kind {
DefaultKind::Value(ref v) => Some(v),
_ => None,
}
}
#[cfg(test)]
pub(crate) fn as_sequence(&self) -> Option<&str> {
match self.kind {
DefaultKind::Sequence(ref name) => Some(name),
_ => None,
}
}
#[cfg(test)]
pub(crate) fn is_db_generated(&self) -> bool {
matches!(self.kind, DefaultKind::DbGenerated(_))
}
pub fn unique_rowid() -> Self {
Self::new(DefaultKind::UniqueRowid)
}
pub fn with_constraint_name(mut self, constraint_name: Option<String>) -> Self {
self.constraint_name = constraint_name;
self
}
pub fn is_empty_dbgenerated(&self) -> bool {
matches!(self.kind, DefaultKind::DbGenerated(None))
}
}
/// Strips surrounding quote characters from `val` and returns an owned copy.
///
/// The trims are applied in this exact order, each removing every repeated occurrence:
/// leading `'`, trailing `'`, leading `\`, leading `"`, trailing `"`, trailing `\`.
/// The order is asymmetric, so e.g. `\'abc\'` becomes `'abc`.
fn unquote_string(val: &str) -> String {
    let stripped = val.trim_start_matches('\'');
    let stripped = stripped.trim_end_matches('\'');
    let stripped = stripped.trim_start_matches('\\');
    let stripped = stripped.trim_start_matches('"');
    let stripped = stripped.trim_end_matches('"');
    let stripped = stripped.trim_end_matches('\\');
    stripped.to_owned()
}
/// Optional precision-related attributes of a column.
/// NOTE(review): the catalog these values come from is not visible here — confirm in the describers.
#[derive(Debug)]
struct Precision {
    /// Maximum character length (presumably for character types).
    character_maximum_length: Option<u32>,
    /// Numeric precision (presumably for numeric types).
    numeric_precision: Option<u32>,
    /// Numeric scale (presumably for numeric types).
    numeric_scale: Option<u32>,
    /// Time precision (presumably fractional-second digits for time types — confirm).
    time_precision: Option<u32>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unquoting_works() {
        let quoted_str = "'abc $$ def'".to_string();
        // Previously garbled to `"ed_str` (an unterminated string literal); pass a `&String`.
        assert_eq!(unquote_string(&quoted_str), "abc $$ def");
        // Unquoted input is returned unchanged, including trailing whitespace.
        assert_eq!(unquote_string("heh "), "heh ");
        // Double quotes, optionally wrapped in escaping backslashes, are stripped.
        assert_eq!(unquote_string("\"abc\""), "abc");
        assert_eq!(unquote_string("\\\"abc\\\""), "abc");
        // Empty input stays empty.
        assert_eq!(unquote_string(""), "");
    }
}