use crate::{
    cursor_condition, filter::FilterBuilder, model_extensions::*, nested_aggregations, ordering::OrderByBuilder,
    sql_trace::SqlTraceComment, Context,
};
use connector_interface::AggregationSelection;
use itertools::Itertools;
use quaint::ast::*;
use query_structure::*;
use tracing::Span;
/// Something that can be turned into the base `SELECT` of a read query against a model.
///
/// Implementors return the select plus a list of additional expressions that the caller is
/// expected to append to the selection set (callers such as `get_records` add them with
/// `Select::value` after adding their own columns). For `QueryArguments` these extra expressions
/// are the columns produced by the nested-aggregation joins built from `virtual_selections`;
/// other implementors may return an empty list.
pub(crate) trait SelectDefinition {
    /// Builds the select for the given model (first argument), taking `virtual_selections` and
    /// `ctx` into account where the implementor needs them.
    ///
    /// Returns `(select, additional_selections)`.
    fn into_select<'a>(
        self,
        _: &Model,
        virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
        ctx: &Context<'_>,
    ) -> (Select<'static>, Vec<Expression<'static>>);
}
impl SelectDefinition for Filter {
    fn into_select<'a>(
        self,
        model: &Model,
        virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
        ctx: &Context<'_>,
    ) -> (Select<'static>, Vec<Expression<'static>>) {
        let args = QueryArguments::from((model.clone(), self));
        args.into_select(model, virtual_selections, ctx)
    }
}
impl SelectDefinition for &Filter {
    fn into_select<'a>(
        self,
        model: &Model,
        virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
        ctx: &Context<'_>,
    ) -> (Select<'static>, Vec<Expression<'static>>) {
        self.clone().into_select(model, virtual_selections, ctx)
    }
}
impl SelectDefinition for Select<'static> {
    /// A ready-made select is passed through untouched. The model, the virtual selections and
    /// the context are ignored, and no additional selections are reported.
    fn into_select<'a>(
        self,
        _model: &Model,
        _virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
        _ctx: &Context<'_>,
    ) -> (Select<'static>, Vec<Expression<'static>>) {
        let no_extra_selections = Vec::new();
        (self, no_extra_selections)
    }
}
impl SelectDefinition for QueryArguments {
    /// Builds the base select for `model` from these query arguments.
    ///
    /// The select is assembled from:
    /// - the condition: the argument filter AND-ed with the cursor condition, where either side
    ///   is dropped when it is `ConditionTree::NoCondition`;
    /// - the model table joined with, in this order, the joins required by the ordering, by the
    ///   nested aggregations built from `virtual_selections`, and by the filter;
    /// - an offset of `skip` (0 when absent or when `ignore_skip` is set), the order-by clauses,
    ///   an optional `distinct_on` over the `distinct` fields, and a limit of `take_abs()`
    ///   (omitted when `ignore_take` is set).
    ///
    /// Returns the select together with the columns produced by the nested-aggregation joins.
    fn into_select<'a>(
        self,
        model: &Model,
        virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
        ctx: &Context<'_>,
    ) -> (Select<'static>, Vec<Expression<'static>>) {
        // Everything that borrows `self` has to happen before `self.filter` / `self.distinct`
        // are moved out further down.
        let order_by_definitions = OrderByBuilder::default().build(&self, ctx);
        let cursor_condition = cursor_condition::build(&self, model, &order_by_definitions, ctx);
        let aggregation_joins = nested_aggregations::build(virtual_selections, ctx);
        // `take_abs()` is only evaluated when the take is not ignored.
        let limit = (!self.ignore_take).then(|| self.take_abs()).flatten();
        let skip = self.skip.filter(|_| !self.ignore_skip).unwrap_or(0);

        let (filter, filter_joins) = match self.filter {
            Some(filter) => FilterBuilder::with_top_level_joins().visit_filter(filter, ctx),
            None => (ConditionTree::NoCondition, None),
        };

        let conditions = if matches!(filter, ConditionTree::NoCondition) {
            cursor_condition
        } else if matches!(cursor_condition, ConditionTree::NoCondition) {
            filter
        } else {
            ConditionTree::and(filter, cursor_condition)
        };

        // Join order is significant: ordering joins, then aggregation joins, then filter joins.
        let mut table = model.as_table(ctx);
        for join in order_by_definitions.iter().flat_map(|definition| &definition.joins) {
            table = table.join(join.clone().data);
        }
        for join in aggregation_joins.joins {
            table = table.join(join.data);
        }
        for join in filter_joins.into_iter().flatten() {
            table = table.join(join.data);
        }

        let mut select_ast = Select::from_table(table)
            .so_that(conditions)
            .offset(skip as usize)
            .append_trace(&Span::current())
            .add_trace_id(ctx.trace_id);

        for definition in &order_by_definitions {
            select_ast = select_ast.order_by(definition.order_definition.clone());
        }

        if let Some(distinct) = self.distinct {
            let distinct_fields = ModelProjection::from(distinct)
                .as_columns(ctx)
                .map(Expression::from)
                .collect::<Vec<_>>();
            select_ast = select_ast.distinct_on(distinct_fields);
        }

        if let Some(take) = limit {
            select_ast = select_ast.limit(take as usize);
        }

        (select_ast, aggregation_joins.columns)
    }
}
/// Builds the select that reads records of `model` described by `query`.
///
/// The base select comes from `query.into_select`. Then `columns` are added as selected columns,
/// the current tracing span / trace id are attached, and finally the additional expressions
/// reported by `into_select` are appended as selected values.
pub(crate) fn get_records<'a, T>(
    model: &Model,
    columns: impl Iterator<Item = Column<'static>>,
    virtual_selections: impl IntoIterator<Item = &'a VirtualSelection>,
    query: T,
    ctx: &Context<'_>,
) -> Select<'static>
where
    T: SelectDefinition,
{
    let (mut select, extra_values) = query.into_select(model, virtual_selections, ctx);

    for column in columns {
        select = select.column(column);
    }

    let mut select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id);

    for value in extra_values {
        select = select.value(value);
    }

    select
}
pub(crate) fn aggregate(
    model: &Model,
    selections: &[AggregationSelection],
    args: QueryArguments,
    ctx: &Context<'_>,
) -> Select<'static> {
    let columns = extract_columns(model, selections, ctx);
    let sub_query = get_records(model, columns.into_iter(), &[], args, ctx);
    let sub_table = Table::from(sub_query).alias("sub");
    selections.iter().fold(
        Select::from_table(sub_table)
            .append_trace(&Span::current())
            .add_trace_id(ctx.trace_id),
        |select, next_op| match next_op {
            AggregationSelection::Field(field) => select.column(
                Column::from(field.db_name().to_owned())
                    .set_is_enum(field.type_identifier().is_enum())
                    .set_is_selected(true),
            ),
            AggregationSelection::Count { all, fields } => {
                let select = fields.iter().fold(select, |select, next_field| {
                    select.value(count(Column::from(next_field.db_name().to_owned())))
                });
                if *all {
                    select.value(count(asterisk()))
                } else {
                    select
                }
            }
            AggregationSelection::Average(fields) => fields.iter().fold(select, |select, next_field| {
                select.value(avg(Column::from(next_field.db_name().to_owned())))
            }),
            AggregationSelection::Sum(fields) => fields.iter().fold(select, |select, next_field| {
                select.value(sum(Column::from(next_field.db_name().to_owned())))
            }),
            AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| {
                select.value(min(Column::from(next_field.db_name().to_owned())
                    .set_is_enum(next_field.type_identifier().is_enum())
                    .set_is_selected(true)))
            }),
            AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| {
                select.value(max(Column::from(next_field.db_name().to_owned())
                    .set_is_enum(next_field.type_identifier().is_enum())
                    .set_is_selected(true)))
            }),
        },
    )
}
/// Builds a grouped aggregation query over `model`.
///
/// The aggregates are added directly to the base select built from `args`, without a sub-query.
/// The additional selections reported by `into_select` are discarded. Each `group_by` field is then
/// added as a group-by column. `having`, when present, becomes the having-condition, built without
/// top-level joins.
pub(crate) fn group_by_aggregate(
    model: &Model,
    args: QueryArguments,
    selections: &[AggregationSelection],
    group_by: Vec<ScalarFieldRef>,
    having: Option<Filter>,
    ctx: &Context<'_>,
) -> Select<'static> {
    let (mut query, _) = args.into_select(model, &[], ctx);

    for selection in selections {
        query = match selection {
            AggregationSelection::Field(field) => query.column(field.as_column(ctx).set_is_selected(true)),
            AggregationSelection::Count { all, fields } => {
                let mut counted = query;
                for field in fields {
                    counted = counted.value(count(field.as_column(ctx)));
                }
                if *all {
                    counted.value(count(asterisk()))
                } else {
                    counted
                }
            }
            AggregationSelection::Average(fields) => {
                let mut averaged = query;
                for field in fields {
                    averaged = averaged.value(avg(field.as_column(ctx)));
                }
                averaged
            }
            AggregationSelection::Sum(fields) => {
                let mut summed = query;
                for field in fields {
                    summed = summed.value(sum(field.as_column(ctx)));
                }
                summed
            }
            AggregationSelection::Min(fields) => {
                let mut minimized = query;
                for field in fields {
                    minimized = minimized.value(min(field.as_column(ctx).set_is_selected(true)));
                }
                minimized
            }
            AggregationSelection::Max(fields) => {
                let mut maximized = query;
                for field in fields {
                    maximized = maximized.value(max(field.as_column(ctx).set_is_selected(true)));
                }
                maximized
            }
        };
    }

    let mut grouped = query.append_trace(&Span::current()).add_trace_id(ctx.trace_id);
    for field in group_by {
        grouped = grouped.group_by(field.as_column(ctx));
    }

    if let Some(filter) = having {
        let condition = FilterBuilder::without_top_level_joins().visit_filter(filter, ctx);
        grouped.having(condition)
    } else {
        grouped
    }
}
/// Collects the columns that the aggregation sub-query built by `aggregate` must project.
///
/// Every field referenced by `selections` is included, de-duplicated by database name (the first
/// occurrence is kept). A `Count` with no explicit fields falls back to the fields of the model's
/// primary identifier.
///
/// # Panics
///
/// When that fallback is taken and the primary identifier contains non-scalar fields.
fn extract_columns(model: &Model, selections: &[AggregationSelection], ctx: &Context<'_>) -> Vec<Column<'static>> {
    let referenced_fields: Vec<_> = selections
        .iter()
        .flat_map(|selection| match selection {
            AggregationSelection::Field(field) => vec![field.clone()],
            AggregationSelection::Count { fields, .. } if fields.is_empty() => model
                .primary_identifier()
                .as_scalar_fields()
                .expect("Primary identifier has non-scalar fields."),
            AggregationSelection::Count { fields, .. }
            | AggregationSelection::Average(fields)
            | AggregationSelection::Sum(fields)
            | AggregationSelection::Min(fields)
            | AggregationSelection::Max(fields) => fields.clone(),
        })
        .unique_by(|field| field.db_name().to_owned())
        .collect();

    referenced_fields.as_columns(ctx).collect()
}