feat: add constant folding and dead code elimination passes

This commit is contained in:
2026-04-21 22:28:30 +02:00
parent 3e0b5c5b00
commit 0162d5b845
5 changed files with 487 additions and 6 deletions
+1 -5
View File
@@ -64,11 +64,7 @@ impl CraneliftBackend {
.optimize(self.module.isa(), &mut ctrl_plane)
.unwrap();
ir_output.push_str(&format!(
"; Function: {}\n{}",
func.name,
self.ctx.func.to_string()
));
ir_output.push_str(&format!("; Function: {}\n{}", func.name, self.ctx.func));
ir_output.push('\n');
let func_id = self
+9 -1
View File
@@ -5,6 +5,8 @@ use clap::Parser as ClapParser;
use crate::frontend::parser::Parser;
use crate::frontend::sema::Sema;
use crate::middle::builder::MirBuilder;
use crate::middle::dce::eliminate_dead_code;
use crate::middle::fold::fold_constants;
pub mod backend;
pub mod frontend;
@@ -57,7 +59,13 @@ fn main() {
exit(1);
}
let mir_module = MirBuilder::build(&typed_module);
let mut mir_module = MirBuilder::build(&typed_module);
fold_constants(&mut mir_module);
let warnings = eliminate_dead_code(&mut mir_module);
for warning in warnings {
eprintln!("Warning: {} at {:?}", warning.message, warning.span);
}
let backend = CraneliftBackend::new();
let (ir, obj_bytes) = backend.compile_module(&mir_module);
+197
View File
@@ -0,0 +1,197 @@
use std::collections::{HashMap, HashSet};
use crate::frontend::token::Span;
use crate::middle::mir::*;
/// A non-fatal diagnostic reported by a MIR pass.
///
/// [`eliminate_dead_code`] returns values of this type for unreachable blocks
/// it removes.
#[derive(Debug, PartialEq, Eq)]
pub struct MirWarning {
/// Human-readable description of the problem (e.g. `"unreachable code"`).
pub message: String,
/// Source location the warning refers to.
pub span: Span,
}
/// Prunes basic blocks that cannot be reached from the entry block of each
/// function in `module`.
///
/// Returns the "unreachable code" warnings gathered from all functions,
/// ordered by the start of their span and with at most one warning per
/// distinct span start.
pub fn eliminate_dead_code(module: &mut MirModule) -> Vec<MirWarning> {
    let mut diagnostics: Vec<MirWarning> = Vec::new();

    module
        .functions
        .iter_mut()
        .for_each(|function| optimize_function(function, &mut diagnostics));

    // Stable sort first so that equal starts are adjacent for the dedup below.
    diagnostics.sort_by_key(|diag| diag.span.start);
    diagnostics.dedup_by_key(|diag| diag.span.start);

    diagnostics
}
/// Removes every block of `func` that is not reachable from its entry block.
///
/// Reachability is computed with a worklist traversal over block terminators.
/// A `CondBranch` whose condition is a boolean constant only contributes the
/// branch that is actually taken.
///
/// For each removed block a [`MirWarning`] is appended to `warnings`, except
/// for blocks that have no statements and end in `Unreachable`.
///
/// After pruning, any remaining `CondBranch` on a boolean constant is rewritten
/// into a `Goto` to the taken target. Without this, the terminator would keep
/// naming a block that has just been removed from the function.
fn optimize_function(func: &mut MirFunction, warnings: &mut Vec<MirWarning>) {
    // The first block generated is always the entry block.
    let Some(entry) = func.blocks.first().map(|b| b.id) else {
        return;
    };

    // The lookup map borrows `func.blocks` immutably, so keep it in its own
    // scope; the blocks are mutated further down.
    let visited: HashSet<BlockId> = {
        let block_map: HashMap<BlockId, &BasicBlock> =
            func.blocks.iter().map(|b| (b.id, b)).collect();

        let mut visited = HashSet::new();
        let mut worklist = vec![entry];

        while let Some(id) = worklist.pop() {
            if !visited.insert(id) {
                continue; // Already visited
            }
            let Some(block) = block_map.get(&id) else {
                continue; // Target id with no matching block: nothing to follow
            };

            match &block.terminator.kind {
                TerminatorKind::Goto { target } => worklist.push(*target),
                TerminatorKind::CondBranch {
                    cond,
                    target_true,
                    target_false,
                } => match cond {
                    Operand::Constant(ConstantValue::Boolean(true)) => {
                        worklist.push(*target_true);
                    }
                    Operand::Constant(ConstantValue::Boolean(false)) => {
                        worklist.push(*target_false);
                    }
                    _ => {
                        worklist.push(*target_true);
                        worklist.push(*target_false);
                    }
                },
                TerminatorKind::Return { .. } | TerminatorKind::Unreachable => {}
            }
        }

        visited
    };

    // Emit warnings for the dead code we are about to remove
    for block in func.blocks.iter().filter(|b| !visited.contains(&b.id)) {
        let is_implicit_unreachable = block.statements.is_empty()
            && matches!(block.terminator.kind, TerminatorKind::Unreachable);

        // Do not warn about implicit unreachables (e.g. missing returns handled by builder)
        if is_implicit_unreachable {
            continue;
        }

        let span = block
            .statements
            .first()
            .map(|s| s.span)
            .unwrap_or(block.terminator.span);
        warnings.push(MirWarning {
            message: "unreachable code".to_string(),
            span,
        });
    }

    // Retain only the blocks that were successfully reached during our traversal
    func.blocks.retain(|b| visited.contains(&b.id));

    // The traversal treated constant-condition branches as single-successor, so
    // the untaken target may be gone now. Make the terminator agree with that.
    for block in &mut func.blocks {
        let taken = match &block.terminator.kind {
            TerminatorKind::CondBranch {
                cond: Operand::Constant(ConstantValue::Boolean(value)),
                target_true,
                target_false,
            } => Some(if *value { *target_true } else { *target_false }),
            _ => None,
        };
        if let Some(target) = taken {
            block.terminator.kind = TerminatorKind::Goto { target };
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::sema::Ty;
    use crate::frontend::token::Span;

    /// Builds a block without statements that ends in `kind`, using a dummy span.
    fn bare_block(id: BlockId, kind: TerminatorKind) -> BasicBlock {
        BasicBlock {
            id,
            statements: vec![],
            terminator: Terminator {
                kind,
                span: Span::new(0, 0),
            },
        }
    }

    /// Wraps `blocks` in a module holding a single parameterless unit function.
    fn unit_module(name: &str, blocks: Vec<BasicBlock>) -> MirModule {
        MirModule {
            functions: vec![MirFunction {
                name: name.to_string(),
                params: vec![],
                return_type: Ty::Unit,
                locals: vec![],
                blocks,
            }],
        }
    }

    /// Terminator kind for a bare `return`.
    fn plain_return() -> TerminatorKind {
        TerminatorKind::Return { value: None }
    }

    /// Checks that blocks 0 and 1 survived and block 2 was dropped.
    fn assert_only_block_two_removed(module: &MirModule) {
        let blocks = &module.functions[0].blocks;
        assert_eq!(blocks.len(), 2, "Expected exactly 2 reachable blocks");
        assert!(blocks.iter().any(|b| b.id == BlockId(0)));
        assert!(blocks.iter().any(|b| b.id == BlockId(1)));
        assert!(
            !blocks.iter().any(|b| b.id == BlockId(2)),
            "Block 2 should have been eliminated"
        );
    }

    #[test]
    fn test_eliminate_dead_blocks() {
        let mut module = unit_module(
            "test_func",
            vec![
                // Entry block
                bare_block(BlockId(0), TerminatorKind::Goto { target: BlockId(1) }),
                // Reachable
                bare_block(BlockId(1), plain_return()),
                // Unreachable
                bare_block(BlockId(2), TerminatorKind::Goto { target: BlockId(1) }),
            ],
        );

        let warnings = eliminate_dead_code(&mut module);

        assert_eq!(warnings.len(), 1);
        assert_only_block_two_removed(&module);
    }

    #[test]
    fn test_eliminate_dead_cond_branch() {
        let mut module = unit_module(
            "test_cond_func",
            vec![
                // Entry block
                bare_block(
                    BlockId(0),
                    TerminatorKind::CondBranch {
                        cond: Operand::Constant(ConstantValue::Boolean(true)),
                        target_true: BlockId(1),
                        target_false: BlockId(2),
                    },
                ),
                // Reachable (true branch)
                bare_block(BlockId(1), plain_return()),
                // Unreachable (false branch)
                bare_block(BlockId(2), plain_return()),
            ],
        );

        let warnings = eliminate_dead_code(&mut module);

        assert_eq!(warnings.len(), 1);
        assert_only_block_two_removed(&module);
    }
}
+278
View File
@@ -0,0 +1,278 @@
use std::collections::HashMap;
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
use crate::middle::mir::*;
/// Folds constant expressions and simplifies control flow branches where the condition is known.
pub fn fold_constants(module: &mut MirModule) {
for func in &mut module.functions {
optimize_function(func);
}
}
fn optimize_function(func: &mut MirFunction) {
for block in &mut func.blocks {
// Block-local constant tracking. This is safe even without strict SSA
// because compiler-generated temporaries are evaluated exactly once
// prior to being used within their basic block.
let mut known_constants = HashMap::new();
for stmt in &mut block.statements {
let StatementKind::Assign(local, rvalue) = &mut stmt.kind;
// Propagate any known constants downwards into the rvalue
propagate_rvalue(rvalue, &known_constants);
// Attempt to compute the rvalue
if let Some(constant) = evaluate_rvalue(rvalue) {
// Replace the complex instruction with a simple constant use
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
known_constants.insert(*local, constant);
}
}
// Propagate constants into the terminator
match &mut block.terminator.kind {
TerminatorKind::CondBranch {
cond,
target_true,
target_false,
} => {
propagate_operand(cond, &known_constants);
// If the condition is statically known, fold the branch into a simple Goto!
if let Operand::Constant(ConstantValue::Boolean(val)) = cond {
let target = if *val { *target_true } else { *target_false };
block.terminator.kind = TerminatorKind::Goto { target };
}
}
TerminatorKind::Return { value: Some(val) } => {
propagate_operand(val, &known_constants);
}
_ => {}
}
}
}
/// Replaces `operand` with a constant when it copies a local whose value is
/// recorded in `known_constants`; leaves it untouched otherwise.
fn propagate_operand(operand: &mut Operand, known_constants: &HashMap<LocalId, ConstantValue>) {
    let replacement = match &*operand {
        Operand::Copy(local) => known_constants.get(local).cloned(),
        _ => None,
    };

    if let Some(value) = replacement {
        *operand = Operand::Constant(value);
    }
}
/// Substitutes known constants into every operand of `rvalue`.
fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, ConstantValue>) {
    match rvalue {
        Rvalue::Use(operand) | Rvalue::UnaryOp(_, operand) => {
            propagate_operand(operand, known_constants);
        }
        Rvalue::BinaryOp(_, lhs, rhs) => {
            for operand in [lhs, rhs] {
                propagate_operand(operand, known_constants);
            }
        }
    }
}
/// Tries to compute `rvalue` at compile time.
///
/// Returns `Some` only when every operand is already a constant and the
/// operation itself can be evaluated; returns `None` in all other cases.
fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
    if let Rvalue::Use(Operand::Constant(value)) = rvalue {
        return Some(value.clone());
    }
    if let Rvalue::UnaryOp(op, Operand::Constant(value)) = rvalue {
        return evaluate_unary(*op, value);
    }
    if let Rvalue::BinaryOp(op, Operand::Constant(lhs), Operand::Constant(rhs)) = rvalue {
        return evaluate_binary(*op, lhs, rhs);
    }
    None
}
/// Evaluates a unary operator applied to a constant.
///
/// Handles integer negation (two's-complement wrapping on the stored bits)
/// and boolean `not`. Any other operator/operand pairing yields `None`.
fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option<ConstantValue> {
    match val {
        ConstantValue::Integer(bits, ty) if matches!(op, UnaryOp::Neg) => {
            let negated = bits.wrapping_neg();
            Some(ConstantValue::Integer(negated, ty.clone()))
        }
        ConstantValue::Boolean(flag) if matches!(op, UnaryOp::Not) => {
            Some(ConstantValue::Boolean(!*flag))
        }
        _ => None,
    }
}
fn evaluate_binary(
op: BinaryOp,
lhs: &ConstantValue,
rhs: &ConstantValue,
) -> Option<ConstantValue> {
match (lhs, rhs) {
(ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => {
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
match op {
BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())),
BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())),
BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())),
// Avoid dividing by 0 during compile-time constant evaluation
BinaryOp::Div if *r != 0 => {
if is_signed {
// `overflowing_div` safely sidesteps the `i64::MIN / -1` panic
let (result, _) = (*l as i64).overflowing_div(*r as i64);
Some(ConstantValue::Integer(result as u64, ty.clone()))
} else {
Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone()))
}
}
BinaryOp::Rem if *r != 0 => {
if is_signed {
let (result, _) = (*l as i64).overflowing_rem(*r as i64);
Some(ConstantValue::Integer(result as u64, ty.clone()))
} else {
Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone()))
}
}
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) < (*r as i64)
} else {
l < r
})),
BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) <= (*r as i64)
} else {
l <= r
})),
BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) > (*r as i64)
} else {
l > r
})),
BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) >= (*r as i64)
} else {
l >= r
})),
_ => None,
}
}
(ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op {
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
_ => None,
},
_ => None,
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::token::Span;

    /// An `i32`-typed integer constant operand.
    fn int32(value: u64) -> Operand {
        Operand::Constant(ConstantValue::Integer(value, Ty::I32))
    }

    /// An assignment statement with a dummy span.
    fn assign(local: LocalId, rvalue: Rvalue) -> Statement {
        Statement {
            kind: StatementKind::Assign(local, rvalue),
            span: Span::new(0, 0),
        }
    }

    /// A module holding one function named `test` that consists of one block.
    fn single_block_module(
        return_type: Ty,
        statements: Vec<Statement>,
        terminator: TerminatorKind,
    ) -> MirModule {
        let func = MirFunction {
            name: "test".to_string(),
            params: vec![],
            return_type,
            locals: vec![],
            blocks: vec![BasicBlock {
                id: BlockId(0),
                statements,
                terminator: Terminator {
                    kind: terminator,
                    span: Span::new(0, 0),
                },
            }],
        };
        MirModule {
            functions: vec![func],
        }
    }

    #[test]
    fn test_fold_arithmetic() {
        let mut module = single_block_module(
            Ty::I32,
            vec![
                assign(LocalId(0), Rvalue::Use(int32(5))),
                assign(LocalId(1), Rvalue::Use(int32(10))),
                assign(
                    LocalId(2),
                    Rvalue::BinaryOp(
                        BinaryOp::Add,
                        Operand::Copy(LocalId(0)),
                        Operand::Copy(LocalId(1)),
                    ),
                ),
            ],
            TerminatorKind::Return {
                value: Some(Operand::Copy(LocalId(2))),
            },
        );

        fold_constants(&mut module);

        let block = &module.functions[0].blocks[0];

        // The third statement (LocalId(2) = ...) should be folded to 15
        let StatementKind::Assign(_, rvalue) = &block.statements[2].kind;
        assert!(matches!(
            rvalue,
            Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
        ));

        // The return terminator should be updated to return the constant 15 directly
        assert!(matches!(
            &block.terminator.kind,
            TerminatorKind::Return {
                value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
            }
        ));
    }

    #[test]
    fn test_fold_cond_branch() {
        let mut module = single_block_module(
            Ty::Unit,
            vec![],
            TerminatorKind::CondBranch {
                cond: Operand::Constant(ConstantValue::Boolean(true)),
                target_true: BlockId(1),
                target_false: BlockId(2),
            },
        );

        fold_constants(&mut module);

        let block = &module.functions[0].blocks[0];

        // The condition branch should be folded into an unconditional Goto
        assert!(matches!(
            block.terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
    }
}
+2
View File
@@ -1,2 +1,4 @@
pub mod builder;
pub mod dce;
pub mod fold;
pub mod mir;