#[cfg(feature = "compression")]
use crate::codec::compression::{
CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
};
use crate::{
body::BoxBody,
codec::{encode_server, Codec, Streaming},
server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
Code, Request, Status,
};
use futures_core::TryStream;
use futures_util::{future, stream, TryStreamExt};
use http_body::Body;
use std::fmt;
/// Unwraps a `Result` inside a handler.
///
/// On `Ok` the macro evaluates to the contained value. On `Err` it returns
/// early from the enclosing function with the error converted through its
/// `to_http()` method.
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.to_http(),
            Ok(inner) => inner,
        }
    };
}
/// A gRPC server handler.
///
/// Wraps a codec `T` and provides the `unary`, `server_streaming`,
/// `client_streaming` and `streaming` entry points, each of which turns an
/// inbound `http::Request` into an `http::Response<BoxBody>`.
pub struct Grpc<T> {
    /// Codec that supplies the decoder for requests and the encoder for
    /// responses.
    codec: T,
    /// Encodings checked against the encoding header of incoming requests.
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
    /// Encodings checked against the request's accept-encoding header when
    /// choosing how to encode a response.
    #[cfg(feature = "compression")]
    send_compression_encodings: EnabledCompressionEncodings,
}
impl<T> Grpc<T>
where
T: Codec,
{
/// Creates a new gRPC server handler around the provided `codec`.
///
/// With the `compression` feature enabled, both compression settings start
/// out as `EnabledCompressionEncodings::default()`.
pub fn new(codec: T) -> Self {
    Self {
        #[cfg(feature = "compression")]
        send_compression_encodings: Default::default(),
        #[cfg(feature = "compression")]
        accept_compression_encodings: Default::default(),
        codec,
    }
}
/// Enables `gzip` in the set of encodings accepted on incoming requests and
/// returns the updated handler.
#[cfg(feature = "compression")]
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
pub fn accept_gzip(self) -> Self {
    let mut this = self;
    this.accept_compression_encodings.enable_gzip();
    this
}

/// Stand-in compiled when the `compression` feature is disabled; it always
/// panics.
#[doc(hidden)]
#[cfg(not(feature = "compression"))]
pub fn accept_gzip(self) -> Self {
    panic!(
        "`accept_gzip` called on a server but the `compression` feature is not enabled on tonic"
    );
}
/// Enables `gzip` in the set of encodings that may be used for outgoing
/// responses and returns the updated handler.
#[cfg(feature = "compression")]
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
pub fn send_gzip(self) -> Self {
    let mut this = self;
    this.send_compression_encodings.enable_gzip();
    this
}

/// Stand-in compiled when the `compression` feature is disabled; it always
/// panics.
#[doc(hidden)]
#[cfg(not(feature = "compression"))]
pub fn send_gzip(self) -> Self {
    panic!("`send_gzip` called on a server but the `compression` feature is not enabled on tonic")
}
/// Applies a pair of compression configurations to this handler.
///
/// Each configuration is destructured exhaustively, so adding a field to
/// `EnabledCompressionEncodings` makes this function fail to compile until
/// the new field is handled here.
#[cfg(feature = "compression")]
#[doc(hidden)]
pub fn apply_compression_config(
    self,
    accept_encodings: EnabledCompressionEncodings,
    send_encodings: EnabledCompressionEncodings,
) -> Self {
    let EnabledCompressionEncodings { gzip: accept_gzip } = accept_encodings;
    let EnabledCompressionEncodings { gzip: send_gzip } = send_encodings;

    let this = if accept_gzip {
        self.accept_gzip()
    } else {
        self
    };

    if send_gzip {
        this.send_gzip()
    } else {
        this
    }
}

/// Variant compiled when the `compression` feature is disabled: there is
/// nothing to configure, so the handler is returned unchanged.
#[cfg(not(feature = "compression"))]
#[doc(hidden)]
#[allow(unused_variables)]
pub fn apply_compression_config(self, accept_encodings: (), send_encodings: ()) -> Self {
    // The unit arguments only keep the call shape the same in both builds.
    self
}
/// Handles a unary gRPC request: one request message in, one response
/// message out.
///
/// The request is decoded with `map_request_unary`, passed to `service`, and
/// the single response message is wrapped in a one-item stream so it can go
/// through `map_response` like the streaming handlers do. A decoding failure
/// is also routed through `map_response`, as an `Err`.
pub async fn unary<S, B>(
    &mut self,
    mut service: S,
    req: http::Request<B>,
) -> http::Response<BoxBody>
where
    S: UnaryService<T::Decode, Response = T::Encode>,
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send,
{
    // Must be read from the headers before `req` is moved into the decoder.
    #[cfg(feature = "compression")]
    let accept_encoding = CompressionEncoding::from_accept_encoding_header(
        req.headers(),
        self.send_compression_encodings,
    );

    let request = match self.map_request_unary(req).await {
        Err(status) => {
            // The stream type cannot be inferred on this path, so spell it out.
            return self
                .map_response::<stream::Once<future::Ready<Result<T::Encode, Status>>>>(
                    Err(status),
                    #[cfg(feature = "compression")]
                    accept_encoding,
                    #[cfg(feature = "compression")]
                    SingleMessageCompressionOverride::default(),
                );
        }
        Ok(request) => request,
    };

    let reply = service.call(request).await;
    let response = reply.map(|res| res.map(|message| stream::once(future::ok(message))));

    // A per-message override may be present in the response extensions.
    #[cfg(feature = "compression")]
    let compression_override = compression_override_from_response(&response);

    self.map_response(
        response,
        #[cfg(feature = "compression")]
        accept_encoding,
        #[cfg(feature = "compression")]
        compression_override,
    )
}
/// Handles a server-streaming gRPC request: one request message in, a stream
/// of response messages out.
///
/// The request is decoded with `map_request_unary`. Either the decoding
/// error or the outcome of `service.call` is then passed to `map_response`,
/// always with the default per-message compression override.
pub async fn server_streaming<S, B>(
    &mut self,
    mut service: S,
    req: http::Request<B>,
) -> http::Response<BoxBody>
where
    S: ServerStreamingService<T::Decode, Response = T::Encode>,
    S::ResponseStream: Send + 'static,
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send,
{
    // Must be read from the headers before `req` is moved into the decoder.
    #[cfg(feature = "compression")]
    let accept_encoding = CompressionEncoding::from_accept_encoding_header(
        req.headers(),
        self.send_compression_encodings,
    );

    // Both outcomes end in the same `map_response` call, so fold them into
    // a single `Result` first.
    let response = match self.map_request_unary(req).await {
        Ok(request) => service.call(request).await,
        Err(status) => Err(status),
    };

    self.map_response::<S::ResponseStream>(
        response,
        #[cfg(feature = "compression")]
        accept_encoding,
        #[cfg(feature = "compression")]
        SingleMessageCompressionOverride::default(),
    )
}
/// Handles a client-streaming gRPC request: a stream of request messages in,
/// one response message out.
///
/// The request body is wrapped with `map_request_streaming`; if that fails
/// the status is converted straight to an HTTP response. Otherwise `service`
/// is called and its single response message is wrapped in a one-item stream
/// for `map_response`.
pub async fn client_streaming<S, B>(
    &mut self,
    mut service: S,
    req: http::Request<B>,
) -> http::Response<BoxBody>
where
    S: ClientStreamingService<T::Decode, Response = T::Encode>,
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send + 'static,
{
    // Must be read from the headers before `req` is moved into the decoder.
    #[cfg(feature = "compression")]
    let accept_encoding = CompressionEncoding::from_accept_encoding_header(
        req.headers(),
        self.send_compression_encodings,
    );

    let request = match self.map_request_streaming(req) {
        Ok(request) => request,
        Err(status) => return status.to_http(),
    };

    let reply = service.call(request).await;
    let response = reply.map(|res| res.map(|message| stream::once(future::ok(message))));

    // A per-message override may be present in the response extensions.
    #[cfg(feature = "compression")]
    let compression_override = compression_override_from_response(&response);

    self.map_response(
        response,
        #[cfg(feature = "compression")]
        accept_encoding,
        #[cfg(feature = "compression")]
        compression_override,
    )
}
/// Handles a bi-directional streaming gRPC request: a stream of request
/// messages in, a stream of response messages out.
///
/// The request body is wrapped with `map_request_streaming` (returning early
/// with the converted status on failure), `service` is called, and its
/// result is passed to `map_response` with the default per-message
/// compression override.
pub async fn streaming<S, B>(
    &mut self,
    mut service: S,
    req: http::Request<B>,
) -> http::Response<BoxBody>
where
    S: StreamingService<T::Decode, Response = T::Encode> + Send,
    S::ResponseStream: Send + 'static,
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send,
{
    // Must be read from the headers before `req` is moved into the decoder.
    #[cfg(feature = "compression")]
    let accept_encoding = CompressionEncoding::from_accept_encoding_header(
        req.headers(),
        self.send_compression_encodings,
    );

    let inbound = t!(self.map_request_streaming(req));
    let outcome = service.call(inbound).await;

    self.map_response(
        outcome,
        #[cfg(feature = "compression")]
        accept_encoding,
        #[cfg(feature = "compression")]
        SingleMessageCompressionOverride::default(),
    )
}
/// Decodes an HTTP request that carries a single message.
///
/// Reads the first message from the body stream, builds a `Request` from the
/// HTTP parts and that message, then merges any trailers into the request
/// metadata.
///
/// # Errors
///
/// Returns a `Status` if the request encoding check fails (with the
/// `compression` feature), if reading the message or the trailers fails, or
/// with `Code::Internal` if the body stream ends without yielding a message.
async fn map_request_unary<B>(
    &mut self,
    request: http::Request<B>,
) -> Result<Request<T::Decode>, Status>
where
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send,
{
    #[cfg(feature = "compression")]
    let request_compression_encoding = self.request_encoding_if_supported(&request)?;

    let (head, body) = request.into_parts();

    #[cfg(feature = "compression")]
    let stream =
        Streaming::new_request(self.codec.decoder(), body, request_compression_encoding);
    #[cfg(not(feature = "compression"))]
    let stream = Streaming::new_request(self.codec.decoder(), body);

    futures_util::pin_mut!(stream);

    let message = match stream.try_next().await? {
        Some(message) => message,
        None => return Err(Status::new(Code::Internal, "Missing request message.")),
    };

    let mut decoded = Request::from_http_parts(head, message);

    // Trailers are fetched only after the message has been read.
    let trailers = stream.trailers().await?;
    if let Some(trailers) = trailers {
        decoded.metadata_mut().merge(trailers);
    }

    Ok(decoded)
}
/// Wraps the body of an HTTP request in a `Streaming` decoder without
/// reading any message from it.
///
/// # Errors
///
/// With the `compression` feature enabled, returns the `Status` produced by
/// `request_encoding_if_supported` when the request encoding check fails.
fn map_request_streaming<B>(
    &mut self,
    request: http::Request<B>,
) -> Result<Request<Streaming<T::Decode>>, Status>
where
    B: Body + Send + 'static,
    B::Error: Into<crate::Error> + Send,
{
    #[cfg(feature = "compression")]
    let encoding = self.request_encoding_if_supported(&request)?;

    // Swap the raw body for a decoding stream; the request head is untouched.
    let (head, body) = request.into_parts();

    #[cfg(feature = "compression")]
    let decoding_body = Streaming::new_request(self.codec.decoder(), body, encoding);
    #[cfg(not(feature = "compression"))]
    let decoding_body = Streaming::new_request(self.codec.decoder(), body);

    let rebuilt = http::Request::from_parts(head, decoding_body);
    Ok(Request::from_http(rebuilt))
}
/// Converts a service result into an HTTP response.
///
/// An `Err` status is converted directly with `to_http()`. For `Ok`, the
/// `content-type` header is set to `application/grpc`, the encoding header
/// is set when an encoding is given (`compression` feature only), and the
/// message stream is encoded with `encode_server` into the response body.
fn map_response<B>(
    &mut self,
    response: Result<crate::Response<B>, Status>,
    #[cfg(feature = "compression")] accept_encoding: Option<CompressionEncoding>,
    #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
) -> http::Response<BoxBody>
where
    B: TryStream<Ok = T::Encode, Error = Status> + Send + 'static,
{
    let (mut head, messages) = match response {
        Err(status) => return status.to_http(),
        Ok(reply) => reply.into_http().into_parts(),
    };

    let content_type = http::header::HeaderValue::from_static("application/grpc");
    head.headers.insert(http::header::CONTENT_TYPE, content_type);

    #[cfg(feature = "compression")]
    if let Some(encoding) = accept_encoding {
        let header_value = encoding.into_header_value();
        head.headers
            .insert(crate::codec::compression::ENCODING_HEADER, header_value);
    }

    let encoded = encode_server(
        self.codec.encoder(),
        messages.into_stream(),
        #[cfg(feature = "compression")]
        accept_encoding,
        #[cfg(feature = "compression")]
        compression_override,
    );

    http::Response::from_parts(head, BoxBody::new(encoded))
}
/// Checks the request's headers against the encodings this handler accepts,
/// by delegating to `CompressionEncoding::from_encoding_header`.
///
/// The `Ok`/`Err` outcome is whatever that helper returns.
#[cfg(feature = "compression")]
fn request_encoding_if_supported<B>(
    &self,
    request: &http::Request<B>,
) -> Result<Option<CompressionEncoding>, Status> {
    let headers = request.headers();
    let accepted = self.accept_compression_encodings;
    CompressionEncoding::from_encoding_header(headers, accepted)
}
}
/// Debug formatting for `Grpc`: prints the codec and, when the `compression`
/// feature is enabled, both compression configurations.
impl<T> fmt::Debug for Grpc<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Grpc");

        builder.field("codec", &self.codec);

        #[cfg(feature = "compression")]
        builder.field(
            "accept_compression_encodings",
            &self.accept_compression_encodings,
        );

        #[cfg(feature = "compression")]
        builder.field(
            "send_compression_encodings",
            &self.send_compression_encodings,
        );

        builder.finish()
    }
}
/// Reads the `SingleMessageCompressionOverride` stored in a successful
/// response's extensions.
///
/// Returns the default override when the result is an `Err` or when the
/// extension is not present.
#[cfg(feature = "compression")]
fn compression_override_from_response<B, E>(
    res: &Result<crate::Response<B>, E>,
) -> SingleMessageCompressionOverride {
    match res {
        Ok(response) => response
            .extensions()
            .get::<SingleMessageCompressionOverride>()
            .copied()
            .unwrap_or_default(),
        Err(_) => SingleMessageCompressionOverride::default(),
    }
}