1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
use super::GQLResponse;
use crate::{GQLError, PrismaResponse, RequestBody};
use futures::FutureExt;
use indexmap::IndexMap;
use query_core::{
    constants::custom_types,
    protocol::EngineProtocol,
    response_ir::{Item, ResponseData},
    schema::QuerySchemaRef,
    ArgumentValue, ArgumentValueObject, BatchDocument, BatchDocumentTransaction, CompactedDocument, Operation,
    QueryDocument, QueryExecutor, TxId,
};
use query_structure::{parse_datetime, stringify_datetime, PrismaValue};
use std::{collections::HashMap, fmt, panic::AssertUnwindSafe};

/// Lookup-table entry used when un-compacting a batch: the arguments derived from one
/// result row, paired with that full result row.
type ArgsToResult = (HashMap<String, ArgumentValue>, IndexMap<String, Item>);

/// Runs incoming request bodies (single operations, batches and compacted batches)
/// through a [`QueryExecutor`] and turns the outcome into a [`PrismaResponse`].
pub struct RequestHandler<'a> {
    /// Executor every operation is delegated to.
    executor: &'a (dyn QueryExecutor + Send + Sync + 'a),
    /// Schema that request bodies are converted and executed against.
    query_schema: &'a QuerySchemaRef,
    /// Protocol forwarded to the executor on every call.
    engine_protocol: EngineProtocol,
}

// Only the type name is printed; none of the fields are included in the output.
impl fmt::Debug for RequestHandler<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RequestHandler")
    }
}

impl<'a> RequestHandler<'a> {
    /// Creates a handler that executes requests with `executor` against `query_schema`,
    /// passing `engine_protocol` along on every execution.
    pub fn new(
        executor: &'a (dyn QueryExecutor + Send + Sync + 'a),
        query_schema: &'a QuerySchemaRef,
        engine_protocol: EngineProtocol,
    ) -> Self {
        Self {
            executor,
            query_schema,
            engine_protocol,
        }
    }

    /// Entry point: converts `body` into a query document and dispatches it.
    ///
    /// - A single operation is run by `handle_single`.
    /// - A batch is passed through `BatchDocument::compact`. A resulting
    ///   `BatchDocument::Multi` runs as a regular batch (`handle_batch`), and a
    ///   `BatchDocument::Compact` runs as one merged query (`handle_compacted`).
    ///
    /// A failure to build the document is reported as a single error response.
    pub async fn handle(&self, body: RequestBody, tx_id: Option<TxId>, trace_id: Option<String>) -> PrismaResponse {
        tracing::debug!("Incoming GraphQL query: {:?}", &body);

        match body.into_doc(self.query_schema) {
            Ok(QueryDocument::Single(query)) => self.handle_single(query, tx_id, trace_id).await,
            Ok(QueryDocument::Multi(batch)) => match batch.compact(self.query_schema) {
                BatchDocument::Multi(batch, transaction) => {
                    self.handle_batch(batch, transaction, tx_id, trace_id).await
                }
                BatchDocument::Compact(compacted) => self.handle_compacted(compacted, tx_id, trace_id).await,
            },

            Err(err) => PrismaResponse::Single(GQLError::from_handler_error(err).into()),
        }
    }

    /// Runs one operation and wraps the outcome in `PrismaResponse::Single`.
    ///
    /// Core errors and panics raised while executing are converted into a `GQLError`
    /// instead of propagating.
    async fn handle_single(&self, query: Operation, tx_id: Option<TxId>, trace_id: Option<String>) -> PrismaResponse {
        let gql_response = match AssertUnwindSafe(self.handle_request(query, tx_id, trace_id))
            .catch_unwind()
            .await
        {
            Ok(Ok(response)) => response.into(),
            Ok(Err(err)) => GQLError::from_core_error(err).into(),
            Err(err) => GQLError::from_panic_payload(err).into(),
        };

        PrismaResponse::Single(gql_response)
    }

    /// Runs a non-compacted batch through `QueryExecutor::execute_all`.
    ///
    /// Each per-operation result becomes its own response (data or error). A batch-level
    /// error, or a panic during execution, becomes a single error in the `Multi` response.
    async fn handle_batch(
        &self,
        queries: Vec<Operation>,
        transaction: Option<BatchDocumentTransaction>,
        tx_id: Option<TxId>,
        trace_id: Option<String>,
    ) -> PrismaResponse {
        match AssertUnwindSafe(self.executor.execute_all(
            tx_id,
            queries,
            transaction,
            self.query_schema.clone(),
            trace_id,
            self.engine_protocol,
        ))
        .catch_unwind()
        .await
        {
            Ok(Ok(responses)) => {
                let gql_responses: Vec<GQLResponse> = responses
                    .into_iter()
                    .map(|response| match response {
                        Ok(data) => data.into(),
                        Err(err) => GQLError::from_core_error(err).into(),
                    })
                    .collect();

                PrismaResponse::Multi(gql_responses.into())
            }
            Ok(Err(err)) => PrismaResponse::Multi(GQLError::from_core_error(err).into()),
            Err(err) => PrismaResponse::Multi(GQLError::from_panic_payload(err).into()),
        }
    }

    /// Runs a compacted batch and splits the single response back into per-query responses.
    ///
    /// In a compacted batch, several findUnique queries were merged into one findMany.
    /// The returned `Multi` response has one entry per element of `document.arguments`,
    /// in that order. Each entry is one of:
    /// - the matching row, restricted to the originally selected fields;
    /// - a `RecordRequiredButNotFound` error, when no row matches and `throw_on_empty` is set;
    /// - `null` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the successful response has no data entry under the plural name, or that
    /// entry is not a list. This post-processing runs outside `catch_unwind`.
    async fn handle_compacted(
        &self,
        document: CompactedDocument,
        tx_id: Option<TxId>,
        trace_id: Option<String>,
    ) -> PrismaResponse {
        let plural_name = document.plural_name();
        let singular_name = document.single_name();
        let throw_on_empty = document.throw_on_empty();
        let keys: Vec<String> = document.keys;
        let arguments = document.arguments;
        let nested_selection = document.nested_selection;

        match AssertUnwindSafe(self.handle_request(document.operation, tx_id, trace_id))
            .catch_unwind()
            .await
        {
            Ok(Ok(response_data)) => {
                let mut gql_response: GQLResponse = response_data.into();

                // At this point, many findUnique queries were converted to a single findMany query and that query was run.
                // This means we have a list of results and we need to map each result back to its original findUnique query.
                // `args_to_results` is the data structure that allows us to do that mapping.
                // It takes the findMany response and converts it to a list of (arguments, result) pairs.
                // Let's take an example. Given the following batched queries:
                // [
                //    findUnique(where: { id: 1, name: "Bob" }) { id name age },
                //    findUnique(where: { id: 2, name: "Alice" }) { id name age }
                // ]
                // 1. This gets converted to: findMany(where: { OR: [{ id: 1, name: "Bob" }, { id: 2, name: "Alice" }] }) { id name age }
                // 2. Say we get the following result back: [{ id: 1, name: "Bob", age: 18 }, { id: 2, name: "Alice", age: 27 }]
                // 3. We know the input arguments are ["id", "name"]
                // 4. So we go over the result and build the following list:
                // [
                //  ({ id: 1, name: "Bob" },   { id: 1, name: "Bob", age: 18 }),
                //  ({ id: 2, name: "Alice" }, { id: 2, name: "Alice", age: 27 })
                // ]
                // 5. Now, given the original findUnique queries and that list, we can easily find which arguments map to which result
                // [
                //    findUnique(where: { id: 1, name: "Bob" }) { id name age } -> { id: 1, name: "Bob", age: 18 }
                //    findUnique(where: { id: 2, name: "Alice" }) { id name age } -> { id: 2, name: "Alice", age: 27 }
                // ]
                let args_to_results: Vec<ArgsToResult> = gql_response
                    .take_data(plural_name)
                    .expect("compacted findMany response is missing its data entry")
                    .into_list()
                    .expect("compacted findMany response data is not a list")
                    .index_by(keys.as_slice());

                let results: Vec<GQLResponse> = arguments
                    .into_iter()
                    .map(|args| {
                        let mut responses = GQLResponse::with_capacity(1);
                        // This is step 5 of the comment above.
                        match Self::find_original_result_from_args(&args_to_results, &args) {
                            Some(result) => {
                                // Keep only the fields selected by the original query. The kept
                                // entries must be copied rather than moved, because the same
                                // arguments may appear more than once in the original batch and
                                // every occurrence must receive the same answer. Filtering before
                                // cloning avoids copying fields that are then discarded.
                                let result: IndexMap<String, Item> = result
                                    .iter()
                                    .filter(|(key, _)| nested_selection.contains(*key))
                                    .map(|(key, item)| (key.clone(), item.clone()))
                                    .collect();

                                responses.insert_data(&singular_name, Item::Map(result));
                            }
                            None if throw_on_empty => responses.insert_error(GQLError::from_user_facing_error(
                                user_facing_errors::query_engine::RecordRequiredButNotFound {
                                    cause: "Expected a record, found none.".to_owned(),
                                }
                                .into(),
                            )),
                            None => responses.insert_data(&singular_name, Item::null()),
                        }

                        responses
                    })
                    .collect();

                PrismaResponse::Multi(results.into())
            }

            Ok(Err(err)) => PrismaResponse::Multi(GQLError::from_core_error(err).into()),

            // panicked
            Err(err) => PrismaResponse::Multi(GQLError::from_panic_payload(err).into()),
        }
    }

    /// Executes a single operation with this handler's schema and protocol.
    async fn handle_request(
        &self,
        query_doc: Operation,
        tx_id: Option<TxId>,
        trace_id: Option<String>,
    ) -> query_core::Result<ResponseData> {
        self.executor
            .execute(
                tx_id,
                query_doc,
                self.query_schema.clone(),
                trace_id,
                self.engine_protocol,
            )
            .await
    }

    /// Returns the first result row whose derived arguments match `input_args`
    /// (see `compare_args`), or `None` if no row matches.
    ///
    /// The returned reference borrows only from `args_to_results`.
    fn find_original_result_from_args<'b>(
        args_to_results: &'b [ArgsToResult],
        input_args: &HashMap<String, ArgumentValue>,
    ) -> Option<&'b IndexMap<String, Item>> {
        args_to_results
            .iter()
            .find(|(arg_from_result, _)| Self::compare_args(arg_from_result, input_args))
            .map(|(_, result)| result)
    }

    /// Returns `true` when both argument maps have exactly the same keys and every pair of
    /// values is equal according to `compare_values`.
    ///
    /// The key sets must be identical. A subset check would ignore input keys that the
    /// result-derived arguments lack, and would let an empty map match anything. Either way,
    /// a row could be attributed to a query whose filter it does not satisfy.
    fn compare_args(left: &HashMap<String, ArgumentValue>, right: &HashMap<String, ArgumentValue>) -> bool {
        // Map keys are unique, so equal sizes plus "every left key is present in right"
        // means both maps have exactly the same key set.
        left.len() == right.len()
            && left.iter().all(|(key, left_value)| {
                right
                    .get(key)
                    .map_or(false, |right_value| Self::compare_values(left_value, right_value))
            })
    }

    /// Compares two argument values with special rules, needed because user-input values are coerced differently than response values.
    /// We need this when comparing user-input values with query response values in the context of compacted queries.
    /// Here are the cases covered:
    /// - DateTime/String: User-input: DateTime / Response: String
    /// - Int/BigInt: User-input: Int / Response: BigInt
    /// - (JSON protocol only) Custom types (eg: { "$type": "BigInt", value: "1" }): User-input: Scalar / Response: Object
    /// - (JSON protocol only) String/Enum: User-input: String / Response: Enum
    ///
    /// This should likely _not_ be used outside of this specific context.
    fn compare_values(left: &ArgumentValue, right: &ArgumentValue) -> bool {
        match (left, right) {
            (ArgumentValue::Scalar(PrismaValue::String(t1)), ArgumentValue::Scalar(PrismaValue::DateTime(t2)))
            | (ArgumentValue::Scalar(PrismaValue::DateTime(t2)), ArgumentValue::Scalar(PrismaValue::String(t1))) => {
                // Compare as datetimes when the string parses; otherwise fall back to
                // comparing against the datetime's string form.
                parse_datetime(t1)
                    .map(|t1| &t1 == t2)
                    .unwrap_or_else(|_| t1 == stringify_datetime(t2).as_str())
            }
            (ArgumentValue::Scalar(PrismaValue::Int(i1)), ArgumentValue::Scalar(PrismaValue::BigInt(i2)))
            | (ArgumentValue::Scalar(PrismaValue::BigInt(i2)), ArgumentValue::Scalar(PrismaValue::Int(i1))) => {
                *i1 == *i2
            }
            (ArgumentValue::Scalar(PrismaValue::Enum(s1)), ArgumentValue::Scalar(PrismaValue::String(s2)))
            | (ArgumentValue::Scalar(PrismaValue::String(s1)), ArgumentValue::Scalar(PrismaValue::Enum(s2))) => {
                *s1 == *s2
            }
            // If either side is a custom-type envelope, compare its inner value against the
            // other side; any other object falls back to plain equality.
            (ArgumentValue::Object(t1), t2) | (t2, ArgumentValue::Object(t1)) => match Self::unwrap_value(t1) {
                Some(t1) => Self::compare_values(t1, t2),
                None => left == right,
            },
            (left, right) => left == right,
        }
    }

    /// Returns the inner value of a custom-type envelope.
    ///
    /// An envelope is an object with exactly two keys, `custom_types::TYPE` and
    /// `custom_types::VALUE`. For any other object this returns `None`.
    fn unwrap_value(obj: &ArgumentValueObject) -> Option<&ArgumentValue> {
        if obj.len() != 2 {
            return None;
        }

        if !obj.contains_key(custom_types::TYPE) || !obj.contains_key(custom_types::VALUE) {
            return None;
        }

        obj.get(custom_types::VALUE)
    }
}