use crate::{
ast::{self, WithName},
interner::StringId,
walkers::RelationFieldId,
DatamodelError, Diagnostics,
{context::Context, types::RelationField},
};
use enumflags2::bitflags;
use rustc_hash::FxHashMap as HashMap;
use std::{collections::BTreeSet, fmt};
/// Rebuilds the relation table for the schema and stores it in `ctx.relations`.
///
/// Each relation field yielded by `ctx.types.iter_relation_fields()` is turned into
/// evidence by `relation_evidence` and handed to `ingest_relation`, which decides whether
/// (and as what kind of relation) it is recorded.
pub(super) fn infer_relations(ctx: &mut Context<'_>) {
    let mut inferred = Relations::default();

    for field in ctx.types.iter_relation_fields() {
        ingest_relation(relation_evidence(field, ctx), &mut inferred, ctx);
    }

    // Overwrite the stored table; the previous value is dropped.
    *ctx.relations = inferred;
}
/// Identifier of a relation: its position in the relation storage of `Relations`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RelationId(u32);

impl RelationId {
    /// Greatest representable id; used as the upper range bound when scanning the ordered
    /// `(model, model, relation)` indexes.
    const MAX: Self = Self(u32::MAX);
    /// Smallest representable id; used as the lower range bound for the same scans.
    const MIN: Self = Self(0);
}
/// A `RelationId` wrapped to mark it as referring to a many-to-many relation.
// NOTE(review): the invariant is implied by the name only; nothing here enforces it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ManyToManyRelationId(pub(crate) RelationId);
/// Table of every relation inferred for the schema, plus lookup indexes over it.
#[derive(Debug, Default)]
pub(crate) struct Relations {
    /// Relation storage; a `RelationId` is an index into this vector.
    relations_storage: Vec<Relation>,
    /// Maps each ingested relation field, and its opposite field when one exists, to the
    /// relation both belong to.
    fields: HashMap<RelationFieldId, RelationId>,
    /// `(model declaring the ingested field, referenced model, relation)` triples, kept
    /// ordered so all entries for one model can be found with a range scan.
    forward: BTreeSet<(ast::ModelId, ast::ModelId, RelationId)>,
    /// `(referenced model, model declaring the ingested field, relation)` triples: the
    /// mirror image of `forward`.
    back: BTreeSet<(ast::ModelId, ast::ModelId, RelationId)>,
}
impl std::ops::Index<RelationId> for Relations {
type Output = Relation;
fn index(&self, index: RelationId) -> &Self::Output {
&self.relations_storage[index.0 as usize]
}
}
/// Resolves a relation field to the id of the relation it belongs to.
///
/// Panics when the field has no entry in the field-to-relation map.
impl std::ops::Index<RelationFieldId> for Relations {
    type Output = RelationId;

    fn index(&self, index: RelationFieldId) -> &Self::Output {
        let field_to_relation = &self.fields;
        &field_to_relation[&index]
    }
}
impl Relations {
    /// Iterates over the ids of all stored relations, in storage order.
    pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = RelationId> + Clone {
        (0..self.relations_storage.len()).map(|idx| RelationId(idx as u32))
    }

    /// Iterates over all stored relations paired with their ids, in storage order.
    pub(crate) fn iter_relations(&self) -> impl Iterator<Item = (&Relation, RelationId)> + '_ {
        self.relations_storage
            .iter()
            .enumerate()
            .map(|(idx, rel)| (rel, RelationId(idx as u32)))
    }

    /// Ids of the relations recorded in the `forward` index under `model_a_id`, i.e. the
    /// relations that were ingested from a relation field declared on that model.
    ///
    /// Ids come out ordered by (referenced model, relation id).
    // `from_model` names a relation direction, not a `from_*` conversion constructor.
    #[allow(clippy::wrong_self_convention)]
    pub(crate) fn from_model(&self, model_a_id: ast::ModelId) -> impl Iterator<Item = RelationId> + '_ {
        self.forward
            .range(Self::keys_for_model(model_a_id))
            .map(|&(_, _, relation_id)| relation_id)
    }

    /// Ids of the relations recorded in the `back` index under `model_a_id`, i.e. the
    /// relations whose ingested field references that model.
    ///
    /// Ids come out ordered by (declaring model, relation id).
    pub(crate) fn to_model(&self, model_a_id: ast::ModelId) -> impl Iterator<Item = RelationId> + '_ {
        self.back
            .range(Self::keys_for_model(model_a_id))
            .map(|&(_, _, relation_id)| relation_id)
    }

    /// Key range covering every `(model_id, _, _)` entry of the `forward`/`back` indexes.
    ///
    /// The upper bound is inclusive: an exclusive `..` end would silently skip the maximal
    /// key `(model_id, ModelId::MAX, RelationId::MAX)`.
    fn keys_for_model(
        model_id: ast::ModelId,
    ) -> std::ops::RangeInclusive<(ast::ModelId, ast::ModelId, RelationId)> {
        (model_id, ast::ModelId::ZERO, RelationId::MIN)..=(model_id, ast::ModelId::MAX, RelationId::MAX)
    }
}
/// The field(s) of a one-to-many relation that exist in the schema.
// NOTE: variant order determines the derived `PartialOrd`/`Ord`; do not reorder casually.
#[derive(PartialOrd, Ord, PartialEq, Eq, Debug)]
pub(super) enum OneToManyRelationFields {
    /// Only the singular (non-list) field exists; there is no opposite field.
    Forward(RelationFieldId),
    /// Only the list field exists; there is no opposite field.
    Back(RelationFieldId),
    /// Both fields exist: the singular field first, then the list field.
    Both(RelationFieldId, RelationFieldId),
}
/// The field(s) of a one-to-one relation that exist in the schema.
// NOTE: variant order determines the derived `PartialOrd`/`Ord`; do not reorder casually.
#[derive(PartialOrd, Ord, PartialEq, Eq, Debug)]
pub(super) enum OneToOneRelationFields {
    /// Only one singular field exists (it has no opposite field).
    Forward(RelationFieldId),
    /// Both fields exist; the first is the end the relation was recorded from.
    Both(RelationFieldId, RelationFieldId),
}
/// The kind of a relation together with the field(s) taking part in it.
// NOTE: variant order determines the derived `PartialOrd`/`Ord` (and therefore that of
// `Relation`, which contains this type); do not reorder casually.
#[derive(PartialOrd, Ord, PartialEq, Eq, Debug)]
pub(super) enum RelationAttributes {
    /// Many-to-many between two list fields where neither end declares `fields`.
    ImplicitManyToMany {
        /// The list field the relation was recorded from.
        field_a: RelationFieldId,
        /// The opposite list field.
        field_b: RelationFieldId,
    },
    /// Many-to-many between two list fields where at least one end declares `fields`.
    TwoWayEmbeddedManyToMany {
        /// The list field the relation was recorded from.
        field_a: RelationFieldId,
        /// The opposite list field.
        field_b: RelationFieldId,
    },
    /// One-to-one relation.
    OneToOne(OneToOneRelationFields),
    /// One-to-many relation.
    OneToMany(OneToManyRelationFields),
}
impl RelationAttributes {
    /// Returns the relation's fields as `(field_a, field_b)`; an end that has no field in the
    /// schema is reported as `None`.
    pub(crate) fn fields(&self) -> (Option<RelationFieldId>, Option<RelationFieldId>) {
        match *self {
            // Only the list end exists.
            RelationAttributes::OneToMany(OneToManyRelationFields::Back(back)) => (None, Some(back)),
            // Only the singular end exists.
            RelationAttributes::OneToOne(OneToOneRelationFields::Forward(forward))
            | RelationAttributes::OneToMany(OneToManyRelationFields::Forward(forward)) => (Some(forward), None),
            // Both ends exist.
            RelationAttributes::ImplicitManyToMany { field_a: a, field_b: b }
            | RelationAttributes::TwoWayEmbeddedManyToMany { field_a: a, field_b: b }
            | RelationAttributes::OneToOne(OneToOneRelationFields::Both(a, b))
            | RelationAttributes::OneToMany(OneToManyRelationFields::Both(a, b)) => (Some(a), Some(b)),
        }
    }
}
/// A single relation between two models.
// NOTE: field order determines the derived `PartialOrd`/`Ord`; do not reorder casually.
#[derive(PartialOrd, Ord, PartialEq, Eq, Debug)]
pub(crate) struct Relation {
    /// Relation name taken from the ingested field (`None` when it has none).
    pub(super) relation_name: Option<StringId>,
    /// Kind of relation and the field(s) that take part in it.
    pub(super) attributes: RelationAttributes,
    /// For back-only one-to-many relations, the model referenced by the list field; for every
    /// other relation, the model declaring the ingested field.
    pub(super) model_a: ast::ModelId,
    /// The other model of the relation.
    pub(super) model_b: ast::ModelId,
}
impl Relation {
    /// Whether this is an implicit many-to-many relation.
    pub(crate) fn is_implicit_many_to_many(&self) -> bool {
        matches!(&self.attributes, RelationAttributes::ImplicitManyToMany { .. })
    }

    /// Whether this is a two-way embedded many-to-many relation.
    pub(crate) fn is_two_way_embedded_many_to_many(&self) -> bool {
        matches!(&self.attributes, RelationAttributes::TwoWayEmbeddedManyToMany { .. })
    }

    /// Both fields of the relation, or `None` when one of the two ends has no field in the
    /// schema.
    pub(crate) fn as_complete_fields(&self) -> Option<(RelationFieldId, RelationFieldId)> {
        let (field_a, field_b) = self.attributes.fields();
        field_a.zip(field_b)
    }
}
/// Everything `ingest_relation` needs to classify one relation field, as gathered by
/// `relation_evidence`.
pub(super) struct RelationEvidence<'db> {
    // The field being ingested.
    /// Id of the relation field being ingested.
    pub(super) field_id: RelationFieldId,
    /// Relation-specific data of the field being ingested.
    pub(super) relation_field: &'db RelationField,
    /// AST node of the field being ingested.
    pub(super) ast_field: &'db ast::Field,
    /// Model declaring the field being ingested.
    pub(super) model_id: ast::ModelId,
    /// AST node of that model.
    pub(super) ast_model: &'db ast::Model,

    // The other end of the relation.
    /// AST node of the model the field references.
    pub(super) opposite_model: &'db ast::Model,
    /// The matching field on the referenced model, if one was found.
    pub(super) opposite_relation_field: Option<(RelationFieldId, &'db ast::Field, &'db RelationField)>,

    // Derived flags.
    /// The field references the model that declares it.
    pub(super) is_self_relation: bool,
    /// An opposite field exists and at least one of the two ends declares `fields`.
    pub(super) is_two_way_embedded_many_to_many_relation: bool,
}
/// Collects the evidence needed to classify the relation field `relation_field`.
///
/// The opposite field is the first relation field of the referenced model that points back
/// at this field's model and carries the same relation name. In a self-relation, the field
/// itself is never picked as its own opposite.
pub(super) fn relation_evidence<'db>(
    (relation_field_id, relation_field): (RelationFieldId, &'db RelationField),
    ctx: &'db Context<'db>,
) -> RelationEvidence<'db> {
    let ast = ctx.ast;
    let own_model_id = relation_field.model_id;
    let target_model_id = relation_field.referenced_model;
    let is_self_relation = own_model_id == target_model_id;

    let ast_model = &ast[own_model_id];
    let ast_field = &ast_model[relation_field.field_id];
    let opposite_model = &ast[target_model_id];

    let opposite_relation_field: Option<(RelationFieldId, &ast::Field, &'db RelationField)> = ctx
        .types
        .range_model_relation_fields(target_model_id)
        .find(|(_, candidate)| {
            candidate.referenced_model == own_model_id
                && (!is_self_relation || candidate.field_id != relation_field.field_id)
                && candidate.name == relation_field.name
        })
        .map(|(opposite_id, opposite)| (opposite_id, &ast[opposite.model_id][opposite.field_id], opposite));

    // Only meaningful when both ends are lists; `ingest_relation` consults it in that case.
    let is_two_way_embedded_many_to_many_relation = match opposite_relation_field {
        Some((_, _, opposite)) => relation_field.fields.is_some() || opposite.fields.is_some(),
        None => false,
    };

    RelationEvidence {
        field_id: relation_field_id,
        relation_field,
        ast_field,
        model_id: own_model_id,
        ast_model,
        opposite_model,
        opposite_relation_field,
        is_self_relation,
        is_two_way_embedded_many_to_many_relation,
    }
}
/// Classifies the relation that the field described by `evidence` takes part in and records
/// it in `relations`.
///
/// Both ends of a relation are ingested separately, so several arms below deliberately
/// `return` early for one of the two ends. This keeps the same relation from being recorded
/// twice.
pub(super) fn ingest_relation<'db>(evidence: RelationEvidence<'db>, relations: &mut Relations, ctx: &Context<'db>) {
    // NOTE: arm order matters. Later arms rely on earlier guarded arms having already caught
    // the more specific arity combinations.
    let relation_type = match (evidence.ast_field.arity, evidence.opposite_relation_field) {
        // Many-to-many: both ends are lists.
        (ast::FieldArity::List, Some((opp_field_id, opp_field, _))) if opp_field.arity.is_list() => {
            // Record only from the end on the model whose name sorts first...
            if evidence.ast_model.name() > evidence.opposite_model.name() {
                return;
            }
            // ...and, for a self-relation, from the end whose field name sorts first.
            if evidence.is_self_relation && evidence.ast_field.name() > opp_field.name() {
                return;
            }
            // Set by `relation_evidence` when either end declares `fields`.
            if evidence.is_two_way_embedded_many_to_many_relation {
                RelationAttributes::TwoWayEmbeddedManyToMany {
                    field_a: evidence.field_id,
                    field_b: opp_field_id,
                }
            } else {
                RelationAttributes::ImplicitManyToMany {
                    field_a: evidence.field_id,
                    field_b: opp_field_id,
                }
            }
        }
        // One-to-one, required here and optional on the other end: record from this end.
        (ast::FieldArity::Required, Some((opp_field_id, opp_field, _))) if opp_field.arity.is_optional() => {
            RelationAttributes::OneToOne(OneToOneRelationFields::Both(evidence.field_id, opp_field_id))
        }
        // Required on both ends: break the tie on (model name, field name) so only one end records it.
        (ast::FieldArity::Required, Some((opp_field_id, opp_field, _))) if opp_field.arity.is_required() => {
            if [evidence.ast_model.name(), evidence.ast_field.name()]
                > [evidence.opposite_model.name(), opp_field.name()]
            {
                return;
            }
            RelationAttributes::OneToOne(OneToOneRelationFields::Both(evidence.field_id, opp_field_id))
        }
        // Optional end of a required/optional pair: recorded when the required end is ingested.
        (ast::FieldArity::Optional, Some((_, opp_field, _))) if opp_field.arity.is_required() => {
            return;
        }
        // Optional on both ends.
        (ast::FieldArity::Optional, Some((opp_field_id, opp_field, opp_field_attributes)))
            if opp_field.arity.is_optional() =>
        {
            // The end that declares `fields` records the relation and becomes the first field.
            // NOTE(review): if BOTH ends declare `fields`, both ends take this branch and the
            // relation is recorded twice. This is presumably rejected by validation elsewhere;
            // confirm.
            if evidence.relation_field.fields.is_some() {
                RelationAttributes::OneToOne(OneToOneRelationFields::Both(evidence.field_id, opp_field_id))
            } else if opp_field_attributes.fields.is_none() {
                // Neither end declares `fields`: break the tie on (model name, field name).
                if [evidence.ast_model.name(), evidence.ast_field.name()]
                    > [evidence.opposite_model.name(), opp_field.name()]
                {
                    return;
                }
                RelationAttributes::OneToOne(OneToOneRelationFields::Both(evidence.field_id, opp_field_id))
            } else {
                // Only the opposite end declares `fields`; that end records the relation.
                return;
            }
        }
        // List end of a one-to-many relation: recorded from the singular end (arm further below).
        (ast::FieldArity::List, Some(_)) => {
            return;
        }
        // List field without an opposite field: back-only one-to-many.
        (ast::FieldArity::List, None) => {
            RelationAttributes::OneToMany(OneToManyRelationFields::Back(evidence.field_id))
        }
        // Singular end whose opposite is (by elimination of the arms above) a list: one-to-many.
        (ast::FieldArity::Required | ast::FieldArity::Optional, Some((opp_field_id, _, _))) => {
            RelationAttributes::OneToMany(OneToManyRelationFields::Both(evidence.field_id, opp_field_id))
        }
        // Singular field without an opposite field: forward-only relation.
        (ast::FieldArity::Optional | ast::FieldArity::Required, None) => {
            match &evidence.relation_field.fields {
                Some(fields) => {
                    // One-to-one if some unique index on this model lists exactly `fields`.
                    // The comparison is positional, so the index must list them in the same
                    // order. NOTE(review): only `Either::Left` index paths can match; presumably
                    // those are plain scalar fields of this model. Confirm.
                    let fields_are_unique =
                        ctx.types.model_attributes[&evidence.model_id]
                            .ast_indexes
                            .iter()
                            .any(|(_, idx)| {
                                idx.is_unique() && idx.fields.len() == fields.len() && {
                                    idx.fields
                                        .iter()
                                        .zip(fields.iter())
                                        .all(|(idx_field, field)| matches!(idx_field.path.field_in_index(), either::Either::Left(id) if id == *field))
                                }
                            });
                    if fields_are_unique {
                        RelationAttributes::OneToOne(OneToOneRelationFields::Forward(evidence.field_id))
                    } else {
                        RelationAttributes::OneToMany(OneToManyRelationFields::Forward(evidence.field_id))
                    }
                }
                // No `fields` declared: always treated as the forward end of a one-to-many.
                _ => RelationAttributes::OneToMany(OneToManyRelationFields::Forward(evidence.field_id)),
            }
        }
    };
    // Back-only one-to-many relations are ingested from the list end. For those, `model_a`
    // is the referenced model and `model_b` is the model declaring the list field.
    let relation = match relation_type {
        RelationAttributes::OneToMany(OneToManyRelationFields::Back(_)) => Relation {
            attributes: relation_type,
            relation_name: evidence.relation_field.name,
            model_a: evidence.relation_field.referenced_model,
            model_b: evidence.model_id,
        },
        _ => Relation {
            attributes: relation_type,
            relation_name: evidence.relation_field.name,
            model_a: evidence.model_id,
            model_b: evidence.relation_field.referenced_model,
        },
    };
    // Ids are dense: a relation's id is its position in storage.
    let relation_id = RelationId(relations.relations_storage.len() as u32);
    relations.relations_storage.push(relation);
    // Both ends (when the opposite exists) resolve to the same relation.
    relations.fields.insert(evidence.field_id, relation_id);
    if let Some((opposite_field_id, _, _)) = evidence.opposite_relation_field {
        relations.fields.insert(opposite_field_id, relation_id);
    }
    // Index by (declaring model, referenced model) in both directions for range lookups.
    relations
        .forward
        .insert((evidence.model_id, evidence.relation_field.referenced_model, relation_id));
    relations
        .back
        .insert((evidence.relation_field.referenced_model, evidence.model_id, relation_id));
}
/// A referential action accepted by the `@relation` attribute.
// NOTE(review): `#[bitflags]` (enumflags2) presumably derives one flag bit per variant from
// declaration order. Confirm against the enumflags2 docs before reordering or inserting
// variants.
#[repr(u8)]
#[bitflags]
#[derive(Debug, Copy, PartialEq, Clone)]
pub enum ReferentialAction {
    /// Delete the child records when the parent record is deleted.
    Cascade,
    /// Prevent deleting a parent record as long as it is referenced.
    Restrict,
    /// Prevent deleting a parent record as long as it is referenced.
    NoAction,
    /// Set the referencing fields to NULL when the referenced record is deleted.
    SetNull,
    /// Set the referencing fields to their default when the referenced record is deleted.
    SetDefault,
}
impl ReferentialAction {
    /// Whether the action changes the referencing records (`Cascade`, `SetNull`,
    /// `SetDefault`) rather than only preventing the operation (`Restrict`, `NoAction`).
    pub fn triggers_modification(self) -> bool {
        matches!(self, Self::Cascade | Self::SetNull | Self::SetDefault)
    }

    /// The schema keyword for this action.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Cascade => "Cascade",
            Self::Restrict => "Restrict",
            Self::NoAction => "NoAction",
            Self::SetNull => "SetNull",
            Self::SetDefault => "SetDefault",
        }
    }

    /// Human-readable explanation of the action.
    pub fn documentation(&self) -> &'static str {
        match *self {
            Self::Cascade => "Delete the child records when the parent record is deleted.",
            Self::Restrict | Self::NoAction => "Prevent deleting a parent record as long as it is referenced.",
            Self::SetNull => "Set the referencing fields to NULL when the referenced record is deleted.",
            Self::SetDefault => {
                "Set the referencing field's value to the default when the referenced record is deleted."
            }
        }
    }

    /// Parses a constant expression into an action.
    ///
    /// Returns `None` if the expression is not a constant (the coercion reports that itself).
    /// Also returns `None` if the constant names no known action, in which case an
    /// `@relation` validation error is pushed.
    pub(crate) fn try_from_expression(
        expr: &ast::Expression,
        diagnostics: &mut Diagnostics,
    ) -> Option<ReferentialAction> {
        const CANDIDATES: [ReferentialAction; 5] = [
            ReferentialAction::Cascade,
            ReferentialAction::Restrict,
            ReferentialAction::NoAction,
            ReferentialAction::SetNull,
            ReferentialAction::SetDefault,
        ];

        let name = crate::coerce::constant(expr, diagnostics)?;
        let action = CANDIDATES.iter().copied().find(|candidate| candidate.as_str() == name);

        if action.is_none() {
            let message = format!("Invalid referential action: `{name}`");
            diagnostics.push_error(DatamodelError::new_attribute_validation_error(
                &message,
                "@relation",
                expr.span(),
            ));
        }

        action
    }
}
impl AsRef<str> for ReferentialAction {
fn as_ref(&self) -> &'static str {
match self {
ReferentialAction::Cascade => "Cascade",
ReferentialAction::Restrict => "Restrict",
ReferentialAction::NoAction => "NoAction",
ReferentialAction::SetNull => "SetNull",
ReferentialAction::SetDefault => "SetDefault",
}
}
}
/// Formats the action as its schema keyword (e.g. `Cascade`), ignoring width/fill flags.
impl fmt::Display for ReferentialAction {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = self.as_str();
        formatter.write_str(keyword)
    }
}