use crate::{
parser::{PrismaDatamodelParser, Rule},
renderer::{LineWriteable, Renderer, TableFormat},
};
use pest::Parser;
use std::iter::Peekable;
type Pair<'a> = pest::iterators::Pair<'a, Rule>;
/// Reformats a Prisma schema source string.
///
/// Returns `None` when `input` fails to parse as a `schema`. On success the
/// returned text is guaranteed to end with a newline.
pub fn reformat(input: &str, indent_width: usize) -> Option<String> {
    let mut schema = match PrismaDatamodelParser::parse(Rule::schema, input) {
        Ok(pairs) => pairs,
        Err(_) => return None,
    };

    let mut renderer = Renderer::new(indent_width);
    // Pre-reserve output capacity proportional to the input length.
    renderer.stream.reserve(input.len() / 2);

    let root = schema.next().unwrap();
    reformat_top(&mut renderer, root);

    let mut output = renderer.stream;
    if !output.ends_with('\n') {
        output.push('\n');
    }
    Some(output)
}
fn reformat_top(target: &mut Renderer, pair: Pair<'_>) {
let mut pairs = pair.into_inner().peekable();
eat_empty_lines(&mut pairs);
while let Some(current) = pairs.next() {
match current.as_rule() {
Rule::model_declaration | Rule::enum_declaration | Rule::config_block => {
reformat_block_element(current, target)
}
Rule::comment_block => {
let mut table = Default::default();
reformat_comment_block(current, &mut table);
table.render(target);
}
Rule::empty_lines => {
match pairs.peek().map(|b| b.as_rule()) {
None | Some(Rule::EOI) => (), _ => target.end_line(),
}
}
Rule::CATCH_ALL | Rule::BLOCK_LEVEL_CATCH_ALL | Rule::arbitrary_block | Rule::type_alias => {
target.write(current.as_str());
}
Rule::EOI => {}
_ => unreachable(¤t),
}
}
}
/// Reformats a `key = value` entry, such as inside a config block, as one table row.
///
/// The key goes in column 0 and `= <expression>` goes in column 1. A trailing
/// comment is appended to the end of the row.
fn reformat_key_value(pair: Pair<'_>, table: &mut TableFormat) {
    table.start_new_line();
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::identifier => table.column_locked_writer_for(0).write(current.as_str()),
            Rule::expression => {
                let mut writer = table.column_locked_writer_for(1);
                writer.write("= ");
                reformat_expression(current, &mut writer);
            }
            Rule::trailing_comment => table.append_suffix_to_current_row(current.as_str()),
            _ => unreachable(&current),
        }
    }
}
/// Reformats a model, enum, or config block.
///
/// Writes the header line `<keyword> <name> {`, increases the indent, and then
/// delegates the block body to `reformat_block_contents`. That function also
/// writes the closing `}`.
fn reformat_block_element(pair: Pair<'_>, renderer: &mut Renderer) {
    let mut pairs = pair.into_inner().peekable();
    // The first inner pair is the block's keyword token, written verbatim before the name.
    let block_type = pairs.next().unwrap().as_str();

    while let Some(current) = pairs.next() {
        match current.as_rule() {
            Rule::BLOCK_OPEN => {
                // Drop blank lines directly after the opening brace.
                eat_empty_lines(&mut pairs);
            }
            // The closing brace is emitted by `reformat_block_contents`.
            Rule::BLOCK_CLOSE => {}
            Rule::model_contents | Rule::config_contents | Rule::enum_contents => {
                reformat_block_contents(&mut current.into_inner().peekable(), renderer)
            }
            Rule::identifier => {
                let block_name = current.as_str();
                renderer.write(block_type);
                renderer.write(" ");
                renderer.write(block_name);
                renderer.write(" {");
                renderer.end_line();
                renderer.indent_up();
            }
            _ => unreachable(&current),
        }
    }
}
/// Reformats the body of a block and closes it.
///
/// Fields, key/values, enum values, and comments are laid out in a
/// `TableFormat`. A blank line before a regular entry flushes the current table
/// and starts a new one. Block attributes (`@@...`) are buffered and emitted
/// after everything else, sorted by `sort_attributes`. Finally the indent is
/// decreased and `}` is written.
fn reformat_block_contents<'a>(
    pairs: &mut Peekable<impl Iterator<Item = Pair<'a>>>,
    renderer: &mut Renderer,
) {
    // Each buffered block attribute carries the comment block (if any) that
    // directly preceded it, so the comment moves with it when sorted.
    let mut attributes: Vec<(Option<Pair<'a>>, Pair<'a>)> = Vec::new();
    let mut table = TableFormat::default();
    let mut pending_block_comment: Option<Pair<'a>> = None;

    // Drop blank lines at the start of the block body.
    eat_empty_lines(pairs);

    loop {
        let ate_empty_lines = eat_empty_lines(pairs);
        if ate_empty_lines {
            let next_rule = pairs.peek().map(|pair| pair.as_rule());
            // Flush and restart the table before a regular entry (presumably so column
            // alignment is computed per blank-line-separated group). Comment blocks
            // handle this in their own branch; block attributes are rendered at the end.
            if !matches!(
                next_rule,
                None | Some(Rule::block_attribute) | Some(Rule::comment_block)
            ) {
                table.render(renderer);
                table = TableFormat::default();
                table.start_new_line();
            }
        }

        let current = match pairs.next() {
            Some(current) => current,
            None => break,
        };

        match current.as_rule() {
            Rule::comment_block => {
                if pairs.peek().map(|pair| pair.as_rule()) == Some(Rule::block_attribute) {
                    // Attach this comment to the attribute it precedes.
                    pending_block_comment = Some(current);
                } else {
                    if ate_empty_lines {
                        table.render(renderer);
                        table = TableFormat::default();
                        table.start_new_line();
                    }
                    reformat_comment_block(current, &mut table);
                }
            }
            Rule::field_declaration => reformat_field(current, &mut table),
            Rule::key_value => reformat_key_value(current, &mut table),
            Rule::enum_value_declaration => reformat_enum_entry(current, &mut table),
            Rule::block_attribute => attributes.push((pending_block_comment.take(), current)),
            Rule::CATCH_ALL | Rule::BLOCK_LEVEL_CATCH_ALL => {
                table.interleave(current.as_str().trim_end_matches('\n'));
            }
            _ => unreachable(&current),
        }
    }

    // Flush the regular entries, then lay out the sorted block attributes in a fresh table.
    table.render(renderer);
    table = TableFormat::default();
    if !attributes.is_empty() {
        // Presumably yields an empty row separating the attributes from the entries above,
        // since each attribute starts its own row as well -- confirm against TableFormat.
        table.start_new_line();
    }
    sort_attributes(&mut attributes);
    for (comment, attribute) in attributes {
        if let Some(comment) = comment {
            reformat_comment_block(comment, &mut table);
        }
        reformat_block_attribute(attribute, &mut table);
    }
    table.render(renderer);

    // Close the block whose header was written by `reformat_block_element`.
    renderer.indent_down();
    renderer.write("}");
    renderer.end_line();
}
/// Reformats a block attribute (`@@name(args)`) as a single-column table row.
///
/// A trailing comment is appended to the end of the row.
fn reformat_block_attribute(pair: Pair<'_>, table: &mut TableFormat) {
    debug_assert!(pair.as_rule() == Rule::block_attribute);
    table.start_new_line();
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::path => {
                let mut writer = table.column_locked_writer_for(0);
                writer.write("@@");
                writer.write(current.as_str());
            }
            Rule::arguments_list => {
                reformat_arguments_list(current, &mut table.column_locked_writer_for(0))
            }
            Rule::trailing_comment => table.append_suffix_to_current_row(current.as_str()),
            _ => unreachable(&current),
        }
    }
}
/// Reformats one enum value declaration.
///
/// The value name starts a new row in column 0, and each field attribute is
/// written as `@...` in column 1. Trailing comments are appended to the row;
/// nested comment blocks are laid out as their own rows.
fn reformat_enum_entry(pair: Pair<'_>, table: &mut TableFormat) {
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::identifier => {
                table.start_new_line();
                table.column_locked_writer_for(0).write(current.as_str())
            }
            Rule::field_attribute => {
                let mut writer = table.column_locked_writer_for(1);
                writer.write("@");
                reformat_function_call(current, &mut writer)
            }
            Rule::trailing_comment => table.append_suffix_to_current_row(current.as_str()),
            Rule::comment_block => reformat_comment_block(current, table),
            _ => unreachable(&current),
        }
    }
}
/// Sorts `(leading comment, attribute)` pairs into canonical attribute order.
///
/// The sort key comes from `get_sort_index_of_attribute`. The sort is stable,
/// so attributes with equal keys (for example, unknown names) keep their
/// source order. The attached comments are ignored when sorting.
fn sort_attributes(attributes: &mut [(Option<Pair<'_>>, Pair<'_>)]) {
    attributes.sort_by_key(|(_, attribute)| get_sort_index_of_attribute(attribute.clone()));
}
/// Reformats a model field declaration as one table row with three columns.
///
/// The columns are name, type, and attributes. Field attributes are collected,
/// sorted with `sort_attributes`, and written as `@...`, separated by single
/// spaces. A trailing comment is appended to the end of the row.
fn reformat_field(pair: Pair<'_>, table: &mut TableFormat) {
    let mut attributes = Vec::new();
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::identifier => {
                table.start_new_line();
                table
                    .column_locked_writer_for(FIELD_NAME_COLUMN)
                    .write(current.as_str());
            }
            Rule::field_type => {
                let mut writer = table.column_locked_writer_for(FIELD_TYPE_COLUMN);
                reformat_field_type(current, &mut writer);
            }
            // The legacy colon token is not reproduced in the output.
            Rule::LEGACY_COLON => {}
            Rule::trailing_comment => table.append_suffix_to_current_row(current.as_str()),
            // Stored as `(comment, attribute)` to match the `sort_attributes` signature;
            // field attributes never carry a leading comment here.
            Rule::field_attribute => attributes.push((None, current)),
            _ => unreachable(&current),
        }
    }

    let mut attributes_writer = table.column_locked_writer_for(FIELD_ATTRIBUTES_COLUMN);
    sort_attributes(&mut attributes);
    for (idx, (_, attribute)) in attributes.into_iter().enumerate() {
        if idx > 0 {
            attributes_writer.write(" ");
        }
        attributes_writer.write("@");
        reformat_function_call(attribute, &mut attributes_writer);
    }
}
/// Writes a field type as its base identifier plus a type modifier.
///
/// The modifiers are `?` for optional, `[]` for list and legacy list, `[]?` for
/// unsupported optional list, and nothing for base and legacy required types.
fn reformat_field_type(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    assert!(pair.as_rule() == Rule::field_type);
    for current in pair.into_inner() {
        let suffix = match current.as_rule() {
            Rule::optional_type => "?",
            Rule::base_type | Rule::legacy_required_type => "",
            Rule::list_type | Rule::legacy_list_type => "[]",
            Rule::unsupported_optional_list_type => "[]?",
            _ => unreachable(&current),
        };
        target.write(get_identifier(current));
        // Skip the empty write so the writer sees exactly the same calls as before.
        if !suffix.is_empty() {
            target.write(suffix);
        }
    }
}
/// Returns the base type identifier of a field-type pair.
///
/// A `base_type` pair is its own identifier. Every wrapped type form holds the
/// `base_type` as its first inner pair. The returned slice borrows from the
/// parsed input, not from `pair`.
fn get_identifier(pair: Pair<'_>) -> &str {
    match pair.as_rule() {
        Rule::base_type => pair.as_str(),
        Rule::list_type
        | Rule::legacy_list_type
        | Rule::legacy_required_type
        | Rule::optional_type
        | Rule::unsupported_optional_list_type => {
            let base = pair.into_inner().next().unwrap();
            assert!(base.as_rule() == Rule::base_type);
            base.as_str()
        }
        _ => unreachable(&pair),
    }
}
/// Writes an argument list as `(a, b: c, ...)`.
///
/// Arguments are separated by `", "`. A trailing comma in the source is not
/// reproduced in the output.
fn reformat_arguments_list(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    debug_assert_eq!(pair.as_rule(), Rule::arguments_list);
    target.write("(");
    for (idx, current) in pair.into_inner().enumerate() {
        match current.as_rule() {
            Rule::named_argument | Rule::empty_argument => {
                if idx > 0 {
                    target.write(", ");
                }
                reformat_attribute_arg(current, target);
            }
            Rule::expression => {
                if idx > 0 {
                    target.write(", ");
                }
                reformat_expression(current, target);
            }
            Rule::trailing_comma => (),
            _ => unreachable(&current),
        }
    }
    target.write(")");
}
/// Writes a named (or empty) attribute argument as `name: <expression>`.
///
/// An empty argument has no expression, so only `name: ` is written. Trailing
/// commas are ignored.
fn reformat_attribute_arg(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::identifier => {
                target.write(current.as_str());
                target.write(": ");
            }
            Rule::expression => reformat_expression(current, target),
            Rule::trailing_comma => (),
            _ => unreachable(&current),
        }
    }
}
/// Writes an expression.
///
/// Numeric literals, string literals, and paths are copied verbatim; function
/// calls and arrays are reformatted recursively.
fn reformat_expression(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::numeric_literal | Rule::string_literal | Rule::path => {
                target.write(current.as_str())
            }
            Rule::function_call => reformat_function_call(current, target),
            Rule::array_expression => reformat_array_expression(current, target),
            _ => unreachable(&current),
        }
    }
}
/// Writes an array expression as `[a, b, ...]`, reformatting each element expression.
fn reformat_array_expression(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    target.write("[");
    // Any non-expression child aborts via `unreachable`, so the enumerate index
    // counts exactly the expressions written so far.
    for (idx, current) in pair.into_inner().enumerate() {
        match current.as_rule() {
            Rule::expression => {
                if idx > 0 {
                    target.write(", ");
                }
                reformat_expression(current, target);
            }
            _ => unreachable(&current),
        }
    }
    target.write("]");
}
/// Writes a function call (or attribute body) as its path followed by an optional argument list.
fn reformat_function_call(pair: Pair<'_>, target: &mut dyn LineWriteable) {
    for current in pair.into_inner() {
        match current.as_rule() {
            Rule::path => target.write(current.as_str()),
            Rule::arguments_list => reformat_arguments_list(current, target),
            _ => unreachable(&current),
        }
    }
}
/// Panics with a diagnostic that includes a parse-tree node the formatter did
/// not expect at this position.
///
/// `#[track_caller]` makes the reported panic location the call site rather
/// than this helper.
#[track_caller]
fn unreachable(pair: &Pair<'_>) -> ! {
    unreachable!(
        "Encountered impossible declaration during formatting: {:?}",
        pair
    )
}
fn reformat_comment_block(pair: Pair<'_>, table: &mut TableFormat) {
assert!(pair.as_rule() == Rule::comment_block);
for current in pair.into_inner() {
match current.as_rule() {
Rule::comment | Rule::doc_comment => {
table.start_new_line();
let prefix = if current.as_rule() == Rule::doc_comment {
"///"
} else {
"//"
};
table.append_suffix_to_current_row(prefix);
for inner in current.into_inner() {
match inner.as_rule() {
Rule::doc_content => table.append_suffix_to_current_row(inner.as_str()),
_ => unreachable!(),
}
}
}
_ => unreachable!(),
}
}
}
/// Consumes the next pair if it is an `empty_lines` pair.
///
/// Returns `true` if one was consumed. At most one pair is ever consumed.
fn eat_empty_lines<'a>(pairs: &mut Peekable<impl Iterator<Item = Pair<'a>>>) -> bool {
    pairs
        .next_if(|pair| pair.as_rule() == Rule::empty_lines)
        .is_some()
}
/// Table column holding a field's name (used by `reformat_field`).
const FIELD_NAME_COLUMN: usize = 0;
/// Table column holding a field's type, including `?` / `[]` modifiers.
const FIELD_TYPE_COLUMN: usize = 1;
/// Table column holding a field's sorted, space-separated `@` attributes.
const FIELD_ATTRIBUTES_COLUMN: usize = 2;
/// Returns the canonical position of an attribute, based on its path.
///
/// The path is the attribute's first inner pair. The canonical order is `id`,
/// `unique`, `default`, `updatedAt`, `index`, `fulltext`, `map`, `relation`,
/// `ignore`. Any other path (e.g. namespaced ones like `db.VarChar`) sorts
/// last, with `usize::MAX`.
fn get_sort_index_of_attribute(attribute: Pair<'_>) -> usize {
    let path = attribute.into_inner().next().unwrap();
    debug_assert_eq!(path.as_rule(), Rule::path);
    match path.as_str() {
        "id" => 0,
        "unique" => 1,
        "default" => 2,
        "updatedAt" => 3,
        "index" => 4,
        "fulltext" => 5,
        "map" => 6,
        "relation" => 7,
        "ignore" => 8,
        _ => usize::MAX,
    }
}