From c3ee0d6e67ead20919e3001a079505830fe9bfe6 Mon Sep 17 00:00:00 2001
From: Jooris Hadeler
Date: Mon, 20 Apr 2026 22:33:41 +0200
Subject: [PATCH] feat: add typed ast and Hindley-Milner style semantic analysis

---
 examples/simple.src       |   4 +-
 src/frontend/mod.rs       |   2 +
 src/frontend/sema.rs      | 631 ++++++++++++++++++++++++++++++++++++++
 src/frontend/typed_ast.rs |  51 +++
 src/main.rs               |  14 +-
 5 files changed, 699 insertions(+), 3 deletions(-)
 create mode 100644 src/frontend/sema.rs
 create mode 100644 src/frontend/typed_ast.rs

diff --git a/examples/simple.src b/examples/simple.src
index 98f5b50..380035c 100644
--- a/examples/simple.src
+++ b/examples/simple.src
@@ -1,3 +1,3 @@
-fn main() -> i32 {
-    return 0;
+fn main() -> i8 {
+    return -1;
 }
\ No newline at end of file
diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs
index 1b8d4c4..63591ad 100644
--- a/src/frontend/mod.rs
+++ b/src/frontend/mod.rs
@@ -1,4 +1,6 @@
 pub mod ast;
 pub mod lexer;
 pub mod parser;
+pub mod sema;
 pub mod token;
+pub mod typed_ast;
diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs
new file mode 100644
index 0000000..3a8680e
--- /dev/null
+++ b/src/frontend/sema.rs
@@ -0,0 +1,631 @@
+use std::collections::HashMap;
+
+use crate::frontend::ast::*;
+use crate::frontend::token::Span;
+use crate::frontend::typed_ast::*;
+
+/// A structured error produced during semantic analysis, carrying a human-readable
+/// message and the [Span] of the offending AST node for precise diagnostics.
+#[derive(Debug, PartialEq, Eq)]
+pub struct SemanticError {
+    /// Human-readable description of the semantic error.
+    pub message: String,
+    /// Source location of the offending node.
+    pub span: Span,
+}
+
+impl SemanticError {
+    /// Creates a new [SemanticError] with the given message and source span.
+    pub fn new(message: impl ToString, span: Span) -> Self {
+        Self {
+            message: message.to_string(),
+            span,
+        }
+    }
+}
+
+/// An internal representation of types used during semantic analysis, including
+/// concrete types, unconstrained type variables, and function types.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Ty {
+    I8,
+    I16,
+    I32,
+    I64,
+    U8,
+    U16,
+    U32,
+    U64,
+    Bool,
+    Unit,
+    Var(usize),
+    Function(Vec<Ty>, Box<Ty>),
+}
+
+impl Ty {
+    /// Returns `true` if the type is a fundamental integer type.
+    pub fn is_integer(&self) -> bool {
+        matches!(
+            self,
+            Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 | Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64
+        )
+    }
+}
+
+impl From<&TypeKind> for Ty {
+    fn from(kind: &TypeKind) -> Self {
+        match kind {
+            TypeKind::I8 => Ty::I8,
+            TypeKind::I16 => Ty::I16,
+            TypeKind::I32 => Ty::I32,
+            TypeKind::I64 => Ty::I64,
+            TypeKind::U8 => Ty::U8,
+            TypeKind::U16 => Ty::U16,
+            TypeKind::U32 => Ty::U32,
+            TypeKind::U64 => Ty::U64,
+            TypeKind::Bool => Ty::Bool,
+        }
+    }
+}
+
+/// The semantic analyzer, responsible for type inference, name resolution, and type checking.
+///
+/// Uses a Hindley-Milner style algorithm with unification to transform an untyped AST into a typed AST.
+pub struct Sema {
+    next_var: usize,
+    subst: HashMap<usize, Ty>,
+    env: HashMap<String, Ty>,
+    errors: Vec<SemanticError>,
+    deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
+    deferred_binary: Vec<(Span, Ty)>,
+    deferred_literals: Vec<(Span, Ty)>,
+}
+
+impl Sema {
+    /// Creates a new, empty [Sema] instance.
+    pub fn new() -> Self {
+        Self {
+            next_var: 0,
+            subst: HashMap::new(),
+            env: HashMap::new(),
+            errors: Vec::new(),
+            deferred_unary_neg: Vec::new(),
+            deferred_binary: Vec::new(),
+            deferred_literals: Vec::new(),
+        }
+    }
+
+    /// Consumes the analyzer and returns the accumulated semantic errors, if any.
+    pub fn errors(self) -> Option<Vec<SemanticError>> {
+        (!self.errors.is_empty()).then_some(self.errors)
+    }
+
+    /// Generates a fresh, unconstrained type variable.
+    fn new_var(&mut self) -> Ty {
+        let v = self.next_var;
+        self.next_var += 1;
+        Ty::Var(v)
+    }
+
+    /// Recursively applies the current type substitutions to resolve a type variable to its inferred type.
+    pub fn apply_subst(&self, ty: &Ty) -> Ty {
+        match ty {
+            Ty::Var(v) => {
+                if let Some(t) = self.subst.get(v) {
+                    self.apply_subst(t)
+                } else {
+                    Ty::Var(*v)
+                }
+            }
+
+            Ty::Function(params, ret) => {
+                let params = params.iter().map(|p| self.apply_subst(p)).collect();
+                let ret = Box::new(self.apply_subst(ret));
+                Ty::Function(params, ret)
+            }
+
+            t => t.clone(),
+        }
+    }
+
+    /// Performs an occurs check to prevent infinite recursive types during unification.
+    fn occurs_check(&self, v: usize, ty: &Ty) -> bool {
+        let ty = self.apply_subst(ty);
+
+        match ty {
+            Ty::Var(v2) => v == v2,
+            Ty::Function(params, ret) => {
+                params.iter().any(|p| self.occurs_check(v, p)) || self.occurs_check(v, &ret)
+            }
+            _ => false,
+        }
+    }
+
+    /// Attempts to unify `t1` (reported as "expected") with `t2` (reported as "found"), recording
+    /// the necessary substitutions. Returns an error message string if unification fails.
+    fn unify(&mut self, t1: &Ty, t2: &Ty) -> Result<(), String> {
+        let t1 = self.apply_subst(t1);
+        let t2 = self.apply_subst(t2);
+
+        if t1 == t2 {
+            return Ok(());
+        }
+
+        match (t1, t2) {
+            (Ty::Var(v), t) | (t, Ty::Var(v)) => {
+                if self.occurs_check(v, &t) {
+                    return Err("recursive type".to_string());
+                }
+                self.subst.insert(v, t);
+                Ok(())
+            }
+            (Ty::Function(p1, r1), Ty::Function(p2, r2)) => {
+                if p1.len() != p2.len() {
+                    return Err("arity mismatch".to_string());
+                }
+                for (arg1, arg2) in p1.iter().zip(p2.iter()) {
+                    self.unify(arg1, arg2)?;
+                }
+                self.unify(&r1, &r2)
+            }
+            (a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)),
+        }
+    }
+
+    /// Analyzes an entire module, collecting function signatures into the environment,
+    /// performing type inference, and returning a fully evaluated [TypedModule].
+    pub fn analyze_module(&mut self, module: &Module) -> TypedModule {
+        for decl in &module.decls {
+            match &decl.kind {
+                DeclKind::Function {
+                    name,
+                    params,
+                    return_type,
+                    ..
+                } => {
+                    let param_tys: Vec<Ty> = params.iter().map(|p| Ty::from(&p.ty.kind)).collect();
+                    let ret_ty = return_type
+                        .as_ref()
+                        .map(|t| Ty::from(&t.kind))
+                        .unwrap_or(Ty::Unit);
+                    self.env
+                        .insert(name.clone(), Ty::Function(param_tys, Box::new(ret_ty)));
+                }
+            }
+        }
+
+        let mut decls = Vec::new();
+        for decl in &module.decls {
+            decls.push(self.analyze_decl(decl));
+        }
+
+        self.finish();
+
+        let mut final_decls = Vec::new();
+        for decl in decls {
+            final_decls.push(self.apply_subst_decl(decl));
+        }
+
+        TypedModule { decls: final_decls }
+    }
+
+    /// Analyzes a single declaration, tracking variable environments for parameters,
+    /// and returning a typed declaration.
+    fn analyze_decl(&mut self, decl: &Decl) -> TypedDecl {
+        match &decl.kind {
+            DeclKind::Function {
+                name,
+                params,
+                return_type,
+                body,
+                ..
+            } => {
+                let mut local_env = self.env.clone();
+                let mut typed_params = Vec::new();
+
+                for param in params {
+                    let ty = Ty::from(&param.ty.kind);
+                    local_env.insert(param.name.clone(), ty.clone());
+                    typed_params.push((param.name.clone(), ty));
+                }
+
+                let global_env = std::mem::replace(&mut self.env, local_env);
+                let expected_ret_ty = return_type
+                    .as_ref()
+                    .map(|t| Ty::from(&t.kind))
+                    .unwrap_or(Ty::Unit);
+
+                let typed_body = self.analyze_stmt(body, &expected_ret_ty);
+                self.env = global_env;
+
+                TypedDecl::Function {
+                    name: name.clone(),
+                    params: typed_params,
+                    return_type: expected_ret_ty,
+                    body: typed_body,
+                }
+            }
+        }
+    }
+
+    /// Analyzes a statement, recursively type-checking its inner expressions and
+    /// validating return statements against the `expected_ret_ty`.
+    fn analyze_stmt(&mut self, stmt: &Stmt, expected_ret_ty: &Ty) -> TypedStmt {
+        match &stmt.kind {
+            StmtKind::Compound { inner } => {
+                let mut typed_inner = Vec::new();
+
+                for s in inner {
+                    typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
+                }
+
+                TypedStmt::Compound { inner: typed_inner }
+            }
+            StmtKind::Return { value } => {
+                if let Some(expr) = value {
+                    let typed_expr = self.analyze_expr(expr);
+
+                    if let Err(err) = self.unify(expected_ret_ty, &typed_expr.ty) {
+                        self.errors.push(SemanticError::new(err, expr.span));
+                    }
+
+                    TypedStmt::Return {
+                        value: Some(typed_expr),
+                    }
+                } else {
+                    if let Err(err) = self.unify(expected_ret_ty, &Ty::Unit) {
+                        self.errors.push(SemanticError::new(err, stmt.span));
+                    }
+
+                    TypedStmt::Return { value: None }
+                }
+            }
+        }
+    }
+
+    /// Analyzes an expression, generating type constraints, registering deferred checks,
+    /// and returning a typed expression with potentially unconstrained type variables.
+    fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr {
+        match &expr.kind {
+            ExprKind::Identifier { name } => {
+                let ty = if let Some(ty) = self.env.get(name) {
+                    ty.clone()
+                } else {
+                    self.errors.push(SemanticError::new(
+                        format!("undeclared identifier `{}`", name),
+                        expr.span,
+                    ));
+
+                    self.new_var()
+                };
+
+                TypedExpr {
+                    kind: TypedExprKind::Identifier { name: name.clone() },
+                    ty,
+                }
+            }
+
+            ExprKind::Integer { value } => {
+                let ty = self.new_var();
+                self.deferred_literals.push((expr.span, ty.clone()));
+
+                TypedExpr {
+                    kind: TypedExprKind::Integer { value: *value },
+                    ty,
+                }
+            }
+
+            ExprKind::Boolean { value } => TypedExpr {
+                kind: TypedExprKind::Boolean { value: *value },
+                ty: Ty::Bool,
+            },
+
+            ExprKind::Unary {
+                op: UnaryOp::Neg,
+                expr: inner_expr,
+            } => {
+                let typed_inner = self.analyze_expr(inner_expr);
+                let result_ty = self.new_var();
+
+                let known_value = match &inner_expr.kind {
+                    ExprKind::Integer { value } => Some(*value),
+                    _ => None,
+                };
+
+                self.deferred_unary_neg.push((
+                    expr.span,
+                    typed_inner.ty.clone(),
+                    result_ty.clone(),
+                    known_value,
+                ));
+
+                TypedExpr {
+                    kind: TypedExprKind::Unary {
+                        op: UnaryOp::Neg,
+                        expr: Box::new(typed_inner),
+                    },
+                    ty: result_ty,
+                }
+            }
+
+            ExprKind::Binary { op, lhs, rhs } => {
+                let typed_lhs = self.analyze_expr(lhs);
+                let typed_rhs = self.analyze_expr(rhs);
+
+                if let Err(e) = self.unify(&typed_lhs.ty, &typed_rhs.ty) {
+                    self.errors.push(SemanticError::new(e, expr.span));
+                }
+
+                let result_ty = typed_lhs.ty.clone();
+                self.deferred_binary.push((expr.span, result_ty.clone()));
+                TypedExpr {
+                    kind: TypedExprKind::Binary {
+                        op: *op,
+                        lhs: Box::new(typed_lhs),
+                        rhs: Box::new(typed_rhs),
+                    },
+                    ty: result_ty,
+                }
+            }
+        }
+    }
+
+    /// Recursively applies the final resolved type substitutions to a typed declaration.
+    fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl {
+        match decl {
+            TypedDecl::Function {
+                name,
+                params,
+                return_type,
+                body,
+            } => {
+                let params = params
+                    .into_iter()
+                    .map(|(n, ty)| (n, self.apply_subst(&ty)))
+                    .collect();
+
+                TypedDecl::Function {
+                    name,
+                    params,
+                    return_type: self.apply_subst(&return_type),
+                    body: self.apply_subst_stmt(body),
+                }
+            }
+        }
+    }
+
+    /// Recursively applies the final resolved type substitutions to a typed statement.
+    fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt {
+        match stmt {
+            TypedStmt::Compound { inner } => TypedStmt::Compound {
+                inner: inner
+                    .into_iter()
+                    .map(|s| self.apply_subst_stmt(s))
+                    .collect(),
+            },
+
+            TypedStmt::Return { value } => TypedStmt::Return {
+                value: value.map(|e| self.apply_subst_expr(e)),
+            },
+        }
+    }
+
+    /// Recursively applies the final resolved type substitutions to a typed expression.
+    fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr {
+        let ty = self.apply_subst(&expr.ty);
+        let kind = match expr.kind {
+            TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name },
+            TypedExprKind::Integer { value } => TypedExprKind::Integer { value },
+            TypedExprKind::Boolean { value } => TypedExprKind::Boolean { value },
+
+            TypedExprKind::Unary { op, expr } => TypedExprKind::Unary {
+                op,
+                expr: Box::new(self.apply_subst_expr(*expr)),
+            },
+
+            TypedExprKind::Binary { op, lhs, rhs } => TypedExprKind::Binary {
+                op,
+                lhs: Box::new(self.apply_subst_expr(*lhs)),
+                rhs: Box::new(self.apply_subst_expr(*rhs)),
+            },
+        };
+
+        TypedExpr { kind, ty }
+    }
+
+    /// Resolves all deferred type constraints accumulated during analysis, such as
+    /// integer literal sizing, unary negation promotion rules, and binary operator limits.
+    fn finish(&mut self) {
+        for (span, inner_ty, result_ty, known_value) in std::mem::take(&mut self.deferred_unary_neg)
+        {
+            let inner_resolved = self.apply_subst(&inner_ty);
+            let result_resolved = self.apply_subst(&result_ty);
+
+            let inner_final = if let Ty::Var(_) = inner_resolved {
+                if result_resolved.is_integer() {
+                    let _ = self.unify(&inner_resolved, &result_resolved);
+                    result_resolved.clone()
+                } else {
+                    let default = Ty::I32;
+                    let _ = self.unify(&inner_resolved, &default);
+                    default
+                }
+            } else {
+                inner_resolved
+            };
+
+            // Determine the result type of the unary minus operation.
+            // Signed integers remain the same type. Unsigned integers are promoted to the
+            // next largest signed integer type to ensure they can safely represent the negative value.
+            // If the value is known at compile time, we can avoid promoting to a larger size
+            // as long as the exact value fits within the signed equivalent of the same size.
+            let promoted_ty = match inner_final {
+                Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => Ok(inner_final),
+
+                Ty::U8 => Ok(if known_value.is_some_and(|v| v <= i8::MAX as u64) {
+                    Ty::I8
+                } else {
+                    Ty::I16
+                }),
+
+                Ty::U16 => Ok(if known_value.is_some_and(|v| v <= i16::MAX as u64) {
+                    Ty::I16
+                } else {
+                    Ty::I32
+                }),
+
+                Ty::U32 => Ok(if known_value.is_some_and(|v| v <= i32::MAX as u64) {
+                    Ty::I32
+                } else {
+                    Ty::I64
+                }),
+
+                Ty::U64 => {
+                    if known_value.is_some_and(|v| v <= i64::MAX as u64) {
+                        Ok(Ty::I64)
+                    } else if known_value.is_some() {
+                        Err("value too large to be promoted to signed 64-bit integer")
+                    } else {
+                        Err("cannot promote u64 to a larger signed type")
+                    }
+                }
+                _ => Err("unary minus only works on integer types"),
+            };
+
+            match promoted_ty {
+                Ok(promoted) => {
+                    if let Err(e) = self.unify(&result_resolved, &promoted) {
+                        self.errors.push(SemanticError::new(e, span));
+                    }
+                }
+                Err(err) => {
+                    self.errors.push(SemanticError::new(err, span));
+                }
+            }
+        }
+
+        for (span, ty) in std::mem::take(&mut self.deferred_binary) {
+            let resolved = self.apply_subst(&ty);
+
+            let final_ty = if let Ty::Var(_) = resolved {
+                let default = Ty::I32;
+                let _ = self.unify(&resolved, &default);
+                default
+            } else {
+                resolved
+            };
+
+            if !final_ty.is_integer() {
+                self.errors.push(SemanticError::new(
+                    "binary operators only work on integer types",
+                    span,
+                ));
+            }
+        }
+
+        for (span, ty) in std::mem::take(&mut self.deferred_literals) {
+            let resolved = self.apply_subst(&ty);
+
+            let final_ty = if let Ty::Var(_) = resolved {
+                let default = Ty::I32;
+                let _ = self.unify(&resolved, &default);
+                default
+            } else {
+                resolved
+            };
+
+            if !final_ty.is_integer() {
+                self.errors.push(SemanticError::new(
+                    "expected integer type for literal",
+                    span,
+                ));
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::frontend::{
+        parser::Parser,
+        sema::{Sema, SemanticError},
+        typed_ast::TypedModule,
+    };
+
+    fn analyze(source: &str) -> Result<TypedModule, Vec<SemanticError>> {
+        let mut parser = Parser::new(source);
+        let module = parser.parse_module();
+
+        if let Some(errors) = parser.errors() {
+            panic!("Parse errors during semantic analysis test: {:?}", errors);
+        }
+
+        let mut sema = Sema::new();
+        let typed_module = sema.analyze_module(&module);
+
+        if let Some(errors) = sema.errors() {
+            Err(errors)
+        } else {
+            Ok(typed_module)
+        }
+    }
+
+    #[test]
+    fn valid_function() {
+        let src = "fn add(a: i32, b: i32) -> i32 { return a + b; }";
+        assert!(analyze(src).is_ok());
+    }
+
+    #[test]
+    fn type_mismatch() {
+        let src = "fn bad(a: i32) -> bool { return a; }";
+        let errors = analyze(src).unwrap_err();
+
+        assert_eq!(errors.len(), 1);
+        assert!(errors[0].message.contains("type mismatch"));
+    }
+
+    #[test]
+    fn undeclared_identifier() {
+        let src = "fn oops() -> i32 { return x; }";
+        let errors = analyze(src).unwrap_err();
+
+        assert_eq!(errors.len(), 1);
+        assert!(errors[0].message.contains("undeclared identifier `x`"));
+    }
+
+    #[test]
+    fn binary_op_non_integer() {
+        let src = "fn bad_bin() -> bool { return true + false; }";
+        let errors = analyze(src).unwrap_err();
+
+        assert!(errors.iter().any(|e| {
+            e.message
+                .contains("binary operators only work on integer types")
+        }));
+    }
+
+    #[test]
+    fn unary_neg_promotion_u64_unknown() {
+        // We cannot promote an unknown u64 parameter since its bounds
+        // are unconstrained and might overflow signed i64.
+        let src = "fn test(x: u64) -> i64 { return -x; }";
+        let errors = analyze(src).unwrap_err();
+
+        assert!(errors.iter().any(|e| {
+            e.message
+                .contains("cannot promote u64 to a larger signed type")
+        }));
+    }
+
+    #[test]
+    fn unary_neg_invalid_type() {
+        let src = "fn test(x: bool) -> bool { return -x; }";
+        let errors = analyze(src).unwrap_err();
+
+        assert!(errors.iter().any(|e| {
+            e.message
+                .contains("unary minus only works on integer types")
+        }));
+    }
+}
diff --git a/src/frontend/typed_ast.rs b/src/frontend/typed_ast.rs
new file mode 100644
index 0000000..16caba4
--- /dev/null
+++ b/src/frontend/typed_ast.rs
@@ -0,0 +1,51 @@
+use crate::frontend::ast::{BinaryOp, UnaryOp};
+use crate::frontend::sema::Ty;
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct TypedModule {
+    pub decls: Vec<TypedDecl>,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum TypedDecl {
+    Function {
+        name: String,
+        params: Vec<(String, Ty)>,
+        return_type: Ty,
+        body: TypedStmt,
+    },
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum TypedStmt {
+    Compound { inner: Vec<TypedStmt> },
+    Return { value: Option<TypedExpr> },
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct TypedExpr {
+    pub kind: TypedExprKind,
+    pub ty: Ty,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum TypedExprKind {
+    Identifier {
+        name: String,
+    },
+    Integer {
+        value: u64,
+    },
+    Boolean {
+        value: bool,
+    },
+    Unary {
+        op: UnaryOp,
+        expr: Box<TypedExpr>,
+    },
+    Binary {
+        op: BinaryOp,
+        lhs: Box<TypedExpr>,
+        rhs: Box<TypedExpr>,
+    },
+}
diff --git a/src/main.rs b/src/main.rs
index 3cd3c92..26f9f95 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
 use std::{env::args, fs::read_to_string, process::exit};
 
 use crate::frontend::parser::Parser;
+use crate::frontend::sema::Sema;
 
 pub mod frontend;
 
@@ -26,5 +27,16 @@ fn main() {
         exit(1);
     }
 
-    println!("{:#?}", module);
+    let mut sema = Sema::new();
+    let typed_module = sema.analyze_module(&module);
+
+    if let Some(errors) = sema.errors() {
+        for error in errors {
+            eprintln!("{:?}", error);
+        }
+
+        exit(1);
+    }
+
+    println!("{:#?}", typed_module);
 }