use std::{
collections::VecDeque,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use bson::{RawDocument, RawDocumentBuf};
use derivative::Derivative;
use futures_core::{future::BoxFuture, Future};
#[cfg(test)]
use tokio::sync::oneshot;
use crate::{
bson::{Bson, Document},
change_stream::event::ResumeToken,
client::{session::ClientSession, AsyncDropToken},
cmap::conn::PinnedConnectionHandle,
error::{Error, ErrorKind, Result},
operation::{self, GetMore},
options::ServerAddress,
results::GetMoreResult,
Client,
Namespace,
};
/// Outcome of a single attempt to move a cursor onto its next document.
pub(super) enum AdvanceResult {
    /// The cursor is now positioned on a new document.
    Advanced,
    /// The cursor is exhausted and has no further documents to yield.
    Exhausted,
    /// A getMore returned an empty batch but the cursor is not exhausted; the caller may
    /// retry.
    Waiting,
}
/// Core cursor implementation, generic over how it holds its session: `S` is either
/// `ImplicitClientSessionHandle` (owns an optional session) or
/// `ExplicitClientSessionHandle` (borrows a session for `'s`).
#[derive(Derivative)]
#[derivative(Debug)]
pub(super) struct GenericCursor<'s, S> {
    // Issues getMore commands with the session; excluded from the derived `Debug` output.
    #[derivative(Debug = "ignore")]
    provider: GetMoreProvider<'s, S>,
    client: Client,
    // Server-side cursor identity (namespace, address, id) and getMore options.
    info: CursorInformation,
    // Buffered documents and iteration state; `None` only after `take_state` has been called.
    state: Option<CursorState>,
}
impl GenericCursor<'static, ImplicitClientSessionHandle> {
    /// Builds a cursor that owns its (implicit) session.
    ///
    /// A cursor id of 0 means the initial response already contained every result: the
    /// cursor starts out exhausted, its provider is `Done`, and the session is not retained.
    pub(super) fn with_implicit_session(
        client: Client,
        spec: CursorSpecification,
        pinned_connection: PinnedConnection,
        session: ImplicitClientSessionHandle,
    ) -> Self {
        let exhausted = spec.id() == 0;
        let provider = if exhausted {
            GetMoreProvider::Done
        } else {
            GetMoreProvider::Idle(Box::new(session))
        };
        let initial_state = CursorState {
            buffer: CursorBuffer::new(spec.initial_buffer),
            exhausted,
            post_batch_resume_token: None,
            pinned_connection,
        };

        Self {
            client,
            provider,
            info: spec.info,
            state: Some(initial_state),
        }
    }

    /// Removes the owned session from an idle provider; `None` while a getMore is in flight
    /// or once the provider is done.
    pub(super) fn take_implicit_session(&mut self) -> Option<ClientSession> {
        let provider = &mut self.provider;
        provider.take_implicit_session()
    }
}
impl<'s> GenericCursor<'s, ExplicitClientSessionHandle<'s>> {
    /// Builds a cursor that issues getMores through a borrowed (explicit) session, starting
    /// from an already-populated `state`.
    pub(super) fn with_explicit_session(
        state: CursorState,
        client: Client,
        info: CursorInformation,
        session: ExplicitClientSessionHandle<'s>,
    ) -> Self {
        let provider = GetMoreProvider::Idle(Box::new(session));
        Self {
            client,
            info,
            provider,
            state: Some(state),
        }
    }
}
impl<'s, S: ClientSessionHandle<'s>> GenericCursor<'s, S> {
pub(super) fn current(&self) -> Option<&RawDocument> {
self.state().buffer.current()
}
#[cfg(test)]
pub(super) fn current_batch(&self) -> &VecDeque<RawDocumentBuf> {
self.state().buffer.as_ref()
}
fn state_mut(&mut self) -> &mut CursorState {
self.state.as_mut().unwrap()
}
pub(super) fn state(&self) -> &CursorState {
self.state.as_ref().unwrap()
}
pub(super) async fn advance(&mut self) -> Result<bool> {
loop {
match self.try_advance().await? {
AdvanceResult::Advanced => return Ok(true),
AdvanceResult::Exhausted => return Ok(false),
AdvanceResult::Waiting => continue,
}
}
}
pub(super) async fn try_advance(&mut self) -> Result<AdvanceResult> {
if self.state_mut().buffer.advance() {
return Ok(AdvanceResult::Advanced);
} else if self.is_exhausted() {
return Ok(AdvanceResult::Exhausted);
}
let client = self.client.clone();
let spec = self.info.clone();
let pin = self.state().pinned_connection.replicate();
let result = self.provider.execute(spec, client, pin).await;
self.handle_get_more_result(result)?;
match self.state_mut().buffer.advance() {
true => Ok(AdvanceResult::Advanced),
false => {
if self.is_exhausted() {
Ok(AdvanceResult::Exhausted)
} else {
Ok(AdvanceResult::Waiting)
}
}
}
}
pub(super) fn take_state(&mut self) -> CursorState {
self.state.take().unwrap()
}
pub(super) fn is_exhausted(&self) -> bool {
self.state().exhausted
}
pub(super) fn id(&self) -> i64 {
self.info.id
}
pub(super) fn namespace(&self) -> &Namespace {
&self.info.ns
}
pub(super) fn address(&self) -> &ServerAddress {
&self.info.address
}
pub(super) fn pinned_connection(&self) -> &PinnedConnection {
&self.state().pinned_connection
}
pub(super) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
self.state().post_batch_resume_token.as_ref()
}
fn mark_exhausted(&mut self) {
self.state_mut().exhausted = true;
self.state_mut().pinned_connection = PinnedConnection::Unpinned;
}
fn handle_get_more_result(&mut self, get_more_result: Result<GetMoreResult>) -> Result<()> {
match get_more_result {
Ok(get_more) => {
if get_more.exhausted {
self.mark_exhausted();
}
if get_more.id != 0 {
self.info.id = get_more.id
}
self.info.ns = get_more.ns;
self.state_mut().buffer = CursorBuffer::new(get_more.batch);
self.state_mut().post_batch_resume_token = get_more.post_batch_resume_token;
Ok(())
}
Err(e) => {
if matches!(*e.kind, ErrorKind::Command(ref e) if e.code == 43 || e.code == 237) {
self.mark_exhausted();
}
if e.is_network_error() {
self.state_mut().pinned_connection.invalidate();
}
Err(e)
}
}
}
}
/// A cursor that can be polled for the next document of its current batch.
pub(crate) trait CursorStream {
    /// Polls for the next `BatchValue`, registering `cx` for wakeup if a pending operation is
    /// not yet complete.
    fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>>;
}
/// Value produced by `CursorStream::poll_next_in_batch`.
pub(crate) enum BatchValue {
    /// A document; `is_last` is true when no further documents remain in the local buffer.
    Some { doc: RawDocumentBuf, is_last: bool },
    /// Nothing was buffered; for `GenericCursor` this means a getMore was just started and the
    /// caller should poll again.
    Empty,
    /// The cursor will produce no more documents.
    Exhausted,
}
impl<'s, S: ClientSessionHandle<'s>> CursorStream for GenericCursor<'s, S> {
    /// Drives any in-flight getMore to completion, then yields the next buffered document.
    ///
    /// If the buffer is empty and the cursor may still produce results, a new getMore is
    /// started and `BatchValue::Empty` is returned so the caller polls again.
    fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>> {
        if let Some(in_flight) = self.provider.executing_future() {
            let completed = match Pin::new(in_flight).poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(completed) => completed,
            };
            // Return the session to the provider before propagating any getMore error.
            let outcome = self.handle_get_more_result(completed.result);
            let exhausted = self.state().exhausted;
            self.provider.clear_execution(completed.session, exhausted);
            outcome?;
        }

        let state = self.state.as_mut().unwrap();
        if let Some(doc) = state.buffer.next() {
            let is_last = state.buffer.is_empty();
            return Poll::Ready(Ok(BatchValue::Some { doc, is_last }));
        }

        if state.exhausted || state.pinned_connection.is_invalid() {
            return Poll::Ready(Ok(BatchValue::Exhausted));
        }

        // Nothing buffered but more results may exist: kick off a getMore.
        let pinned = state.pinned_connection.handle();
        self.provider
            .start_execution(self.info.clone(), self.client.clone(), pinned);
        Poll::Ready(Ok(BatchValue::Empty))
    }
}
/// Polls `this` until it yields a document, which is deserialized into `V`, or reports
/// exhaustion (`Ready(None)`).
///
/// Errors from the cursor and from deserialization are surfaced as `Ready(Some(Err(_)))`.
pub(crate) fn stream_poll_next<S, V>(this: &mut S, cx: &mut Context<'_>) -> Poll<Option<Result<V>>>
where
    S: CursorStream,
    V: for<'a> serde::Deserialize<'a>,
{
    loop {
        let value = match this.poll_next_in_batch(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(ready) => ready?,
        };
        match value {
            // A getMore was just started; poll again so its future is driven.
            BatchValue::Empty => {}
            BatchValue::Exhausted => return Poll::Ready(None),
            BatchValue::Some { doc, .. } => {
                let decoded: V = bson::from_slice(doc.as_bytes())?;
                return Poll::Ready(Some(Ok(decoded)));
            }
        }
    }
}
/// Future that resolves to the next `BatchValue` of the wrapped cursor stream.
pub(crate) struct NextInBatchFuture<'a, T>(&'a mut T);

impl<'a, T> NextInBatchFuture<'a, T>
where
    T: CursorStream,
{
    /// Wraps a mutable borrow of `stream`.
    pub(crate) fn new(stream: &'a mut T) -> Self {
        NextInBatchFuture(stream)
    }
}

impl<'a, C> Future for NextInBatchFuture<'a, C>
where
    C: CursorStream,
{
    type Output = Result<BatchValue>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `&mut`, which is `Unpin`, so unwrapping the pin is sound.
        let Self(stream) = self.get_mut();
        stream.poll_next_in_batch(cx)
    }
}
/// Holds the session used for getMores and tracks whether one is currently running.
enum GetMoreProvider<'s, S> {
    /// A getMore started by `start_execution` is in progress; its future hands the session
    /// back on completion.
    Executing(BoxFuture<'s, GetMoreResultAndSession<S>>),
    /// No getMore is running and the session is available.
    Idle(Box<S>),
    /// Terminal state: `execute` returns an error and `start_execution` does nothing.
    Done,
}
impl GetMoreProvider<'static, ImplicitClientSessionHandle> {
    /// Takes the owned session out of an idle provider; returns `None` in any other state.
    fn take_implicit_session(&mut self) -> Option<ClientSession> {
        if let Self::Idle(handle) = self {
            handle.take_implicit_session()
        } else {
            None
        }
    }
}
impl<'s, S: ClientSessionHandle<'s>> GetMoreProvider<'s, S> {
fn executing_future(&mut self) -> Option<&mut BoxFuture<'s, GetMoreResultAndSession<S>>> {
if let Self::Executing(future) = self {
Some(future)
} else {
None
}
}
fn clear_execution(&mut self, session: S, exhausted: bool) {
if exhausted && session.is_implicit() {
*self = Self::Done
} else {
*self = Self::Idle(Box::new(session))
}
}
fn start_execution(
&mut self,
info: CursorInformation,
client: Client,
pinned_connection: Option<&PinnedConnectionHandle>,
) {
take_mut::take(self, |self_| {
if let Self::Idle(mut session) = self_ {
let pinned_connection = pinned_connection.map(|c| c.replicate());
let future = Box::pin(async move {
let get_more = GetMore::new(info, pinned_connection.as_ref());
let get_more_result = client
.execute_operation(get_more, session.borrow_mut())
.await;
GetMoreResultAndSession {
result: get_more_result,
session: *session,
}
});
Self::Executing(future)
} else {
self_
}
})
}
fn execute(
&mut self,
info: CursorInformation,
client: Client,
pinned_connection: PinnedConnection,
) -> BoxFuture<'_, Result<GetMoreResult>> {
match self {
Self::Idle(ref mut session) => Box::pin(async move {
let get_more = GetMore::new(info, pinned_connection.handle());
client
.execute_operation(get_more, session.borrow_mut())
.await
}),
Self::Executing(_fut) => Box::pin(async {
Err(Error::internal(
"streaming the cursor was cancelled while a request was in progress and must \
be continued before iterating manually",
))
}),
Self::Done => {
Box::pin(async { Err(Error::internal("cursor iterated after already exhausted")) })
}
}
}
}
/// Output of a background getMore: the command result plus the session the future owned,
/// so the session can be handed back to the provider.
struct GetMoreResultAndSession<S> {
    result: Result<GetMoreResult>,
    session: S,
}
/// Data needed to construct a cursor, built from an `operation::CursorInfo`.
#[derive(Debug, Clone)]
pub(crate) struct CursorSpecification {
    /// Server-side cursor identity and the options reused for getMores.
    pub(crate) info: CursorInformation,
    /// Documents from the initial response's first batch.
    pub(crate) initial_buffer: VecDeque<RawDocumentBuf>,
    /// Post-batch resume token from the initial response, if present.
    pub(crate) post_batch_resume_token: Option<ResumeToken>,
}
impl CursorSpecification {
    /// Builds a specification from a parsed cursor response and the options to use for
    /// subsequent getMores.
    pub(crate) fn new(
        info: operation::CursorInfo,
        address: ServerAddress,
        batch_size: impl Into<Option<u32>>,
        max_time: impl Into<Option<Duration>>,
        comment: impl Into<Option<Bson>>,
    ) -> Self {
        let post_batch_resume_token = ResumeToken::from_raw(info.post_batch_resume_token);
        let cursor_info = CursorInformation {
            ns: info.ns,
            id: info.id,
            address,
            batch_size: batch_size.into(),
            max_time: max_time.into(),
            comment: comment.into(),
        };

        Self {
            info: cursor_info,
            initial_buffer: info.first_batch,
            post_batch_resume_token,
        }
    }

    /// The server cursor id; 0 means the cursor is already closed.
    pub(crate) fn id(&self) -> i64 {
        let CursorInformation { id, .. } = self.info;
        id
    }

    /// The address of the server that owns the cursor.
    #[cfg(test)]
    pub(crate) fn address(&self) -> &ServerAddress {
        let info = &self.info;
        &info.address
    }

    /// The batch size requested for getMores.
    #[cfg(test)]
    pub(crate) fn batch_size(&self) -> Option<u32> {
        let info = &self.info;
        info.batch_size
    }

    /// The maxTimeMS requested for getMores.
    #[cfg(test)]
    pub(crate) fn max_time(&self) -> Option<Duration> {
        let info = &self.info;
        info.max_time
    }
}
/// Identity of a server-side cursor plus the options passed to `GetMore::new`.
#[derive(Clone, Debug)]
pub(crate) struct CursorInformation {
    /// Namespace being iterated; refreshed from each getMore response.
    pub(crate) ns: Namespace,
    /// Server that owns the cursor.
    pub(crate) address: ServerAddress,
    /// Server cursor id; only overwritten by non-zero ids from getMore responses.
    pub(crate) id: i64,
    pub(crate) batch_size: Option<u32>,
    pub(crate) max_time: Option<Duration>,
    pub(crate) comment: Option<Bson>,
}
/// The connection a cursor is pinned to, if any.
#[derive(Debug)]
pub(crate) enum PinnedConnection {
    /// Pinned and usable.
    Valid(PinnedConnectionHandle),
    /// Still pinned but flagged unusable (set after a network error on a getMore).
    /// New streamed getMores and killCursors are skipped.
    Invalid(PinnedConnectionHandle),
    /// Not pinned to any connection.
    Unpinned,
}
impl PinnedConnection {
    /// Wraps an optional handle: `Some` becomes `Valid`, `None` becomes `Unpinned`.
    pub(super) fn new(handle: Option<PinnedConnectionHandle>) -> Self {
        handle.map_or(Self::Unpinned, Self::Valid)
    }

    /// Produces another reference to the same pinned connection, preserving validity.
    pub(crate) fn replicate(&self) -> Self {
        match self {
            Self::Unpinned => Self::Unpinned,
            Self::Invalid(handle) => Self::Invalid(handle.replicate()),
            Self::Valid(handle) => Self::Valid(handle.replicate()),
        }
    }

    /// Borrows the underlying handle regardless of validity.
    pub(crate) fn handle(&self) -> Option<&PinnedConnectionHandle> {
        match self {
            Self::Unpinned => None,
            Self::Valid(handle) | Self::Invalid(handle) => Some(handle),
        }
    }

    /// Whether the connection has been flagged unusable.
    fn is_invalid(&self) -> bool {
        if let Self::Invalid(_) = self {
            true
        } else {
            false
        }
    }

    /// Flags a valid connection as unusable; other states are left unchanged.
    fn invalidate(&mut self) {
        // `Unpinned` is only a momentary placeholder while the old value is inspected.
        *self = match std::mem::replace(self, Self::Unpinned) {
            Self::Valid(handle) => Self::Invalid(handle),
            unchanged => unchanged,
        };
    }
}
/// Spawns a best-effort killCursors for `cursor_id` on `drop_token`.
///
/// Nothing is sent if the pinned connection has been flagged invalid. The outcome of the
/// command is ignored. In tests, `kill_watcher` is signalled after the command is attempted.
pub(super) fn kill_cursor(
    client: Client,
    drop_token: &mut AsyncDropToken,
    ns: &Namespace,
    cursor_id: i64,
    pinned_conn: PinnedConnection,
    drop_address: Option<ServerAddress>,
    #[cfg(test)] kill_watcher: Option<oneshot::Sender<()>>,
) {
    let db = client.database(ns.db.as_str());
    let coll = db.collection::<Document>(ns.coll.as_str());
    drop_token.spawn(async move {
        if pinned_conn.is_invalid() {
            return;
        }

        let _ = coll
            .kill_cursor(cursor_id, pinned_conn.handle(), drop_address)
            .await;

        #[cfg(test)]
        if let Some(tx) = kill_watcher {
            let _ = tx.send(());
        }
    });
}
/// Mutable iteration state of a cursor.
#[derive(Debug)]
pub(crate) struct CursorState {
    /// Documents from the most recent batch that have not been consumed yet.
    pub(crate) buffer: CursorBuffer,
    /// Set once the cursor should issue no further getMores.
    pub(crate) exhausted: bool,
    /// Post-batch resume token from the most recent getMore response, if any.
    pub(crate) post_batch_resume_token: Option<ResumeToken>,
    /// Connection the cursor is pinned to; reset to `Unpinned` by `mark_exhausted`.
    pub(crate) pinned_connection: PinnedConnection,
}
/// FIFO buffer of a batch's documents.
///
/// Supports two consumption styles:
/// - positional iteration (`advance` + `current`);
/// - draining iteration (`next`).
#[derive(Debug, Clone)]
pub(crate) struct CursorBuffer {
    docs: VecDeque<RawDocumentBuf>,
    fresh: bool,
}

impl CursorBuffer {
    /// Wraps a batch; the first `advance` will position on its front document.
    pub(crate) fn new(initial_buffer: VecDeque<RawDocumentBuf>) -> Self {
        CursorBuffer {
            fresh: true,
            docs: initial_buffer,
        }
    }

    /// Whether no documents remain.
    pub(crate) fn is_empty(&self) -> bool {
        self.docs.is_empty()
    }

    /// Removes and returns the front document.
    pub(crate) fn next(&mut self) -> Option<RawDocumentBuf> {
        self.fresh = false;
        self.docs.pop_front()
    }

    /// Moves to the next document and reports whether one is now current.
    ///
    /// The first call on a fresh buffer only clears the freshness flag, so that the front
    /// document becomes current. Later calls discard the current front document.
    pub(crate) fn advance(&mut self) -> bool {
        let was_fresh = std::mem::take(&mut self.fresh);
        if !was_fresh {
            self.docs.pop_front();
        }
        !self.docs.is_empty()
    }

    /// The document the buffer is positioned on, if any.
    pub(crate) fn current(&self) -> Option<&RawDocument> {
        self.docs.front().map(|front| front.as_ref())
    }
}

impl AsRef<VecDeque<RawDocumentBuf>> for CursorBuffer {
    fn as_ref(&self) -> &VecDeque<RawDocumentBuf> {
        let Self { docs, .. } = self;
        docs
    }
}
#[test]
fn test_buffer() {
    use bson::rawdoc;

    let expected = [rawdoc! { "x": 1 }, rawdoc! { "x": 2 }, rawdoc! { "x": 3 }];
    let mut buffer = CursorBuffer::new(VecDeque::from(expected.clone()));

    // Each advance surfaces the next document in insertion order.
    for doc in expected.iter() {
        assert!(buffer.advance());
        assert_eq!(buffer.current(), Some(doc.as_ref()));
    }

    // Advancing past the final document leaves nothing current.
    assert!(!buffer.advance());
    assert_eq!(buffer.current(), None);
}
/// Session handle that owns an optional implicit session.
pub(super) struct ImplicitClientSessionHandle(pub(super) Option<ClientSession>);

impl ImplicitClientSessionHandle {
    /// Moves the session out, leaving the handle empty.
    fn take_implicit_session(&mut self) -> Option<ClientSession> {
        std::mem::take(&mut self.0)
    }
}

impl ClientSessionHandle<'_> for ImplicitClientSessionHandle {
    fn is_implicit(&self) -> bool {
        true
    }

    fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
        let Self(session) = self;
        session.as_mut()
    }
}

/// Session handle that borrows an explicit session for `'a`.
pub(super) struct ExplicitClientSessionHandle<'a>(pub(super) &'a mut ClientSession);

impl<'a> ClientSessionHandle<'a> for ExplicitClientSessionHandle<'a> {
    fn is_implicit(&self) -> bool {
        false
    }

    fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
        Some(&mut *self.0)
    }
}
/// Abstraction over how a cursor holds the session it uses for getMores.
pub(super) trait ClientSessionHandle<'a>: Send + 'a {
    /// `true` for an implicit handle (which owns its session), `false` for an explicit one
    /// (which borrows it).
    fn is_implicit(&self) -> bool;
    /// Mutably borrows the session, if the handle currently holds one.
    fn borrow_mut(&mut self) -> Option<&mut ClientSession>;
}