From 6f7e2eb4998722d5c42a705300736a1dd6d4244e Mon Sep 17 00:00:00 2001 From: Nick Krichevsky Date: Wed, 22 May 2024 14:10:29 -0400 Subject: [PATCH] Add block scoping of variables --- build.rs | 1 + src/eval.rs | 23 ++++-- src/eval/environment.rs | 173 +++++++++++++++++++++++++++++++++++++++- src/parse.rs | 16 ++++ 4 files changed, 202 insertions(+), 11 deletions(-) diff --git a/build.rs b/build.rs index 4b235ff..5e50bb0 100644 --- a/build.rs +++ b/build.rs @@ -59,6 +59,7 @@ fn do_ast_codegen(mut output: W) { let statement_types = BTreeMap::from([ ("Expression", "expression: Expr"), + ("Block", "statements: Vec"), ("Print", "expression: Expr"), ("Var", "name: Token, initializer: Option"), ]); diff --git a/src/eval.rs b/src/eval.rs index dc4203d..72d2c39 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,8 +1,4 @@ -use std::{ - fmt::{self, Display, Formatter}, - rc::Rc, - string::ParseError, -}; +use std::fmt::{self, Display, Formatter}; use crate::{ ast::{Expr, ExprVisitor, LiteralValue, Stmt, StmtVisitor}, @@ -174,7 +170,7 @@ impl ExprVisitor> for InterpreterRunner<'_> fn visit_variable(&mut self, name: &Token) -> Result { self.interpreter .env - .get(name.lexeme()) + .get_value(name.lexeme()) .ok_or_else(|| ScriptError { message: format!("Undefined variable: {}", name.lexeme()), location: String::new(), @@ -200,10 +196,23 @@ impl StmtVisitor> for InterpreterRunner<'_> { .as_ref() .map_or(Ok(EvaluatedValue::Nil), |expr| self.visit_expr(expr))?; - self.interpreter.env.set(name.lexeme(), initialized_value); + self.interpreter + .env + .set_value(name.lexeme(), initialized_value); Ok(()) } + + fn visit_block(&mut self, statements: &Vec) -> Result<(), ScriptError> { + self.interpreter.env.enter_scope(); + let result = statements + .iter() + .try_for_each(|statement| self.visit_stmt(statement)); + + self.interpreter.env.exit_scope(); + + result + } } fn convert_arithmetic_operands( diff --git a/src/eval/environment.rs b/src/eval/environment.rs index 60f3c4c..a5639cc 100644 --- a/src/eval/environment.rs +++ b/src/eval/environment.rs @@ -1,22 +1,187 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use super::value::EvaluatedValue; -#[derive(Debug, Default)] +#[derive(Debug)] pub struct Environment { + scopes: BTreeMap, +} + +#[derive(Debug)] +struct Scope { + parent: Option, values: HashMap, } +impl Default for Environment { + fn default() -> Self { + let mut scopes = BTreeMap::new(); + scopes.insert(Self::ROOT_KEY, Scope::new_global_scope()); + Self { scopes } + } +} + impl Environment { + const ROOT_KEY: u32 = 0; + pub fn new() -> Self { Self::default() } - pub fn get(&self, name: &str) -> Option<&EvaluatedValue> { + #[must_use] + pub fn root_scope(&self) -> &Scope { + self.scopes + .get(&Self::ROOT_KEY) + .expect("no root environment defined") + } + + #[must_use] + pub fn root_scope_mut(&mut self) -> &mut Scope { + self.scopes + .get_mut(&Self::ROOT_KEY) + .expect("no root environment defined") + } + + pub fn enter_scope(&mut self) { + let parent_key = self.current_scope_key(); + let child_key = self.next_scope_key(); + + let env = Scope::new_child_scope(parent_key); + let insert_result = self.scopes.insert(child_key, env); + assert!( + insert_result.is_none(), + "collision when inserting environment {child_key}" + ); + } + + pub fn exit_scope(&mut self) { + let to_remove_key = self.current_scope_key(); + assert!(to_remove_key != Self::ROOT_KEY, "cannot leave global scope"); + + self.scopes.remove(&to_remove_key); + } + + pub fn get_value(&self, name: &str) -> Option<&EvaluatedValue> { + let mut cursor = Some(self.current_scope()); + while let Some(scope) = cursor { + let value = scope.get(name); + if value.is_some() { + return value; + } + + cursor = scope.parent.and_then(|key| self.scopes.get(&key)); + } + + None + } + + pub fn set_value(&mut self, name: &str, value: EvaluatedValue) { + self.current_scope_mut().set(name, value); + } + + fn current_scope(&self) -> &Scope { + // this can't fail, given we know where these scope ids come from + self.scopes.get(&self.current_scope_key()).unwrap() + } + + fn current_scope_mut(&mut self) -> &mut Scope { + // this can't fail, given we know where these scope ids come from + self.scopes.get_mut(&self.current_scope_key()).unwrap() + } + + fn current_scope_key(&self) -> u32 { + self.scopes + .last_key_value() + .map(|(&k, _v)| k) + // this should *NEVER* happen + .expect("no root scope defined") + } + + fn next_scope_key(&self) -> u32 { + let parent = self.current_scope_key(); + + parent.checked_add(1).expect("too many nested scopes") + } +} + +impl Scope { + fn new_global_scope() -> Self { + Self { + values: HashMap::new(), + parent: None, + } + } + + fn new_child_scope(parent: u32) -> Self { + Self { + values: HashMap::new(), + parent: Some(parent), + } + } + + fn get(&self, name: &str) -> Option<&EvaluatedValue> { self.values.get(name) } - pub fn set>(&mut self, name: S, value: EvaluatedValue) { + fn set>(&mut self, name: S, value: EvaluatedValue) { self.values.insert(name.into(), value); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_can_get_value_from_root_scope() { + let mut global = Environment::new(); + global.set_value("foo", EvaluatedValue::Number(42_f64)); + + assert_eq!( + Some(&EvaluatedValue::Number(42_f64)), + global.get_value(r"foo") + ); + } + + #[test] + fn test_can_get_value_from_parent_scope() { + let mut global = Environment::new(); + global.set_value("foo", EvaluatedValue::Number(42_f64)); + + // Make two child scopes, just so they exist + global.enter_scope(); + global.enter_scope(); + + assert_eq!( + Some(&EvaluatedValue::Number(42_f64)), + global.get_value(r"foo") + ); + } + + #[test] + fn test_insert_value_into_non_root_scope() { + let mut global = Environment::new(); + + global.enter_scope(); + global.set_value("foo", EvaluatedValue::Number(42_f64)); + + global.enter_scope(); + + assert_eq!( + Some(&EvaluatedValue::Number(42_f64)), + global.get_value(r"foo") + ); + } + + #[test] + fn test_leaving_scope_removes_value() { + let mut global = Environment::new(); + + global.enter_scope(); + global.set_value("foo", EvaluatedValue::Number(42_f64)); + + global.exit_scope(); + + assert_eq!(None, global.get_value(r"foo")); + } +} diff --git a/src/parse.rs b/src/parse.rs index b9d3696..b4b0b3d 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,5 +1,7 @@ use std::iter::Peekable; +use itertools::Itertools; + use crate::{ ast::{Expr, LiteralValue, Stmt}, lex::{Token, TokenKind}, @@ -120,6 +122,8 @@ fn parse_var_declaration>( fn parse_statement>(iter: &mut Peekable) -> Result { if match_token_kind!(iter, TokenKind::Print).is_ok() { parse_statement_containing_expression(iter, |expression| Stmt::Print { expression }) + } else if match_token_kind!(iter, TokenKind::LeftBrace).is_ok() { + parse_block_statement(iter) } else { parse_statement_containing_expression(iter, |expression| Stmt::Expression { expression }) } @@ -138,6 +142,18 @@ fn parse_statement_containing_expression, F: Fn(Expr) }) } +fn parse_block_statement>( + iter: &mut Peekable, +) -> Result { + let mut statements = vec![]; + while match_token_kind!(iter, TokenKind::RightBrace).is_err() { + let statement = parse_declaration(iter)?; + statements.push(statement); + } + + Ok(Stmt::Block { statements }) +} + fn parse_expression>(iter: &mut Peekable) -> Result { parse_equality(iter) }