diff --git a/e2e.sh b/e2e.sh index 7f1295f..fe57179 100755 --- a/e2e.sh +++ b/e2e.sh @@ -101,4 +101,7 @@ run_harness_test "tests/booleans.src" "tests/booleans.c" 0 # Test if and else statements. run_harness_test "tests/if-else.src" "tests/if-else.c" 0 +# Test variable declarations and scoping. +run_test "tests/let_stmt.src" 30 + echo "All end-to-end tests passed!" \ No newline at end of file diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index 218ad58..a794d90 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -101,6 +101,12 @@ pub enum StmtKind { Return { value: Option>, }, + Let { + name: String, + name_span: Span, + ty: P::ReturnType, + initializer: Option>, + }, } #[derive(Debug, PartialEq, Eq)] diff --git a/src/frontend/lexer.rs b/src/frontend/lexer.rs index bc1b84c..add1c71 100644 --- a/src/frontend/lexer.rs +++ b/src/frontend/lexer.rs @@ -67,6 +67,7 @@ impl<'src> Lexer<'src> { "if" => TokenKind::If, "else" => TokenKind::Else, "return" => TokenKind::Return, + "let" => TokenKind::Let, "i8" => TokenKind::I8, "i16" => TokenKind::I16, @@ -205,7 +206,7 @@ mod test { #[test] fn identifiers() { assert_eq!( - tokenize("HELLO _hello _0@ fn if else return"), + tokenize("HELLO _hello _0@ fn if else return let"), vec![ Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)), Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)), @@ -215,6 +216,7 @@ mod test { Token::new(TokenKind::If, "if", Span::new(20, 22)), Token::new(TokenKind::Else, "else", Span::new(23, 27)), Token::new(TokenKind::Return, "return", Span::new(28, 34)), + Token::new(TokenKind::Let, "let", Span::new(35, 38)), ] ) } diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index 441cab9..77a1ef5 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -285,7 +285,9 @@ impl<'src> Parser<'src> { /// ```ebnf /// stmt = compound_stmt /// | if_stmt - /// | return_stmt ; + /// | return_stmt + /// | let_stmt + /// ; /// ``` pub fn parse_stmt(&mut self) -> ParseResult { let peek_token = self.peek_no_eof()?; @@ -294,6 +296,7 @@ impl<'src> Parser<'src> { TokenKind::LBrace => self.parse_compound_stmt(), TokenKind::If => self.parse_if_stmt(), TokenKind::Return => self.parse_return_stmt(), + TokenKind::Let => self.parse_let_stmt(), _ => Err(ParseError::new( format!("expected a statement but found {} instead", peek_token.kind), @@ -398,6 +401,47 @@ impl<'src> Parser<'src> { }) } + /// Parses a let statement. + /// + /// ```ebnf + /// let_stmt = "let" IDENTIFIER [ ":" type ] [ "=" expr ] ";" ; + /// ``` + fn parse_let_stmt(&mut self) -> ParseResult { + let let_token = self.expect(TokenKind::Let)?; + + let (name, name_span) = { + let ident_token = self.expect(TokenKind::Identifier)?; + (ident_token.text.to_string(), ident_token.span) + }; + + let ty = if self.is_peek(TokenKind::Colon) { + self.advance(); + Some(self.parse_type()?) + } else { + None + }; + + let initializer = if self.is_peek(TokenKind::Assign) { + self.advance(); + Some(self.parse_expr()?) + } else { + None + }; + + let semi_token = self.expect(TokenKind::Semicolon)?; + let span = let_token.span.join(semi_token.span); + + Ok(Stmt { + kind: StmtKind::Let { + name, + name_span, + ty, + initializer, + }, + span, + }) + } + // ====== Pratt Parsing Implementation ====== /// Parses an expression. @@ -795,6 +839,42 @@ mod test { ); } + #[test] + fn let_stmt() { + assert_eq!( + parse("let a: i32 = 5;", Parser::parse_stmt), + Success(Stmt { + kind: StmtKind::Let { + name: "a".to_string(), + name_span: Span::new(4, 5), + ty: Some(Type { + kind: TypeKind::I32, + span: Span::new(7, 10), + }), + initializer: Some(Expr { + kind: ExprKind::Integer { value: 5 }, + ty: (), + span: Span::new(13, 14), + }), + }, + span: Span::new(0, 15) + }) + ); + + assert_eq!( + parse("let a;", Parser::parse_stmt), + Success(Stmt { + kind: StmtKind::Let { + name: "a".to_string(), + name_span: Span::new(4, 5), + ty: None, + initializer: None, + }, + span: Span::new(0, 6) + }) + ); + } + #[test] fn function_decl() { assert_eq!( diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs index 7cc784d..5a74c7d 100644 --- a/src/frontend/sema.rs +++ b/src/frontend/sema.rs @@ -348,6 +348,48 @@ impl Sema { } } } + StmtKind::Let { + name, + name_span, + ty: type_annotation, + initializer, + } => { + let explicit_ty = type_annotation.as_ref().map(|t| Ty::from(&t.kind)); + + let (inferred_ty, typed_initializer) = match initializer { + Some(expr) => { + let typed_expr = self.analyze_expr(expr); + if let Some(ref ext_ty) = explicit_ty { + if let Err(err) = self.unify(&typed_expr.ty, ext_ty) { + self.errors.push(SemanticError::new(err, expr.span)); + } + } + (typed_expr.ty.clone(), Some(typed_expr)) + } + None => { + if let Some(ref ext_ty) = explicit_ty { + (ext_ty.clone(), None) + } else { + self.errors + .push(SemanticError::new("type annotation needed", *name_span)); + (self.new_var(), None) + } + } + }; + + let final_ty = explicit_ty.unwrap_or(inferred_ty); + self.bind(name, final_ty.clone()); + + TypedStmt { + kind: TypedStmtKind::Let { + name: name.clone(), + name_span: *name_span, + ty: final_ty, + initializer: typed_initializer, + }, + span: stmt.span, + } + } } } @@ -532,6 +574,18 @@ impl Sema { TypedStmtKind::Return { value } => TypedStmtKind::Return { value: value.map(|e| self.apply_subst_expr(e)), }, + + TypedStmtKind::Let { + name, + name_span, + ty, + initializer, + } => TypedStmtKind::Let { + name, + name_span, + ty: self.apply_subst(&ty), + initializer: initializer.map(|e| self.apply_subst_expr(e)), + }, }; TypedStmt { kind, span } @@ -797,4 +851,16 @@ mod test { let src = "fn test() { if 12 {} }"; assert!(analyze(src).is_err()); } + + #[test] + fn valid_let() { + let src = "fn test() { let a = 5; let b: i32 = a; }"; + assert!(analyze(src).is_ok()); + } + + #[test] + fn let_missing_type() { + let src = "fn test() { let a; }"; + assert!(analyze(src).is_err()); + } } diff --git a/src/frontend/token.rs b/src/frontend/token.rs index 7594e9c..c3412ca 100644 --- a/src/frontend/token.rs +++ b/src/frontend/token.rs @@ -59,6 +59,7 @@ pub enum TokenKind { If, Else, Return, + Let, // Types I8, @@ -115,6 +116,7 @@ impl Display for TokenKind { TokenKind::If => "`if`", TokenKind::Else => "`else`", TokenKind::Return => "`return`", + TokenKind::Let => "`let`", TokenKind::I8 => "`i8`", TokenKind::I16 => "`i16`", TokenKind::I32 => "`i32`", diff --git a/src/middle/builder.rs b/src/middle/builder.rs index bc9063c..4c01cff 100644 --- a/src/middle/builder.rs +++ b/src/middle/builder.rs @@ -83,8 +83,8 @@ struct FuncBuilder { /// Counter for generating unique `BlockId`s. next_block_id: usize, - /// Mapping from user-defined variable names to their corresponding `LocalId`. - vars: HashMap, + /// Scoped mapping from user-defined variable names to their corresponding `LocalId`. + scopes: Vec>, } impl FuncBuilder { @@ -99,10 +99,18 @@ impl FuncBuilder { current_block: None, current_statements: Vec::new(), next_block_id: 0, - vars: HashMap::new(), + scopes: vec![HashMap::new()], } } + fn enter_scope(&mut self) { + self.scopes.push(HashMap::new()); + } + + fn leave_scope(&mut self) { + self.scopes.pop(); + } + /// Registers a new user-defined local variable and returns its `LocalId`. fn new_local(&mut self, name: String, ty: Ty) -> LocalId { let id = LocalId(self.locals.len()); @@ -112,7 +120,7 @@ impl FuncBuilder { mutable: false, name: Some(name.clone()), }); - self.vars.insert(name, id); + self.scopes.last_mut().unwrap().insert(name, id); id } @@ -128,6 +136,15 @@ impl FuncBuilder { id } + fn lookup(&self, name: &str) -> LocalId { + for scope in self.scopes.iter().rev() { + if let Some(id) = scope.get(name) { + return *id; + } + } + panic!("undeclared variable `{}` in MIR lowering", name); + } + /// Allocates a new, empty basic block and returns its `BlockId`. fn new_block(&mut self) -> BlockId { let id = BlockId(self.next_block_id); @@ -180,9 +197,11 @@ impl FuncBuilder { fn lower_stmt(&mut self, stmt: &TypedStmt) { match &stmt.kind { TypedStmtKind::Compound { inner } => { + self.enter_scope(); for s in inner { self.lower_stmt(s); } + self.leave_scope(); } TypedStmtKind::If { condition, @@ -237,6 +256,22 @@ impl FuncBuilder { span: stmt.span, }); } + TypedStmtKind::Let { + name, + name_span: _, + ty, + initializer, + } => { + let local_id = self.new_local(name.clone(), ty.clone()); + + if let Some(init_expr) = initializer { + let val_op = self.lower_expr(init_expr); + self.emit_stmt(Statement { + kind: StatementKind::Assign(local_id, Rvalue::Use(val_op)), + span: stmt.span, + }); + } + } } } @@ -244,10 +279,7 @@ impl FuncBuilder { fn lower_expr(&mut self, expr: &TypedExpr) -> Operand { match &expr.kind { TypedExprKind::Identifier { name } => { - let local = *self - .vars - .get(name) - .expect("undeclared variable in MIR lowering"); + let local = self.lookup(name); Operand::Copy(local) } TypedExprKind::Integer { value } => { diff --git a/tests/let_stmt.src b/tests/let_stmt.src new file mode 100644 index 0000000..606612f --- /dev/null +++ b/tests/let_stmt.src @@ -0,0 +1,11 @@ +fn main() -> i32 { + let a = 10; + let b: i32 = 20; + let c = a + b; + { + // Shadow 'a' and 'c' in a new scope + let a = 5; + let c = a + b; + } + return c; // Should return 30, not 25, because the inner 'c' is dropped +} \ No newline at end of file