// Source file: compiler-old/src/frontend/sema.rs (Rust; listed by the file viewer as 937 lines, 30 KiB).

use std::collections::HashMap;
use crate::frontend::ast::*;
use crate::frontend::token::Span;
/// A structured error produced during semantic analysis.
///
/// Pairs a human-readable message with the [Span] the error is reported at, so
/// that diagnostics can point at the offending source location.
#[derive(Debug, PartialEq, Eq)]
pub struct SemanticError {
/// Human-readable description of what went wrong.
pub message: String,
/// Source location the error is attached to.
pub span: Span,
}
impl SemanticError {
/// Creates a new [SemanticError] with the given message and source span.
pub fn new(message: impl ToString, span: Span) -> Self {
Self {
message: message.to_string(),
span,
}
}
}
/// An internal representation of types used during semantic analysis, including
/// concrete types, unconstrained type variables, and function types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ty {
/// Signed 8-bit integer.
I8,
/// Signed 16-bit integer.
I16,
/// Signed 32-bit integer; also the default chosen for unconstrained integer types.
I32,
/// Signed 64-bit integer.
I64,
/// Unsigned 8-bit integer.
U8,
/// Unsigned 16-bit integer.
U16,
/// Unsigned 32-bit integer.
U32,
/// Unsigned 64-bit integer.
U64,
/// Boolean type.
Bool,
/// The unit type; used as the return type of functions that declare none.
Unit,
/// A type variable identified by its numeric id, resolved through the substitution map.
Var(usize),
/// A function type: parameter types followed by the return type.
Function(Vec<Ty>, Box<Ty>),
}
impl Ty {
    /// Reports whether this is one of the eight fixed-width integer types
    /// (`i8`..`i64`, `u8`..`u64`). Type variables and all other types yield `false`.
    pub fn is_integer(&self) -> bool {
        match self {
            Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => true,
            Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 => true,
            Ty::Bool | Ty::Unit | Ty::Var(_) | Ty::Function(..) => false,
        }
    }
}
/// Lowers a syntactic [TypeKind] annotation into the analyzer's [Ty] representation.
/// Every annotation kind maps onto the concrete type of the same name.
impl From<&TypeKind> for Ty {
    fn from(kind: &TypeKind) -> Self {
        match kind {
            // Signed integers.
            TypeKind::I8 => Self::I8,
            TypeKind::I16 => Self::I16,
            TypeKind::I32 => Self::I32,
            TypeKind::I64 => Self::I64,
            // Unsigned integers.
            TypeKind::U8 => Self::U8,
            TypeKind::U16 => Self::U16,
            TypeKind::U32 => Self::U32,
            TypeKind::U64 => Self::U64,
            // Non-numeric.
            TypeKind::Bool => Self::Bool,
        }
    }
}
/// The semantic analyzer, responsible for type inference, name resolution, and type checking.
///
/// Uses a Hindley-Milner style algorithm with unification to transform an untyped AST into a
/// typed AST. Constraints that cannot be decided while walking the tree (integer literal
/// types, unary negation results, binary operand types) are queued and resolved at the end.
///
/// `Default` yields the same empty analyzer as [Sema::new].
#[derive(Debug, Default)]
pub struct Sema {
    /// Id that will be given to the next fresh type variable.
    next_var: usize,
    /// Substitution built by unification: type variable id -> the type it was bound to.
    subst: HashMap<usize, Ty>,
    /// Stack of lexical scopes mapping names to types; the innermost scope is last.
    scopes: Vec<HashMap<String, Ty>>,
    /// Errors accumulated so far.
    errors: Vec<SemanticError>,
    /// Pending unary-minus checks: (span of the negation, operand type, result type,
    /// literal value of the operand when it is an integer literal).
    deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
    /// Pending binary-operator checks: (span of the expression, type of the left operand).
    deferred_binary: Vec<(Span, Ty)>,
    /// Pending integer-literal checks: (span of the literal, its type variable).
    deferred_literals: Vec<(Span, Ty)>,
}
impl Sema {
/// Creates a new, empty [Sema] instance.
pub fn new() -> Self {
Self {
next_var: 0,
subst: HashMap::new(),
scopes: Vec::new(),
errors: Vec::new(),
deferred_unary_neg: Vec::new(),
deferred_binary: Vec::new(),
deferred_literals: Vec::new(),
}
}
/// Consumes the analyzer and hands back the errors it collected.
///
/// Returns `None` when no error was recorded, otherwise `Some` with every error
/// in the order it was pushed.
pub fn errors(self) -> Option<Vec<SemanticError>> {
    if self.errors.is_empty() {
        None
    } else {
        Some(self.errors)
    }
}
/// Pushes a new, empty scope onto the environment stack.
fn enter_scope(&mut self) {
self.scopes.push(HashMap::new());
}
/// Closes the innermost scope, discarding every binding it held.
/// Does nothing when the scope stack is already empty.
fn leave_scope(&mut self) {
    let _discarded = self.scopes.pop();
}
/// Records `name` with type `ty` in the innermost scope, replacing any binding
/// of the same name already present in that scope.
///
/// # Panics
/// Panics if no scope has been entered.
fn bind(&mut self, name: &str, ty: Ty) {
    let innermost = self.scopes.last_mut().unwrap();
    innermost.insert(name.to_string(), ty);
}
/// Resolves `name` by walking the scope stack from the innermost scope outwards.
/// Returns the type of the first binding found, or `None` if the name is unbound.
fn lookup(&self, name: &str) -> Option<&Ty> {
    for scope in self.scopes.iter().rev() {
        if let Some(ty) = scope.get(name) {
            return Some(ty);
        }
    }
    None
}
/// Allocates a type variable with an id not handed out before and advances the counter.
fn new_var(&mut self) -> Ty {
    let fresh = Ty::Var(self.next_var);
    self.next_var += 1;
    fresh
}
/// Resolves `ty` under the current substitution.
///
/// Bound type variables are followed until an unbound variable or a concrete type
/// is reached; function types are resolved component-wise. Everything else is
/// returned as a clone.
pub fn apply_subst(&self, ty: &Ty) -> Ty {
    match ty {
        Ty::Var(id) => match self.subst.get(id) {
            // The binding may itself mention variables, so keep resolving.
            Some(bound) => self.apply_subst(bound),
            None => Ty::Var(*id),
        },
        Ty::Function(params, ret) => {
            let resolved_params = params.iter().map(|p| self.apply_subst(p)).collect();
            let resolved_ret = self.apply_subst(ret);
            Ty::Function(resolved_params, Box::new(resolved_ret))
        }
        concrete => concrete.clone(),
    }
}
/// Reports whether type variable `v` appears anywhere inside `ty` once the current
/// substitution has been applied. Used by unification to refuse bindings that
/// would make a type contain itself.
fn occurs_check(&self, v: usize, ty: &Ty) -> bool {
    match self.apply_subst(ty) {
        Ty::Var(other) => other == v,
        Ty::Function(params, ret) => {
            let in_params = params.iter().any(|p| self.occurs_check(v, p));
            in_params || self.occurs_check(v, &ret)
        }
        _ => false,
    }
}
/// Unifies `t1` with `t2`, extending the substitution where a type variable has to
/// be bound.
///
/// Both sides are resolved first. An unbound variable is bound to the other side
/// unless that would create a recursive type; function types are unified parameter
/// by parameter and then on their return types.
///
/// # Errors
/// Returns a message string on failure: `"recursive type"`, `"arity mismatch"`, or a
/// `"type mismatch"` message that reports `t1` as the expected type and `t2` as the
/// found type.
fn unify(&mut self, t1: &Ty, t2: &Ty) -> Result<(), String> {
    let left = self.apply_subst(t1);
    let right = self.apply_subst(t2);
    if left == right {
        return Ok(());
    }
    match (left, right) {
        (Ty::Var(var), other) | (other, Ty::Var(var)) => {
            if self.occurs_check(var, &other) {
                Err("recursive type".to_string())
            } else {
                self.subst.insert(var, other);
                Ok(())
            }
        }
        (Ty::Function(left_params, left_ret), Ty::Function(right_params, right_ret)) => {
            if left_params.len() != right_params.len() {
                return Err("arity mismatch".to_string());
            }
            for (l, r) in left_params.iter().zip(right_params.iter()) {
                self.unify(l, r)?;
            }
            self.unify(&left_ret, &right_ret)
        }
        (expected, found) => Err(format!(
            "type mismatch: expected {:?}, found {:?}",
            expected, found
        )),
    }
}
/// Analyzes a whole module and returns its typed counterpart.
///
/// Runs in three steps inside one module-level scope:
/// 1. bind every function's signature so bodies can refer to any function;
/// 2. analyze each declaration, collecting constraints and deferred checks;
/// 3. resolve the deferred checks and apply the final substitution to every declaration.
///
/// Errors are accumulated on the analyzer rather than returned.
pub fn analyze_module(&mut self, module: &Module) -> TypedModule {
    self.enter_scope();

    // Step 1: signatures.
    for decl in &module.decls {
        match &decl.kind {
            DeclKind::Function {
                name,
                params,
                return_type,
                ..
            } => {
                let mut param_tys: Vec<Ty> = Vec::new();
                for param in params {
                    param_tys.push(Ty::from(&param.ty.kind));
                }
                let ret_ty = match return_type {
                    Some(annotation) => Ty::from(&annotation.kind),
                    None => Ty::Unit,
                };
                let signature = Ty::Function(param_tys, Box::new(ret_ty));
                self.bind(name, signature);
            }
        }
    }

    // Step 2: bodies.
    let mut analyzed = Vec::new();
    for decl in &module.decls {
        let typed = self.analyze_decl(decl);
        analyzed.push(typed);
    }

    // Step 3: deferred checks, then substitute the now-final types.
    self.finish();
    let mut resolved = Vec::new();
    for typed in analyzed {
        resolved.push(self.apply_subst_decl(typed));
    }

    self.leave_scope();
    TypedModule { decls: resolved }
}
/// Analyzes one declaration and returns its typed form.
///
/// For a function, the parameters are bound in a scope of their own, the body is
/// checked against the declared return type (`Unit` when none is written), and
/// the scope is closed again.
fn analyze_decl(&mut self, decl: &Decl) -> TypedDecl {
    match &decl.kind {
        DeclKind::Function {
            name,
            name_span,
            params,
            return_type,
            body,
        } => {
            self.enter_scope();

            let mut typed_params = Vec::new();
            for param in params {
                let param_ty = Ty::from(&param.ty.kind);
                typed_params.push((param.name.clone(), param_ty.clone()));
                self.bind(&param.name, param_ty);
            }

            let declared_ret = match return_type {
                Some(annotation) => Ty::from(&annotation.kind),
                None => Ty::Unit,
            };
            let typed_body = self.analyze_stmt(body, &declared_ret);

            self.leave_scope();

            let kind = TypedDeclKind::Function {
                name: name.clone(),
                name_span: *name_span,
                params: typed_params,
                return_type: declared_ret,
                body: typed_body,
            };
            TypedDecl {
                kind,
                span: decl.span,
            }
        }
    }
}
/// Analyzes a statement, recursively type-checking its inner expressions and
/// validating return statements against `expected_ret_ty`.
///
/// * `Compound` opens a scope for its inner statements.
/// * `If` requires the condition to unify with `Bool`; both branches are analyzed.
/// * `Return` unifies the returned value's type (or `Unit` when there is none)
///   with `expected_ret_ty`.
/// * `Let` unifies the initializer with the annotation when both exist, reports
///   "type annotation needed" when neither exists, and binds the name in the
///   current scope (the annotation wins over the inferred type).
/// * `Expression` just analyzes the expression.
///
/// Errors are pushed onto the analyzer; a typed statement is always returned.
/// Every `unify` call passes the type demanded by the context first, because the
/// mismatch message reports its first argument as "expected" and its second as "found".
fn analyze_stmt(&mut self, stmt: &Stmt, expected_ret_ty: &Ty) -> TypedStmt {
    match &stmt.kind {
        StmtKind::Compound { inner } => {
            let mut typed_inner = Vec::new();
            self.enter_scope();
            for s in inner {
                typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
            }
            self.leave_scope();
            TypedStmt {
                kind: TypedStmtKind::Compound { inner: typed_inner },
                span: stmt.span,
            }
        }
        StmtKind::If {
            condition,
            then,
            elze,
        } => {
            let typed_condition = self.analyze_expr(condition);
            // Expected `Bool`, found whatever the condition evaluates to.
            if let Err(err) = self.unify(&Ty::Bool, &typed_condition.ty) {
                self.errors.push(SemanticError::new(err, condition.span));
            }
            let typed_then = self.analyze_stmt(then, expected_ret_ty);
            let typed_elze = elze.as_ref().map(|e| self.analyze_stmt(e, expected_ret_ty));
            TypedStmt {
                kind: TypedStmtKind::If {
                    condition: typed_condition,
                    then: Box::new(typed_then),
                    elze: typed_elze.map(Box::new),
                },
                span: stmt.span,
            }
        }
        StmtKind::Return { value } => {
            if let Some(expr) = value {
                let typed_expr = self.analyze_expr(expr);
                // Expected the declared return type, found the returned expression's type.
                if let Err(err) = self.unify(expected_ret_ty, &typed_expr.ty) {
                    self.errors.push(SemanticError::new(err, expr.span));
                }
                TypedStmt {
                    kind: TypedStmtKind::Return {
                        value: Some(typed_expr),
                    },
                    span: stmt.span,
                }
            } else {
                // A bare `return;` yields `Unit`.
                if let Err(err) = self.unify(expected_ret_ty, &Ty::Unit) {
                    self.errors.push(SemanticError::new(err, stmt.span));
                }
                TypedStmt {
                    kind: TypedStmtKind::Return { value: None },
                    span: stmt.span,
                }
            }
        }
        StmtKind::Let {
            name,
            name_span,
            ty: type_annotation,
            initializer,
        } => {
            let explicit_ty = type_annotation.as_ref().map(|t| Ty::from(&t.kind));
            let (inferred_ty, typed_initializer) = match initializer {
                Some(expr) => {
                    let typed_expr = self.analyze_expr(expr);
                    if let Some(ref ext_ty) = explicit_ty {
                        // Expected the annotated type, found the initializer's type.
                        if let Err(err) = self.unify(ext_ty, &typed_expr.ty) {
                            self.errors.push(SemanticError::new(err, expr.span));
                        }
                    }
                    (typed_expr.ty.clone(), Some(typed_expr))
                }
                None => {
                    if let Some(ref ext_ty) = explicit_ty {
                        (ext_ty.clone(), None)
                    } else {
                        // Neither annotation nor initializer: nothing to infer from.
                        self.errors
                            .push(SemanticError::new("type annotation needed", *name_span));
                        (self.new_var(), None)
                    }
                }
            };
            // The annotation, when present, decides the variable's type.
            let final_ty = explicit_ty.unwrap_or(inferred_ty);
            self.bind(name, final_ty.clone());
            TypedStmt {
                kind: TypedStmtKind::Let {
                    name: name.clone(),
                    name_span: *name_span,
                    ty: final_ty,
                    initializer: typed_initializer,
                },
                span: stmt.span,
            }
        }
        StmtKind::Expression { expr } => {
            let typed_expr = self.analyze_expr(expr);
            TypedStmt {
                kind: TypedStmtKind::Expression { expr: typed_expr },
                span: stmt.span,
            }
        }
    }
}
/// Analyzes an expression, generating type constraints, registering deferred checks,
/// and returning a typed expression whose type may still contain type variables.
///
/// * Identifiers take the type of their binding; an unknown name is reported and
///   given a fresh type variable so analysis can continue.
/// * Integer literals get a fresh type variable and a deferred literal check.
/// * Booleans are `Bool`.
/// * Unary minus gets a fresh result variable; the relation between operand and
///   result is decided later by the deferred negation check.
/// * Logical not requires a `Bool` operand and yields `Bool`.
/// * Binary operators unify both operands; comparisons yield `Bool`, all other
///   operators yield the left operand's type. An integer check on the operand
///   type is deferred.
/// * Assignment requires an identifier on the left, unifies both sides and yields
///   the right-hand side's type.
///
/// The returned node's `span` is always the span of the whole expression.
fn analyze_expr(&mut self, expr: &Expr) -> TypedExpr {
    match &expr.kind {
        ExprKind::Identifier { name } => {
            let ty = match self.lookup(name).cloned() {
                Some(bound) => bound,
                None => {
                    self.errors.push(SemanticError::new(
                        format!("undeclared identifier `{}`", name),
                        expr.span,
                    ));
                    self.new_var()
                }
            };
            TypedExpr {
                kind: TypedExprKind::Identifier { name: name.clone() },
                ty,
                span: expr.span,
            }
        }
        ExprKind::Integer { value } => {
            let ty = self.new_var();
            self.deferred_literals.push((expr.span, ty.clone()));
            TypedExpr {
                kind: TypedExprKind::Integer { value: *value },
                ty,
                span: expr.span,
            }
        }
        ExprKind::Boolean { value } => TypedExpr {
            kind: TypedExprKind::Boolean { value: *value },
            ty: Ty::Bool,
            span: expr.span,
        },
        ExprKind::Unary {
            op: UnaryOp::Neg,
            expr: inner_expr,
        } => {
            let typed_inner = self.analyze_expr(inner_expr);
            let result_ty = self.new_var();
            // Remember the literal value (if the operand is one) for the deferred check.
            let known_value = match &inner_expr.kind {
                ExprKind::Integer { value } => Some(*value),
                _ => None,
            };
            self.deferred_unary_neg.push((
                expr.span,
                typed_inner.ty.clone(),
                result_ty.clone(),
                known_value,
            ));
            TypedExpr {
                kind: TypedExprKind::Unary {
                    op: UnaryOp::Neg,
                    expr: Box::new(typed_inner),
                },
                ty: result_ty,
                span: expr.span,
            }
        }
        ExprKind::Unary {
            op: UnaryOp::Not,
            // Bound under its own name: calling it `expr` would shadow the whole
            // expression and make the node below carry the operand's span.
            expr: inner_expr,
        } => {
            let typed_inner = self.analyze_expr(inner_expr);
            if let Err(e) = self.unify(&Ty::Bool, &typed_inner.ty) {
                // The operand is at fault, so the error points at it.
                self.errors.push(SemanticError::new(e, inner_expr.span));
            }
            TypedExpr {
                kind: TypedExprKind::Unary {
                    op: UnaryOp::Not,
                    expr: Box::new(typed_inner),
                },
                ty: Ty::Bool,
                span: expr.span,
            }
        }
        ExprKind::Binary { op, lhs, rhs } => {
            let typed_lhs = self.analyze_expr(lhs);
            let typed_rhs = self.analyze_expr(rhs);
            if let Err(e) = self.unify(&typed_lhs.ty, &typed_rhs.ty) {
                self.errors.push(SemanticError::new(e, expr.span));
            }
            let is_comparison = matches!(
                op,
                BinaryOp::Eq
                    | BinaryOp::Neq
                    | BinaryOp::Lt
                    | BinaryOp::Le
                    | BinaryOp::Gt
                    | BinaryOp::Ge
            );
            let result_ty = if is_comparison {
                Ty::Bool
            } else {
                typed_lhs.ty.clone()
            };
            // Operands were unified, so checking the left one later covers both.
            self.deferred_binary.push((expr.span, typed_lhs.ty.clone()));
            TypedExpr {
                kind: TypedExprKind::Binary {
                    op: *op,
                    lhs: Box::new(typed_lhs),
                    rhs: Box::new(typed_rhs),
                },
                ty: result_ty,
                span: expr.span,
            }
        }
        ExprKind::Assign { lval, rval } => {
            let typed_rval = self.analyze_expr(rval);
            let typed_lval = self.analyze_expr(lval);
            match &typed_lval.kind {
                TypedExprKind::Identifier { .. } => {}
                _ => self.errors.push(SemanticError::new(
                    "invalid left-hand side of assignment",
                    lval.span,
                )),
            }
            if let Err(e) = self.unify(&typed_lval.ty, &typed_rval.ty) {
                self.errors.push(SemanticError::new(e, expr.span));
            }
            TypedExpr {
                ty: typed_rval.ty.clone(),
                kind: TypedExprKind::Assign {
                    lval: Box::new(typed_lval),
                    rval: Box::new(typed_rval),
                },
                span: expr.span,
            }
        }
    }
}
/// Rewrites a typed declaration so that every type it carries (parameters, return
/// type and everything inside the body) is resolved under the current substitution.
fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl {
    let TypedDecl { kind, span } = decl;
    match kind {
        TypedDeclKind::Function {
            name,
            name_span,
            params,
            return_type,
            body,
        } => {
            let resolved_params = params
                .into_iter()
                .map(|(param_name, param_ty)| (param_name, self.apply_subst(&param_ty)))
                .collect();
            let resolved_ret = self.apply_subst(&return_type);
            let resolved_body = self.apply_subst_stmt(body);
            TypedDecl {
                kind: TypedDeclKind::Function {
                    name,
                    name_span,
                    params: resolved_params,
                    return_type: resolved_ret,
                    body: resolved_body,
                },
                span,
            }
        }
    }
}
/// Rewrites a typed statement so that every type inside it, including those of
/// nested statements and expressions, is resolved under the current substitution.
fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt {
    let TypedStmt { kind, span } = stmt;
    let resolved = match kind {
        TypedStmtKind::Compound { inner } => {
            let mut resolved_inner = Vec::new();
            for child in inner {
                resolved_inner.push(self.apply_subst_stmt(child));
            }
            TypedStmtKind::Compound {
                inner: resolved_inner,
            }
        }
        TypedStmtKind::If {
            condition,
            then,
            elze,
        } => {
            let condition = self.apply_subst_expr(condition);
            let then = Box::new(self.apply_subst_stmt(*then));
            let elze = match elze {
                Some(branch) => Some(Box::new(self.apply_subst_stmt(*branch))),
                None => None,
            };
            TypedStmtKind::If {
                condition,
                then,
                elze,
            }
        }
        TypedStmtKind::Return { value } => {
            let value = match value {
                Some(returned) => Some(self.apply_subst_expr(returned)),
                None => None,
            };
            TypedStmtKind::Return { value }
        }
        TypedStmtKind::Let {
            name,
            name_span,
            ty,
            initializer,
        } => {
            let ty = self.apply_subst(&ty);
            let initializer = match initializer {
                Some(init) => Some(self.apply_subst_expr(init)),
                None => None,
            };
            TypedStmtKind::Let {
                name,
                name_span,
                ty,
                initializer,
            }
        }
        TypedStmtKind::Expression { expr } => {
            let expr = self.apply_subst_expr(expr);
            TypedStmtKind::Expression { expr }
        }
    };
    TypedStmt {
        kind: resolved,
        span,
    }
}
/// Rewrites a typed expression so that its own type and the types of all of its
/// sub-expressions are resolved under the current substitution.
fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr {
    let TypedExpr { kind, ty, span } = expr;
    let ty = self.apply_subst(&ty);
    let kind = match kind {
        // Leaves carry no nested expressions; only their own type changes.
        TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name },
        TypedExprKind::Integer { value } => TypedExprKind::Integer { value },
        TypedExprKind::Boolean { value } => TypedExprKind::Boolean { value },
        TypedExprKind::Unary { op, expr: operand } => {
            let operand = self.apply_subst_expr(*operand);
            TypedExprKind::Unary {
                op,
                expr: Box::new(operand),
            }
        }
        TypedExprKind::Binary { op, lhs, rhs } => {
            let lhs = self.apply_subst_expr(*lhs);
            let rhs = self.apply_subst_expr(*rhs);
            TypedExprKind::Binary {
                op,
                lhs: Box::new(lhs),
                rhs: Box::new(rhs),
            }
        }
        TypedExprKind::Assign { lval, rval } => {
            let lval = self.apply_subst_expr(*lval);
            let rval = self.apply_subst_expr(*rval);
            TypedExprKind::Assign {
                lval: Box::new(lval),
                rval: Box::new(rval),
            }
        }
    };
    TypedExpr { kind, ty, span }
}
/// Resolves all deferred type constraints accumulated during analysis, such as
/// integer literal sizing, unary negation promotion rules, and binary operator limits.
fn finish(&mut self) {
for (span, inner_ty, result_ty, known_value) in std::mem::take(&mut self.deferred_unary_neg)
{
let inner_resolved = self.apply_subst(&inner_ty);
let result_resolved = self.apply_subst(&result_ty);
let inner_final = if let Ty::Var(_) = inner_resolved {
if result_resolved.is_integer() {
let _ = self.unify(&inner_resolved, &result_resolved);
result_resolved.clone()
} else {
let default = Ty::I32;
let _ = self.unify(&inner_resolved, &default);
default
}
} else {
inner_resolved
};
// Determine the result type of the unary minus operation.
// Signed integers remain the same type. Unsigned integers are promoted to the
// next largest signed integer type to ensure they can safely represent the negative value.
// If the value is known at compile time, we can avoid promoting to a larger size
// as long as the exact value fits within the signed equivalent of the same size.
let promoted_ty = match inner_final {
Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 => Ok(inner_final),
Ty::U8 => Ok(if known_value.is_some_and(|v| v <= i8::MAX as u64) {
Ty::I8
} else {
Ty::I16
}),
Ty::U16 => Ok(if known_value.is_some_and(|v| v <= i16::MAX as u64) {
Ty::I16
} else {
Ty::I32
}),
Ty::U32 => Ok(if known_value.is_some_and(|v| v <= i32::MAX as u64) {
Ty::I32
} else {
Ty::I64
}),
Ty::U64 => {
if known_value.is_some_and(|v| v <= i64::MAX as u64) {
Ok(Ty::I64)
} else if known_value.is_some() {
Err("value too large to be promoted to signed 64-bit integer")
} else {
Err("cannot promote u64 to a larger signed type")
}
}
_ => Err("unary minus only works on integer types"),
};
match promoted_ty {
Ok(promoted) => {
if let Err(e) = self.unify(&promoted, &result_resolved) {
self.errors.push(SemanticError::new(e, span));
}
}
Err(err) => {
self.errors.push(SemanticError::new(err, span));
}
}
}
for (span, ty) in std::mem::take(&mut self.deferred_binary) {
let resolved = self.apply_subst(&ty);
let final_ty = if let Ty::Var(_) = resolved {
let default = Ty::I32;
let _ = self.unify(&resolved, &default);
default
} else {
resolved
};
if !final_ty.is_integer() {
self.errors.push(SemanticError::new(
"binary operators only work on integer types",
span,
));
}
}
for (span, ty) in std::mem::take(&mut self.deferred_literals) {
let resolved = self.apply_subst(&ty);
let final_ty = if let Ty::Var(_) = resolved {
let default = Ty::I32;
let _ = self.unify(&resolved, &default);
default
} else {
resolved
};
if !final_ty.is_integer() {
self.errors.push(SemanticError::new(
"expected integer type for literal",
span,
));
}
}
}
}
#[cfg(test)]
mod test {
    use crate::frontend::{
        ast::TypedModule,
        parser::Parser,
        sema::{Sema, SemanticError},
    };

    /// Parses `source` and runs semantic analysis on it.
    ///
    /// Panics on parse errors, since these tests only target the analyzer.
    /// Returns the typed module on success, or every semantic error otherwise.
    fn analyze(source: &str) -> Result<TypedModule, Vec<SemanticError>> {
        let mut parser = Parser::new(source);
        let module = parser.parse_module();
        if let Some(errors) = parser.errors() {
            panic!("Parse errors during semantic analysis test: {:?}", errors);
        }
        let mut sema = Sema::new();
        let typed_module = sema.analyze_module(&module);
        match sema.errors() {
            Some(errors) => Err(errors),
            None => Ok(typed_module),
        }
    }

    /// Returns `true` when at least one error message contains `needle`.
    fn any_message_contains(errors: &[SemanticError], needle: &str) -> bool {
        errors.iter().any(|e| e.message.contains(needle))
    }

    #[test]
    fn valid_function() {
        let result = analyze("fn add(a: i32, b: i32) -> i32 { return a + b; }");
        assert!(result.is_ok());
    }

    #[test]
    fn type_mismatch() {
        let errors = analyze("fn bad(a: i32) -> bool { return a; }").unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("type mismatch"));
    }

    #[test]
    fn undeclared_identifier() {
        let errors = analyze("fn oops() -> i32 { return x; }").unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("undeclared identifier `x`"));
    }

    #[test]
    fn binary_op_non_integer() {
        let errors = analyze("fn bad_bin() -> bool { return true + false; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "binary operators only work on integer types"
        ));
    }

    #[test]
    fn unary_neg_promotion_u64_unknown() {
        // An arbitrary u64 may exceed i64::MAX, so negating a non-literal u64
        // cannot be given a signed type.
        let errors = analyze("fn test(x: u64) -> i64 { return -x; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "cannot promote u64 to a larger signed type"
        ));
    }

    #[test]
    fn unary_neg_invalid_type() {
        let errors = analyze("fn test(x: bool) -> bool { return -x; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "unary minus only works on integer types"
        ));
    }

    #[test]
    fn valid_logical_not() {
        let result = analyze("fn test(a: bool) -> bool { return !a; }");
        assert!(result.is_ok());
    }

    #[test]
    fn invalid_logical_not() {
        let errors = analyze("fn test(a: i32) -> bool { return !a; }").unwrap_err();
        assert!(any_message_contains(&errors, "type mismatch"));
    }

    #[test]
    fn valid_comparison() {
        let result = analyze("fn test(a: i32, b: i32) -> bool { return a <= b; }");
        assert!(result.is_ok());
    }

    #[test]
    fn invalid_comparison() {
        let errors = analyze("fn test(a: bool, b: bool) -> bool { return a == b; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "binary operators only work on integer types"
        ));
    }

    #[test]
    fn valid_if() {
        let result = analyze("fn test() { if true { return; } }");
        assert!(result.is_ok());
    }

    #[test]
    fn invalid_if() {
        let result = analyze("fn test() { if 12 {} }");
        assert!(result.is_err());
    }

    #[test]
    fn valid_let() {
        let result = analyze("fn test() { let a = 5; let b: i32 = a; }");
        assert!(result.is_ok());
    }

    #[test]
    fn let_missing_type() {
        let result = analyze("fn test() { let a; }");
        assert!(result.is_err());
    }

    #[test]
    fn valid_expression_stmt() {
        let result = analyze("fn test() { 5 + 5; }");
        assert!(result.is_ok());
    }

    #[test]
    fn valid_assignment() {
        let result = analyze("fn test() { let a = 5; a = 10; }");
        assert!(result.is_ok());
    }

    #[test]
    fn invalid_assignment_type() {
        let errors = analyze("fn test() { let a: i32 = 5; a = true; }").unwrap_err();
        assert!(any_message_contains(&errors, "type mismatch"));
    }

    #[test]
    fn invalid_lvalue() {
        let errors = analyze("fn test() { 5 = 10; }").unwrap_err();
        assert!(any_message_contains(
            &errors,
            "invalid left-hand side of assignment"
        ));
    }
}