#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
#[cfg(test)]
mod test;
/// A [`MakeTlsConnect`] implementation backed by a [`native_tls::TlsConnector`].
///
/// Every call to `make_tls_connect` clones the wrapped connector and pairs the
/// clone with the requested domain. Only available with the `runtime` feature.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new `MakeTlsConnector` that wraps the given, already configured,
    /// `native_tls::TlsConnector`.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        Self(connector)
    }
}
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    /// Builds a per-domain [`TlsConnector`] from a clone of the wrapped connector.
    ///
    /// This implementation never fails; it always returns `Ok`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, native_tls::Error> {
        let MakeTlsConnector(inner) = self;
        let per_domain = TlsConnector::new(inner.clone(), domain);
        Ok(per_domain)
    }
}
/// A [`TlsConnect`] implementation that negotiates TLS for one specific domain.
///
/// The domain given at construction is passed to the TLS handshake performed by
/// `connect`.
pub struct TlsConnector {
    // Async wrapper around the configured `native_tls` connector.
    connector: tokio_native_tls::TlsConnector,
    // Name handed to the handshake in `connect`.
    domain: String,
}
impl TlsConnector {
    /// Creates a connector that will perform TLS handshakes against `domain` using
    /// the supplied `native_tls::TlsConnector`.
    ///
    /// `domain` is copied into an owned `String`.
    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
        let connector: tokio_native_tls::TlsConnector = connector.into();
        let domain = domain.to_owned();
        TlsConnector { connector, domain }
    }
}
impl<S> TlsConnect<S> for TlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
type Stream = TlsStream<S>;
type Error = native_tls::Error;
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
fn connect(self, stream: S) -> Self::Future {
let stream = BufReader::with_capacity(8192, stream);
let future = async move {
let stream = self.connector.connect(&self.domain, stream).await?;
Ok(TlsStream(stream))
};
Box::pin(future)
}
}
/// The encrypted stream produced by [`TlsConnector`].
///
/// Wraps a `tokio_native_tls::TlsStream` whose transport is the caller's stream
/// behind a `BufReader`.
pub struct TlsStream<S>(tokio_native_tls::TlsStream<BufReader<S>>);
impl<S> AsyncRead for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted bytes by delegating to the inner TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The wrapper is `Unpin` (the original projection already required it),
        // so it is fine to unwrap the pin and re-pin the inner field.
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_read(cx, buf)
    }
}
impl<S> AsyncWrite for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Encrypts and writes `buf` via the inner TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_write(cx, buf)
    }

    /// Flushes the inner TLS stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_flush(cx)
    }

    /// Shuts down the inner TLS stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_shutdown(cx)
    }
}
impl<S> tls::TlsStream for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Returns `tls-server-end-point` channel binding data when the underlying
    /// `native_tls` stream can provide it.
    ///
    /// Both an error and an absent value from `tls_server_end_point` yield
    /// `ChannelBinding::none()`.
    fn channel_binding(&self) -> ChannelBinding {
        let end_point = self.0.get_ref().tls_server_end_point();
        if let Ok(Some(data)) = end_point {
            ChannelBinding::tls_server_end_point(data)
        } else {
            ChannelBinding::none()
        }
    }
}