feat: add support for if/else statements

This commit is contained in:
2026-04-21 18:20:15 +02:00
parent eb3663dfbb
commit 0c288c2247
11 changed files with 247 additions and 17 deletions
+82 -9
View File
@@ -74,7 +74,7 @@ impl From<&TypeKind> for Ty {
pub struct Sema {
next_var: usize,
subst: HashMap<usize, Ty>,
env: HashMap<String, Ty>,
scopes: Vec<HashMap<String, Ty>>,
errors: Vec<SemanticError>,
deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
deferred_binary: Vec<(Span, Ty)>,
@@ -87,7 +87,7 @@ impl Sema {
Self {
next_var: 0,
subst: HashMap::new(),
env: HashMap::new(),
scopes: Vec::new(),
errors: Vec::new(),
deferred_unary_neg: Vec::new(),
deferred_binary: Vec::new(),
@@ -100,6 +100,22 @@ impl Sema {
(!self.errors.is_empty()).then_some(self.errors)
}
fn enter_scope(&mut self) {
self.scopes.push(HashMap::new());
}
fn leave_scope(&mut self) {
self.scopes.pop();
}
fn bind(&mut self, name: &str, ty: Ty) {
self.scopes.last_mut().unwrap().insert(name.to_string(), ty);
}
fn lookup(&self, name: &str) -> Option<&Ty> {
self.scopes.iter().rev().find_map(|scope| scope.get(name))
}
/// Generates a fresh, unconstrained type variable.
fn new_var(&mut self) -> Ty {
let v = self.next_var;
@@ -156,16 +172,20 @@ impl Sema {
if self.occurs_check(v, &t) {
return Err("recursive type".to_string());
}
self.subst.insert(v, t);
Ok(())
}
(Ty::Function(p1, r1), Ty::Function(p2, r2)) => {
if p1.len() != p2.len() {
return Err("arity mismatch".to_string());
}
for (arg1, arg2) in p1.iter().zip(p2.iter()) {
self.unify(arg1, arg2)?;
}
self.unify(&r1, &r2)
}
(a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)),
@@ -175,6 +195,8 @@ impl Sema {
/// Analyzes an entire module, collecting function signatures into the environment,
/// performing type inference, and returning a fully evaluated [TypedModule].
pub fn analyze_module(&mut self, module: &Module) -> TypedModule {
self.enter_scope();
for decl in &module.decls {
match &decl.kind {
DeclKind::Function {
@@ -188,8 +210,8 @@ impl Sema {
.as_ref()
.map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit);
self.env
.insert(name.clone(), Ty::Function(param_tys, Box::new(ret_ty)));
self.bind(name, Ty::Function(param_tys, Box::new(ret_ty)));
}
}
}
@@ -206,6 +228,8 @@ impl Sema {
final_decls.push(self.apply_subst_decl(decl));
}
self.leave_scope();
TypedModule { decls: final_decls }
}
@@ -220,23 +244,24 @@ impl Sema {
body,
..
} => {
let mut local_env = self.env.clone();
let mut typed_params = Vec::new();
self.enter_scope();
for param in params {
let ty = Ty::from(&param.ty.kind);
local_env.insert(param.name.clone(), ty.clone());
self.bind(&param.name, ty.clone());
typed_params.push((param.name.clone(), ty));
}
let global_env = std::mem::replace(&mut self.env, local_env);
let expected_ret_ty = return_type
.as_ref()
.map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit);
let typed_body = self.analyze_stmt(body, &expected_ret_ty);
self.env = global_env;
self.leave_scope();
TypedDecl::Function {
name: name.clone(),
@@ -255,12 +280,38 @@ impl Sema {
StmtKind::Compound { inner } => {
let mut typed_inner = Vec::new();
self.enter_scope();
for s in inner {
typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
}
self.leave_scope();
TypedStmt::Compound { inner: typed_inner }
}
StmtKind::If {
condition,
then,
elze,
} => {
let typed_condition = self.analyze_expr(condition);
if let Err(err) = self.unify(&typed_condition.ty, &Ty::Bool) {
self.errors.push(SemanticError::new(err, condition.span));
}
let typed_then = self.analyze_stmt(then, expected_ret_ty);
let typed_elze = elze
.as_ref()
.map(|stmt| self.analyze_stmt(stmt, expected_ret_ty));
TypedStmt::If {
condition: typed_condition,
then: Box::new(typed_then),
elze: typed_elze.map(Box::new),
}
}
StmtKind::Return { value } => {
if let Some(expr) = value {
let typed_expr = self.analyze_expr(expr);
@@ -288,7 +339,7 @@ impl Sema {
fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr {
match &expr.kind {
ExprKind::Identifier { name } => {
let ty = if let Some(ty) = self.env.get(name) {
let ty = if let Some(ty) = self.lookup(name) {
ty.clone()
} else {
self.errors.push(SemanticError::new(
@@ -439,6 +490,16 @@ impl Sema {
.collect(),
},
TypedStmt::If {
condition,
then,
elze,
} => TypedStmt::If {
condition: self.apply_subst_expr(condition),
then: Box::new(self.apply_subst_stmt(*then)),
elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))),
},
TypedStmt::Return { value } => TypedStmt::Return {
value: value.map(|e| self.apply_subst_expr(e)),
},
@@ -692,4 +753,16 @@ mod test {
.contains("binary operators only work on integer types")
}));
}
#[test]
fn valid_if() {
let src = "fn test() { if true { return; } }";
assert!(analyze(src).is_ok())
}
#[test]
fn invalid_if() {
let src = "fn test() { if 12 {} }";
assert!(analyze(src).is_err());
}
}