use std::collections::HashMap; use crate::frontend::ast::*; use crate::frontend::token::Span; use crate::frontend::typed_ast::*; /// A structured error produced during semantic analysis, carrying a human-readable /// message and the [Span] of the offending AST node for precise diagnostics. #[derive(Debug, PartialEq, Eq)] pub struct SemanticError { /// Human-readable description of the semantic error. pub message: String, /// Source location of the offending node. pub span: Span, } impl SemanticError { /// Creates a new [SemanticError] with the given message and source span. pub fn new(message: impl ToString, span: Span) -> Self { Self { message: message.to_string(), span, } } } /// An internal representation of types used during semantic analysis, including /// concrete types, unconstrained type variables, and function types. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Ty { I8, I16, I32, I64, U8, U16, U32, U64, Bool, Unit, Var(usize), Function(Vec, Box), } impl Ty { /// Returns `true` if the type is a fundamental integer type. pub fn is_integer(&self) -> bool { matches!( self, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 | Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 ) } } impl From<&TypeKind> for Ty { fn from(kind: &TypeKind) -> Self { match kind { TypeKind::I8 => Ty::I8, TypeKind::I16 => Ty::I16, TypeKind::I32 => Ty::I32, TypeKind::I64 => Ty::I64, TypeKind::U8 => Ty::U8, TypeKind::U16 => Ty::U16, TypeKind::U32 => Ty::U32, TypeKind::U64 => Ty::U64, TypeKind::Bool => Ty::Bool, } } } /// The semantic analyzer, responsible for type inference, name resolution, and type checking. /// /// Uses a Hindley-Milner style algorithm with unification to transform an untyped AST into a typed AST. pub struct Sema { next_var: usize, subst: HashMap, scopes: Vec>, errors: Vec, deferred_unary_neg: Vec<(Span, Ty, Ty, Option)>, deferred_binary: Vec<(Span, Ty)>, deferred_literals: Vec<(Span, Ty)>, is_reachable: bool, } impl Sema { /// Creates a new, empty [Sema] instance. pub fn new() -> Self { Self { next_var: 0, subst: HashMap::new(), scopes: Vec::new(), errors: Vec::new(), deferred_unary_neg: Vec::new(), deferred_binary: Vec::new(), deferred_literals: Vec::new(), is_reachable: true, } } /// Consumes the analyzer and returns the accumulated semantic errors, if any. pub fn errors(self) -> Option> { (!self.errors.is_empty()).then_some(self.errors) } /// Pushes a new, empty scope onto the environment stack. fn enter_scope(&mut self) { self.scopes.push(HashMap::new()); } /// Pops the current scope from the environment stack. fn leave_scope(&mut self) { self.scopes.pop(); } /// Binds a name to a type in the current innermost scope. fn bind(&mut self, name: &str, ty: Ty) { self.scopes.last_mut().unwrap().insert(name.to_string(), ty); } /// Looks up a name in the environment, searching from the innermost scope outwards. fn lookup(&self, name: &str) -> Option<&Ty> { self.scopes.iter().rev().find_map(|scope| scope.get(name)) } /// Generates a fresh, unconstrained type variable. fn new_var(&mut self) -> Ty { let v = self.next_var; self.next_var += 1; Ty::Var(v) } /// Recursively applies the current type substitutions to resolve a type variable to its inferred type. pub fn apply_subst(&self, ty: &Ty) -> Ty { match ty { Ty::Var(v) => { if let Some(t) = self.subst.get(v) { self.apply_subst(t) } else { Ty::Var(*v) } } Ty::Function(params, ret) => { let params = params.iter().map(|p| self.apply_subst(p)).collect(); let ret = Box::new(self.apply_subst(ret)); Ty::Function(params, ret) } t => t.clone(), } } /// Performs an occurs check to prevent infinite recursive types during unification. fn occurs_check(&self, v: usize, ty: &Ty) -> bool { let ty = self.apply_subst(ty); match ty { Ty::Var(v2) => v == v2, Ty::Function(params, ret) => { params.iter().any(|p| self.occurs_check(v, p)) || self.occurs_check(v, &ret) } _ => false, } } /// Attempts to unify two types. If they are compatible, it records the necessary /// substitutions. Returns an error message string if unification fails. fn unify(&mut self, t1: &Ty, t2: &Ty) -> Result<(), String> { let t1 = self.apply_subst(t1); let t2 = self.apply_subst(t2); if t1 == t2 { return Ok(()); } match (t1, t2) { (Ty::Var(v), t) | (t, Ty::Var(v)) => { if self.occurs_check(v, &t) { return Err("recursive type".to_string()); } self.subst.insert(v, t); Ok(()) } (Ty::Function(p1, r1), Ty::Function(p2, r2)) => { if p1.len() != p2.len() { return Err("arity mismatch".to_string()); } for (arg1, arg2) in p1.iter().zip(p2.iter()) { self.unify(arg1, arg2)?; } self.unify(&r1, &r2) } (a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)), } } /// Analyzes an entire module, collecting function signatures into the environment, /// performing type inference, and returning a fully evaluated [TypedModule]. pub fn analyze_module(&mut self, module: &Module) -> TypedModule { self.enter_scope(); for decl in &module.decls { match &decl.kind { DeclKind::Function { name, params, return_type, .. } => { let param_tys: Vec = params.iter().map(|p| Ty::from(&p.ty.kind)).collect(); let ret_ty = return_type .as_ref() .map(|t| Ty::from(&t.kind)) .unwrap_or(Ty::Unit); self.bind(name, Ty::Function(param_tys, Box::new(ret_ty))); } } } let mut decls = Vec::new(); for decl in &module.decls { decls.push(self.analyze_decl(decl)); } self.finish(); let mut final_decls = Vec::new(); for decl in decls { final_decls.push(self.apply_subst_decl(decl)); } self.leave_scope(); TypedModule { decls: final_decls } } /// Analyzes a single declaration, tracking variable environments for parameters, /// and returning a typed declaration. fn analyze_decl(&mut self, decl: &Decl) -> TypedDecl { match &decl.kind { DeclKind::Function { name, params, return_type, body, .. } => { let mut typed_params = Vec::new(); self.enter_scope(); for param in params { let ty = Ty::from(¶m.ty.kind); self.bind(¶m.name, ty.clone()); typed_params.push((param.name.clone(), ty)); } let expected_ret_ty = return_type .as_ref() .map(|t| Ty::from(&t.kind)) .unwrap_or(Ty::Unit); self.is_reachable = true; let typed_body = self.analyze_stmt(body, &expected_ret_ty); if expected_ret_ty != Ty::Unit && self.is_reachable { self.errors.push(SemanticError::new( "not all control paths return a value", decl.span, )); } self.leave_scope(); TypedDecl::Function { name: name.clone(), params: typed_params, return_type: expected_ret_ty, body: typed_body, } } } } /// Analyzes a statement, recursively type-checking its inner expressions and /// validating return statements against the `expected_ret_ty`. fn analyze_stmt(&mut self, stmt: &Stmt, expected_ret_ty: &Ty) -> TypedStmt { match &stmt.kind { StmtKind::Compound { inner } => { let mut typed_inner = Vec::new(); let mut reported_unreachable = false; self.enter_scope(); for s in inner { if !self.is_reachable && !reported_unreachable { self.errors .push(SemanticError::new("unreachable statement", s.span)); reported_unreachable = true; } typed_inner.push(self.analyze_stmt(s, expected_ret_ty)); } self.leave_scope(); TypedStmt::Compound { inner: typed_inner } } StmtKind::If { condition, then, elze, } => { let typed_condition = self.analyze_expr(condition); if let Err(err) = self.unify(&typed_condition.ty, &Ty::Bool) { self.errors.push(SemanticError::new(err, condition.span)); } let initial_reachable = self.is_reachable; self.is_reachable = initial_reachable; let typed_then = self.analyze_stmt(then, expected_ret_ty); let reachable_after_then = self.is_reachable; let typed_elze = elze.as_ref().map(|e| { self.is_reachable = initial_reachable; self.analyze_stmt(e, expected_ret_ty) }); let reachable_after_else = if elze.is_some() { self.is_reachable } else { initial_reachable }; self.is_reachable = reachable_after_then || reachable_after_else; TypedStmt::If { condition: typed_condition, then: Box::new(typed_then), elze: typed_elze.map(Box::new), } } StmtKind::Return { value } => { if let Some(expr) = value { let typed_expr = self.analyze_expr(expr); if let Err(err) = self.unify(&typed_expr.ty, expected_ret_ty) { self.errors.push(SemanticError::new(err, expr.span)); } self.is_reachable = false; TypedStmt::Return { value: Some(typed_expr), } } else { if let Err(err) = self.unify(&Ty::Unit, expected_ret_ty) { self.errors.push(SemanticError::new(err, stmt.span)); } self.is_reachable = false; TypedStmt::Return { value: None } } } } } /// Analyzes an expression, generating type constraints, registering deferred checks, /// and returning a typed expression with potentially unconstrained type variables. fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr { match &expr.kind { ExprKind::Identifier { name } => { let ty = if let Some(ty) = self.lookup(name) { ty.clone() } else { self.errors.push(SemanticError::new( format!("undeclared identifier `{}`", name), expr.span, )); self.new_var() }; TypedExpr { kind: TypedExprKind::Identifier { name: name.clone() }, ty, } } ExprKind::Integer { value } => { let ty = self.new_var(); self.deferred_literals.push((expr.span, ty.clone())); TypedExpr { kind: TypedExprKind::Integer { value: *value }, ty, } } ExprKind::Boolean { value } => TypedExpr { kind: TypedExprKind::Boolean { value: *value }, ty: Ty::Bool, }, ExprKind::Unary { op: UnaryOp::Neg, expr: inner_expr, } => { let typed_inner = self.analyze_expr(inner_expr); let result_ty = self.new_var(); let known_value = match &inner_expr.kind { ExprKind::Integer { value } => Some(*value), _ => None, }; self.deferred_unary_neg.push(( expr.span, typed_inner.ty.clone(), result_ty.clone(), known_value, )); TypedExpr { kind: TypedExprKind::Unary { op: UnaryOp::Neg, expr: Box::new(typed_inner), }, ty: result_ty, } } ExprKind::Unary { op: UnaryOp::Not, expr, } => { let typed_inner = self.analyze_expr(expr); if let Err(e) = self.unify(&Ty::Bool, &typed_inner.ty) { self.errors.push(SemanticError::new(e, expr.span)); } TypedExpr { kind: TypedExprKind::Unary { op: UnaryOp::Not, expr: Box::new(typed_inner), }, ty: Ty::Bool, } } ExprKind::Binary { op, lhs, rhs } => { let typed_lhs = self.analyze_expr(lhs); let typed_rhs = self.analyze_expr(rhs); if let Err(e) = self.unify(&typed_lhs.ty, &typed_rhs.ty) { self.errors.push(SemanticError::new(e, expr.span)); } let is_comparison = matches!( op, BinaryOp::Eq | BinaryOp::Neq | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge ); let result_ty = if is_comparison { Ty::Bool } else { typed_lhs.ty.clone() }; self.deferred_binary.push((expr.span, typed_lhs.ty.clone())); TypedExpr { kind: TypedExprKind::Binary { op: *op, lhs: Box::new(typed_lhs), rhs: Box::new(typed_rhs), }, ty: result_ty, } } } } /// Recursively applies the final resolved type substitutions to a typed declaration. fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl { match decl { TypedDecl::Function { name, params, return_type, body, } => { let params = params .into_iter() .map(|(n, ty)| (n, self.apply_subst(&ty))) .collect(); TypedDecl::Function { name, params, return_type: self.apply_subst(&return_type), body: self.apply_subst_stmt(body), } } } } /// Recursively applies the final resolved type substitutions to a typed statement. fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt { match stmt { TypedStmt::Compound { inner } => TypedStmt::Compound { inner: inner .into_iter() .map(|s| self.apply_subst_stmt(s)) .collect(), }, TypedStmt::If { condition, then, elze, } => TypedStmt::If { condition: self.apply_subst_expr(condition), then: Box::new(self.apply_subst_stmt(*then)), elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))), }, TypedStmt::Return { value } => TypedStmt::Return { value: value.map(|e| self.apply_subst_expr(e)), }, } } /// Recursively applies the final resolved type substitutions to a typed expression. fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr { let ty = self.apply_subst(&expr.ty); let kind = match expr.kind { TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name }, TypedExprKind::Integer { value } => TypedExprKind::Integer { value }, TypedExprKind::Boolean { value } => TypedExprKind::Boolean { value }, TypedExprKind::Unary { op, expr } => TypedExprKind::Unary { op, expr: Box::new(self.apply_subst_expr(*expr)), }, TypedExprKind::Binary { op, lhs, rhs } => TypedExprKind::Binary { op, lhs: Box::new(self.apply_subst_expr(*lhs)), rhs: Box::new(self.apply_subst_expr(*rhs)), }, }; TypedExpr { kind, ty } } /// Resolves all deferred type constraints accumulated during analysis, such as /// integer literal sizing, unary negation promotion rules, and binary operator limits. fn finish(&mut self) { for (span, inner_ty, result_ty, known_value) in std::mem::take(&mut self.deferred_unary_neg) { let inner_resolved = self.apply_subst(&inner_ty); let result_resolved = self.apply_subst(&result_ty); let inner_final = if let Ty::Var(_) = inner_resolved { if result_resolved.is_integer() { let _ = self.unify(&inner_resolved, &result_resolved); result_resolved.clone() } else { let default = Ty::I32; let _ = self.unify(&inner_resolved, &default); default } } else { inner_resolved }; // Determine the result type of the unary minus operation. // Signed integers remain the same type. Unsigned integers are promoted to the // next largest signed integer type to ensure they can safely represent the negative value. // If the value is known at compile time, we can avoid promoting to a larger size // as long as the exact value fits within the signed equivalent of the same size. let promoted_ty = match inner_final { Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => Ok(inner_final), Ty::U8 => Ok(if known_value.is_some_and(|v| v <= i8::MAX as u64) { Ty::I8 } else { Ty::I16 }), Ty::U16 => Ok(if known_value.is_some_and(|v| v <= i16::MAX as u64) { Ty::I16 } else { Ty::I32 }), Ty::U32 => Ok(if known_value.is_some_and(|v| v <= i32::MAX as u64) { Ty::I32 } else { Ty::I64 }), Ty::U64 => { if known_value.is_some_and(|v| v <= i64::MAX as u64) { Ok(Ty::I64) } else if known_value.is_some() { Err("value too large to be promoted to signed 64-bit integer") } else { Err("cannot promote u64 to a larger signed type") } } _ => Err("unary minus only works on integer types"), }; match promoted_ty { Ok(promoted) => { if let Err(e) = self.unify(&promoted, &result_resolved) { self.errors.push(SemanticError::new(e, span)); } } Err(err) => { self.errors.push(SemanticError::new(err, span)); } } } for (span, ty) in std::mem::take(&mut self.deferred_binary) { let resolved = self.apply_subst(&ty); let final_ty = if let Ty::Var(_) = resolved { let default = Ty::I32; let _ = self.unify(&resolved, &default); default } else { resolved }; if !final_ty.is_integer() { self.errors.push(SemanticError::new( "binary operators only work on integer types", span, )); } } for (span, ty) in std::mem::take(&mut self.deferred_literals) { let resolved = self.apply_subst(&ty); let final_ty = if let Ty::Var(_) = resolved { let default = Ty::I32; let _ = self.unify(&resolved, &default); default } else { resolved }; if !final_ty.is_integer() { self.errors.push(SemanticError::new( "expected integer type for literal", span, )); } } } } #[cfg(test)] mod test { use crate::frontend::{ parser::Parser, sema::{Sema, SemanticError}, typed_ast::TypedModule, }; fn analyze(source: &str) -> Result> { let mut parser = Parser::new(source); let module = parser.parse_module(); if let Some(errors) = parser.errors() { panic!("Parse errors during semantic analysis test: {:?}", errors); } let mut sema = Sema::new(); let typed_module = sema.analyze_module(&module); if let Some(errors) = sema.errors() { Err(errors) } else { Ok(typed_module) } } #[test] fn valid_function() { let src = "fn add(a: i32, b: i32) -> i32 { return a + b; }"; assert!(analyze(src).is_ok()); } #[test] fn type_mismatch() { let src = "fn bad(a: i32) -> bool { return a; }"; let errors = analyze(src).unwrap_err(); assert_eq!(errors.len(), 1); assert!(errors[0].message.contains("type mismatch")); } #[test] fn undeclared_identifier() { let src = "fn oops() -> i32 { return x; }"; let errors = analyze(src).unwrap_err(); assert_eq!(errors.len(), 1); assert!(errors[0].message.contains("undeclared identifier `x`")); } #[test] fn binary_op_non_integer() { let src = "fn bad_bin() -> bool { return true + false; }"; let errors = analyze(src).unwrap_err(); assert!(errors.iter().any(|e| { e.message .contains("binary operators only work on integer types") })); } #[test] fn unary_neg_promotion_u64_unknown() { // We cannot promote an unknown u64 parameter since its bounds // are unconstrained and might overflow signed i64. let src = "fn test(x: u64) -> i64 { return -x; }"; let errors = analyze(src).unwrap_err(); assert!(errors.iter().any(|e| { e.message .contains("cannot promote u64 to a larger signed type") })); } #[test] fn unary_neg_invalid_type() { let src = "fn test(x: bool) -> bool { return -x; }"; let errors = analyze(src).unwrap_err(); assert!(errors.iter().any(|e| { e.message .contains("unary minus only works on integer types") })); } #[test] fn valid_logical_not() { let src = "fn test(a: bool) -> bool { return !a; }"; assert!(analyze(src).is_ok()); } #[test] fn invalid_logical_not() { let src = "fn test(a: i32) -> bool { return !a; }"; let errors = analyze(src).unwrap_err(); assert!(errors.iter().any(|e| e.message.contains("type mismatch"))); } #[test] fn valid_comparison() { let src = "fn test(a: i32, b: i32) -> bool { return a <= b; }"; assert!(analyze(src).is_ok()); } #[test] fn invalid_comparison() { let src = "fn test(a: bool, b: bool) -> bool { return a == b; }"; let errors = analyze(src).unwrap_err(); assert!(errors.iter().any(|e| { e.message .contains("binary operators only work on integer types") })); } #[test] fn valid_if() { let src = "fn test() { if true { return; } }"; assert!(analyze(src).is_ok()) } #[test] fn invalid_if() { let src = "fn test() { if 12 {} }"; assert!(analyze(src).is_err()); } #[test] fn not_all_paths_return() { let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } else { } }"; assert!(analyze(src).is_err()); let src = "fn test() -> i32 { }"; assert!(analyze(src).is_err()); let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } return 10; }"; assert!(analyze(src).is_ok()); } #[test] fn unreachable_code() { let src = "fn test() -> i32 { return 5; return 10; }"; let errors = analyze(src).unwrap_err(); assert!( errors .iter() .any(|e| e.message.contains("unreachable statement")) ); } }