1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use either::Either;
use prisma_value::PrismaValue;
use psl::{datamodel_connector::constraint_names::ConstraintNames, parser_database::walkers};
use sql::postgres::PostgresSchemaExt;
use sql_schema_describer as sql;
use std::{borrow::Cow, fmt};

use super::IntrospectionPair;

/// Pairs the default value of a field in the previous PSL schema (`previous`, absent when
/// the field had no `@default`) with the introspected SQL column it corresponds to (`next`).
pub(crate) type DefaultValuePair<'a> =
    IntrospectionPair<'a, Option<walkers::DefaultValueWalker<'a>>, sql::ColumnWalker<'a>>;

/// A column default value as derived during introspection.
///
/// Borrowed payloads live for the lifetime `'a` of the [`DefaultValuePair`] they were
/// derived from (the SQL schema or the previous PSL schema).
pub(crate) enum DefaultKind<'a> {
    // --- Defaults described by the database itself ---
    /// A CockroachDB sequence backing the column default.
    Sequence(&'a sql::postgres::Sequence),
    /// An autoincrement default (autoincrement integer columns, sequences, `unique_rowid()`).
    Autoincrement,
    /// The `now` default on a `DateTime` column.
    Now,
    /// A database-generated default, holding the raw expression text when one is available.
    DbGenerated(Option<&'a str>),

    // --- Generator functions recovered from the previous PSL schema ---
    /// `uuid()`.
    Uuid,
    /// `cuid()`.
    Cuid,
    /// `nanoid()`, with its optional length argument.
    Nanoid(Option<u8>),

    // --- Literal values ---
    /// A single string literal.
    String(&'a str),
    /// A list of string literals.
    StringList(Vec<&'a str>),
    /// An enum variant, identified by its Prisma-side name.
    EnumVariant(Cow<'a, str>),
    /// A single non-string scalar, held as a `Display` value.
    Constant(&'a dyn fmt::Display),
    /// A list of non-string scalars, each held as a `Display` value.
    ConstantList(Vec<&'a dyn fmt::Display>),
    /// A single bytes literal.
    Bytes(&'a [u8]),
    /// A list of bytes literals.
    BytesList(Vec<&'a [u8]>),
}

impl<'a> DefaultValuePair<'a> {
    /// The default value, if defined either in the database or PSL.
    ///
    /// The column's database default (`self.next`) is inspected first. Only when the
    /// database reports no default on a `String` or `Uuid` column is the previous PSL
    /// default (`self.previous`) consulted, to carry over `cuid()`, `uuid()` and
    /// `nanoid()` generators.
    ///
    /// Returns `None` when no default can be derived.
    ///
    /// # Panics
    ///
    /// Panics if the schema is internally inconsistent: a CockroachDB sequence default
    /// names a sequence missing from the connector data, an enum default names a variant
    /// missing from its enum, or a previous `nanoid()` argument is not a numeric literal
    /// that fits in a `u8`.
    pub(crate) fn kind(self) -> Option<DefaultKind<'a>> {
        let sql_kind = self.next.default().map(|d| d.kind());
        let family = self.next.column_type_family();

        // Arm order is significant: several arms overlap, and earlier arms (and their
        // guards) take precedence. For example, CockroachDB sequences must be matched
        // before the generic autoincrement/sequence arms.
        match (sql_kind, family) {
            (Some(sql::DefaultKind::Sequence(name)), _) if self.context.is_cockroach() => {
                let connector_data: &PostgresSchemaExt = self.context.sql_schema.downcast_connector_data();

                // Binary search assumes `sequences` is sorted by name.
                let sequence_idx = connector_data
                    .sequences
                    .binary_search_by_key(&name, |s| &s.name)
                    .expect("sequence referenced by a column default must exist in the schema");

                Some(DefaultKind::Sequence(&connector_data.sequences[sequence_idx]))
            }
            (_, sql::ColumnTypeFamily::Int | sql::ColumnTypeFamily::BigInt) if self.next.is_autoincrement() => {
                Some(DefaultKind::Autoincrement)
            }
            (Some(sql::DefaultKind::Sequence(_)), _) => Some(DefaultKind::Autoincrement),
            (Some(sql::DefaultKind::UniqueRowid), _) => Some(DefaultKind::Autoincrement),

            (Some(sql::DefaultKind::DbGenerated(default_string)), _) => {
                Some(DefaultKind::DbGenerated(default_string.as_deref()))
            }

            // `now` is only meaningful on DateTime columns; elsewhere it falls through to `None`.
            (Some(sql::DefaultKind::Now), sql::ColumnTypeFamily::DateTime) => Some(DefaultKind::Now),

            (Some(sql::DefaultKind::Value(PrismaValue::Null)), _) => Some(DefaultKind::Constant(&"null")),
            (Some(sql::DefaultKind::Value(PrismaValue::String(val))), _) => Some(DefaultKind::String(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::Json(val))), _) => Some(DefaultKind::String(val)),

            (Some(sql::DefaultKind::Value(PrismaValue::Boolean(val))), _) => Some(DefaultKind::Constant(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::Enum(variant))), sql::ColumnTypeFamily::Enum(enum_id)) => {
                let variant = self
                    .context
                    .sql_schema
                    .walk(*enum_id)
                    .variants()
                    .find(|v| v.name() == variant)
                    .expect("enum default must name an existing variant of the column's enum");

                let variant_name = self.context.enum_variant_name(variant.id);

                // A variant without a usable Prisma name cannot be referenced directly
                // in PSL, so fall back to its database-side (mapped) name.
                if !variant_name.prisma_name().is_empty() {
                    Some(DefaultKind::EnumVariant(variant_name.prisma_name()))
                } else {
                    Some(DefaultKind::DbGenerated(variant_name.mapped_name()))
                }
            }
            (Some(sql::DefaultKind::Value(PrismaValue::Int(val))), _) => Some(DefaultKind::Constant(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::Uuid(val))), _) => Some(DefaultKind::Constant(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::DateTime(val))), _) => Some(DefaultKind::Constant(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::Float(val))), _) => Some(DefaultKind::Constant(val)),
            (Some(sql::DefaultKind::Value(PrismaValue::BigInt(val))), _) => Some(DefaultKind::Constant(val)),

            (Some(sql::DefaultKind::Value(PrismaValue::Bytes(val))), _) => Some(DefaultKind::Bytes(val)),

            // The element type of a list default is decided by its first element.
            (Some(sql::DefaultKind::Value(PrismaValue::List(vals))), _) => match vals.first() {
                None => Some(DefaultKind::ConstantList(Vec::new())),
                Some(PrismaValue::String(_) | PrismaValue::Json(_)) => {
                    let vals = vals.iter().filter_map(|val| val.as_string()).collect();
                    Some(DefaultKind::StringList(vals))
                }
                Some(
                    PrismaValue::Boolean(_)
                    | PrismaValue::Enum(_)
                    | PrismaValue::Int(_)
                    | PrismaValue::Uuid(_)
                    | PrismaValue::DateTime(_)
                    | PrismaValue::Float(_)
                    | PrismaValue::BigInt(_),
                ) => {
                    let vals = vals.iter().map(|val| val as &'a dyn fmt::Display).collect();
                    Some(DefaultKind::ConstantList(vals))
                }
                Some(PrismaValue::Null) => {
                    // Render every NULL element as `null`, but keep the non-null elements
                    // that follow a leading NULL instead of overwriting them with `null`.
                    // Non-null elements go through `Display`, as in the scalar-constant
                    // list arm above.
                    let vals = vals
                        .iter()
                        .map(|val| match val {
                            PrismaValue::Null => &"null" as &'a dyn fmt::Display,
                            other => other as &'a dyn fmt::Display,
                        })
                        .collect();
                    Some(DefaultKind::ConstantList(vals))
                }
                Some(PrismaValue::Bytes(_)) => {
                    let vals = vals.iter().filter_map(|val| val.as_bytes()).collect();
                    Some(DefaultKind::BytesList(vals))
                }
                _ => unreachable!("unsupported element type in a list default value"),
            },

            // No database default: recover generator functions from the previous PSL schema.
            (None, sql::ColumnTypeFamily::String | sql::ColumnTypeFamily::Uuid) => match self.previous {
                Some(previous) if previous.is_cuid() => Some(DefaultKind::Cuid),
                Some(previous) if previous.is_uuid() => Some(DefaultKind::Uuid),
                Some(previous) if previous.is_nanoid() => {
                    // `nanoid()` may carry an optional numeric length argument.
                    let length = previous.value().as_function().and_then(|(_, args, _)| {
                        args.arguments.first().map(|arg| {
                            arg.value
                                .as_numeric_value()
                                .expect("nanoid() argument must be a numeric literal")
                                .0
                                .parse::<u8>()
                                .expect("nanoid() length must fit in a u8")
                        })
                    });

                    Some(DefaultKind::Nanoid(length))
                }
                _ => None,
            },

            _ => None,
        }
    }

    /// The constraint name, if the database uses them for defaults
    /// and it is non-default.
    ///
    /// Returns `None` when the column has no database default, the default has no
    /// constraint name, or the name equals the one the active connector would generate
    /// (see `default_name`).
    pub(crate) fn mapped_name(self) -> Option<&'a str> {
        match self.next.default() {
            Some(def) => def.constraint_name().filter(move |name| name != &self.default_name()),
            None => None,
        }
    }

    /// The default-constraint name the active connector generates for this column,
    /// built from the owning table or view name and the column name.
    fn default_name(self) -> String {
        let container_name = match self.next.refine() {
            Either::Left(col) => col.table().name(),
            Either::Right(col) => col.view().name(),
        };

        ConstraintNames::default_name(container_name, self.next.name(), self.context.active_connector())
    }
}