feat: add constant folding and dead code elimination passes
This commit is contained in:
@@ -64,11 +64,7 @@ impl CraneliftBackend {
|
||||
.optimize(self.module.isa(), &mut ctrl_plane)
|
||||
.unwrap();
|
||||
|
||||
ir_output.push_str(&format!(
|
||||
"; Function: {}\n{}",
|
||||
func.name,
|
||||
self.ctx.func.to_string()
|
||||
));
|
||||
ir_output.push_str(&format!("; Function: {}\n{}", func.name, self.ctx.func));
|
||||
ir_output.push('\n');
|
||||
|
||||
let func_id = self
|
||||
|
||||
+9
-1
@@ -5,6 +5,8 @@ use clap::Parser as ClapParser;
|
||||
use crate::frontend::parser::Parser;
|
||||
use crate::frontend::sema::Sema;
|
||||
use crate::middle::builder::MirBuilder;
|
||||
use crate::middle::dce::eliminate_dead_code;
|
||||
use crate::middle::fold::fold_constants;
|
||||
|
||||
pub mod backend;
|
||||
pub mod frontend;
|
||||
@@ -57,7 +59,13 @@ fn main() {
|
||||
exit(1);
|
||||
}
|
||||
|
||||
let mir_module = MirBuilder::build(&typed_module);
|
||||
let mut mir_module = MirBuilder::build(&typed_module);
|
||||
fold_constants(&mut mir_module);
|
||||
let warnings = eliminate_dead_code(&mut mir_module);
|
||||
|
||||
for warning in warnings {
|
||||
eprintln!("Warning: {} at {:?}", warning.message, warning.span);
|
||||
}
|
||||
|
||||
let backend = CraneliftBackend::new();
|
||||
let (ir, obj_bytes) = backend.compile_module(&mir_module);
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use crate::frontend::token::Span;
|
||||
use crate::middle::mir::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct MirWarning {
|
||||
pub message: String,
|
||||
pub span: Span,
|
||||
}
|
||||
|
||||
/// Eliminates unreachable basic blocks from the Control Flow Graph.
|
||||
pub fn eliminate_dead_code(module: &mut MirModule) -> Vec<MirWarning> {
|
||||
let mut warnings = Vec::new();
|
||||
for func in &mut module.functions {
|
||||
optimize_function(func, &mut warnings);
|
||||
}
|
||||
warnings.sort_by_key(|w| w.span.start);
|
||||
warnings.dedup_by_key(|w| w.span.start);
|
||||
warnings
|
||||
}
|
||||
|
||||
fn optimize_function(func: &mut MirFunction, warnings: &mut Vec<MirWarning>) {
|
||||
if func.blocks.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
// The first block generated is always the entry block.
|
||||
let mut worklist = vec![func.blocks[0].id];
|
||||
|
||||
let block_map: HashMap<BlockId, &BasicBlock> = func.blocks.iter().map(|b| (b.id, b)).collect();
|
||||
|
||||
while let Some(id) = worklist.pop() {
|
||||
if !visited.insert(id) {
|
||||
continue; // Already visited
|
||||
}
|
||||
|
||||
if let Some(block) = block_map.get(&id) {
|
||||
match &block.terminator.kind {
|
||||
TerminatorKind::Goto { target } => worklist.push(*target),
|
||||
TerminatorKind::CondBranch {
|
||||
cond,
|
||||
target_true,
|
||||
target_false,
|
||||
} => match cond {
|
||||
Operand::Constant(ConstantValue::Boolean(true)) => {
|
||||
worklist.push(*target_true);
|
||||
}
|
||||
Operand::Constant(ConstantValue::Boolean(false)) => {
|
||||
worklist.push(*target_false);
|
||||
}
|
||||
_ => {
|
||||
worklist.push(*target_true);
|
||||
worklist.push(*target_false);
|
||||
}
|
||||
},
|
||||
TerminatorKind::Return { .. } | TerminatorKind::Unreachable => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit warnings for the dead code we are about to remove
|
||||
for block in &func.blocks {
|
||||
if !visited.contains(&block.id) {
|
||||
let is_implicit_unreachable = block.statements.is_empty()
|
||||
&& matches!(block.terminator.kind, TerminatorKind::Unreachable);
|
||||
|
||||
// Do not warn about implicit unreachables (e.g. missing returns handled by builder)
|
||||
if !is_implicit_unreachable {
|
||||
let span = block
|
||||
.statements
|
||||
.first()
|
||||
.map(|s| s.span)
|
||||
.unwrap_or(block.terminator.span);
|
||||
warnings.push(MirWarning {
|
||||
message: "unreachable code".to_string(),
|
||||
span,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Retain only the blocks that were successfully reached during our traversal
|
||||
func.blocks.retain(|b| visited.contains(&b.id));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::frontend::sema::Ty;
|
||||
use crate::frontend::token::Span;
|
||||
|
||||
#[test]
|
||||
fn test_eliminate_dead_blocks() {
|
||||
let mut module = MirModule {
|
||||
functions: vec![MirFunction {
|
||||
name: "test_func".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::Unit,
|
||||
locals: vec![],
|
||||
blocks: vec![
|
||||
BasicBlock {
|
||||
id: BlockId(0), // Entry block
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Goto { target: BlockId(1) },
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
BasicBlock {
|
||||
id: BlockId(1), // Reachable
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Return { value: None },
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
BasicBlock {
|
||||
id: BlockId(2), // Unreachable
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Goto { target: BlockId(1) },
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
};
|
||||
|
||||
let warnings = eliminate_dead_code(&mut module);
|
||||
assert_eq!(warnings.len(), 1);
|
||||
|
||||
let blocks = &module.functions[0].blocks;
|
||||
assert_eq!(blocks.len(), 2, "Expected exactly 2 reachable blocks");
|
||||
assert!(blocks.iter().any(|b| b.id == BlockId(0)));
|
||||
assert!(blocks.iter().any(|b| b.id == BlockId(1)));
|
||||
assert!(
|
||||
!blocks.iter().any(|b| b.id == BlockId(2)),
|
||||
"Block 2 should have been eliminated"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eliminate_dead_cond_branch() {
|
||||
let mut module = MirModule {
|
||||
functions: vec![MirFunction {
|
||||
name: "test_cond_func".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::Unit,
|
||||
locals: vec![],
|
||||
blocks: vec![
|
||||
BasicBlock {
|
||||
id: BlockId(0), // Entry block
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::CondBranch {
|
||||
cond: Operand::Constant(ConstantValue::Boolean(true)),
|
||||
target_true: BlockId(1),
|
||||
target_false: BlockId(2),
|
||||
},
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
BasicBlock {
|
||||
id: BlockId(1), // Reachable (true branch)
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Return { value: None },
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
BasicBlock {
|
||||
id: BlockId(2), // Unreachable (false branch)
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Return { value: None },
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
};
|
||||
|
||||
let warnings = eliminate_dead_code(&mut module);
|
||||
assert_eq!(warnings.len(), 1);
|
||||
|
||||
let blocks = &module.functions[0].blocks;
|
||||
assert_eq!(blocks.len(), 2, "Expected exactly 2 reachable blocks");
|
||||
assert!(blocks.iter().any(|b| b.id == BlockId(0)));
|
||||
assert!(blocks.iter().any(|b| b.id == BlockId(1)));
|
||||
assert!(
|
||||
!blocks.iter().any(|b| b.id == BlockId(2)),
|
||||
"Block 2 should have been eliminated"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::frontend::ast::{BinaryOp, UnaryOp};
|
||||
use crate::frontend::sema::Ty;
|
||||
use crate::middle::mir::*;
|
||||
|
||||
/// Folds constant expressions and simplifies control flow branches where the condition is known.
|
||||
pub fn fold_constants(module: &mut MirModule) {
|
||||
for func in &mut module.functions {
|
||||
optimize_function(func);
|
||||
}
|
||||
}
|
||||
|
||||
fn optimize_function(func: &mut MirFunction) {
|
||||
for block in &mut func.blocks {
|
||||
// Block-local constant tracking. This is safe even without strict SSA
|
||||
// because compiler-generated temporaries are evaluated exactly once
|
||||
// prior to being used within their basic block.
|
||||
let mut known_constants = HashMap::new();
|
||||
|
||||
for stmt in &mut block.statements {
|
||||
let StatementKind::Assign(local, rvalue) = &mut stmt.kind;
|
||||
|
||||
// Propagate any known constants downwards into the rvalue
|
||||
propagate_rvalue(rvalue, &known_constants);
|
||||
|
||||
// Attempt to compute the rvalue
|
||||
if let Some(constant) = evaluate_rvalue(rvalue) {
|
||||
// Replace the complex instruction with a simple constant use
|
||||
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
|
||||
known_constants.insert(*local, constant);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate constants into the terminator
|
||||
match &mut block.terminator.kind {
|
||||
TerminatorKind::CondBranch {
|
||||
cond,
|
||||
target_true,
|
||||
target_false,
|
||||
} => {
|
||||
propagate_operand(cond, &known_constants);
|
||||
|
||||
// If the condition is statically known, fold the branch into a simple Goto!
|
||||
if let Operand::Constant(ConstantValue::Boolean(val)) = cond {
|
||||
let target = if *val { *target_true } else { *target_false };
|
||||
|
||||
block.terminator.kind = TerminatorKind::Goto { target };
|
||||
}
|
||||
}
|
||||
TerminatorKind::Return { value: Some(val) } => {
|
||||
propagate_operand(val, &known_constants);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_operand(operand: &mut Operand, known_constants: &HashMap<LocalId, ConstantValue>) {
|
||||
if let Operand::Copy(local) = operand
|
||||
&& let Some(constant) = known_constants.get(local)
|
||||
{
|
||||
*operand = Operand::Constant(constant.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, ConstantValue>) {
|
||||
match rvalue {
|
||||
Rvalue::Use(op) => propagate_operand(op, known_constants),
|
||||
Rvalue::UnaryOp(_, op) => propagate_operand(op, known_constants),
|
||||
Rvalue::BinaryOp(_, lhs, rhs) => {
|
||||
propagate_operand(lhs, known_constants);
|
||||
propagate_operand(rhs, known_constants);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
|
||||
match rvalue {
|
||||
Rvalue::Use(Operand::Constant(c)) => Some(c.clone()),
|
||||
Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c),
|
||||
Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
|
||||
evaluate_binary(*op, l, r)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option<ConstantValue> {
|
||||
match (op, val) {
|
||||
(UnaryOp::Neg, ConstantValue::Integer(v, ty)) => {
|
||||
Some(ConstantValue::Integer(v.wrapping_neg(), ty.clone()))
|
||||
}
|
||||
(UnaryOp::Not, ConstantValue::Boolean(b)) => Some(ConstantValue::Boolean(!b)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_binary(
|
||||
op: BinaryOp,
|
||||
lhs: &ConstantValue,
|
||||
rhs: &ConstantValue,
|
||||
) -> Option<ConstantValue> {
|
||||
match (lhs, rhs) {
|
||||
(ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => {
|
||||
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
|
||||
|
||||
match op {
|
||||
BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())),
|
||||
BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())),
|
||||
BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())),
|
||||
|
||||
// Avoid dividing by 0 during compile-time constant evaluation
|
||||
BinaryOp::Div if *r != 0 => {
|
||||
if is_signed {
|
||||
// `overflowing_div` safely sidesteps the `i64::MIN / -1` panic
|
||||
let (result, _) = (*l as i64).overflowing_div(*r as i64);
|
||||
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
||||
} else {
|
||||
Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone()))
|
||||
}
|
||||
}
|
||||
BinaryOp::Rem if *r != 0 => {
|
||||
if is_signed {
|
||||
let (result, _) = (*l as i64).overflowing_rem(*r as i64);
|
||||
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
||||
} else {
|
||||
Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
||||
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
||||
|
||||
BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) < (*r as i64)
|
||||
} else {
|
||||
l < r
|
||||
})),
|
||||
BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) <= (*r as i64)
|
||||
} else {
|
||||
l <= r
|
||||
})),
|
||||
BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) > (*r as i64)
|
||||
} else {
|
||||
l > r
|
||||
})),
|
||||
BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) >= (*r as i64)
|
||||
} else {
|
||||
l >= r
|
||||
})),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
(ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op {
|
||||
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
||||
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
||||
_ => None,
|
||||
},
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::frontend::token::Span;
|
||||
|
||||
#[test]
|
||||
fn test_fold_arithmetic() {
|
||||
let func = MirFunction {
|
||||
name: "test".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::I32,
|
||||
locals: vec![],
|
||||
blocks: vec![BasicBlock {
|
||||
id: BlockId(0),
|
||||
statements: vec![
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(0),
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(5, Ty::I32))),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(1),
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(10, Ty::I32))),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(2),
|
||||
Rvalue::BinaryOp(
|
||||
BinaryOp::Add,
|
||||
Operand::Copy(LocalId(0)),
|
||||
Operand::Copy(LocalId(1)),
|
||||
),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Return {
|
||||
value: Some(Operand::Copy(LocalId(2))),
|
||||
},
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
functions: vec![func],
|
||||
};
|
||||
|
||||
fold_constants(&mut module);
|
||||
|
||||
let block = &module.functions[0].blocks[0];
|
||||
|
||||
// The third statement (LocalId(2) = ...) should be folded to 15
|
||||
let StatementKind::Assign(_, rvalue) = &block.statements[2].kind;
|
||||
assert!(matches!(
|
||||
rvalue,
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
|
||||
));
|
||||
|
||||
// The return terminator should be updated to return the constant 15 directly
|
||||
assert!(matches!(
|
||||
&block.terminator.kind,
|
||||
TerminatorKind::Return {
|
||||
value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_cond_branch() {
|
||||
let func = MirFunction {
|
||||
name: "test".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::Unit,
|
||||
locals: vec![],
|
||||
blocks: vec![BasicBlock {
|
||||
id: BlockId(0),
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::CondBranch {
|
||||
cond: Operand::Constant(ConstantValue::Boolean(true)),
|
||||
target_true: BlockId(1),
|
||||
target_false: BlockId(2),
|
||||
},
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
functions: vec![func],
|
||||
};
|
||||
|
||||
fold_constants(&mut module);
|
||||
|
||||
let block = &module.functions[0].blocks[0];
|
||||
|
||||
// The condition branch should be folded into an unconditional Goto
|
||||
assert!(matches!(
|
||||
block.terminator.kind,
|
||||
TerminatorKind::Goto { target: BlockId(1) }
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,4 @@
|
||||
pub mod builder;
|
||||
pub mod dce;
|
||||
pub mod fold;
|
||||
pub mod mir;
|
||||
|
||||
Reference in New Issue
Block a user