// This whole thing could maybe be a set of macros, but I'm playing along with the book for now
use std::{
    collections::BTreeMap,
    fs::File,
    io::{self, Write},
    iter,
};

/// Generates `./src/ast.rs` from the node descriptions in [`do_ast_codegen`].
///
/// # Panics
/// Panics if the output file cannot be created, written, or flushed.
pub fn main() {
    let ast_file = File::create("./src/ast.rs").expect("could not open src/ast.rs for writing");
    // Buffer the many small writes; flush explicitly because a flush on drop
    // would silently discard any I/O error.
    let mut writer = io::BufWriter::new(ast_file);
    do_ast_codegen(&mut writer);
    writer
        .flush()
        .expect("could not flush src/ast.rs after writing");
}

/// Writes the complete generated AST module to `output`: a header, the
/// imports, the `LiteralValue` enum, then the `Expr` and `Stmt` enums, each
/// followed by its visitor trait.
///
/// # Panics
/// Panics if any write to `output` fails.
fn do_ast_codegen<W: Write>(mut output: W) {
    // Autogenerated file, clippy can leave us alone
    writeln!(output, "#![allow(clippy::all)]").unwrap();
    writeln!(
        output,
        "// This file is auto-generated. Do not modify it directly, but rather modify build.rs.\n"
    )
    .unwrap();

    define_imports(&mut output, &["crate::lex::Token"]).expect("failed to define imports");
    writeln!(output).unwrap();

    define_literals(
        &mut output,
        "LiteralValue",
        &BTreeMap::from([
            ("Number", Some("f64")),
            ("String", Some("String")),
            ("True", None),
            ("False", None),
            ("Nil", None),
        ]),
    )
    .expect("failed to generate literals");
    writeln!(output).unwrap();

    // Field specs are `name: Type` pairs separated by ", ". Recursive fields
    // must be spelled `Box<...>`: the visitor generators below recognise that
    // exact wrapper and hand visitors a plain reference to the boxed value.
    let expr_types = BTreeMap::from([
        ("Assign", "name: Token, value: Box<Expr>"),
        (
            "Binary",
            "left: Box<Expr>, operator: Token, right: Box<Expr>",
        ),
        ("Grouping", "expr: Box<Expr>"),
        ("Unary", "expr: Box<Expr>, operator: Token"),
        ("Literal", "value: LiteralValue"),
        ("Variable", "name: Token"),
    ]);
    define_ast(&mut output, "Expr", &expr_types).expect("failed to generate ast values");
    writeln!(output).unwrap();
    define_visitor_trait(&mut output, "ExprVisitor", "Expr", &expr_types)
        .expect("failed to generate visitor trait");
    writeln!(output).unwrap();

    let statement_types = BTreeMap::from([
        ("Expression", "expression: Expr"),
        ("Block", "statements: Vec<Stmt>"),
        ("Print", "expression: Expr"),
        ("Var", "name: Token, initializer: Option<Expr>"),
    ]);
    define_ast(&mut output, "Stmt", &statement_types).expect("failed to generate ast values");
    writeln!(output).unwrap();
    define_visitor_trait(&mut output, "StmtVisitor", "Stmt", &statement_types)
        .expect("failed to generate visitor trait");
}

/// Splits one `name: Type` field spec into `(name, type)`.
///
/// # Panics
/// Panics if `field` contains no `": "` separator; the specs are hard-coded in
/// this file, so a malformed one is a programming error.
fn split_field(field: &str) -> (&str, &str) {
    field
        .split_once(": ")
        .unwrap_or_else(|| panic!("malformed field spec {field:?}: expected `name: Type`"))
}

/// Returns the field name of a `name: Type` spec, or the whole spec when it
/// has no `": "` separator.
fn field_name(field: &str) -> &str {
    field.split(": ").next().unwrap_or(field)
}

/// Writes one `use <path>;` line per entry in `paths`.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_imports<W: Write>(mut w: W, paths: &[&str]) -> Result<(), io::Error> {
    for path in paths {
        writeln!(w, "use {path};")?;
    }
    Ok(())
}

/// Writes a `pub enum {name}` deriving `Debug` and `Clone`, with one variant
/// per entry in `types`: a single-field tuple variant when a payload type is
/// given, a unit variant otherwise. Variants appear in the map's key order.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_literals<W: Write>(
    mut w: W,
    name: &str,
    types: &BTreeMap<&str, Option<&str>>,
) -> Result<(), io::Error> {
    writeln!(w, "#[derive(Debug, Clone)]")?;
    writeln!(w, "pub enum {name} {{")?;
    for (key, maybe_value_type) in types {
        match maybe_value_type {
            Some(value_type) => writeln!(w, " {key}({value_type}),")?,
            None => writeln!(w, " {key},")?,
        }
    }
    writeln!(w, "}}")?;
    Ok(())
}

/// Writes a `pub enum {base_name}` deriving `Debug` and `Clone`, with one
/// struct-like variant per entry in `types` (variant name -> field specs).
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_ast<W: Write>(
    mut w: W,
    base_name: &str,
    types: &BTreeMap<&str, &str>,
) -> Result<(), io::Error> {
    writeln!(w, "#[derive(Debug, Clone)]")?;
    writeln!(w, "pub enum {base_name} {{")?;
    for (key, value_types) in types {
        define_ast_variant(&mut w, key, value_types)?;
    }
    writeln!(w, "}}")?;
    Ok(())
}

/// Writes one struct-like enum variant called `name`, with one line per
/// ", "-separated field spec in `value_types`.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_ast_variant<W: Write>(
    mut w: W,
    name: &str,
    value_types: &str,
) -> Result<(), io::Error> {
    writeln!(w, " {name} {{")?;
    for field in value_types.split(", ") {
        writeln!(w, " {field},")?;
    }
    writeln!(w, " }},")?;
    Ok(())
}

/// Writes the visitor trait `{name}<T>` for the enum `base_name`: a provided
/// dispatch method followed by one required `visit_*` method per variant.
///
/// `T` is the visitor's result type. It is declared on the trait because every
/// generated method returns it.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_visitor_trait<W: Write>(
    mut w: W,
    name: &str,
    base_name: &str,
    types: &BTreeMap<&str, &str>,
) -> Result<(), io::Error> {
    writeln!(w, "pub trait {name}<T> {{")?;
    define_main_visitor_trait_method(&mut w, base_name, "T", types)?;
    writeln!(w)?;
    for (key, value_types) in types {
        define_visitor_trait_method(&mut w, key, value_types)?;
    }
    writeln!(w, "}}")?;
    Ok(())
}

/// Writes the required-method signature `visit_<variant>(&mut self, ...) -> T;`
/// for one variant, taking each field by shared reference.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
///
/// # Panics
/// Panics if a field spec in `value_types` is not of the form `name: Type`.
fn define_visitor_trait_method<W: Write>(
    mut w: W,
    type_name: &str,
    value_types: &str,
) -> Result<(), io::Error> {
    let snake_key = camel_to_snake(type_name);
    let params = value_types
        .split(", ")
        .map(|field| {
            let (name, raw_type) = split_field(field);
            // Boxes are the only wrapper we see through: the match arm passes
            // `field.as_ref()`, so the visitor receives `&Inner`, not `&Box<Inner>`.
            let inner_type = raw_type
                .strip_prefix("Box<")
                .and_then(|rest| rest.strip_suffix('>'))
                .unwrap_or(raw_type);
            format!("{name}: &{inner_type}")
        })
        .collect::<Vec<_>>()
        .join(", ");

    write!(w, " fn visit_{snake_key}(&mut self, ")?;
    write!(w, "{params}")?;
    writeln!(w, ") -> T;")?;
    Ok(())
}

/// Writes the provided method `visit_<base>(&mut self, <base>: &Base) -> {result_name}`
/// whose body matches on the enum and forwards to the per-variant method.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
fn define_main_visitor_trait_method<W: Write>(
    mut w: W,
    base_name: &str,
    result_name: &str,
    types: &BTreeMap<&str, &str>,
) -> Result<(), io::Error> {
    let snake_name = camel_to_snake(base_name);
    writeln!(
        w,
        " fn visit_{snake_name}(&mut self, {snake_name}: &{base_name}) -> {result_name} {{",
    )?;
    writeln!(w, " match {snake_name} {{")?;
    for (key, value_types) in types {
        define_main_visitor_match_arm(&mut w, base_name, key, value_types)?;
    }
    writeln!(w, " }}")?;
    writeln!(w, " }}")?;
    Ok(())
}

/// Writes one match arm that destructures `{base_name}::{type_name}` and calls
/// the matching `self.visit_<variant>(...)` with a reference to each field.
///
/// # Errors
/// Returns the first I/O error reported by `w`.
///
/// # Panics
/// Panics if a field spec in `value_types` is not of the form `name: Type`.
fn define_main_visitor_match_arm<W: Write>(
    mut w: W,
    base_name: &str,
    type_name: &str,
    value_types: &str,
) -> Result<(), io::Error> {
    let fields = value_types.split(", ").collect::<Vec<_>>();
    write!(w, " {base_name}::{type_name} {{ ")?;
    write_comma_separated_field_names(&mut w, &fields)?;
    write!(w, " }} => ")?;
    write!(w, "self.visit_{}(", camel_to_snake(type_name))?;
    write_comma_separated_field_names_with_ref(&mut w, &fields)?;
    writeln!(w, "),")?;
    Ok(())
}

/// Writes the names of `fields` (the part of each spec before `": "`),
/// separated by ", " with no trailing separator.
///
/// # Errors
/// Returns any I/O error reported by `w`.
fn write_comma_separated_field_names<W: Write>(
    mut w: W,
    fields: &[&str],
) -> Result<(), io::Error> {
    let names = fields
        .iter()
        .map(|field| field_name(field))
        .collect::<Vec<_>>()
        .join(", ");
    write!(w, "{names}")
}

/// Writes one call argument per field, separated by ", ": `name.as_ref()` for
/// fields whose type starts with `Box`, `&name` for everything else.
///
/// # Errors
/// Returns any I/O error reported by `w`.
///
/// # Panics
/// Panics if a field spec is not of the form `name: Type`.
fn write_comma_separated_field_names_with_ref<W: Write>(
    mut w: W,
    fields: &[&str],
) -> Result<(), io::Error> {
    let args = fields
        .iter()
        .map(|field| {
            let (name, field_type) = split_field(field);
            // Same Box special case as the method signature: unwrap the box so
            // the argument type lines up with the `&Inner` parameter.
            if field_type.starts_with("Box") {
                format!("{name}.as_ref()")
            } else {
                format!("&{name}")
            }
        })
        .collect::<Vec<_>>()
        .join(", ");
    write!(w, "{args}")
}

/// Converts a CamelCase identifier to snake_case.
///
/// The first character is lowercased; every later ASCII uppercase letter is
/// lowercased and preceded by an underscore (`FooBar` -> `foo_bar`). Runs of
/// capitals are split letter by letter. An empty input yields an empty string.
fn camel_to_snake(s: &str) -> String {
    let mut rest = s.chars();
    let first = match rest.next() {
        Some(c) => c,
        None => return String::new(),
    };
    iter::once(first.to_ascii_lowercase())
        .chain(rest)
        .fold(String::with_capacity(s.len()), |mut snake, c| {
            if c.is_ascii_uppercase() {
                // The underscore goes before the new word's first letter.
                snake.push('_');
                snake.push(c.to_ascii_lowercase());
            } else {
                snake.push(c);
            }
            snake
        })
}