From 0162d5b845f85792e915e1f25c973f46fd9bc2ea Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Tue, 21 Apr 2026 22:28:30 +0200 Subject: [PATCH] feat: add constant folding and dead code elimination passes --- src/backend/cranelift.rs | 6 +- src/main.rs | 10 +- src/middle/dce.rs | 197 +++++++++++++++++++++++++++ src/middle/fold.rs | 278 +++++++++++++++++++++++++++++++++++++++ src/middle/mod.rs | 2 + 5 files changed, 487 insertions(+), 6 deletions(-) create mode 100644 src/middle/dce.rs create mode 100644 src/middle/fold.rs diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs index beb2ab1..3406757 100644 --- a/src/backend/cranelift.rs +++ b/src/backend/cranelift.rs @@ -64,11 +64,7 @@ impl CraneliftBackend { .optimize(self.module.isa(), &mut ctrl_plane) .unwrap(); - ir_output.push_str(&format!( - "; Function: {}\n{}", - func.name, - self.ctx.func.to_string() - )); + ir_output.push_str(&format!("; Function: {}\n{}", func.name, self.ctx.func)); ir_output.push('\n'); let func_id = self diff --git a/src/main.rs b/src/main.rs index ece3b51..7c5c490 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ use clap::Parser as ClapParser; use crate::frontend::parser::Parser; use crate::frontend::sema::Sema; use crate::middle::builder::MirBuilder; +use crate::middle::dce::eliminate_dead_code; +use crate::middle::fold::fold_constants; pub mod backend; pub mod frontend; @@ -57,7 +59,13 @@ fn main() { exit(1); } - let mir_module = MirBuilder::build(&typed_module); + let mut mir_module = MirBuilder::build(&typed_module); + fold_constants(&mut mir_module); + let warnings = eliminate_dead_code(&mut mir_module); + + for warning in warnings { + eprintln!("Warning: {} at {:?}", warning.message, warning.span); + } let backend = CraneliftBackend::new(); let (ir, obj_bytes) = backend.compile_module(&mir_module); diff --git a/src/middle/dce.rs b/src/middle/dce.rs new file mode 100644 index 0000000..82c9ad4 --- /dev/null +++ b/src/middle/dce.rs @@ -0,0 +1,197 @@ +use std::collections::{HashMap, HashSet}; + +use crate::frontend::token::Span; +use crate::middle::mir::*; + +#[derive(Debug, PartialEq, Eq)] +pub struct MirWarning { + pub message: String, + pub span: Span, +} + +/// Eliminates unreachable basic blocks from the Control Flow Graph. +pub fn eliminate_dead_code(module: &mut MirModule) -> Vec { + let mut warnings = Vec::new(); + for func in &mut module.functions { + optimize_function(func, &mut warnings); + } + warnings.sort_by_key(|w| w.span.start); + warnings.dedup_by_key(|w| w.span.start); + warnings +} + +fn optimize_function(func: &mut MirFunction, warnings: &mut Vec) { + if func.blocks.is_empty() { + return; + } + + let mut visited = HashSet::new(); + // The first block generated is always the entry block. + let mut worklist = vec![func.blocks[0].id]; + + let block_map: HashMap = func.blocks.iter().map(|b| (b.id, b)).collect(); + + while let Some(id) = worklist.pop() { + if !visited.insert(id) { + continue; // Already visited + } + + if let Some(block) = block_map.get(&id) { + match &block.terminator.kind { + TerminatorKind::Goto { target } => worklist.push(*target), + TerminatorKind::CondBranch { + cond, + target_true, + target_false, + } => match cond { + Operand::Constant(ConstantValue::Boolean(true)) => { + worklist.push(*target_true); + } + Operand::Constant(ConstantValue::Boolean(false)) => { + worklist.push(*target_false); + } + _ => { + worklist.push(*target_true); + worklist.push(*target_false); + } + }, + TerminatorKind::Return { .. } | TerminatorKind::Unreachable => {} + } + } + } + + // Emit warnings for the dead code we are about to remove + for block in &func.blocks { + if !visited.contains(&block.id) { + let is_implicit_unreachable = block.statements.is_empty() + && matches!(block.terminator.kind, TerminatorKind::Unreachable); + + // Do not warn about implicit unreachables (e.g. missing returns handled by builder) + if !is_implicit_unreachable { + let span = block + .statements + .first() + .map(|s| s.span) + .unwrap_or(block.terminator.span); + warnings.push(MirWarning { + message: "unreachable code".to_string(), + span, + }); + } + } + } + + // Retain only the blocks that were successfully reached during our traversal + func.blocks.retain(|b| visited.contains(&b.id)); +} + +#[cfg(test)] +mod test { + use super::*; + use crate::frontend::sema::Ty; + use crate::frontend::token::Span; + + #[test] + fn test_eliminate_dead_blocks() { + let mut module = MirModule { + functions: vec![MirFunction { + name: "test_func".to_string(), + params: vec![], + return_type: Ty::Unit, + locals: vec![], + blocks: vec![ + BasicBlock { + id: BlockId(0), // Entry block + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::Goto { target: BlockId(1) }, + span: Span::new(0, 0), + }, + }, + BasicBlock { + id: BlockId(1), // Reachable + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::Return { value: None }, + span: Span::new(0, 0), + }, + }, + BasicBlock { + id: BlockId(2), // Unreachable + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::Goto { target: BlockId(1) }, + span: Span::new(0, 0), + }, + }, + ], + }], + }; + + let warnings = eliminate_dead_code(&mut module); + assert_eq!(warnings.len(), 1); + + let blocks = &module.functions[0].blocks; + assert_eq!(blocks.len(), 2, "Expected exactly 2 reachable blocks"); + assert!(blocks.iter().any(|b| b.id == BlockId(0))); + assert!(blocks.iter().any(|b| b.id == BlockId(1))); + assert!( + !blocks.iter().any(|b| b.id == BlockId(2)), + "Block 2 should have been eliminated" + ); + } + + #[test] + fn test_eliminate_dead_cond_branch() { + let mut module = MirModule { + functions: vec![MirFunction { + name: "test_cond_func".to_string(), + params: vec![], + return_type: Ty::Unit, + locals: vec![], + blocks: vec![ + BasicBlock { + id: BlockId(0), // Entry block + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::CondBranch { + cond: Operand::Constant(ConstantValue::Boolean(true)), + target_true: BlockId(1), + target_false: BlockId(2), + }, + span: Span::new(0, 0), + }, + }, + BasicBlock { + id: BlockId(1), // Reachable (true branch) + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::Return { value: None }, + span: Span::new(0, 0), + }, + }, + BasicBlock { + id: BlockId(2), // Unreachable (false branch) + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::Return { value: None }, + span: Span::new(0, 0), + }, + }, + ], + }], + }; + + let warnings = eliminate_dead_code(&mut module); + assert_eq!(warnings.len(), 1); + + let blocks = &module.functions[0].blocks; + assert_eq!(blocks.len(), 2, "Expected exactly 2 reachable blocks"); + assert!(blocks.iter().any(|b| b.id == BlockId(0))); + assert!(blocks.iter().any(|b| b.id == BlockId(1))); + assert!( + !blocks.iter().any(|b| b.id == BlockId(2)), + "Block 2 should have been eliminated" + ); + } +} diff --git a/src/middle/fold.rs b/src/middle/fold.rs new file mode 100644 index 0000000..dc1a7f9 --- /dev/null +++ b/src/middle/fold.rs @@ -0,0 +1,278 @@ +use std::collections::HashMap; + +use crate::frontend::ast::{BinaryOp, UnaryOp}; +use crate::frontend::sema::Ty; +use crate::middle::mir::*; + +/// Folds constant expressions and simplifies control flow branches where the condition is known. +pub fn fold_constants(module: &mut MirModule) { + for func in &mut module.functions { + optimize_function(func); + } +} + +fn optimize_function(func: &mut MirFunction) { + for block in &mut func.blocks { + // Block-local constant tracking. This is safe even without strict SSA + // because compiler-generated temporaries are evaluated exactly once + // prior to being used within their basic block. + let mut known_constants = HashMap::new(); + + for stmt in &mut block.statements { + let StatementKind::Assign(local, rvalue) = &mut stmt.kind; + + // Propagate any known constants downwards into the rvalue + propagate_rvalue(rvalue, &known_constants); + + // Attempt to compute the rvalue + if let Some(constant) = evaluate_rvalue(rvalue) { + // Replace the complex instruction with a simple constant use + *rvalue = Rvalue::Use(Operand::Constant(constant.clone())); + known_constants.insert(*local, constant); + } + } + + // Propagate constants into the terminator + match &mut block.terminator.kind { + TerminatorKind::CondBranch { + cond, + target_true, + target_false, + } => { + propagate_operand(cond, &known_constants); + + // If the condition is statically known, fold the branch into a simple Goto! + if let Operand::Constant(ConstantValue::Boolean(val)) = cond { + let target = if *val { *target_true } else { *target_false }; + + block.terminator.kind = TerminatorKind::Goto { target }; + } + } + TerminatorKind::Return { value: Some(val) } => { + propagate_operand(val, &known_constants); + } + _ => {} + } + } +} + +fn propagate_operand(operand: &mut Operand, known_constants: &HashMap) { + if let Operand::Copy(local) = operand + && let Some(constant) = known_constants.get(local) + { + *operand = Operand::Constant(constant.clone()); + } +} + +fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap) { + match rvalue { + Rvalue::Use(op) => propagate_operand(op, known_constants), + Rvalue::UnaryOp(_, op) => propagate_operand(op, known_constants), + Rvalue::BinaryOp(_, lhs, rhs) => { + propagate_operand(lhs, known_constants); + propagate_operand(rhs, known_constants); + } + } +} + +fn evaluate_rvalue(rvalue: &Rvalue) -> Option { + match rvalue { + Rvalue::Use(Operand::Constant(c)) => Some(c.clone()), + Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c), + Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => { + evaluate_binary(*op, l, r) + } + _ => None, + } +} + +fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option { + match (op, val) { + (UnaryOp::Neg, ConstantValue::Integer(v, ty)) => { + Some(ConstantValue::Integer(v.wrapping_neg(), ty.clone())) + } + (UnaryOp::Not, ConstantValue::Boolean(b)) => Some(ConstantValue::Boolean(!b)), + _ => None, + } +} + +fn evaluate_binary( + op: BinaryOp, + lhs: &ConstantValue, + rhs: &ConstantValue, +) -> Option { + match (lhs, rhs) { + (ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => { + let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64); + + match op { + BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())), + BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())), + BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())), + + // Avoid dividing by 0 during compile-time constant evaluation + BinaryOp::Div if *r != 0 => { + if is_signed { + // `overflowing_div` safely sidesteps the `i64::MIN / -1` panic + let (result, _) = (*l as i64).overflowing_div(*r as i64); + Some(ConstantValue::Integer(result as u64, ty.clone())) + } else { + Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone())) + } + } + BinaryOp::Rem if *r != 0 => { + if is_signed { + let (result, _) = (*l as i64).overflowing_rem(*r as i64); + Some(ConstantValue::Integer(result as u64, ty.clone())) + } else { + Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone())) + } + } + + BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)), + BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)), + + BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed { + (*l as i64) < (*r as i64) + } else { + l < r + })), + BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed { + (*l as i64) <= (*r as i64) + } else { + l <= r + })), + BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed { + (*l as i64) > (*r as i64) + } else { + l > r + })), + BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed { + (*l as i64) >= (*r as i64) + } else { + l >= r + })), + _ => None, + } + } + + (ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op { + BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)), + BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)), + _ => None, + }, + + _ => None, + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::frontend::token::Span; + + #[test] + fn test_fold_arithmetic() { + let func = MirFunction { + name: "test".to_string(), + params: vec![], + return_type: Ty::I32, + locals: vec![], + blocks: vec![BasicBlock { + id: BlockId(0), + statements: vec![ + Statement { + kind: StatementKind::Assign( + LocalId(0), + Rvalue::Use(Operand::Constant(ConstantValue::Integer(5, Ty::I32))), + ), + span: Span::new(0, 0), + }, + Statement { + kind: StatementKind::Assign( + LocalId(1), + Rvalue::Use(Operand::Constant(ConstantValue::Integer(10, Ty::I32))), + ), + span: Span::new(0, 0), + }, + Statement { + kind: StatementKind::Assign( + LocalId(2), + Rvalue::BinaryOp( + BinaryOp::Add, + Operand::Copy(LocalId(0)), + Operand::Copy(LocalId(1)), + ), + ), + span: Span::new(0, 0), + }, + ], + terminator: Terminator { + kind: TerminatorKind::Return { + value: Some(Operand::Copy(LocalId(2))), + }, + span: Span::new(0, 0), + }, + }], + }; + + let mut module = MirModule { + functions: vec![func], + }; + + fold_constants(&mut module); + + let block = &module.functions[0].blocks[0]; + + // The third statement (LocalId(2) = ...) should be folded to 15 + let StatementKind::Assign(_, rvalue) = &block.statements[2].kind; + assert!(matches!( + rvalue, + Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32))) + )); + + // The return terminator should be updated to return the constant 15 directly + assert!(matches!( + &block.terminator.kind, + TerminatorKind::Return { + value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32))) + } + )); + } + + #[test] + fn test_fold_cond_branch() { + let func = MirFunction { + name: "test".to_string(), + params: vec![], + return_type: Ty::Unit, + locals: vec![], + blocks: vec![BasicBlock { + id: BlockId(0), + statements: vec![], + terminator: Terminator { + kind: TerminatorKind::CondBranch { + cond: Operand::Constant(ConstantValue::Boolean(true)), + target_true: BlockId(1), + target_false: BlockId(2), + }, + span: Span::new(0, 0), + }, + }], + }; + + let mut module = MirModule { + functions: vec![func], + }; + + fold_constants(&mut module); + + let block = &module.functions[0].blocks[0]; + + // The condition branch should be folded into an unconditional Goto + assert!(matches!( + block.terminator.kind, + TerminatorKind::Goto { target: BlockId(1) } + )); + } +} diff --git a/src/middle/mod.rs b/src/middle/mod.rs index e7f29ad..bd53e44 100644 --- a/src/middle/mod.rs +++ b/src/middle/mod.rs @@ -1,2 +1,4 @@ pub mod builder; +pub mod dce; +pub mod fold; pub mod mir;