1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use tokio_postgres::error::DbError;

use crate::{
    connector::postgres::error::PostgresError,
    error::{Error, ErrorKind, NativeErrorKind},
};

impl From<&DbError> for PostgresError {
    fn from(value: &DbError) -> Self {
        PostgresError {
            code: value.code().code().to_string(),
            severity: value.severity().to_string(),
            message: value.message().to_string(),
            detail: value.detail().map(ToString::to_string),
            column: value.column().map(ToString::to_string),
            hint: value.hint().map(ToString::to_string),
        }
    }
}

impl From<tokio_postgres::error::Error> for Error {
    fn from(e: tokio_postgres::error::Error) -> Error {
        if e.is_closed() {
            return Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionClosed)).build();
        }

        if let Some(db_error) = e.as_db_error() {
            return PostgresError::from(db_error).into();
        }

        if let Some(tls_error) = try_extracting_tls_error(&e) {
            return tls_error;
        }

        // Same for IO errors.
        if let Some(io_error) = try_extracting_io_error(&e) {
            return io_error;
        }

        if let Some(uuid_error) = try_extracting_uuid_error(&e) {
            return uuid_error;
        }

        let reason = format!("{e}");
        let code = e.code().map(|c| c.code());

        match reason.as_str() {
            "error connecting to server: timed out" => {
                let mut builder = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectTimeout));

                if let Some(code) = code {
                    builder.set_original_code(code);
                };

                builder.set_original_message(reason);
                builder.build()
            } // sigh...
            // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37
            "error performing TLS handshake: server does not support TLS" => {
                let mut builder = Error::builder(ErrorKind::Native(NativeErrorKind::TlsError {
                    message: reason.clone(),
                }));

                if let Some(code) = code {
                    builder.set_original_code(code);
                };

                builder.set_original_message(reason);
                builder.build()
            } // double sigh
            _ => {
                let code = code.map(|c| c.to_string());
                let mut builder = Error::builder(ErrorKind::QueryError(e.into()));

                if let Some(code) = code {
                    builder.set_original_code(code);
                };

                builder.set_original_message(reason);
                builder.build()
            }
        }
    }
}

fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option<Error> {
    use std::error::Error as _;

    err.source()
        .and_then(|err| err.downcast_ref::<uuid::Error>())
        .map(|err| ErrorKind::UUIDError(format!("{err}")))
        .map(|kind| Error::builder(kind).build())
}

fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option<Error> {
    use std::error::Error;

    err.source()
        .and_then(|err| err.downcast_ref::<native_tls::Error>())
        .map(|err| err.into())
}

fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option<Error> {
    use std::error::Error as _;

    err.source()
        .and_then(|err| err.downcast_ref::<std::io::Error>())
        .map(|err| {
            ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(std::io::Error::new(
                err.kind(),
                format!("{err}"),
            ))))
        })
        .map(|kind| Error::builder(kind).build())
}

impl From<native_tls::Error> for Error {
    fn from(e: native_tls::Error) -> Error {
        Error::from(&e)
    }
}

impl From<&native_tls::Error> for Error {
    fn from(e: &native_tls::Error) -> Error {
        let kind = ErrorKind::Native(NativeErrorKind::TlsError {
            message: format!("{e}"),
        });

        Error::builder(kind).build()
    }
}