diff --git a/PLAN.md b/PLAN.md index 5839166..80f1814 100644 --- a/PLAN.md +++ b/PLAN.md @@ -49,7 +49,7 @@ A Rust-flavored, C-targeting language - built pipeline-first. ### Control flow - [x] booleans and comparision operators -- [ ] `if` / `else` branching +- [x] `if` / `else` branching - [ ] `while` loops ### Types & memory diff --git a/e2e.sh b/e2e.sh index a508537..7f1295f 100755 --- a/e2e.sh +++ b/e2e.sh @@ -98,4 +98,7 @@ run_test "tests/return_neg_69.src" 187 # Test boolean operations. run_harness_test "tests/booleans.src" "tests/booleans.c" 0 +# Test if and else statements. +run_harness_test "tests/if-else.src" "tests/if-else.c" 0 + echo "All end-to-end tests passed!" \ No newline at end of file diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs index 3534426..c7e1593 100644 --- a/src/backend/cranelift.rs +++ b/src/backend/cranelift.rs @@ -154,12 +154,53 @@ struct FunctionTranslator<'a> { impl<'a> FunctionTranslator<'a> { /// Translates a statement, recursively compiling its inner components. - fn translate_stmt(&mut self, stmt: &TypedStmt) { + /// Returns `true` if the statement resulted in a basic block terminator. + fn translate_stmt(&mut self, stmt: &TypedStmt) -> bool { match stmt { TypedStmt::Compound { inner } => { for s in inner { - self.translate_stmt(s); + if self.translate_stmt(s) { + return true; + } } + false + } + TypedStmt::If { + condition, + then, + elze, + } => { + let cond_val = self.translate_expr(condition); + + let then_block = self.builder.create_block(); + let else_block = self.builder.create_block(); + let merge_block = self.builder.create_block(); + + self.builder + .ins() + .brif(cond_val, then_block, &[], else_block, &[]); + + self.builder.switch_to_block(then_block); + self.builder.seal_block(then_block); + let then_terminated = self.translate_stmt(then); + if !then_terminated { + self.builder.ins().jump(merge_block, &[]); + } + + self.builder.switch_to_block(else_block); + self.builder.seal_block(else_block); + let else_terminated = elze + .as_ref() + .map(|stmt| self.translate_stmt(stmt)) + .unwrap_or(false); + if !else_terminated { + self.builder.ins().jump(merge_block, &[]); + } + + self.builder.switch_to_block(merge_block); + self.builder.seal_block(merge_block); + + then_terminated && else_terminated } TypedStmt::Return { value } => { if let Some(expr) = value { @@ -168,6 +209,7 @@ impl<'a> FunctionTranslator<'a> { } else { self.builder.ins().return_(&[]); } + true } } } diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index 78bb7cd..b85f3b2 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -56,8 +56,17 @@ pub struct Stmt { #[derive(Debug, PartialEq, Eq)] pub enum StmtKind { - Compound { inner: Vec }, - Return { value: Option }, + Compound { + inner: Vec, + }, + If { + condition: Expr, + then: Box, + elze: Option>, + }, + Return { + value: Option, + }, } #[derive(Debug, PartialEq, Eq)] diff --git a/src/frontend/lexer.rs b/src/frontend/lexer.rs index 22e0b0c..bc1b84c 100644 --- a/src/frontend/lexer.rs +++ b/src/frontend/lexer.rs @@ -64,6 +64,8 @@ impl<'src> Lexer<'src> { match &self.source[start..self.cursor] { "fn" => TokenKind::Fn, + "if" => TokenKind::If, + "else" => TokenKind::Else, "return" => TokenKind::Return, "i8" => TokenKind::I8, @@ -203,12 +205,16 @@ mod test { #[test] fn identifiers() { assert_eq!( - tokenize("HELLO _hello _0@"), + tokenize("HELLO _hello _0@ fn if else return"), vec![ Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)), Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)), Token::new(TokenKind::Identifier, "_0", Span::new(13, 15)), Token::new(TokenKind::Invalid, "@", Span::new(15, 16)), + Token::new(TokenKind::Fn, "fn", Span::new(17, 19)), + Token::new(TokenKind::If, "if", Span::new(20, 22)), + Token::new(TokenKind::Else, "else", Span::new(23, 27)), + Token::new(TokenKind::Return, "return", Span::new(28, 34)), ] ) } diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index d449e5b..4d2f174 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -284,6 +284,7 @@ impl<'src> Parser<'src> { /// /// ```ebnf /// stmt = compound_stmt + /// | if_stmt /// | return_stmt ; /// ``` pub fn parse_stmt(&mut self) -> ParseResult { @@ -291,6 +292,7 @@ impl<'src> Parser<'src> { match peek_token.kind { TokenKind::LBrace => self.parse_compound_stmt(), + TokenKind::If => self.parse_if_stmt(), TokenKind::Return => self.parse_return_stmt(), _ => Err(ParseError::new( @@ -333,6 +335,46 @@ impl<'src> Parser<'src> { }) } + /// Parses an if statement. + /// + /// ```ebnf + /// if_stmt = "if" expr compound_stmt [ "else" ( if_stmt | compound_stmt ) ] ; + /// ``` + fn parse_if_stmt(&mut self) -> ParseResult { + let if_token = self.expect(TokenKind::If)?; + + let condition = self.parse_expr()?; + let consequence = self.parse_compound_stmt()?; + + let alternative = if self.is_peek(TokenKind::Else) { + self.advance(); + + Some(if self.is_peek(TokenKind::If) { + self.parse_if_stmt() + } else { + self.parse_compound_stmt() + }?) + } else { + None + }; + + let span = if_token.span.join( + alternative + .as_ref() + .map(|stmt| stmt.span) + .unwrap_or(consequence.span), + ); + + Ok(Stmt { + kind: StmtKind::If { + condition, + then: Box::new(consequence), + elze: alternative.map(Box::new), + }, + span, + }) + } + /// Parses a return statement. /// /// ```ebnf @@ -676,6 +718,32 @@ mod test { ); } + #[test] + fn if_stmt() { + assert_eq!( + parse("if true { return; }", Parser::parse_stmt), + Success(Stmt { + kind: StmtKind::If { + condition: Expr { + kind: ExprKind::Boolean { value: true }, + span: Span::new(3, 7) + }, + then: Box::new(Stmt { + kind: StmtKind::Compound { + inner: vec![Stmt { + kind: StmtKind::Return { value: None }, + span: Span::new(10, 17) + }] + }, + span: Span::new(8, 19) + }), + elze: None + }, + span: Span::new(0, 19) + }) + ) + } + #[test] fn compound_stmt() { assert_eq!( diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs index 8abcba2..8f0993c 100644 --- a/src/frontend/sema.rs +++ b/src/frontend/sema.rs @@ -74,7 +74,7 @@ impl From<&TypeKind> for Ty { pub struct Sema { next_var: usize, subst: HashMap, - env: HashMap, + scopes: Vec>, errors: Vec, deferred_unary_neg: Vec<(Span, Ty, Ty, Option)>, deferred_binary: Vec<(Span, Ty)>, @@ -87,7 +87,7 @@ impl Sema { Self { next_var: 0, subst: HashMap::new(), - env: HashMap::new(), + scopes: Vec::new(), errors: Vec::new(), deferred_unary_neg: Vec::new(), deferred_binary: Vec::new(), @@ -100,6 +100,22 @@ impl Sema { (!self.errors.is_empty()).then_some(self.errors) } + fn enter_scope(&mut self) { + self.scopes.push(HashMap::new()); + } + + fn leave_scope(&mut self) { + self.scopes.pop(); + } + + fn bind(&mut self, name: &str, ty: Ty) { + self.scopes.last_mut().unwrap().insert(name.to_string(), ty); + } + + fn lookup(&self, name: &str) -> Option<&Ty> { + self.scopes.iter().rev().find_map(|scope| scope.get(name)) + } + /// Generates a fresh, unconstrained type variable. fn new_var(&mut self) -> Ty { let v = self.next_var; @@ -156,16 +172,20 @@ impl Sema { if self.occurs_check(v, &t) { return Err("recursive type".to_string()); } + self.subst.insert(v, t); + Ok(()) } (Ty::Function(p1, r1), Ty::Function(p2, r2)) => { if p1.len() != p2.len() { return Err("arity mismatch".to_string()); } + for (arg1, arg2) in p1.iter().zip(p2.iter()) { self.unify(arg1, arg2)?; } + self.unify(&r1, &r2) } (a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)), @@ -175,6 +195,8 @@ impl Sema { /// Analyzes an entire module, collecting function signatures into the environment, /// performing type inference, and returning a fully evaluated [TypedModule]. pub fn analyze_module(&mut self, module: &Module) -> TypedModule { + self.enter_scope(); + for decl in &module.decls { match &decl.kind { DeclKind::Function { @@ -188,8 +210,8 @@ impl Sema { .as_ref() .map(|t| Ty::from(&t.kind)) .unwrap_or(Ty::Unit); - self.env - .insert(name.clone(), Ty::Function(param_tys, Box::new(ret_ty))); + + self.bind(name, Ty::Function(param_tys, Box::new(ret_ty))); } } } @@ -206,6 +228,8 @@ impl Sema { final_decls.push(self.apply_subst_decl(decl)); } + self.leave_scope(); + TypedModule { decls: final_decls } } @@ -220,23 +244,24 @@ impl Sema { body, .. } => { - let mut local_env = self.env.clone(); let mut typed_params = Vec::new(); + self.enter_scope(); + for param in params { let ty = Ty::from(¶m.ty.kind); - local_env.insert(param.name.clone(), ty.clone()); + self.bind(¶m.name, ty.clone()); typed_params.push((param.name.clone(), ty)); } - let global_env = std::mem::replace(&mut self.env, local_env); let expected_ret_ty = return_type .as_ref() .map(|t| Ty::from(&t.kind)) .unwrap_or(Ty::Unit); let typed_body = self.analyze_stmt(body, &expected_ret_ty); - self.env = global_env; + + self.leave_scope(); TypedDecl::Function { name: name.clone(), @@ -255,12 +280,38 @@ impl Sema { StmtKind::Compound { inner } => { let mut typed_inner = Vec::new(); + self.enter_scope(); + for s in inner { typed_inner.push(self.analyze_stmt(s, expected_ret_ty)); } + self.leave_scope(); + TypedStmt::Compound { inner: typed_inner } } + StmtKind::If { + condition, + then, + elze, + } => { + let typed_condition = self.analyze_expr(condition); + + if let Err(err) = self.unify(&typed_condition.ty, &Ty::Bool) { + self.errors.push(SemanticError::new(err, condition.span)); + } + + let typed_then = self.analyze_stmt(then, expected_ret_ty); + let typed_elze = elze + .as_ref() + .map(|stmt| self.analyze_stmt(stmt, expected_ret_ty)); + + TypedStmt::If { + condition: typed_condition, + then: Box::new(typed_then), + elze: typed_elze.map(Box::new), + } + } StmtKind::Return { value } => { if let Some(expr) = value { let typed_expr = self.analyze_expr(expr); @@ -288,7 +339,7 @@ impl Sema { fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr { match &expr.kind { ExprKind::Identifier { name } => { - let ty = if let Some(ty) = self.env.get(name) { + let ty = if let Some(ty) = self.lookup(name) { ty.clone() } else { self.errors.push(SemanticError::new( @@ -439,6 +490,16 @@ impl Sema { .collect(), }, + TypedStmt::If { + condition, + then, + elze, + } => TypedStmt::If { + condition: self.apply_subst_expr(condition), + then: Box::new(self.apply_subst_stmt(*then)), + elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))), + }, + TypedStmt::Return { value } => TypedStmt::Return { value: value.map(|e| self.apply_subst_expr(e)), }, @@ -692,4 +753,16 @@ mod test { .contains("binary operators only work on integer types") })); } + + #[test] + fn valid_if() { + let src = "fn test() { if true { return; } }"; + assert!(analyze(src).is_ok()) + } + + #[test] + fn invalid_if() { + let src = "fn test() { if 12 {} }"; + assert!(analyze(src).is_err()); + } } diff --git a/src/frontend/token.rs b/src/frontend/token.rs index 3d43715..6d9309f 100644 --- a/src/frontend/token.rs +++ b/src/frontend/token.rs @@ -50,6 +50,8 @@ pub enum TokenKind { // Keywords Fn, + If, + Else, Return, // Types @@ -104,6 +106,8 @@ impl Display for TokenKind { TokenKind::IntegerLit => "an integer", TokenKind::BooleanLit => "a boolean", TokenKind::Fn => "`fn`", + TokenKind::If => "`if`", + TokenKind::Else => "`else`", TokenKind::Return => "`return`", TokenKind::I8 => "`i8`", TokenKind::I16 => "`i16`", diff --git a/src/frontend/typed_ast.rs b/src/frontend/typed_ast.rs index 16caba4..bcc2c14 100644 --- a/src/frontend/typed_ast.rs +++ b/src/frontend/typed_ast.rs @@ -18,8 +18,17 @@ pub enum TypedDecl { #[derive(Debug, PartialEq, Eq)] pub enum TypedStmt { - Compound { inner: Vec }, - Return { value: Option }, + Compound { + inner: Vec, + }, + If { + condition: TypedExpr, + then: Box, + elze: Option>, + }, + Return { + value: Option, + }, } #[derive(Debug, PartialEq, Eq)] diff --git a/tests/if-else.c b/tests/if-else.c new file mode 100644 index 0000000..c87d4a7 --- /dev/null +++ b/tests/if-else.c @@ -0,0 +1,9 @@ +extern int min(int a, int b); + +int main() { + if (min(12, 15)) { + return 0; + } else { + return -1; + } +} \ No newline at end of file diff --git a/tests/if-else.src b/tests/if-else.src new file mode 100644 index 0000000..49c8ac9 --- /dev/null +++ b/tests/if-else.src @@ -0,0 +1,7 @@ +fn min(a: i32, b: i32) -> i32 { + if a < b { + return a; + } else { + return b; + } +} \ No newline at end of file