1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// Copyright (c) 2016 Anatoly Ikorsky
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

use std::{
    fmt,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use futures_core::ready;
#[cfg(feature = "tracing")]
use {
    std::sync::Arc,
    tracing::{debug_span, Span},
};

use crate::{
    conn::{
        pool::{Pool, QueueId},
        Conn,
    },
    error::*,
};

/// States of the `GetConn` future's internal state machine.
pub(crate) enum GetConnInner {
    /// Initial state, in which `GetConn::poll` asks the pool for the next step
    /// via `Pool::poll_new_conn`. Also re-entered after a failed liveness check,
    /// and the value left behind by `GetConnInner::take`.
    New,
    /// Terminal state: `GetConn::poll` has already produced its result.
    Done,
    /// A boxed future, handed out by `Pool::poll_new_conn`, that resolves to a
    /// connection (presumably a newly established one — confirm in `Pool`).
    // TODO: one day this should be an existential
    Connecting(crate::BoxFuture<'static, Conn>),
    /// This future will check that an idling connection is alive.
    Checking(crate::BoxFuture<'static, Conn>),
}

impl fmt::Debug for GetConnInner {
    /// Renders the state as `GetConnInner::<Variant>`. Variants that own a boxed
    /// future cannot show it, so they print a `"<future>"` placeholder field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, holds_future) = match self {
            GetConnInner::New => ("GetConnInner::New", false),
            GetConnInner::Done => ("GetConnInner::Done", false),
            GetConnInner::Connecting(_) => ("GetConnInner::Connecting", true),
            GetConnInner::Checking(_) => ("GetConnInner::Checking", true),
        };

        let mut tuple = f.debug_tuple(label);
        if holds_future {
            tuple.field(&"<future>");
        }
        tuple.finish()
    }
}

impl GetConnInner {
    /// Moves the current state out of `self`, leaving `GetConnInner::New` in
    /// its place, and returns the previous state.
    pub fn take(&mut self) -> GetConnInner {
        let mut previous = GetConnInner::New;
        std::mem::swap(self, &mut previous);
        previous
    }
}

/// This future will take a connection from a pool and resolve to [`Conn`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct GetConn {
    /// Identifier for this future's entry in the pool's wait queue. It is
    /// assigned by `poll` the first time it runs the `New` state and then
    /// reused. If the future is dropped while `pool` is still `Some`, the
    /// identifier is passed to `Pool::unqueue`.
    pub(crate) queue_id: Option<QueueId>,
    /// Pool handle. `poll` moves it out once a `Connecting` or successful
    /// `Checking` step produces the result. On success it is stored in the
    /// returned `Conn`. While it is still `Some`, `Drop` uses it to clean up.
    pub(crate) pool: Option<Pool>,
    /// Current state of the state machine driven by `poll`.
    pub(crate) inner: GetConnInner,
    /// Tracing span entered on every poll (only with the `tracing` feature).
    #[cfg(feature = "tracing")]
    span: Arc<Span>,
}

impl GetConn {
    /// Creates a `GetConn` in the `New` state that will draw a connection
    /// from (a clone of) `pool`.
    pub(crate) fn new(pool: &Pool) -> GetConn {
        let pool = Some(pool.clone());
        GetConn {
            inner: GetConnInner::New,
            queue_id: None,
            pool,
            #[cfg(feature = "tracing")]
            span: Arc::new(debug_span!("mysql_async::get_conn")),
        }
    }

    /// Borrows the pool handle mutably.
    ///
    /// # Panics
    ///
    /// Panics if the handle was already moved out, i.e. `poll` is called again
    /// after it produced a result.
    fn pool_mut(&mut self) -> &mut Pool {
        match self.pool {
            Some(ref mut pool) => pool,
            None => panic!("GetConn::poll polled after returning Async::Ready"),
        }
    }

    /// Moves the pool handle out, leaving `None` behind.
    ///
    /// # Panics
    ///
    /// Panics if the handle was already moved out.
    fn pool_take(&mut self) -> Pool {
        match self.pool.take() {
            Some(pool) => pool,
            None => panic!("GetConn::poll polled after returning Async::Ready"),
        }
    }
}

// This manual implementation of `Future` may seem unnecessary, but we need it
// to get the dropping behavior we want. All progress is recorded in
// `self.inner`, `self.pool` and `self.queue_id`, which the `Drop` impl
// inspects to decide what to clean up.
impl Future for GetConn {
    type Output = Result<Conn>;

    /// Drives the state machine until a connection is ready, an error occurs,
    /// or an inner step returns `Poll::Pending`.
    ///
    /// * `New`: asks the pool for the next step via `Pool::poll_new_conn`.
    ///   That step is either `Connecting` or `Checking`.
    /// * `Connecting`: waits for the connect future, then moves the pool handle
    ///   out and switches to `Done`.
    /// * `Checking`: waits for the liveness check of an idling connection. On
    ///   failure the connection is discarded and the loop restarts at `New`.
    /// * `Done`: terminal.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Pool::poll_new_conn`, leaving the state at `New`
    /// with the pool handle retained. Also propagates errors from the connect
    /// future, after calling `Pool::cancel_connection`. A failed liveness check
    /// is not returned; it triggers a retry instead.
    ///
    /// # Panics
    ///
    /// Panics if polled again after a `Connecting` or successful `Checking` step
    /// has produced the result (state `Done`).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        #[cfg(feature = "tracing")]
        let span = self.span.clone();
        #[cfg(feature = "tracing")]
        let _span_guard = span.enter();
        loop {
            match self.inner {
                GetConnInner::New => {
                    // `queued` is true when an earlier poll already assigned a
                    // queue id. The id is assigned once and reused on every
                    // subsequent call.
                    let queued = self.queue_id.is_some();
                    let queue_id = *self.queue_id.get_or_insert_with(QueueId::next);
                    // `Pending` returns `Pending`; `Err` returns the error with
                    // the state still `New` and the pool handle kept.
                    let next =
                        ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queued, queue_id))?;
                    match next {
                        GetConnInner::Connecting(conn_fut) => {
                            self.inner = GetConnInner::Connecting(conn_fut);
                        }
                        GetConnInner::Checking(conn_fut) => {
                            self.inner = GetConnInner::Checking(conn_fut);
                        }
                        GetConnInner::Done => unreachable!(
                            "Pool::poll_new_conn never gives out already-consumed GetConns"
                        ),
                        GetConnInner::New => {
                            unreachable!("Pool::poll_new_conn never gives out GetConnInner::New")
                        }
                    }
                }
                GetConnInner::Done => {
                    unreachable!("GetConn::poll polled after returning Async::Ready");
                }
                GetConnInner::Connecting(ref mut f) => {
                    let result = ready!(Pin::new(f).poll(cx));
                    // From here on `self.pool` is `None`, so `Drop` does nothing.
                    let pool = self.pool_take();

                    self.inner = GetConnInner::Done;

                    return match result {
                        Ok(mut c) => {
                            // The connection keeps the pool handle.
                            c.inner.pool = Some(pool);
                            Poll::Ready(Ok(c))
                        }
                        Err(e) => {
                            // Tell the pool this connection attempt failed.
                            pool.cancel_connection();
                            Poll::Ready(Err(e))
                        }
                    };
                }
                GetConnInner::Checking(ref mut f) => {
                    let result = ready!(Pin::new(f).poll(cx));
                    match result {
                        Ok(mut checked_conn) => {
                            self.inner = GetConnInner::Done;

                            let pool = self.pool_take();
                            checked_conn.inner.pool = Some(pool);
                            return Poll::Ready(Ok(checked_conn));
                        }
                        Err(_) => {
                            // Idling connection is broken. We'll drop it and try again.
                            // `queue_id` is kept, so the next `poll_new_conn` call
                            // is made with `queued == true`.
                            self.inner = GetConnInner::New;

                            let pool = self.pool_mut();
                            pool.cancel_connection();
                            continue;
                        }
                    }
                }
            }
        }
    }
}

impl Drop for GetConn {
    /// Cancels a `GetConn` that is dropped before it resolves.
    ///
    /// If `pool` is still `Some`, `poll` never handed the result back, so the
    /// pool-side state is unwound. The wait-queue entry is removed, and any
    /// connect or liveness-check step that was in flight is cancelled.
    fn drop(&mut self) {
        // We drop a connection before it can be resolved, a.k.a. cancelling it.
        // Make sure we maintain the necessary invariants towards the pool.
        if let Some(pool) = self.pool.take() {
            // Remove the waker from the pool's waitlist in case this task was
            // woken by another waker, like from tokio::time::timeout.
            if let Some(queue_id) = self.queue_id {
                pool.unqueue(queue_id);
            }
            match self.inner.take() {
                // `poll` calls `cancel_connection` when either a connect or a
                // liveness check fails. Abandoning either step mid-flight must
                // do the same, or the pool's accounting for that connection is
                // never released.
                // NOTE(review): confirm against `Pool`'s connection accounting.
                GetConnInner::Connecting(..) | GetConnInner::Checking(..) => {
                    pool.cancel_connection();
                }
                // `New`: no step in flight. `Done`: not expected here, because
                // `poll` takes the pool handle whenever it sets `Done`.
                GetConnInner::New | GetConnInner::Done => {}
            }
        }
    }
}