feat: add typed ast and Hindley-Milner style semantic analysis

This commit is contained in:
2026-04-20 22:33:41 +02:00
parent 27d033135c
commit c3ee0d6e67
5 changed files with 699 additions and 3 deletions
+2
View File
@@ -1,4 +1,6 @@
pub mod ast;
pub mod lexer;
pub mod parser;
pub mod sema;
pub mod token;
pub mod typed_ast;
+631
View File
@@ -0,0 +1,631 @@
use std::collections::HashMap;
use crate::frontend::ast::*;
use crate::frontend::token::Span;
use crate::frontend::typed_ast::*;
/// A diagnostic raised by semantic analysis: a description of what went wrong
/// paired with the [Span] of the AST node it refers to.
#[derive(Debug, PartialEq, Eq)]
pub struct SemanticError {
    /// Explanation of the problem, phrased for the user.
    pub message: String,
    /// Where in the source text the offending node lives.
    pub span: Span,
}

impl SemanticError {
    /// Builds a [SemanticError] from anything that can be rendered via
    /// [ToString] and the span the message should be attached to.
    pub fn new(message: impl ToString, span: Span) -> Self {
        let message = message.to_string();
        Self { message, span }
    }
}
/// The type language used internally by semantic analysis: the concrete
/// primitive types, not-yet-resolved type variables, and function types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ty {
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// Boolean.
    Bool,
    /// The unit type.
    Unit,
    /// A type variable, identified by its numeric id.
    Var(usize),
    /// A function type: parameter types followed by the return type.
    Function(Vec<Ty>, Box<Ty>),
}

impl Ty {
    /// Reports whether this is one of the eight fixed-width integer types.
    ///
    /// Type variables are never considered integers, even if they would
    /// resolve to one after substitution.
    pub fn is_integer(&self) -> bool {
        match self {
            Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => true,
            Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 => true,
            Ty::Bool | Ty::Unit | Ty::Var(_) | Ty::Function(..) => false,
        }
    }
}
/// Maps a type annotation kind from the untyped AST onto the analyzer's
/// internal [Ty] representation. Every annotation kind has a concrete
/// counterpart, so the conversion never produces a type variable.
impl From<&TypeKind> for Ty {
    fn from(kind: &TypeKind) -> Self {
        match kind {
            // Signed integers.
            TypeKind::I8 => Self::I8,
            TypeKind::I16 => Self::I16,
            TypeKind::I32 => Self::I32,
            TypeKind::I64 => Self::I64,
            // Unsigned integers.
            TypeKind::U8 => Self::U8,
            TypeKind::U16 => Self::U16,
            TypeKind::U32 => Self::U32,
            TypeKind::U64 => Self::U64,
            // Everything else.
            TypeKind::Bool => Self::Bool,
        }
    }
}
/// The semantic analyzer, responsible for type inference, name resolution, and type checking.
///
/// Uses a Hindley-Milner style algorithm with unification to transform an untyped AST into a typed AST.
///
/// `Default` yields the same empty analyzer as [Sema::new]; it is derived so the
/// type can be used wherever a `Default` bound is required.
#[derive(Debug, Default)]
pub struct Sema {
    /// Id that will be given to the next fresh type variable.
    next_var: usize,
    /// Substitution built up by unification: type variable id -> bound type.
    subst: HashMap<usize, Ty>,
    /// Names currently in scope (functions and parameters) and their types.
    env: HashMap<String, Ty>,
    /// Errors collected so far; analysis keeps going after an error.
    errors: Vec<SemanticError>,
    /// Unary negations awaiting resolution: (span, operand type, result type,
    /// literal operand value if the operand was an integer literal).
    deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
    /// Binary operations whose operand type must turn out to be an integer.
    deferred_binary: Vec<(Span, Ty)>,
    /// Integer literals whose type must turn out to be an integer.
    deferred_literals: Vec<(Span, Ty)>,
}
impl Sema {
/// Builds a [Sema] in its initial state: no type variables handed out, an
/// empty substitution and environment, and nothing recorded or deferred.
pub fn new() -> Self {
    let subst = HashMap::new();
    let env = HashMap::new();
    Self {
        next_var: 0,
        subst,
        env,
        errors: vec![],
        deferred_unary_neg: vec![],
        deferred_binary: vec![],
        deferred_literals: vec![],
    }
}
/// Takes ownership of the analyzer and hands back the semantic errors it
/// collected, or `None` when analysis produced no errors.
pub fn errors(self) -> Option<Vec<SemanticError>> {
    if self.errors.is_empty() {
        None
    } else {
        Some(self.errors)
    }
}
/// Allocates a type variable that has not been used before and is not yet
/// bound to anything.
fn new_var(&mut self) -> Ty {
    let fresh = Ty::Var(self.next_var);
    self.next_var += 1;
    fresh
}
/// Resolves `ty` against the substitution gathered so far, following chains
/// of bound type variables and descending into function types. Variables
/// without a binding are returned unchanged.
pub fn apply_subst(&self, ty: &Ty) -> Ty {
    match ty {
        Ty::Var(id) => match self.subst.get(id) {
            // A binding may itself mention bound variables, so keep resolving.
            Some(bound) => self.apply_subst(bound),
            None => ty.clone(),
        },
        Ty::Function(params, ret) => {
            let mut resolved_params = Vec::with_capacity(params.len());
            for param in params {
                resolved_params.push(self.apply_subst(param));
            }
            let resolved_ret = self.apply_subst(ret);
            Ty::Function(resolved_params, Box::new(resolved_ret))
        }
        concrete => concrete.clone(),
    }
}
/// Returns `true` when type variable `v` appears anywhere inside `ty` (after
/// applying the current substitution). Binding `v` to such a type would
/// create an infinitely recursive type.
fn occurs_check(&self, v: usize, ty: &Ty) -> bool {
    match self.apply_subst(ty) {
        Ty::Var(other) => other == v,
        Ty::Function(params, ret) => {
            params.iter().any(|param| self.occurs_check(v, param))
                || self.occurs_check(v, &ret)
        }
        // Primitive types contain no variables at all.
        Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => false,
        Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 => false,
        Ty::Bool | Ty::Unit => false,
    }
}
/// Makes `t1` and `t2` equal by extending the substitution where needed.
///
/// Both sides are first resolved against the current substitution. An unbound
/// variable on either side is bound to the other side (subject to the occurs
/// check); function types are unified parameter by parameter and then on the
/// return type. On failure the error is a message string, and any bindings
/// recorded before the failure are kept.
fn unify(&mut self, t1: &Ty, t2: &Ty) -> Result<(), String> {
    let left = self.apply_subst(t1);
    let right = self.apply_subst(t2);
    if left == right {
        return Ok(());
    }

    match (left, right) {
        (Ty::Var(var), other) | (other, Ty::Var(var)) => {
            if self.occurs_check(var, &other) {
                Err("recursive type".to_string())
            } else {
                self.subst.insert(var, other);
                Ok(())
            }
        }
        (Ty::Function(left_params, left_ret), Ty::Function(right_params, right_ret)) => {
            if left_params.len() != right_params.len() {
                return Err("arity mismatch".to_string());
            }
            for (left_param, right_param) in left_params.iter().zip(&right_params) {
                self.unify(left_param, right_param)?;
            }
            self.unify(&left_ret, &right_ret)
        }
        (expected, found) => Err(format!(
            "type mismatch: expected {:?}, found {:?}",
            expected, found
        )),
    }
}
/// Type-checks a whole module and produces its [TypedModule].
///
/// Works in three steps: register every function's signature in the
/// environment (so functions can be referenced regardless of declaration
/// order), analyze each declaration, then settle the deferred checks and
/// apply the final substitution to the typed declarations.
pub fn analyze_module(&mut self, module: &Module) -> TypedModule {
    // Step 1: signatures only, bodies are not inspected yet.
    for decl in &module.decls {
        match &decl.kind {
            DeclKind::Function {
                name,
                params,
                return_type,
                ..
            } => {
                let mut param_tys = Vec::with_capacity(params.len());
                for param in params {
                    param_tys.push(Ty::from(&param.ty.kind));
                }
                // A function without a return annotation returns unit.
                let ret_ty = match return_type {
                    Some(annotation) => Ty::from(&annotation.kind),
                    None => Ty::Unit,
                };
                let signature = Ty::Function(param_tys, Box::new(ret_ty));
                self.env.insert(name.clone(), signature);
            }
        }
    }

    // Step 2: infer types for every declaration body.
    let mut analyzed = Vec::new();
    for decl in &module.decls {
        let typed = self.analyze_decl(decl);
        analyzed.push(typed);
    }

    // Step 3: resolve deferred constraints, then substitute the results in.
    self.finish();
    let decls = analyzed
        .into_iter()
        .map(|decl| self.apply_subst_decl(decl))
        .collect();
    TypedModule { decls }
}
/// Analyzes a single declaration, tracking variable environments for parameters,
/// and returning a typed declaration.
fn analyze_decl(&mut self, decl: &Decl) -> TypedDecl {
match &decl.kind {
DeclKind::Function {
name,
params,
return_type,
body,
..
} => {
let mut local_env = self.env.clone();
let mut typed_params = Vec::new();
for param in params {
let ty = Ty::from(&param.ty.kind);
local_env.insert(param.name.clone(), ty.clone());
typed_params.push((param.name.clone(), ty));
}
let global_env = std::mem::replace(&mut self.env, local_env);
let expected_ret_ty = return_type
.as_ref()
.map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit);
let typed_body = self.analyze_stmt(body, &expected_ret_ty);
self.env = global_env;
TypedDecl::Function {
name: name.clone(),
params: typed_params,
return_type: expected_ret_ty,
body: typed_body,
}
}
}
}
/// Analyzes a statement, recursively type-checking its inner expressions and
/// validating return statements against the `expected_ret_ty`.
///
/// Compound statements are analyzed child by child with the same expected
/// return type. A `return` with a value unifies the value's type with
/// `expected_ret_ty`; a bare `return` unifies the unit type with it. Failures
/// are recorded in `self.errors` and analysis continues, so a typed statement
/// is always returned.
fn analyze_stmt(&mut self, stmt: &Stmt, expected_ret_ty: &Ty) -> TypedStmt {
    match &stmt.kind {
        StmtKind::Compound { inner } => {
            let mut typed_inner = Vec::new();
            for s in inner {
                typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
            }
            TypedStmt::Compound { inner: typed_inner }
        }
        StmtKind::Return { value } => {
            // `unify` words a mismatch as "expected <first>, found <second>", so
            // the declared return type must be the first argument; otherwise the
            // diagnostic names the two types the wrong way round.
            if let Some(expr) = value {
                let typed_expr = self.analyze_expr(expr);
                if let Err(err) = self.unify(expected_ret_ty, &typed_expr.ty) {
                    self.errors.push(SemanticError::new(err, expr.span));
                }
                TypedStmt::Return {
                    value: Some(typed_expr),
                }
            } else {
                // A bare `return` produces unit.
                if let Err(err) = self.unify(expected_ret_ty, &Ty::Unit) {
                    self.errors.push(SemanticError::new(err, stmt.span));
                }
                TypedStmt::Return { value: None }
            }
        }
    }
}
/// Infers the type of an expression and returns the typed expression.
///
/// Constraints that can be solved immediately are unified on the spot; checks
/// that need fully resolved types (integer literals, unary minus, binary
/// operators) are queued and settled later. The returned type may therefore
/// still contain unresolved type variables.
fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr {
    match &expr.kind {
        ExprKind::Identifier { name } => {
            let ty = match self.env.get(name) {
                Some(known) => known.clone(),
                None => {
                    let message = format!("undeclared identifier `{}`", name);
                    self.errors.push(SemanticError::new(message, expr.span));
                    // Use a fresh variable so analysis can continue past the error.
                    self.new_var()
                }
            };
            let kind = TypedExprKind::Identifier { name: name.clone() };
            TypedExpr { kind, ty }
        }
        ExprKind::Integer { value } => {
            // The literal's width is decided by its context once everything
            // else has been unified.
            let literal_ty = self.new_var();
            self.deferred_literals.push((expr.span, literal_ty.clone()));
            let kind = TypedExprKind::Integer { value: *value };
            TypedExpr {
                kind,
                ty: literal_ty,
            }
        }
        ExprKind::Boolean { value } => {
            let kind = TypedExprKind::Boolean { value: *value };
            TypedExpr { kind, ty: Ty::Bool }
        }
        ExprKind::Unary {
            op: UnaryOp::Neg,
            expr: inner_expr,
        } => {
            let operand = self.analyze_expr(inner_expr);
            let negated_ty = self.new_var();
            // Remember the literal's value when the operand is a plain integer
            // literal; the deferred check uses it to pick the result type.
            let literal_value = if let ExprKind::Integer { value } = &inner_expr.kind {
                Some(*value)
            } else {
                None
            };
            let pending = (
                expr.span,
                operand.ty.clone(),
                negated_ty.clone(),
                literal_value,
            );
            self.deferred_unary_neg.push(pending);
            let kind = TypedExprKind::Unary {
                op: UnaryOp::Neg,
                expr: Box::new(operand),
            };
            TypedExpr {
                kind,
                ty: negated_ty,
            }
        }
        ExprKind::Binary { op, lhs, rhs } => {
            let left = self.analyze_expr(lhs);
            let right = self.analyze_expr(rhs);
            // Both operands must share one type, which is also the result type.
            if let Err(mismatch) = self.unify(&left.ty, &right.ty) {
                self.errors.push(SemanticError::new(mismatch, expr.span));
            }
            let ty = left.ty.clone();
            self.deferred_binary.push((expr.span, ty.clone()));
            let kind = TypedExprKind::Binary {
                op: *op,
                lhs: Box::new(left),
                rhs: Box::new(right),
            };
            TypedExpr { kind, ty }
        }
    }
}
/// Rewrites a typed declaration so that every type it mentions (parameters,
/// return type, and everything inside the body) is resolved against the
/// current substitution.
fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl {
    match decl {
        TypedDecl::Function {
            name,
            params,
            return_type,
            body,
        } => {
            let mut resolved_params = Vec::with_capacity(params.len());
            for (param_name, param_ty) in params {
                resolved_params.push((param_name, self.apply_subst(&param_ty)));
            }
            let return_type = self.apply_subst(&return_type);
            let body = self.apply_subst_stmt(body);
            TypedDecl::Function {
                name,
                params: resolved_params,
                return_type,
                body,
            }
        }
    }
}
/// Rewrites a typed statement so that the types of all expressions nested in
/// it are resolved against the current substitution.
fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt {
    match stmt {
        TypedStmt::Compound { inner } => {
            let mut resolved = Vec::with_capacity(inner.len());
            for child in inner {
                resolved.push(self.apply_subst_stmt(child));
            }
            TypedStmt::Compound { inner: resolved }
        }
        TypedStmt::Return { value } => {
            let value = value.map(|returned| self.apply_subst_expr(returned));
            TypedStmt::Return { value }
        }
    }
}
/// Rewrites a typed expression so that its own type and the types of all of
/// its sub-expressions are resolved against the current substitution.
fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr {
    let TypedExpr { kind, ty } = expr;
    let ty = self.apply_subst(&ty);
    let kind = match kind {
        // Leaves carry no nested expressions and pass through untouched.
        leaf @ (TypedExprKind::Identifier { .. }
        | TypedExprKind::Integer { .. }
        | TypedExprKind::Boolean { .. }) => leaf,
        TypedExprKind::Unary { op, expr: operand } => {
            let operand = self.apply_subst_expr(*operand);
            TypedExprKind::Unary {
                op,
                expr: Box::new(operand),
            }
        }
        TypedExprKind::Binary { op, lhs, rhs } => {
            let lhs = Box::new(self.apply_subst_expr(*lhs));
            let rhs = Box::new(self.apply_subst_expr(*rhs));
            TypedExprKind::Binary { op, lhs, rhs }
        }
    };
    TypedExpr { kind, ty }
}
/// Resolves all deferred type constraints accumulated during analysis, such as
/// integer literal sizing, unary negation promotion rules, and binary operator limits.
fn finish(&mut self) {
for (span, inner_ty, result_ty, known_value) in std::mem::take(&mut self.deferred_unary_neg)
{
let inner_resolved = self.apply_subst(&inner_ty);
let result_resolved = self.apply_subst(&result_ty);
let inner_final = if let Ty::Var(_) = inner_resolved {
if result_resolved.is_integer() {
let _ = self.unify(&inner_resolved, &result_resolved);
result_resolved.clone()
} else {
let default = Ty::I32;
let _ = self.unify(&inner_resolved, &default);
default
}
} else {
inner_resolved
};
// Determine the result type of the unary minus operation.
// Signed integers remain the same type. Unsigned integers are promoted to the
// next largest signed integer type to ensure they can safely represent the negative value.
// If the value is known at compile time, we can avoid promoting to a larger size
// as long as the exact value fits within the signed equivalent of the same size.
let promoted_ty = match inner_final {
Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => Ok(inner_final),
Ty::U8 => Ok(if known_value.is_some_and(|v| v <= i8::MAX as u64) {
Ty::I8
} else {
Ty::I16
}),
Ty::U16 => Ok(if known_value.is_some_and(|v| v <= i16::MAX as u64) {
Ty::I16
} else {
Ty::I32
}),
Ty::U32 => Ok(if known_value.is_some_and(|v| v <= i32::MAX as u64) {
Ty::I32
} else {
Ty::I64
}),
Ty::U64 => {
if known_value.is_some_and(|v| v <= i64::MAX as u64) {
Ok(Ty::I64)
} else if known_value.is_some() {
Err("value too large to be promoted to signed 64-bit integer")
} else {
Err("cannot promote u64 to a larger signed type")
}
}
_ => Err("unary minus only works on integer types"),
};
match promoted_ty {
Ok(promoted) => {
if let Err(e) = self.unify(&promoted, &result_resolved) {
self.errors.push(SemanticError::new(e, span));
}
}
Err(err) => {
self.errors.push(SemanticError::new(err, span));
}
}
}
for (span, ty) in std::mem::take(&mut self.deferred_binary) {
let resolved = self.apply_subst(&ty);
let final_ty = if let Ty::Var(_) = resolved {
let default = Ty::I32;
let _ = self.unify(&resolved, &default);
default
} else {
resolved
};
if !final_ty.is_integer() {
self.errors.push(SemanticError::new(
"binary operators only work on integer types",
span,
));
}
}
for (span, ty) in std::mem::take(&mut self.deferred_literals) {
let resolved = self.apply_subst(&ty);
let final_ty = if let Ty::Var(_) = resolved {
let default = Ty::I32;
let _ = self.unify(&resolved, &default);
default
} else {
resolved
};
if !final_ty.is_integer() {
self.errors.push(SemanticError::new(
"expected integer type for literal",
span,
));
}
}
}
}
#[cfg(test)]
mod test {
    use crate::frontend::{
        parser::Parser,
        sema::{Sema, SemanticError},
        typed_ast::TypedModule,
    };

    /// Parses `source` and runs semantic analysis on it. Panics if the source
    /// does not parse, since these tests only target the analyzer.
    fn analyze(source: &str) -> Result<TypedModule, Vec<SemanticError>> {
        let mut parser = Parser::new(source);
        let module = parser.parse_module();
        if let Some(errors) = parser.errors() {
            panic!("Parse errors during semantic analysis test: {:?}", errors);
        }

        let mut sema = Sema::new();
        let typed_module = sema.analyze_module(&module);
        match sema.errors() {
            Some(errors) => Err(errors),
            None => Ok(typed_module),
        }
    }

    /// Whether any of `errors` has a message containing `needle`.
    fn any_message_contains(errors: &[SemanticError], needle: &str) -> bool {
        errors.iter().any(|error| error.message.contains(needle))
    }

    #[test]
    fn valid_function() {
        let outcome = analyze("fn add(a: i32, b: i32) -> i32 { return a + b; }");
        assert!(outcome.is_ok());
    }

    #[test]
    fn type_mismatch() {
        let errors = analyze("fn bad(a: i32) -> bool { return a; }").unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("type mismatch"));
    }

    #[test]
    fn undeclared_identifier() {
        let errors = analyze("fn oops() -> i32 { return x; }").unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("undeclared identifier `x`"));
    }

    #[test]
    fn binary_op_non_integer() {
        let errors = analyze("fn bad_bin() -> bool { return true + false; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "binary operators only work on integer types"
        ));
    }

    #[test]
    fn unary_neg_promotion_u64_unknown() {
        // A `u64` parameter has no compile-time value, and there is no wider
        // signed type to promote it to, so negating it must be rejected.
        let errors = analyze("fn test(x: u64) -> i64 { return -x; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "cannot promote u64 to a larger signed type"
        ));
    }

    #[test]
    fn unary_neg_invalid_type() {
        let errors = analyze("fn test(x: bool) -> bool { return -x; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "unary minus only works on integer types"
        ));
    }
}
+51
View File
@@ -0,0 +1,51 @@
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
/// Root of the typed AST: a module as a list of typed declarations.
#[derive(Debug, PartialEq, Eq)]
pub struct TypedModule {
/// The module's typed declarations.
pub decls: Vec<TypedDecl>,
}
/// A top-level declaration annotated with types. Functions are the only kind
/// of declaration defined here.
#[derive(Debug, PartialEq, Eq)]
pub enum TypedDecl {
/// A function definition.
Function {
/// Name of the function.
name: String,
/// Parameters as `(name, type)` pairs.
params: Vec<(String, Ty)>,
/// Type the function returns.
return_type: Ty,
/// Typed statement forming the function body.
body: TypedStmt,
},
}
/// A statement whose nested expressions carry types.
#[derive(Debug, PartialEq, Eq)]
pub enum TypedStmt {
/// A sequence of statements (`inner`) grouped into one compound statement.
Compound { inner: Vec<TypedStmt> },
/// A `return`; `value` is `None` when no expression is returned.
Return { value: Option<TypedExpr> },
}
/// An expression together with the type attached to it.
#[derive(Debug, PartialEq, Eq)]
pub struct TypedExpr {
/// What kind of expression this is, including any sub-expressions.
pub kind: TypedExprKind,
/// The expression's type. NOTE(review): this may be a `Ty::Var` if the type
/// was never resolved — confirm against the producer of this value.
pub ty: Ty,
}
/// The shape of a typed expression. Sub-expressions are themselves
/// [TypedExpr]s, so every node in the tree carries its own type.
#[derive(Debug, PartialEq, Eq)]
pub enum TypedExprKind {
/// A reference to a name.
Identifier {
/// The referenced name.
name: String,
},
/// An integer literal.
Integer {
/// The literal's value, stored as `u64`.
value: u64,
},
/// A boolean literal.
Boolean {
/// The literal's value.
value: bool,
},
/// A unary operator applied to one operand.
Unary {
/// The operator.
op: UnaryOp,
/// The operand.
expr: Box<TypedExpr>,
},
/// A binary operator applied to two operands.
Binary {
/// The operator.
op: BinaryOp,
/// Left-hand operand.
lhs: Box<TypedExpr>,
/// Right-hand operand.
rhs: Box<TypedExpr>,
},
}
+13 -1
View File
@@ -1,6 +1,7 @@
use std::{env::args, fs::read_to_string, process::exit};
use crate::frontend::parser::Parser;
use crate::frontend::sema::Sema;
pub mod frontend;
@@ -26,5 +27,16 @@ fn main() {
exit(1);
}
println!("{:#?}", module);
let mut sema = Sema::new();
let typed_module = sema.analyze_module(&module);
if let Some(errors) = sema.errors() {
for error in errors {
eprintln!("{:?}", error);
}
exit(1);
}
println!("{:#?}", typed_module);
}