From 3e0b5c5b009e9f55d6465ae6484985d36b4cd8b6 Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Tue, 21 Apr 2026 22:08:15 +0200 Subject: [PATCH] feat: unify AST structures, introduce MIR and update codegen --- src/backend/cranelift.rs | 288 ++++++++++++++++++++------------------ src/frontend/ast.rs | 81 ++++++++--- src/frontend/mod.rs | 1 - src/frontend/parser.rs | 29 ++++ src/frontend/sema.rs | 146 ++++++++----------- src/frontend/typed_ast.rs | 60 -------- src/main.rs | 6 +- src/middle/builder.rs | 282 +++++++++++++++++++++++++++++++++++++ src/middle/mir.rs | 98 +++++++++++++ src/middle/mod.rs | 2 + 10 files changed, 686 insertions(+), 307 deletions(-) delete mode 100644 src/frontend/typed_ast.rs create mode 100644 src/middle/builder.rs create mode 100644 src/middle/mir.rs create mode 100644 src/middle/mod.rs diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs index c7e1593..beb2ab1 100644 --- a/src/backend/cranelift.rs +++ b/src/backend/cranelift.rs @@ -13,8 +13,8 @@ use cranelift_object::{ObjectBuilder, ObjectModule}; use crate::frontend::{ ast::{BinaryOp, UnaryOp}, sema::Ty, - typed_ast::*, }; +use crate::middle::mir::*; /// The backend responsible for lowering a `TypedModule` into Cranelift IR and /// generating native machine code object files. @@ -49,63 +49,53 @@ impl CraneliftBackend { } } - /// Compiles a fully typed AST module into native object code. + /// Compiles a MIR module into native object code. /// /// Returns a tuple containing the generated Cranelift IR (as a human-readable string) and the assembled object file bytes. - pub fn compile_module(mut self, module: &TypedModule) -> (String, Vec) { + pub fn compile_module(mut self, module: &MirModule) -> (String, Vec) { let mut ir_output = String::new(); - for decl in &module.decls { - match decl { - TypedDecl::Function { - name, - params, - return_type, - body, - } => { - self.compile_function(params, return_type, body); + for func in &module.functions { + self.compile_function(func); - // Run Cranelift's optimization passes before emitting the text IR - let mut ctrl_plane = ControlPlane::default(); - self.ctx - .optimize(self.module.isa(), &mut ctrl_plane) - .unwrap(); + // Run Cranelift's optimization passes before emitting the text IR + let mut ctrl_plane = ControlPlane::default(); + self.ctx + .optimize(self.module.isa(), &mut ctrl_plane) + .unwrap(); - ir_output.push_str(&format!( - "; Function: {}\n{}", - name, - self.ctx.func.to_string() - )); - ir_output.push('\n'); + ir_output.push_str(&format!( + "; Function: {}\n{}", + func.name, + self.ctx.func.to_string() + )); + ir_output.push('\n'); - let func_id = self - .module - .declare_function(name, Linkage::Export, &self.ctx.func.signature) - .unwrap(); + let func_id = self + .module + .declare_function(&func.name, Linkage::Export, &self.ctx.func.signature) + .unwrap(); - self.module.define_function(func_id, &mut self.ctx).unwrap(); - self.module.clear_context(&mut self.ctx); - } - } + self.module.define_function(func_id, &mut self.ctx).unwrap(); + self.module.clear_context(&mut self.ctx); } let obj_bytes = self.module.finish().emit().unwrap(); (ir_output, obj_bytes) } - /// Lowers a single function declaration into Cranelift IR. - /// - /// This sets up the function signature, ABI parameters, entry block, and declares the parameters as local variables. - fn compile_function(&mut self, params: &[(String, Ty)], return_type: &Ty, body: &TypedStmt) { + /// Lowers a single MIR function into Cranelift IR. + fn compile_function(&mut self, func: &MirFunction) { let mut sig = self.module.make_signature(); - for (_, ty) in params { - sig.params.push(AbiParam::new(Self::lower_type(ty))); + for param_id in &func.params { + let param_ty = &func.locals[param_id.0].ty; + sig.params.push(AbiParam::new(Self::lower_type(param_ty))); } - if return_type != &Ty::Unit { + if func.return_type != Ty::Unit { sig.returns - .push(AbiParam::new(Self::lower_type(return_type))); + .push(AbiParam::new(Self::lower_type(&func.return_type))); } self.ctx.func.signature = sig; @@ -113,23 +103,55 @@ impl CraneliftBackend { let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); - let entry_block = builder.create_block(); - builder.append_block_params_for_function_params(entry_block); - builder.switch_to_block(entry_block); - builder.seal_block(entry_block); - - let mut vars = HashMap::new(); - - for (i, (param_name, ty)) in params.iter().enumerate() { - let var = builder.declare_var(Self::lower_type(ty)); - let val = builder.block_params(entry_block)[i]; - builder.def_var(var, val); - vars.insert(param_name.clone(), var); + let mut block_map = HashMap::new(); + for block in &func.blocks { + block_map.insert(block.id, builder.create_block()); } - let mut trans = FunctionTranslator { builder, vars }; + let mut var_map = HashMap::new(); + for local in &func.locals { + let var = builder.declare_var(Self::lower_type(&local.ty)); + var_map.insert(local.id, var); + } + + let mut trans = FunctionTranslator { + builder, + var_map, + block_map, + locals: &func.locals, + }; + + if let Some(first_block) = func.blocks.first() { + let entry_block = trans.block_map[&first_block.id]; + trans + .builder + .append_block_params_for_function_params(entry_block); + } + + for (i, block) in func.blocks.iter().enumerate() { + let cl_block = trans.block_map[&block.id]; + trans.builder.switch_to_block(cl_block); + + // Retrieve function arguments explicitly if this is the entry block + if i == 0 { + for (j, param_id) in func.params.iter().enumerate() { + let val = trans.builder.block_params(cl_block)[j]; + trans.builder.def_var(trans.var_map[param_id], val); + } + } + + for stmt in &block.statements { + trans.translate_stmt(stmt); + } + + trans.translate_terminator(&block.terminator); + } + + // Seal all blocks now that all branches are translated and predecessors are known + for cl_block in trans.block_map.values() { + trans.builder.seal_block(*cl_block); + } - trans.translate_stmt(body); trans.builder.finalize(); } @@ -146,109 +168,107 @@ impl CraneliftBackend { } } -/// A visitor that traverses typed statements and expressions, emitting Cranelift IR instructions into the current function builder. +/// A visitor that traverses MIR basic blocks and instructions, emitting Cranelift IR instructions +/// into the current function builder. struct FunctionTranslator<'a> { builder: FunctionBuilder<'a>, - vars: HashMap, + var_map: HashMap, + block_map: HashMap, + locals: &'a [LocalDecl], } impl<'a> FunctionTranslator<'a> { - /// Translates a statement, recursively compiling its inner components. - /// Returns `true` if the statement resulted in a basic block terminator. - fn translate_stmt(&mut self, stmt: &TypedStmt) -> bool { - match stmt { - TypedStmt::Compound { inner } => { - for s in inner { - if self.translate_stmt(s) { - return true; - } - } - false - } - TypedStmt::If { - condition, - then, - elze, - } => { - let cond_val = self.translate_expr(condition); - - let then_block = self.builder.create_block(); - let else_block = self.builder.create_block(); - let merge_block = self.builder.create_block(); - - self.builder - .ins() - .brif(cond_val, then_block, &[], else_block, &[]); - - self.builder.switch_to_block(then_block); - self.builder.seal_block(then_block); - let then_terminated = self.translate_stmt(then); - if !then_terminated { - self.builder.ins().jump(merge_block, &[]); - } - - self.builder.switch_to_block(else_block); - self.builder.seal_block(else_block); - let else_terminated = elze - .as_ref() - .map(|stmt| self.translate_stmt(stmt)) - .unwrap_or(false); - if !else_terminated { - self.builder.ins().jump(merge_block, &[]); - } - - self.builder.switch_to_block(merge_block); - self.builder.seal_block(merge_block); - - then_terminated && else_terminated - } - TypedStmt::Return { value } => { - if let Some(expr) = value { - let val = self.translate_expr(expr); - self.builder.ins().return_(&[val]); - } else { - self.builder.ins().return_(&[]); - } - true + fn translate_stmt(&mut self, stmt: &Statement) { + match &stmt.kind { + StatementKind::Assign(local_id, rvalue) => { + let val = self.translate_rvalue(rvalue); + let var = self.var_map[local_id]; + self.builder.def_var(var, val); } } } - /// Translates an expression into a Cranelift IR value. - /// Emits appropriate computation instructions based on operators and operand types. - fn translate_expr(&mut self, expr: &TypedExpr) -> ir::Value { - match &expr.kind { - TypedExprKind::Identifier { name } => { - let var = self.vars.get(name).expect("Undeclared variable"); - self.builder.use_var(*var) + fn translate_terminator(&mut self, term: &Terminator) { + match &term.kind { + TerminatorKind::Goto { target } => { + self.builder.ins().jump(self.block_map[target], &[]); } - TypedExprKind::Integer { value } => { - let ty = CraneliftBackend::lower_type(&expr.ty); - self.builder.ins().iconst(ty, *value as i64) + TerminatorKind::CondBranch { + cond, + target_true, + target_false, + } => { + let cond_val = self.translate_operand(cond); + self.builder.ins().brif( + cond_val, + self.block_map[target_true], + &[], + self.block_map[target_false], + &[], + ); } - TypedExprKind::Boolean { value } => { - let ty = CraneliftBackend::lower_type(&expr.ty); - self.builder.ins().iconst(ty, if *value { 1 } else { 0 }) + TerminatorKind::Return { value } => { + if let Some(op) = value { + let val = self.translate_operand(op); + self.builder.ins().return_(&[val]); + } else { + self.builder.ins().return_(&[]); + } } - TypedExprKind::Unary { op, expr: inner } => { - let inner_val = self.translate_expr(inner); + TerminatorKind::Unreachable => { + self.builder.ins().trap(ir::TrapCode::user(5).unwrap()); + } + } + } + + fn get_operand_type(&self, op: &Operand) -> Ty { + match op { + Operand::Copy(local_id) => self.locals[local_id.0].ty.clone(), + Operand::Constant(ConstantValue::Integer(_, ty)) => ty.clone(), + Operand::Constant(ConstantValue::Boolean(_)) => Ty::Bool, + } + } + + fn translate_operand(&mut self, op: &Operand) -> ir::Value { + match op { + Operand::Copy(local_id) => { + let var = self.var_map[local_id]; + self.builder.use_var(var) + } + Operand::Constant(ConstantValue::Integer(val, ty)) => { + let cl_ty = CraneliftBackend::lower_type(ty); + self.builder.ins().iconst(cl_ty, *val as i64) + } + Operand::Constant(ConstantValue::Boolean(val)) => self + .builder + .ins() + .iconst(types::I8, if *val { 1 } else { 0 }), + } + } + + fn translate_rvalue(&mut self, rvalue: &Rvalue) -> ir::Value { + match rvalue { + Rvalue::Use(op) => self.translate_operand(op), + Rvalue::UnaryOp(op, inner) => { + let inner_val = self.translate_operand(inner); match op { UnaryOp::Neg => self.builder.ins().ineg(inner_val), UnaryOp::Not => { - // `!x` is equivalent to `x == 0` for booleans (0 or 1). - let ty = CraneliftBackend::lower_type(&inner.ty); - let zero = self.builder.ins().iconst(ty, 0); + let ty = self.get_operand_type(inner); + let cl_ty = CraneliftBackend::lower_type(&ty); + let zero = self.builder.ins().iconst(cl_ty, 0); self.builder .ins() .icmp(ir::condcodes::IntCC::Equal, inner_val, zero) } } } - TypedExprKind::Binary { op, lhs, rhs } => { - let lhs_val = self.translate_expr(lhs); - let rhs_val = self.translate_expr(rhs); + Rvalue::BinaryOp(op, lhs, rhs) => { + let lhs_val = self.translate_operand(lhs); + let rhs_val = self.translate_operand(rhs); - let is_signed = matches!(lhs.ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64); + let ty = self.get_operand_type(lhs); + let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64); match op { BinaryOp::Add => self.builder.ins().iadd(lhs_val, rhs_val), diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index b85f3b2..218ad58 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -1,24 +1,58 @@ +use crate::frontend::sema::Ty; use crate::frontend::token::Span; +use std::fmt::Debug; -#[derive(Debug, PartialEq, Eq)] -pub struct Module { - pub decls: Vec, +pub trait Phase: Debug + PartialEq + Eq { + type ReturnType: Debug + PartialEq + Eq; + type ParamType: Debug + PartialEq + Eq; + type ExprType: Debug + PartialEq + Eq; } #[derive(Debug, PartialEq, Eq)] -pub struct Decl { - pub kind: DeclKind, +pub struct Untyped; + +impl Phase for Untyped { + type ReturnType = Option; + type ParamType = FunctionParam; + type ExprType = (); +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Typed; + +impl Phase for Typed { + type ReturnType = Ty; + type ParamType = (String, Ty); + type ExprType = Ty; +} + +pub type TypedModule = Module; +pub type TypedDecl = Decl; +pub type TypedDeclKind = DeclKind; +pub type TypedStmt = Stmt; +pub type TypedStmtKind = StmtKind; +pub type TypedExpr = Expr; +pub type TypedExprKind = ExprKind; + +#[derive(Debug, PartialEq, Eq)] +pub struct Module { + pub decls: Vec>, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Decl { + pub kind: DeclKind

, pub span: Span, } #[derive(Debug, PartialEq, Eq)] -pub enum DeclKind { +pub enum DeclKind { Function { name: String, name_span: Span, - params: Vec, - return_type: Option, - body: Stmt, + params: Vec, + return_type: P::ReturnType, + body: Stmt

, }, } @@ -49,34 +83,35 @@ pub enum TypeKind { } #[derive(Debug, PartialEq, Eq)] -pub struct Stmt { - pub kind: StmtKind, +pub struct Stmt { + pub kind: StmtKind

, pub span: Span, } #[derive(Debug, PartialEq, Eq)] -pub enum StmtKind { +pub enum StmtKind { Compound { - inner: Vec, + inner: Vec>, }, If { - condition: Expr, - then: Box, - elze: Option>, + condition: Expr

, + then: Box>, + elze: Option>>, }, Return { - value: Option, + value: Option>, }, } #[derive(Debug, PartialEq, Eq)] -pub struct Expr { - pub kind: ExprKind, +pub struct Expr { + pub kind: ExprKind

, + pub ty: P::ExprType, pub span: Span, } #[derive(Debug, PartialEq, Eq)] -pub enum ExprKind { +pub enum ExprKind { Identifier { name: String, }, @@ -88,12 +123,12 @@ pub enum ExprKind { }, Unary { op: UnaryOp, - expr: Box, + expr: Box>, }, Binary { op: BinaryOp, - lhs: Box, - rhs: Box, + lhs: Box>, + rhs: Box>, }, } diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index 63591ad..d246a9a 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -3,4 +3,3 @@ pub mod lexer; pub mod parser; pub mod sema; pub mod token; -pub mod typed_ast; diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index 4d2f174..441cab9 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -431,6 +431,7 @@ impl<'src> Parser<'src> { lhs: Box::new(lhs), rhs: Box::new(rhs), }, + ty: (), span, }; } @@ -451,6 +452,7 @@ impl<'src> Parser<'src> { kind: ExprKind::Identifier { name: token.text.to_string(), }, + ty: (), span: token.span, }) } @@ -472,6 +474,7 @@ impl<'src> Parser<'src> { Ok(Expr { kind: ExprKind::Integer { value }, + ty: (), span: token.span, }) } @@ -483,6 +486,7 @@ impl<'src> Parser<'src> { kind: ExprKind::Boolean { value: token.text == "true", }, + ty: (), span: token.span, }) } @@ -494,6 +498,7 @@ impl<'src> Parser<'src> { Ok(Expr { kind: expr.kind, + ty: (), span: lparen.span.join(rparen.span), }) } @@ -503,6 +508,7 @@ impl<'src> Parser<'src> { let rhs = self.parse_expr_bp(r_bp)?; Ok(Expr { + ty: (), span: op_token.span.join(rhs.span), kind: ExprKind::Unary { op, @@ -598,6 +604,7 @@ mod test { parse("0xBEEF;", Parser::parse_expr), Success(Expr { kind: ExprKind::Integer { value: 0xBEEF }, + ty: (), span: Span::new(0, 6) }) ); @@ -606,6 +613,7 @@ mod test { parse("0o777;", Parser::parse_expr), Success(Expr { kind: ExprKind::Integer { value: 0o777 }, + ty: (), span: Span::new(0, 5) }) ); @@ -614,6 +622,7 @@ mod test { parse("0b1001;", Parser::parse_expr), Success(Expr { kind: ExprKind::Integer { value: 0b1001 }, + ty: (), span: Span::new(0, 6) }) ); @@ -622,6 +631,7 @@ mod test { parse("1337;", Parser::parse_expr), Success(Expr { kind: ExprKind::Integer { value: 1337 }, + ty: (), span: Span::new(0, 4) }) ); @@ -633,6 +643,7 @@ mod test { parse("true;", Parser::parse_expr), Success(Expr { kind: ExprKind::Boolean { value: true }, + ty: (), span: Span::new(0, 4) }) ); @@ -641,6 +652,7 @@ mod test { parse("false;", Parser::parse_expr), Success(Expr { kind: ExprKind::Boolean { value: false }, + ty: (), span: Span::new(0, 5) }) ); @@ -655,9 +667,11 @@ mod test { op: UnaryOp::Neg, expr: Box::new(Expr { kind: ExprKind::Integer { value: 5 }, + ty: (), span: Span::new(1, 2) }) }, + ty: (), span: Span::new(0, 2) }) ); @@ -672,6 +686,7 @@ mod test { op: BinaryOp::Add, lhs: Box::new(Expr { kind: ExprKind::Integer { value: 12 }, + ty: (), span: Span::new(0, 2) }), rhs: Box::new(Expr { @@ -679,16 +694,20 @@ mod test { op: BinaryOp::Mul, lhs: Box::new(Expr { kind: ExprKind::Integer { value: 3 }, + ty: (), span: Span::new(5, 6) }), rhs: Box::new(Expr { kind: ExprKind::Integer { value: 6 }, + ty: (), span: Span::new(9, 10) }) }, + ty: (), span: Span::new(5, 10) }) }, + ty: (), span: Span::new(0, 10) }) ); @@ -710,6 +729,7 @@ mod test { kind: StmtKind::Return { value: Some(Expr { kind: ExprKind::Integer { value: 0 }, + ty: (), span: Span::new(7, 8) }) }, @@ -726,6 +746,7 @@ mod test { kind: StmtKind::If { condition: Expr { kind: ExprKind::Boolean { value: true }, + ty: (), span: Span::new(3, 7) }, then: Box::new(Stmt { @@ -818,15 +839,18 @@ mod test { kind: ExprKind::Identifier { name: "a".to_string() }, + ty: (), span: Span::new(39, 40) }), rhs: Box::new(Expr { kind: ExprKind::Identifier { name: "b".to_string() }, + ty: (), span: Span::new(43, 44) }) }, + ty: (), span: Span::new(39, 44) }) }, @@ -852,13 +876,16 @@ mod test { kind: ExprKind::Identifier { name: "a".to_string() }, + ty: (), span: Span::new(0, 1) }), rhs: Box::new(Expr { kind: ExprKind::Integer { value: 5 }, + ty: (), span: Span::new(5, 6) }) }, + ty: (), span: Span::new(0, 6) }) ); @@ -873,9 +900,11 @@ mod test { op: UnaryOp::Not, expr: Box::new(Expr { kind: ExprKind::Boolean { value: true }, + ty: (), span: Span::new(1, 5) }) }, + ty: (), span: Span::new(0, 5) }) ); diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs index f20eef8..7cc784d 100644 --- a/src/frontend/sema.rs +++ b/src/frontend/sema.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use crate::frontend::ast::*; use crate::frontend::token::Span; -use crate::frontend::typed_ast::*; /// A structured error produced during semantic analysis, carrying a human-readable /// message and the [Span] of the offending AST node for precise diagnostics. @@ -79,7 +78,6 @@ pub struct Sema { deferred_unary_neg: Vec<(Span, Ty, Ty, Option)>, deferred_binary: Vec<(Span, Ty)>, deferred_literals: Vec<(Span, Ty)>, - is_reachable: bool, } impl Sema { @@ -93,7 +91,6 @@ impl Sema { deferred_unary_neg: Vec::new(), deferred_binary: Vec::new(), deferred_literals: Vec::new(), - is_reachable: true, } } @@ -245,10 +242,10 @@ impl Sema { match &decl.kind { DeclKind::Function { name, + name_span, params, return_type, body, - .. } => { let mut typed_params = Vec::new(); @@ -265,24 +262,19 @@ impl Sema { .map(|t| Ty::from(&t.kind)) .unwrap_or(Ty::Unit); - self.is_reachable = true; - let typed_body = self.analyze_stmt(body, &expected_ret_ty); - if expected_ret_ty != Ty::Unit && self.is_reachable { - self.errors.push(SemanticError::new( - "not all control paths return a value", - decl.span, - )); - } - self.leave_scope(); - TypedDecl::Function { - name: name.clone(), - params: typed_params, - return_type: expected_ret_ty, - body: typed_body, + TypedDecl { + kind: TypedDeclKind::Function { + name: name.clone(), + name_span: *name_span, + params: typed_params, + return_type: expected_ret_ty, + body: typed_body, + }, + span: decl.span, } } } @@ -294,22 +286,19 @@ impl Sema { match &stmt.kind { StmtKind::Compound { inner } => { let mut typed_inner = Vec::new(); - let mut reported_unreachable = false; self.enter_scope(); for s in inner { - if !self.is_reachable && !reported_unreachable { - self.errors - .push(SemanticError::new("unreachable statement", s.span)); - reported_unreachable = true; - } typed_inner.push(self.analyze_stmt(s, expected_ret_ty)); } self.leave_scope(); - TypedStmt::Compound { inner: typed_inner } + TypedStmt { + kind: TypedStmtKind::Compound { inner: typed_inner }, + span: stmt.span, + } } StmtKind::If { condition, @@ -322,29 +311,16 @@ impl Sema { self.errors.push(SemanticError::new(err, condition.span)); } - let initial_reachable = self.is_reachable; - - self.is_reachable = initial_reachable; let typed_then = self.analyze_stmt(then, expected_ret_ty); - let reachable_after_then = self.is_reachable; + let typed_elze = elze.as_ref().map(|e| self.analyze_stmt(e, expected_ret_ty)); - let typed_elze = elze.as_ref().map(|e| { - self.is_reachable = initial_reachable; - self.analyze_stmt(e, expected_ret_ty) - }); - - let reachable_after_else = if elze.is_some() { - self.is_reachable - } else { - initial_reachable - }; - - self.is_reachable = reachable_after_then || reachable_after_else; - - TypedStmt::If { - condition: typed_condition, - then: Box::new(typed_then), - elze: typed_elze.map(Box::new), + TypedStmt { + kind: TypedStmtKind::If { + condition: typed_condition, + then: Box::new(typed_then), + elze: typed_elze.map(Box::new), + }, + span: stmt.span, } } StmtKind::Return { value } => { @@ -355,19 +331,21 @@ impl Sema { self.errors.push(SemanticError::new(err, expr.span)); } - self.is_reachable = false; - - TypedStmt::Return { - value: Some(typed_expr), + TypedStmt { + kind: TypedStmtKind::Return { + value: Some(typed_expr), + }, + span: stmt.span, } } else { if let Err(err) = self.unify(&Ty::Unit, expected_ret_ty) { self.errors.push(SemanticError::new(err, stmt.span)); } - self.is_reachable = false; - - TypedStmt::Return { value: None } + TypedStmt { + kind: TypedStmtKind::Return { value: None }, + span: stmt.span, + } } } } @@ -392,6 +370,7 @@ impl Sema { TypedExpr { kind: TypedExprKind::Identifier { name: name.clone() }, ty, + span: expr.span, } } @@ -402,12 +381,14 @@ impl Sema { TypedExpr { kind: TypedExprKind::Integer { value: *value }, ty, + span: expr.span, } } ExprKind::Boolean { value } => TypedExpr { kind: TypedExprKind::Boolean { value: *value }, ty: Ty::Bool, + span: expr.span, }, ExprKind::Unary { @@ -435,6 +416,7 @@ impl Sema { expr: Box::new(typed_inner), }, ty: result_ty, + span: expr.span, } } @@ -454,6 +436,7 @@ impl Sema { expr: Box::new(typed_inner), }, ty: Ty::Bool, + span: expr.span, } } @@ -490,6 +473,7 @@ impl Sema { rhs: Box::new(typed_rhs), }, ty: result_ty, + span: expr.span, } } } @@ -497,9 +481,11 @@ impl Sema { /// Recursively applies the final resolved type substitutions to a typed declaration. fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl { - match decl { - TypedDecl::Function { + let span = decl.span; + let kind = match decl.kind { + TypedDeclKind::Function { name, + name_span, params, return_type, body, @@ -509,45 +495,52 @@ impl Sema { .map(|(n, ty)| (n, self.apply_subst(&ty))) .collect(); - TypedDecl::Function { + TypedDeclKind::Function { name, + name_span, params, return_type: self.apply_subst(&return_type), body: self.apply_subst_stmt(body), } } - } + }; + + TypedDecl { kind, span } } /// Recursively applies the final resolved type substitutions to a typed statement. fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt { - match stmt { - TypedStmt::Compound { inner } => TypedStmt::Compound { + let span = stmt.span; + let kind = match stmt.kind { + TypedStmtKind::Compound { inner } => TypedStmtKind::Compound { inner: inner .into_iter() .map(|s| self.apply_subst_stmt(s)) .collect(), }, - TypedStmt::If { + TypedStmtKind::If { condition, then, elze, - } => TypedStmt::If { + } => TypedStmtKind::If { condition: self.apply_subst_expr(condition), then: Box::new(self.apply_subst_stmt(*then)), elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))), }, - TypedStmt::Return { value } => TypedStmt::Return { + TypedStmtKind::Return { value } => TypedStmtKind::Return { value: value.map(|e| self.apply_subst_expr(e)), }, - } + }; + + TypedStmt { kind, span } } /// Recursively applies the final resolved type substitutions to a typed expression. fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr { let ty = self.apply_subst(&expr.ty); + let span = expr.span; let kind = match expr.kind { TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name }, TypedExprKind::Integer { value } => TypedExprKind::Integer { value }, @@ -565,7 +558,7 @@ impl Sema { }, }; - TypedExpr { kind, ty } + TypedExpr { kind, ty, span } } /// Resolves all deferred type constraints accumulated during analysis, such as @@ -682,9 +675,9 @@ impl Sema { #[cfg(test)] mod test { use crate::frontend::{ + ast::TypedModule, parser::Parser, sema::{Sema, SemanticError}, - typed_ast::TypedModule, }; fn analyze(source: &str) -> Result> { @@ -804,27 +797,4 @@ mod test { let src = "fn test() { if 12 {} }"; assert!(analyze(src).is_err()); } - - #[test] - fn not_all_paths_return() { - let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } else { } }"; - assert!(analyze(src).is_err()); - - let src = "fn test() -> i32 { }"; - assert!(analyze(src).is_err()); - - let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } return 10; }"; - assert!(analyze(src).is_ok()); - } - - #[test] - fn unreachable_code() { - let src = "fn test() -> i32 { return 5; return 10; }"; - let errors = analyze(src).unwrap_err(); - assert!( - errors - .iter() - .any(|e| e.message.contains("unreachable statement")) - ); - } } diff --git a/src/frontend/typed_ast.rs b/src/frontend/typed_ast.rs deleted file mode 100644 index bcc2c14..0000000 --- a/src/frontend/typed_ast.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::frontend::ast::{BinaryOp, UnaryOp}; -use crate::frontend::sema::Ty; - -#[derive(Debug, PartialEq, Eq)] -pub struct TypedModule { - pub decls: Vec, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum TypedDecl { - Function { - name: String, - params: Vec<(String, Ty)>, - return_type: Ty, - body: TypedStmt, - }, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum TypedStmt { - Compound { - inner: Vec, - }, - If { - condition: TypedExpr, - then: Box, - elze: Option>, - }, - Return { - value: Option, - }, -} - -#[derive(Debug, PartialEq, Eq)] -pub struct TypedExpr { - pub kind: TypedExprKind, - pub ty: Ty, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum TypedExprKind { - Identifier { - name: String, - }, - Integer { - value: u64, - }, - Boolean { - value: bool, - }, - Unary { - op: UnaryOp, - expr: Box, - }, - Binary { - op: BinaryOp, - lhs: Box, - rhs: Box, - }, -} diff --git a/src/main.rs b/src/main.rs index 3475d9e..ece3b51 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,9 +4,11 @@ use clap::Parser as ClapParser; use crate::frontend::parser::Parser; use crate::frontend::sema::Sema; +use crate::middle::builder::MirBuilder; pub mod backend; pub mod frontend; +pub mod middle; use crate::backend::cranelift::CraneliftBackend; @@ -55,8 +57,10 @@ fn main() { exit(1); } + let mir_module = MirBuilder::build(&typed_module); + let backend = CraneliftBackend::new(); - let (ir, obj_bytes) = backend.compile_module(&typed_module); + let (ir, obj_bytes) = backend.compile_module(&mir_module); if cli.emit_ir { println!("{}", ir); diff --git a/src/middle/builder.rs b/src/middle/builder.rs new file mode 100644 index 0000000..bc9063c --- /dev/null +++ b/src/middle/builder.rs @@ -0,0 +1,282 @@ +use std::collections::HashMap; + +use crate::frontend::ast::*; +use crate::frontend::sema::Ty; +use crate::middle::mir::*; + +/// Lowers a fully-typed AST into a linear Control Flow Graph (MIR). +pub struct MirBuilder; + +impl MirBuilder { + /// Builds a `MirModule` from a `TypedModule`. + pub fn build(module: &TypedModule) -> MirModule { + let mut functions = Vec::new(); + + for decl in &module.decls { + match &decl.kind { + TypedDeclKind::Function { + name, + name_span: _, + params, + return_type, + body, + } => { + let mut builder = FuncBuilder::new(name.clone(), return_type.clone()); + + // Register parameters as local variables + for (param_name, ty) in params { + let local_id = builder.new_local(param_name.clone(), ty.clone()); + builder.params.push(local_id); + } + + let entry = builder.new_block(); + builder.switch_to_block(entry); + + builder.lower_stmt(body); + + // If the final block was never terminated (e.g. missing an explicit return), + // we insert an implicit return for Unit, or an Unreachable trap otherwise. + if builder.current_block.is_some() { + let span = body.span; + if *return_type == Ty::Unit { + builder.terminate(Terminator { + kind: TerminatorKind::Return { value: None }, + span, + }); + } else { + builder.terminate(Terminator { + kind: TerminatorKind::Unreachable, + span, + }); + } + } + + functions.push(builder.finish()); + } + } + } + + MirModule { functions } + } +} + +/// A helper struct that manages the state required to construct a single `MirFunction`. +/// +/// It keeps track of variable declarations, basic blocks, and handles the flattening +/// of nested AST nodes into a sequential Control Flow Graph. +struct FuncBuilder { + /// The name of the function being built. + name: String, + /// The expected return type of the function. + return_type: Ty, + /// `LocalId`s corresponding to the function's parameters. + params: Vec, + /// All local variables and compiler-generated temporaries used in the function. + locals: Vec, + + /// The basic blocks that make up the function's control flow graph. + blocks: Vec, + /// The block currently being populated with statements. If `None`, the builder is in "dead code" territory. + current_block: Option, + /// Statements buffered for the current block before a terminator is reached. + current_statements: Vec, + /// Counter for generating unique `BlockId`s. + next_block_id: usize, + + /// Mapping from user-defined variable names to their corresponding `LocalId`. + vars: HashMap, +} + +impl FuncBuilder { + /// Creates a new `FuncBuilder` for a function with the given name and return type. + fn new(name: String, return_type: Ty) -> Self { + Self { + name, + return_type, + params: Vec::new(), + locals: Vec::new(), + blocks: Vec::new(), + current_block: None, + current_statements: Vec::new(), + next_block_id: 0, + vars: HashMap::new(), + } + } + + /// Registers a new user-defined local variable and returns its `LocalId`. + fn new_local(&mut self, name: String, ty: Ty) -> LocalId { + let id = LocalId(self.locals.len()); + self.locals.push(LocalDecl { + id, + ty, + mutable: false, + name: Some(name.clone()), + }); + self.vars.insert(name, id); + id + } + + /// Creates a new compiler-generated temporary variable and returns its `LocalId`. + fn new_temp(&mut self, ty: Ty) -> LocalId { + let id = LocalId(self.locals.len()); + self.locals.push(LocalDecl { + id, + ty, + mutable: false, + name: None, + }); + id + } + + /// Allocates a new, empty basic block and returns its `BlockId`. + fn new_block(&mut self) -> BlockId { + let id = BlockId(self.next_block_id); + self.next_block_id += 1; + id + } + + /// Sets the given block as the active insertion point for new statements. + /// Note: The previous block should be terminated before switching. + fn switch_to_block(&mut self, id: BlockId) { + assert!(self.current_statements.is_empty()); + self.current_block = Some(id); + } + + /// Appends a statement to the current basic block. + /// If the current block has already been terminated, the statement is ignored (dead code elimination). + fn emit_stmt(&mut self, stmt: Statement) { + if self.current_block.is_some() { + self.current_statements.push(stmt); + } + } + + /// Terminates the current basic block, sealing it and finalizing its instructions. + /// Subsequent statements will be ignored until `switch_to_block` is called. + fn terminate(&mut self, terminator: Terminator) { + if let Some(id) = self.current_block.take() { + self.blocks.push(BasicBlock { + id, + statements: std::mem::take(&mut self.current_statements), + terminator, + }); + } + } + + /// Finalizes the function construction, returning the complete `MirFunction`. + fn finish(mut self) -> MirFunction { + // Ensure basic blocks are strictly ordered by their numerical IDs + self.blocks.sort_by_key(|b| b.id.0); + + MirFunction { + name: self.name, + params: self.params, + return_type: self.return_type, + locals: self.locals, + blocks: self.blocks, + } + } + + /// Recursively lowers a typed statement into MIR instructions and basic blocks. + fn lower_stmt(&mut self, stmt: &TypedStmt) { + match &stmt.kind { + TypedStmtKind::Compound { inner } => { + for s in inner { + self.lower_stmt(s); + } + } + TypedStmtKind::If { + condition, + then, + elze, + } => { + let cond_op = self.lower_expr(condition); + + let then_block = self.new_block(); + let merge_block = self.new_block(); + let else_block = if elze.is_some() { + self.new_block() + } else { + merge_block + }; + + self.terminate(Terminator { + kind: TerminatorKind::CondBranch { + cond: cond_op, + target_true: then_block, + target_false: else_block, + }, + span: stmt.span, + }); + + self.switch_to_block(then_block); + self.lower_stmt(then); + self.terminate(Terminator { + kind: TerminatorKind::Goto { + target: merge_block, + }, + span: then.span, + }); + + if let Some(e) = elze { + self.switch_to_block(else_block); + self.lower_stmt(e); + self.terminate(Terminator { + kind: TerminatorKind::Goto { + target: merge_block, + }, + span: e.span, + }); + } + + self.switch_to_block(merge_block); + } + TypedStmtKind::Return { value } => { + let val_op = value.as_ref().map(|v| self.lower_expr(v)); + self.terminate(Terminator { + kind: TerminatorKind::Return { value: val_op }, + span: stmt.span, + }); + } + } + } + + /// Recursively lowers a typed expression into MIR instructions, returning an `Operand` representing its result. + fn lower_expr(&mut self, expr: &TypedExpr) -> Operand { + match &expr.kind { + TypedExprKind::Identifier { name } => { + let local = *self + .vars + .get(name) + .expect("undeclared variable in MIR lowering"); + Operand::Copy(local) + } + TypedExprKind::Integer { value } => { + Operand::Constant(ConstantValue::Integer(*value, expr.ty.clone())) + } + TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)), + TypedExprKind::Unary { op, expr: inner } => { + let inner_op = self.lower_expr(inner); + let temp = self.new_temp(expr.ty.clone()); + + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), + span: expr.span, + }); + + Operand::Copy(temp) + } + TypedExprKind::Binary { op, lhs, rhs } => { + let lhs_op = self.lower_expr(lhs); + let rhs_op = self.lower_expr(rhs); + let temp = self.new_temp(expr.ty.clone()); + + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::BinaryOp(*op, lhs_op, rhs_op)), + span: expr.span, + }); + + Operand::Copy(temp) + } + } + } +} diff --git a/src/middle/mir.rs b/src/middle/mir.rs new file mode 100644 index 0000000..3c3939f --- /dev/null +++ b/src/middle/mir.rs @@ -0,0 +1,98 @@ +use crate::frontend::ast::{BinaryOp, UnaryOp}; +use crate::frontend::sema::Ty; +use crate::frontend::token::Span; + +/// A strongly-typed reference to a basic block. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BlockId(pub usize); + +/// A strongly-typed reference to a local variable or temporary. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct LocalId(pub usize); + +#[derive(Debug)] +pub struct MirModule { + pub functions: Vec, +} + +#[derive(Debug)] +pub struct MirFunction { + pub name: String, + pub params: Vec, + pub return_type: Ty, + pub locals: Vec, + pub blocks: Vec, +} + +/// A declaration of a local variable or compiler-generated temporary. +#[derive(Debug)] +pub struct LocalDecl { + pub id: LocalId, + pub ty: Ty, + pub mutable: bool, + /// Contains the name if it is a user-declared variable, or `None` if it's a compiler temporary. + pub name: Option, +} + +/// A sequential list of non-branching statements followed by a single terminator. +#[derive(Debug)] +pub struct BasicBlock { + pub id: BlockId, + pub statements: Vec, + pub terminator: Terminator, +} + +#[derive(Debug)] +pub struct Statement { + pub kind: StatementKind, + pub span: Span, +} + +#[derive(Debug)] +pub enum StatementKind { + /// Assigns the result of an Rvalue to a local variable or temporary. + Assign(LocalId, Rvalue), +} + +/// Operations that produce a value. +#[derive(Debug)] +pub enum Rvalue { + Use(Operand), + UnaryOp(UnaryOp, Operand), + BinaryOp(BinaryOp, Operand, Operand), +} + +/// An atomic value used as inputs to instructions. +#[derive(Debug, Clone)] +pub enum Operand { + Copy(LocalId), + Constant(ConstantValue), +} + +#[derive(Debug, Clone)] +pub enum ConstantValue { + Integer(u64, Ty), + Boolean(bool), +} + +#[derive(Debug)] +pub struct Terminator { + pub kind: TerminatorKind, + pub span: Span, +} + +#[derive(Debug)] +pub enum TerminatorKind { + Goto { + target: BlockId, + }, + CondBranch { + cond: Operand, + target_true: BlockId, + target_false: BlockId, + }, + Return { + value: Option, + }, + Unreachable, +} diff --git a/src/middle/mod.rs b/src/middle/mod.rs new file mode 100644 index 0000000..e7f29ad --- /dev/null +++ b/src/middle/mod.rs @@ -0,0 +1,2 @@ +pub mod builder; +pub mod mir;