feat: add support for if/else statements

This commit is contained in:
2026-04-21 18:20:15 +02:00
parent eb3663dfbb
commit 0c288c2247
11 changed files with 247 additions and 17 deletions
+1 -1
View File
@@ -49,7 +49,7 @@ A Rust-flavored, C-targeting language - built pipeline-first.
### Control flow
- [x] booleans and comparision operators
- [ ] `if` / `else` branching
- [x] `if` / `else` branching
- [ ] `while` loops
### Types & memory
+3
View File
@@ -98,4 +98,7 @@ run_test "tests/return_neg_69.src" 187
# Test boolean operations.
run_harness_test "tests/booleans.src" "tests/booleans.c" 0
# Test if and else statements.
run_harness_test "tests/if-else.src" "tests/if-else.c" 0
echo "All end-to-end tests passed!"
+44 -2
View File
@@ -154,13 +154,54 @@ struct FunctionTranslator<'a> {
impl<'a> FunctionTranslator<'a> {
/// Translates a statement, recursively compiling its inner components.
fn translate_stmt(&mut self, stmt: &TypedStmt) {
/// Returns `true` if the statement resulted in a basic block terminator.
fn translate_stmt(&mut self, stmt: &TypedStmt) -> bool {
match stmt {
TypedStmt::Compound { inner } => {
for s in inner {
self.translate_stmt(s);
if self.translate_stmt(s) {
return true;
}
}
false
}
TypedStmt::If {
condition,
then,
elze,
} => {
let cond_val = self.translate_expr(condition);
let then_block = self.builder.create_block();
let else_block = self.builder.create_block();
let merge_block = self.builder.create_block();
self.builder
.ins()
.brif(cond_val, then_block, &[], else_block, &[]);
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
let then_terminated = self.translate_stmt(then);
if !then_terminated {
self.builder.ins().jump(merge_block, &[]);
}
self.builder.switch_to_block(else_block);
self.builder.seal_block(else_block);
let else_terminated = elze
.as_ref()
.map(|stmt| self.translate_stmt(stmt))
.unwrap_or(false);
if !else_terminated {
self.builder.ins().jump(merge_block, &[]);
}
self.builder.switch_to_block(merge_block);
self.builder.seal_block(merge_block);
then_terminated && else_terminated
}
TypedStmt::Return { value } => {
if let Some(expr) = value {
let val = self.translate_expr(expr);
@@ -168,6 +209,7 @@ impl<'a> FunctionTranslator<'a> {
} else {
self.builder.ins().return_(&[]);
}
true
}
}
}
+11 -2
View File
@@ -56,8 +56,17 @@ pub struct Stmt {
#[derive(Debug, PartialEq, Eq)]
pub enum StmtKind {
Compound { inner: Vec<Stmt> },
Return { value: Option<Expr> },
Compound {
inner: Vec<Stmt>,
},
If {
condition: Expr,
then: Box<Stmt>,
elze: Option<Box<Stmt>>,
},
Return {
value: Option<Expr>,
},
}
#[derive(Debug, PartialEq, Eq)]
+7 -1
View File
@@ -64,6 +64,8 @@ impl<'src> Lexer<'src> {
match &self.source[start..self.cursor] {
"fn" => TokenKind::Fn,
"if" => TokenKind::If,
"else" => TokenKind::Else,
"return" => TokenKind::Return,
"i8" => TokenKind::I8,
@@ -203,12 +205,16 @@ mod test {
#[test]
fn identifiers() {
assert_eq!(
tokenize("HELLO _hello _0@"),
tokenize("HELLO _hello _0@ fn if else return"),
vec![
Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)),
Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)),
Token::new(TokenKind::Identifier, "_0", Span::new(13, 15)),
Token::new(TokenKind::Invalid, "@", Span::new(15, 16)),
Token::new(TokenKind::Fn, "fn", Span::new(17, 19)),
Token::new(TokenKind::If, "if", Span::new(20, 22)),
Token::new(TokenKind::Else, "else", Span::new(23, 27)),
Token::new(TokenKind::Return, "return", Span::new(28, 34)),
]
)
}
+68
View File
@@ -284,6 +284,7 @@ impl<'src> Parser<'src> {
///
/// ```ebnf
/// stmt = compound_stmt
/// | if_stmt
/// | return_stmt ;
/// ```
pub fn parse_stmt(&mut self) -> ParseResult<Stmt> {
@@ -291,6 +292,7 @@ impl<'src> Parser<'src> {
match peek_token.kind {
TokenKind::LBrace => self.parse_compound_stmt(),
TokenKind::If => self.parse_if_stmt(),
TokenKind::Return => self.parse_return_stmt(),
_ => Err(ParseError::new(
@@ -333,6 +335,46 @@ impl<'src> Parser<'src> {
})
}
/// Parses an if statement.
///
/// ```ebnf
/// if_stmt = "if" expr compound_stmt [ "else" ( if_stmt | compound_stmt ) ] ;
/// ```
fn parse_if_stmt(&mut self) -> ParseResult<Stmt> {
let if_token = self.expect(TokenKind::If)?;
let condition = self.parse_expr()?;
let consequence = self.parse_compound_stmt()?;
let alternative = if self.is_peek(TokenKind::Else) {
self.advance();
Some(if self.is_peek(TokenKind::If) {
self.parse_if_stmt()
} else {
self.parse_compound_stmt()
}?)
} else {
None
};
let span = if_token.span.join(
alternative
.as_ref()
.map(|stmt| stmt.span)
.unwrap_or(consequence.span),
);
Ok(Stmt {
kind: StmtKind::If {
condition,
then: Box::new(consequence),
elze: alternative.map(Box::new),
},
span,
})
}
/// Parses a return statement.
///
/// ```ebnf
@@ -676,6 +718,32 @@ mod test {
);
}
#[test]
fn if_stmt() {
assert_eq!(
parse("if true { return; }", Parser::parse_stmt),
Success(Stmt {
kind: StmtKind::If {
condition: Expr {
kind: ExprKind::Boolean { value: true },
span: Span::new(3, 7)
},
then: Box::new(Stmt {
kind: StmtKind::Compound {
inner: vec![Stmt {
kind: StmtKind::Return { value: None },
span: Span::new(10, 17)
}]
},
span: Span::new(8, 19)
}),
elze: None
},
span: Span::new(0, 19)
})
)
}
#[test]
fn compound_stmt() {
assert_eq!(
+82 -9
View File
@@ -74,7 +74,7 @@ impl From<&TypeKind> for Ty {
pub struct Sema {
next_var: usize,
subst: HashMap<usize, Ty>,
env: HashMap<String, Ty>,
scopes: Vec<HashMap<String, Ty>>,
errors: Vec<SemanticError>,
deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
deferred_binary: Vec<(Span, Ty)>,
@@ -87,7 +87,7 @@ impl Sema {
Self {
next_var: 0,
subst: HashMap::new(),
env: HashMap::new(),
scopes: Vec::new(),
errors: Vec::new(),
deferred_unary_neg: Vec::new(),
deferred_binary: Vec::new(),
@@ -100,6 +100,22 @@ impl Sema {
(!self.errors.is_empty()).then_some(self.errors)
}
fn enter_scope(&mut self) {
self.scopes.push(HashMap::new());
}
fn leave_scope(&mut self) {
self.scopes.pop();
}
fn bind(&mut self, name: &str, ty: Ty) {
self.scopes.last_mut().unwrap().insert(name.to_string(), ty);
}
fn lookup(&self, name: &str) -> Option<&Ty> {
self.scopes.iter().rev().find_map(|scope| scope.get(name))
}
/// Generates a fresh, unconstrained type variable.
fn new_var(&mut self) -> Ty {
let v = self.next_var;
@@ -156,16 +172,20 @@ impl Sema {
if self.occurs_check(v, &t) {
return Err("recursive type".to_string());
}
self.subst.insert(v, t);
Ok(())
}
(Ty::Function(p1, r1), Ty::Function(p2, r2)) => {
if p1.len() != p2.len() {
return Err("arity mismatch".to_string());
}
for (arg1, arg2) in p1.iter().zip(p2.iter()) {
self.unify(arg1, arg2)?;
}
self.unify(&r1, &r2)
}
(a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)),
@@ -175,6 +195,8 @@ impl Sema {
/// Analyzes an entire module, collecting function signatures into the environment,
/// performing type inference, and returning a fully evaluated [TypedModule].
pub fn analyze_module(&mut self, module: &Module) -> TypedModule {
self.enter_scope();
for decl in &module.decls {
match &decl.kind {
DeclKind::Function {
@@ -188,8 +210,8 @@ impl Sema {
.as_ref()
.map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit);
self.env
.insert(name.clone(), Ty::Function(param_tys, Box::new(ret_ty)));
self.bind(name, Ty::Function(param_tys, Box::new(ret_ty)));
}
}
}
@@ -206,6 +228,8 @@ impl Sema {
final_decls.push(self.apply_subst_decl(decl));
}
self.leave_scope();
TypedModule { decls: final_decls }
}
@@ -220,23 +244,24 @@ impl Sema {
body,
..
} => {
let mut local_env = self.env.clone();
let mut typed_params = Vec::new();
self.enter_scope();
for param in params {
let ty = Ty::from(&param.ty.kind);
local_env.insert(param.name.clone(), ty.clone());
self.bind(&param.name, ty.clone());
typed_params.push((param.name.clone(), ty));
}
let global_env = std::mem::replace(&mut self.env, local_env);
let expected_ret_ty = return_type
.as_ref()
.map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit);
let typed_body = self.analyze_stmt(body, &expected_ret_ty);
self.env = global_env;
self.leave_scope();
TypedDecl::Function {
name: name.clone(),
@@ -255,12 +280,38 @@ impl Sema {
StmtKind::Compound { inner } => {
let mut typed_inner = Vec::new();
self.enter_scope();
for s in inner {
typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
}
self.leave_scope();
TypedStmt::Compound { inner: typed_inner }
}
StmtKind::If {
condition,
then,
elze,
} => {
let typed_condition = self.analyze_expr(condition);
if let Err(err) = self.unify(&typed_condition.ty, &Ty::Bool) {
self.errors.push(SemanticError::new(err, condition.span));
}
let typed_then = self.analyze_stmt(then, expected_ret_ty);
let typed_elze = elze
.as_ref()
.map(|stmt| self.analyze_stmt(stmt, expected_ret_ty));
TypedStmt::If {
condition: typed_condition,
then: Box::new(typed_then),
elze: typed_elze.map(Box::new),
}
}
StmtKind::Return { value } => {
if let Some(expr) = value {
let typed_expr = self.analyze_expr(expr);
@@ -288,7 +339,7 @@ impl Sema {
fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr {
match &expr.kind {
ExprKind::Identifier { name } => {
let ty = if let Some(ty) = self.env.get(name) {
let ty = if let Some(ty) = self.lookup(name) {
ty.clone()
} else {
self.errors.push(SemanticError::new(
@@ -439,6 +490,16 @@ impl Sema {
.collect(),
},
TypedStmt::If {
condition,
then,
elze,
} => TypedStmt::If {
condition: self.apply_subst_expr(condition),
then: Box::new(self.apply_subst_stmt(*then)),
elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))),
},
TypedStmt::Return { value } => TypedStmt::Return {
value: value.map(|e| self.apply_subst_expr(e)),
},
@@ -692,4 +753,16 @@ mod test {
.contains("binary operators only work on integer types")
}));
}
#[test]
fn valid_if() {
let src = "fn test() { if true { return; } }";
assert!(analyze(src).is_ok())
}
#[test]
fn invalid_if() {
let src = "fn test() { if 12 {} }";
assert!(analyze(src).is_err());
}
}
+4
View File
@@ -50,6 +50,8 @@ pub enum TokenKind {
// Keywords
Fn,
If,
Else,
Return,
// Types
@@ -104,6 +106,8 @@ impl Display for TokenKind {
TokenKind::IntegerLit => "an integer",
TokenKind::BooleanLit => "a boolean",
TokenKind::Fn => "`fn`",
TokenKind::If => "`if`",
TokenKind::Else => "`else`",
TokenKind::Return => "`return`",
TokenKind::I8 => "`i8`",
TokenKind::I16 => "`i16`",
+11 -2
View File
@@ -18,8 +18,17 @@ pub enum TypedDecl {
#[derive(Debug, PartialEq, Eq)]
pub enum TypedStmt {
Compound { inner: Vec<TypedStmt> },
Return { value: Option<TypedExpr> },
Compound {
inner: Vec<TypedStmt>,
},
If {
condition: TypedExpr,
then: Box<TypedStmt>,
elze: Option<Box<TypedStmt>>,
},
Return {
value: Option<TypedExpr>,
},
}
#[derive(Debug, PartialEq, Eq)]
+9
View File
@@ -0,0 +1,9 @@
extern int min(int a, int b);
int main() {
if (min(12, 15)) {
return 0;
} else {
return -1;
}
}
+7
View File
@@ -0,0 +1,7 @@
fn min(a: i32, b: i32) -> i32 {
if a < b {
return a;
} else {
return b;
}
}