use std::{
net::SocketAddr,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::{
error::{ErrorKind, Result},
options::ServerAddress,
runtime,
};
use super::{tls::AsyncTlsStream, TlsConfig};
/// Default time limit for establishing a connection (10 seconds). Exposed crate-wide;
/// it is not referenced by the code in this module.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Idle time (two minutes) configured via `socket2::TcpKeepalive::with_time` on every TCP
/// connection opened by `AsyncTcpStream::try_connect`.
const KEEPALIVE_TIME: Duration = Duration::from_secs(2 * 60);
/// A connected byte stream to a server, covering every transport this module can open.
// The `large_enum_variant` lint is silenced instead of boxing a variant; presumably the
// `Tls` variant is the large one and the extra allocation was judged not worthwhile —
// TODO confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum AsyncStream {
    /// Placeholder stream: reads complete immediately without data (EOF) and writes
    /// report zero bytes written.
    Null,
    /// Plain TCP connection.
    Tcp(AsyncTcpStream),
    /// TLS session layered on top of a TCP connection.
    Tls(AsyncTlsStream),
    /// Unix domain socket connection; only compiled on Unix targets.
    #[cfg(unix)]
    Unix(unix::AsyncUnixStream),
}
impl AsyncStream {
    /// Opens a stream to `address`.
    ///
    /// For TCP addresses a plain TCP connection is made first; when `tls_cfg` is `Some`,
    /// a TLS handshake is then performed over it using the address's host name. Unix
    /// socket addresses (Unix targets only) connect directly and ignore `tls_cfg`.
    ///
    /// # Errors
    /// Propagates any failure from the TCP, TLS or Unix-socket connection step.
    pub(crate) async fn connect(
        address: ServerAddress,
        tls_cfg: Option<&TlsConfig>,
    ) -> Result<Self> {
        let stream = match &address {
            ServerAddress::Tcp { host, .. } => {
                let tcp = AsyncTcpStream::connect(&address).await?;
                if let Some(cfg) = tls_cfg {
                    AsyncStream::Tls(AsyncTlsStream::connect(host, tcp, cfg).await?)
                } else {
                    AsyncStream::Tcp(tcp)
                }
            }
            #[cfg(unix)]
            ServerAddress::Unix { .. } => {
                AsyncStream::Unix(unix::AsyncUnixStream::connect(&address).await?)
            }
        };
        Ok(stream)
    }
}
#[cfg(unix)]
mod unix {
    //! Unix domain socket transport, backed by whichever async runtime feature is enabled.

    use std::{
        path::Path,
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    use crate::{client::options::ServerAddress, error::Result};

    /// A Unix domain socket connection owned by the active runtime.
    #[derive(Debug)]
    pub(crate) enum AsyncUnixStream {
        /// Socket driven by tokio.
        #[cfg(feature = "tokio-runtime")]
        Tokio(tokio::net::UnixStream),
        /// Socket driven by async-std.
        #[cfg(feature = "async-std-runtime")]
        AsyncStd(async_std::os::unix::net::UnixStream),
    }

    #[cfg(feature = "tokio-runtime")]
    impl From<tokio::net::UnixStream> for AsyncUnixStream {
        fn from(socket: tokio::net::UnixStream) -> Self {
            AsyncUnixStream::Tokio(socket)
        }
    }

    #[cfg(feature = "async-std-runtime")]
    impl From<async_std::os::unix::net::UnixStream> for AsyncUnixStream {
        fn from(socket: async_std::os::unix::net::UnixStream) -> Self {
            AsyncUnixStream::AsyncStd(socket)
        }
    }

    impl AsyncUnixStream {
        /// Connects to the socket file at `address` using tokio.
        #[cfg(feature = "tokio-runtime")]
        async fn try_connect(address: &Path) -> Result<Self> {
            let socket = tokio::net::UnixStream::connect(address).await?;
            Ok(Self::from(socket))
        }

        /// Connects to the socket file at `address` using async-std.
        #[cfg(feature = "async-std-runtime")]
        async fn try_connect(address: &Path) -> Result<Self> {
            let socket = async_std::os::unix::net::UnixStream::connect(address).await?;
            Ok(Self::from(socket))
        }

        /// Connects to the Unix domain socket named by `address`.
        ///
        /// # Panics
        /// Panics if `address` is not `ServerAddress::Unix`. In debug builds the panic
        /// message is "address must be unix".
        pub(crate) async fn connect(address: &ServerAddress) -> Result<Self> {
            debug_assert!(
                matches!(address, ServerAddress::Unix { .. }),
                "address must be unix"
            );
            if let ServerAddress::Unix { path } = address {
                Self::try_connect(path).await
            } else {
                unreachable!()
            }
        }
    }

    impl AsyncRead for AsyncUnixStream {
        /// Delegates to the runtime's socket. async-std sockets are adapted through a
        /// temporary `tokio_util` compat wrapper for this one call.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            match self.get_mut() {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => Pin::new(socket).poll_read(cx, buf),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(socket) => {
                    use tokio_util::compat::FuturesAsyncReadCompatExt;
                    Pin::new(&mut socket.compat()).poll_read(cx, buf)
                }
            }
        }
    }

    impl AsyncWrite for AsyncUnixStream {
        /// Delegates to the runtime's socket (through a compat wrapper for async-std).
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<tokio::io::Result<usize>> {
            match self.get_mut() {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => Pin::new(socket).poll_write(cx, buf),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(socket) => {
                    use tokio_util::compat::FuturesAsyncReadCompatExt;
                    Pin::new(&mut socket.compat()).poll_write(cx, buf)
                }
            }
        }

        /// Delegates to the runtime's socket (through a compat wrapper for async-std).
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
            match self.get_mut() {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => Pin::new(socket).poll_flush(cx),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(socket) => {
                    use tokio_util::compat::FuturesAsyncReadCompatExt;
                    Pin::new(&mut socket.compat()).poll_flush(cx)
                }
            }
        }

        /// Delegates to the runtime's socket (through a compat wrapper for async-std).
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            match self.get_mut() {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => Pin::new(socket).poll_shutdown(cx),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(socket) => {
                    use tokio_util::compat::FuturesAsyncReadCompatExt;
                    Pin::new(&mut socket.compat()).poll_shutdown(cx)
                }
            }
        }

        /// Delegates to the runtime's socket (through a compat wrapper for async-std).
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[futures_io::IoSlice<'_>],
        ) -> Poll<std::result::Result<usize, std::io::Error>> {
            match self.get_mut() {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => Pin::new(socket).poll_write_vectored(cx, bufs),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(socket) => {
                    use tokio_util::compat::FuturesAsyncReadCompatExt;
                    Pin::new(&mut socket.compat()).poll_write_vectored(cx, bufs)
                }
            }
        }

        /// tokio sockets report their own capability; async-std sockets always report `false`.
        fn is_write_vectored(&self) -> bool {
            match self {
                #[cfg(feature = "tokio-runtime")]
                AsyncUnixStream::Tokio(socket) => socket.is_write_vectored(),
                #[cfg(feature = "async-std-runtime")]
                AsyncUnixStream::AsyncStd(_) => false,
            }
        }
    }
}
/// A TCP connection owned by whichever async runtime feature the crate was built with.
#[derive(Debug)]
pub(crate) enum AsyncTcpStream {
    /// Stream driven by tokio (`tokio-runtime` feature).
    #[cfg(feature = "tokio-runtime")]
    Tokio(tokio::net::TcpStream),
    /// Stream driven by async-std (`async-std-runtime` feature).
    #[cfg(feature = "async-std-runtime")]
    AsyncStd(async_std::net::TcpStream),
}
/// Wraps an already-connected tokio TCP stream.
#[cfg(feature = "tokio-runtime")]
impl From<tokio::net::TcpStream> for AsyncTcpStream {
    fn from(tcp: tokio::net::TcpStream) -> Self {
        AsyncTcpStream::Tokio(tcp)
    }
}
/// Wraps an already-connected async-std TCP stream.
#[cfg(feature = "async-std-runtime")]
impl From<async_std::net::TcpStream> for AsyncTcpStream {
    fn from(tcp: async_std::net::TcpStream) -> Self {
        AsyncTcpStream::AsyncStd(tcp)
    }
}
impl AsyncTcpStream {
    /// Opens a TCP connection to `address` with tokio.
    ///
    /// Enables `TCP_NODELAY` and sets the TCP keepalive idle time to `KEEPALIVE_TIME`.
    #[cfg(feature = "tokio-runtime")]
    async fn try_connect(address: &SocketAddr) -> Result<Self> {
        let tokio_stream = tokio::net::TcpStream::connect(address).await?;
        tokio_stream.set_nodelay(true)?;

        // Keepalive is configured through socket2, which needs an owned std socket, so the
        // stream is converted out of tokio and then back again.
        let socket: socket2::Socket = tokio_stream.into_std()?.into();
        socket.set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(KEEPALIVE_TIME))?;

        let tokio_stream = tokio::net::TcpStream::from_std(socket.into())?;
        Ok(Self::from(tokio_stream))
    }

    /// Opens a TCP connection to `address` with async-std.
    ///
    /// Enables `TCP_NODELAY` and sets the TCP keepalive idle time to `KEEPALIVE_TIME`.
    #[cfg(feature = "async-std-runtime")]
    async fn try_connect(address: &SocketAddr) -> Result<Self> {
        let async_stream = async_std::net::TcpStream::connect(address).await?;
        async_stream.set_nodelay(true)?;

        // Same round trip as the tokio path: async-std -> std -> socket2 -> std -> async-std.
        let std_stream: std::net::TcpStream = async_stream.try_into()?;
        let socket = socket2::Socket::from(std_stream);
        socket.set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(KEEPALIVE_TIME))?;

        let async_stream = async_std::net::TcpStream::from(std::net::TcpStream::from(socket));
        Ok(Self::from(async_stream))
    }

    /// Resolves `address` and connects to the first resolved socket address that accepts.
    /// IPv4 addresses are tried before IPv6 ones.
    ///
    /// # Errors
    /// Returns `ErrorKind::DnsResolve` if resolution yields no addresses. If every attempt
    /// fails, returns the error from the last attempt.
    pub(crate) async fn connect(address: &ServerAddress) -> Result<Self> {
        let mut candidates: Vec<_> = runtime::resolve_address(address).await?.collect();

        if candidates.is_empty() {
            return Err(ErrorKind::DnsResolve {
                message: format!("No DNS results for domain {}", address),
            }
            .into());
        }

        // Stable sort on a bool key: IPv4 (`false`) orders before IPv6 (`true`), and each
        // family keeps the resolver's relative order.
        candidates.sort_by_key(|candidate| !candidate.is_ipv4());

        let mut last_error = None;
        for candidate in candidates.iter() {
            match Self::try_connect(candidate).await {
                Ok(stream) => return Ok(stream),
                Err(error) => last_error = Some(error),
            }
        }

        // `candidates` is non-empty, so the loop recorded an error; the fallback is defensive.
        Err(last_error.unwrap_or_else(|| {
            ErrorKind::Internal {
                message: "connecting to all DNS results failed but no error reported".to_string(),
            }
            .into()
        }))
    }
}
impl tokio::io::AsyncRead for AsyncStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(())),
Self::Tcp(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
Self::Tls(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
#[cfg(unix)]
Self::Unix(ref mut inner) => tokio::io::AsyncRead::poll_read(Pin::new(inner), cx, buf),
}
}
}
impl AsyncWrite for AsyncStream {
    /// Forwards the write to the underlying transport. A `Null` stream accepts zero bytes.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(0)),
            Self::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
            #[cfg(unix)]
            Self::Unix(uds) => Pin::new(uds).poll_write(cx, buf),
        }
    }

    /// Flushes the underlying transport. A `Null` stream succeeds immediately.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(())),
            Self::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Tls(tls) => Pin::new(tls).poll_flush(cx),
            #[cfg(unix)]
            Self::Unix(uds) => Pin::new(uds).poll_flush(cx),
        }
    }

    /// Shuts down the underlying transport. A `Null` stream succeeds immediately.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.deref_mut() {
            Self::Null => Poll::Ready(Ok(())),
            Self::Tcp(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
            #[cfg(unix)]
            Self::Unix(uds) => Pin::new(uds).poll_shutdown(cx),
        }
    }

    /// Forwards a vectored write. A `Null` stream accepts zero bytes.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[futures_io::IoSlice<'_>],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        match self.get_mut() {
            Self::Null => Poll::Ready(Ok(0)),
            Self::Tcp(tcp) => Pin::new(tcp).poll_write_vectored(cx, bufs),
            Self::Tls(tls) => Pin::new(tls).poll_write_vectored(cx, bufs),
            #[cfg(unix)]
            Self::Unix(uds) => Pin::new(uds).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports the underlying transport's vectored-write support; `Null` reports `false`.
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Null => false,
            Self::Tcp(tcp) => tcp.is_write_vectored(),
            Self::Tls(tls) => tls.is_write_vectored(),
            #[cfg(unix)]
            Self::Unix(uds) => uds.is_write_vectored(),
        }
    }
}
impl AsyncRead for AsyncTcpStream {
    /// Delegates to the runtime's stream. async-std streams are adapted through a
    /// temporary `tokio_util` compat wrapper for this one call.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<tokio::io::Result<()>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(tcp) => {
                use tokio_util::compat::FuturesAsyncReadCompatExt;
                Pin::new(&mut tcp.compat()).poll_read(cx, buf)
            }
        }
    }
}
impl AsyncWrite for AsyncTcpStream {
    /// Delegates to the runtime's stream (through a compat wrapper for async-std).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<tokio::io::Result<usize>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => Pin::new(tcp).poll_write(cx, buf),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(tcp) => {
                use tokio_util::compat::FuturesAsyncReadCompatExt;
                Pin::new(&mut tcp.compat()).poll_write(cx, buf)
            }
        }
    }

    /// Delegates to the runtime's stream (through a compat wrapper for async-std).
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => Pin::new(tcp).poll_flush(cx),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(tcp) => {
                use tokio_util::compat::FuturesAsyncReadCompatExt;
                Pin::new(&mut tcp.compat()).poll_flush(cx)
            }
        }
    }

    /// Delegates to the runtime's stream (through a compat wrapper for async-std).
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => Pin::new(tcp).poll_shutdown(cx),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(tcp) => {
                use tokio_util::compat::FuturesAsyncReadCompatExt;
                Pin::new(&mut tcp.compat()).poll_shutdown(cx)
            }
        }
    }

    /// Delegates to the runtime's stream (through a compat wrapper for async-std).
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[futures_io::IoSlice<'_>],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => Pin::new(tcp).poll_write_vectored(cx, bufs),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(tcp) => {
                use tokio_util::compat::FuturesAsyncReadCompatExt;
                Pin::new(&mut tcp.compat()).poll_write_vectored(cx, bufs)
            }
        }
    }

    /// tokio streams report their own capability; async-std streams always report `false`.
    fn is_write_vectored(&self) -> bool {
        match self {
            #[cfg(feature = "tokio-runtime")]
            AsyncTcpStream::Tokio(tcp) => tcp.is_write_vectored(),
            #[cfg(feature = "async-std-runtime")]
            AsyncTcpStream::AsyncStd(_) => false,
        }
    }
}