From 70625afcafed7c0dd8e1c30a969e3e00afb054b0 Mon Sep 17 00:00:00 2001 From: Nick Krichevsky Date: Wed, 15 May 2024 11:20:17 -0400 Subject: [PATCH] Add AST generation code and test visitor --- Cargo.toml | 5 +- bin/astgen.rs | 252 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/ast.rs | 50 ++++++++++ src/lex.rs | 7 ++ src/lib.rs | 90 ++++++++++++++++++ 5 files changed, 403 insertions(+), 1 deletion(-) create mode 100644 bin/astgen.rs create mode 100644 src/ast.rs diff --git a/Cargo.toml b/Cargo.toml index a7adcfa..65dd956 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,11 @@ name = "jlox-rust" version = "0.1.0" edition = "2021" +default-run = "jlox-rust" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "astgen" +path = "bin/astgen.rs" [dependencies] anyhow = "1.0.83" diff --git a/bin/astgen.rs b/bin/astgen.rs new file mode 100644 index 0000000..910948a --- /dev/null +++ b/bin/astgen.rs @@ -0,0 +1,252 @@ +// This whole thing could maybe be a set of macros, but I'm playing along with the book for now + +use std::{ + collections::BTreeMap, + io::{self, Write}, + iter, +}; + +fn main() { + let mut stdout = io::stdout(); + // Autogenerated file, clippy can leave us alone + writeln!(stdout, "#![allow(clippy::all)]").unwrap(); + + writeln!( + stdout, + "// This file is auto-generated. Do not modify it directly, but rather use bin/astgen.rs.\n" + ) + .unwrap(); + + define_imports(&mut stdout, &["crate::lex::Token"]).expect("failed to define imports"); + writeln!(stdout).unwrap(); + + define_literals( + &mut stdout, + "LiteralValue", + &BTreeMap::from([ + ("Number", Some("f64")), + ("String", Some("String")), + ("True", None), + ("False", None), + ("Nil", None), + ]), + ) + .expect("failed to generate literals"); + writeln!(stdout).unwrap(); + + let types = BTreeMap::from([ + ( + "Binary", + "left: Box, operator: Token, right: Box", + ), + ("Grouping", "expr: Box"), + ("Unary", "expr: Box, operator: Token"), + ("Literal", "value: LiteralValue"), + ]); + + define_ast(&mut stdout, "Expr", &types).expect("failed to generate ast values"); + writeln!(stdout).unwrap(); + define_visitor_trait(&mut stdout, "Visitor", "Expr", &types) + .expect("failed to generate visitor trait"); +} + +fn define_imports(mut w: W, paths: &[&str]) -> Result<(), io::Error> { + for path in paths { + writeln!(w, "use {path};")?; + } + + Ok(()) +} + +fn define_literals( + mut w: W, + name: &str, + types: &BTreeMap<&str, Option<&str>>, +) -> Result<(), io::Error> { + writeln!(w, "pub enum {name} {{")?; + for (key, maybe_value_type) in types { + if let Some(value_type) = maybe_value_type { + writeln!(w, " {key}({value_type}),")?; + } else { + writeln!(w, " {key},")?; + } + } + writeln!(w, "}}")?; + + Ok(()) +} + +fn define_ast( + mut w: W, + base_name: &str, + types: &BTreeMap<&str, &str>, +) -> Result<(), io::Error> { + writeln!(w, "pub enum {base_name} {{")?; + for (key, value_types) in types { + define_ast_variant(&mut w, key, value_types)?; + } + writeln!(w, "}}")?; + + Ok(()) +} + +fn define_ast_variant(mut w: W, name: &str, value_types: &str) -> Result<(), io::Error> { + writeln!(w, " {name} {{")?; + for field in value_types.split(", ") { + writeln!(w, " {field},")?; + } + writeln!(w, " }},")?; + + Ok(()) +} + +fn define_visitor_trait( + mut w: W, + name: &str, + base_name: &str, + types: &BTreeMap<&str, &str>, +) -> Result<(), io::Error> { + writeln!(w, "pub trait {name} {{")?; + define_main_visitor_trait_method(&mut w, base_name, "T", types)?; + writeln!(w)?; + + for (key, value_types) in types { + define_visitor_trait_method(&mut w, key, value_types)?; + } + + writeln!(w, "}}")?; + Ok(()) +} + +fn define_visitor_trait_method( + mut w: W, + type_name: &str, + value_types: &str, +) -> Result<(), io::Error> { + let snake_key = camel_to_snake(type_name); + write!(w, " fn visit_{snake_key}(&mut self, ")?; + let fields = value_types.split(", ").collect::>(); + for (i, field) in fields.iter().enumerate() { + let mut field_components = field.split(": "); + + let field_name = field_components.next().unwrap(); + let field_raw_type = field_components.next().unwrap(); + // Filthy hack for box types, but we don't use any other non-referential type right now + let stripped_field_type = field_raw_type + .strip_prefix("Box<") + .and_then(|val| val.strip_suffix('>')) + .unwrap_or(field_raw_type); + let arg_type = format!("&{stripped_field_type}"); + + if i == fields.len() - 1 { + write!(w, "{field_name}: {arg_type}")?; + } else { + write!(w, "{field_name}: {arg_type}, ")?; + } + } + + writeln!(w, ") -> T;")?; + + Ok(()) +} + +fn define_main_visitor_trait_method( + mut w: W, + base_name: &str, + result_name: &str, + types: &BTreeMap<&str, &str>, +) -> Result<(), io::Error> { + let snake_name = camel_to_snake(base_name); + writeln!( + w, + " fn visit_{snake_name}(&mut self, {snake_name}: &{base_name}) -> {result_name} {{", + )?; + + writeln!(w, " match {snake_name} {{")?; + for (key, value_types) in types { + define_main_visitor_match_arm(&mut w, base_name, key, value_types)?; + } + + writeln!(w, " }}")?; + writeln!(w, " }}")?; + + Ok(()) +} + +fn define_main_visitor_match_arm( + mut w: W, + base_name: &str, + type_name: &str, + value_types: &str, +) -> Result<(), io::Error> { + let fields = value_types.split(", ").collect::>(); + write!(w, " {base_name}::{type_name} {{ ")?; + write_comma_separated_field_names(&mut w, &fields)?; + write!(w, " }} => ")?; + + write!(w, "self.visit_{}(", camel_to_snake(type_name))?; + write_comma_separated_field_names_with_ref(&mut w, &fields)?; + writeln!(w, "),")?; + + Ok(()) +} + +fn write_comma_separated_field_names(mut w: W, fields: &[&str]) -> Result<(), io::Error> { + let num_fields = fields.len(); + for (i, field) in fields.iter().enumerate() { + let field_name = field.split(": ").next().unwrap(); + if i == num_fields - 1 { + write!(w, "{field_name}")?; + } else { + write!(w, "{field_name}, ")?; + } + } + + Ok(()) +} + +fn write_comma_separated_field_names_with_ref( + mut w: W, + fields: &[&str], +) -> Result<(), io::Error> { + let num_fields = fields.len(); + for (i, field) in fields.iter().enumerate() { + let mut field_components = field.split(": "); + let field_name = field_components.next().unwrap(); + let field_type = field_components.next().unwrap(); + // Filthy box hack again + let reference_name = if field_type.starts_with("Box") { + format!("{field_name}.as_ref()") + } else { + format!("&{field_name}") + }; + + if i == num_fields - 1 { + write!(w, "{reference_name}")?; + } else { + write!(w, "{reference_name}, ")?; + } + } + + Ok(()) +} + +fn camel_to_snake(s: &str) -> String { + if s.is_empty() { + return String::new(); + } + let mut s_iter = s.chars(); + + iter::once(s_iter.next().unwrap().to_ascii_lowercase()) + .chain(s_iter) + .fold(String::new(), |mut snake, c| { + if c.is_ascii_uppercase() { + snake.push(c.to_ascii_lowercase()); + snake.push('_'); + } else { + snake.push(c); + } + + snake + }) +} diff --git a/src/ast.rs b/src/ast.rs new file mode 100644 index 0000000..e37ba90 --- /dev/null +++ b/src/ast.rs @@ -0,0 +1,50 @@ +#![allow(clippy::all)] +// This file is auto-generated. Do not modify it directly, but rather use bin/astgen.rs. + +use crate::lex::Token; + +pub enum LiteralValue { + False, + Nil, + Number(f64), + String(String), + True, +} + +pub enum Expr { + Binary { + left: Box, + operator: Token, + right: Box, + }, + Grouping { + expr: Box, + }, + Literal { + value: LiteralValue, + }, + Unary { + expr: Box, + operator: Token, + }, +} + +pub trait Visitor { + fn visit_expr(&mut self, expr: &Expr) -> T { + match expr { + Expr::Binary { + left, + operator, + right, + } => self.visit_binary(left.as_ref(), &operator, right.as_ref()), + Expr::Grouping { expr } => self.visit_grouping(expr.as_ref()), + Expr::Literal { value } => self.visit_literal(&value), + Expr::Unary { expr, operator } => self.visit_unary(expr.as_ref(), &operator), + } + } + + fn visit_binary(&mut self, left: &Expr, operator: &Token, right: &Expr) -> T; + fn visit_grouping(&mut self, expr: &Expr) -> T; + fn visit_literal(&mut self, value: &LiteralValue) -> T; + fn visit_unary(&mut self, expr: &Expr, operator: &Token) -> T; +} diff --git a/src/lex.rs b/src/lex.rs index 9bc197b..fb7f4e9 100644 --- a/src/lex.rs +++ b/src/lex.rs @@ -62,6 +62,13 @@ pub struct Token { line: usize, } +#[cfg(test)] +impl Token { + pub fn new(kind: TokenKind, lexeme: String, line: usize) -> Self { + Self { kind, lexeme, line } + } +} + #[derive(Debug, Clone)] struct Consumed { token: Option, diff --git a/src/lib.rs b/src/lib.rs index 46398bd..b44183c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ use std::fmt::{self, Display, Formatter}; +mod ast; mod lex; #[derive(thiserror::Error, Debug, Clone)] @@ -46,3 +47,92 @@ pub fn run(script: &str) -> Result<(), ScriptErrors> { Ok(()) } + +#[cfg(test)] +mod tests { + use self::{ + ast::{Expr, Visitor}, + lex::{Token, TokenKind}, + }; + + use super::*; + + struct ASTPrinter; + impl ASTPrinter { + // meh + #[allow(clippy::format_collect)] + fn parenthesize(&mut self, name: &str, exprs: &[&Expr]) -> String { + format!( + "({name}{})", + exprs + .iter() + .map(|expr| format!(" {}", self.visit_expr(expr))) + .collect::() + ) + } + } + + impl Visitor for ASTPrinter { + fn visit_binary( + &mut self, + left: &ast::Expr, + operator: &lex::Token, + right: &ast::Expr, + ) -> String { + self.parenthesize(operator.lexeme(), &[left, right]) + } + + fn visit_grouping(&mut self, expr: &Expr) -> String { + self.parenthesize("group", &[expr]) + } + + fn visit_unary(&mut self, expr: &Expr, operator: &lex::Token) -> String { + self.parenthesize(operator.lexeme(), &[expr]) + } + + fn visit_literal(&mut self, value: &ast::LiteralValue) -> String { + match value { + ast::LiteralValue::Nil => "nil".to_string(), + ast::LiteralValue::False => "false".to_string(), + ast::LiteralValue::True => "true".to_string(), + ast::LiteralValue::Number(n) => n.to_string(), + ast::LiteralValue::String(s) => format!("\"{}\"", s.clone()), + } + } + } + + #[test] + fn test_simple_add() { + let result = ASTPrinter.visit_expr(&Expr::Binary { + left: Box::new(Expr::Literal { + value: ast::LiteralValue::Number(123_f64), + }), + operator: Token::new(TokenKind::Plus, "+".to_string(), 1), + right: Box::new(Expr::Literal { + value: ast::LiteralValue::Number(456_f64), + }), + }); + + assert_eq!("(+ 123 456)", result); + } + + #[test] + fn test_complicated_arithmetic() { + let result = ASTPrinter.visit_expr(&Expr::Binary { + left: Box::new(Expr::Unary { + expr: Box::new(Expr::Literal { + value: ast::LiteralValue::Number(123_f64), + }), + operator: Token::new(TokenKind::Minus, "-".to_string(), 1), + }), + operator: Token::new(TokenKind::Plus, "*".to_string(), 1), + right: Box::new(Expr::Grouping { + expr: Box::new(Expr::Literal { + value: ast::LiteralValue::Number(456.789_f64), + }), + }), + }); + + assert_eq!("(* (- 123) (group 456.789))", result); + } +}