feat: add constant folding and dead code elimination passes
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::frontend::ast::{BinaryOp, UnaryOp};
|
||||
use crate::frontend::sema::Ty;
|
||||
use crate::middle::mir::*;
|
||||
|
||||
/// Folds constant expressions and simplifies control flow branches where the condition is known.
|
||||
pub fn fold_constants(module: &mut MirModule) {
|
||||
for func in &mut module.functions {
|
||||
optimize_function(func);
|
||||
}
|
||||
}
|
||||
|
||||
fn optimize_function(func: &mut MirFunction) {
|
||||
for block in &mut func.blocks {
|
||||
// Block-local constant tracking. This is safe even without strict SSA
|
||||
// because compiler-generated temporaries are evaluated exactly once
|
||||
// prior to being used within their basic block.
|
||||
let mut known_constants = HashMap::new();
|
||||
|
||||
for stmt in &mut block.statements {
|
||||
let StatementKind::Assign(local, rvalue) = &mut stmt.kind;
|
||||
|
||||
// Propagate any known constants downwards into the rvalue
|
||||
propagate_rvalue(rvalue, &known_constants);
|
||||
|
||||
// Attempt to compute the rvalue
|
||||
if let Some(constant) = evaluate_rvalue(rvalue) {
|
||||
// Replace the complex instruction with a simple constant use
|
||||
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
|
||||
known_constants.insert(*local, constant);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate constants into the terminator
|
||||
match &mut block.terminator.kind {
|
||||
TerminatorKind::CondBranch {
|
||||
cond,
|
||||
target_true,
|
||||
target_false,
|
||||
} => {
|
||||
propagate_operand(cond, &known_constants);
|
||||
|
||||
// If the condition is statically known, fold the branch into a simple Goto!
|
||||
if let Operand::Constant(ConstantValue::Boolean(val)) = cond {
|
||||
let target = if *val { *target_true } else { *target_false };
|
||||
|
||||
block.terminator.kind = TerminatorKind::Goto { target };
|
||||
}
|
||||
}
|
||||
TerminatorKind::Return { value: Some(val) } => {
|
||||
propagate_operand(val, &known_constants);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_operand(operand: &mut Operand, known_constants: &HashMap<LocalId, ConstantValue>) {
|
||||
if let Operand::Copy(local) = operand
|
||||
&& let Some(constant) = known_constants.get(local)
|
||||
{
|
||||
*operand = Operand::Constant(constant.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, ConstantValue>) {
|
||||
match rvalue {
|
||||
Rvalue::Use(op) => propagate_operand(op, known_constants),
|
||||
Rvalue::UnaryOp(_, op) => propagate_operand(op, known_constants),
|
||||
Rvalue::BinaryOp(_, lhs, rhs) => {
|
||||
propagate_operand(lhs, known_constants);
|
||||
propagate_operand(rhs, known_constants);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
|
||||
match rvalue {
|
||||
Rvalue::Use(Operand::Constant(c)) => Some(c.clone()),
|
||||
Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c),
|
||||
Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
|
||||
evaluate_binary(*op, l, r)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option<ConstantValue> {
|
||||
match (op, val) {
|
||||
(UnaryOp::Neg, ConstantValue::Integer(v, ty)) => {
|
||||
Some(ConstantValue::Integer(v.wrapping_neg(), ty.clone()))
|
||||
}
|
||||
(UnaryOp::Not, ConstantValue::Boolean(b)) => Some(ConstantValue::Boolean(!b)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_binary(
|
||||
op: BinaryOp,
|
||||
lhs: &ConstantValue,
|
||||
rhs: &ConstantValue,
|
||||
) -> Option<ConstantValue> {
|
||||
match (lhs, rhs) {
|
||||
(ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => {
|
||||
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
|
||||
|
||||
match op {
|
||||
BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())),
|
||||
BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())),
|
||||
BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())),
|
||||
|
||||
// Avoid dividing by 0 during compile-time constant evaluation
|
||||
BinaryOp::Div if *r != 0 => {
|
||||
if is_signed {
|
||||
// `overflowing_div` safely sidesteps the `i64::MIN / -1` panic
|
||||
let (result, _) = (*l as i64).overflowing_div(*r as i64);
|
||||
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
||||
} else {
|
||||
Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone()))
|
||||
}
|
||||
}
|
||||
BinaryOp::Rem if *r != 0 => {
|
||||
if is_signed {
|
||||
let (result, _) = (*l as i64).overflowing_rem(*r as i64);
|
||||
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
||||
} else {
|
||||
Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
||||
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
||||
|
||||
BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) < (*r as i64)
|
||||
} else {
|
||||
l < r
|
||||
})),
|
||||
BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) <= (*r as i64)
|
||||
} else {
|
||||
l <= r
|
||||
})),
|
||||
BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) > (*r as i64)
|
||||
} else {
|
||||
l > r
|
||||
})),
|
||||
BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed {
|
||||
(*l as i64) >= (*r as i64)
|
||||
} else {
|
||||
l >= r
|
||||
})),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
(ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op {
|
||||
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
||||
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
||||
_ => None,
|
||||
},
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::frontend::token::Span;
|
||||
|
||||
#[test]
|
||||
fn test_fold_arithmetic() {
|
||||
let func = MirFunction {
|
||||
name: "test".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::I32,
|
||||
locals: vec![],
|
||||
blocks: vec![BasicBlock {
|
||||
id: BlockId(0),
|
||||
statements: vec![
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(0),
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(5, Ty::I32))),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(1),
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(10, Ty::I32))),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
Statement {
|
||||
kind: StatementKind::Assign(
|
||||
LocalId(2),
|
||||
Rvalue::BinaryOp(
|
||||
BinaryOp::Add,
|
||||
Operand::Copy(LocalId(0)),
|
||||
Operand::Copy(LocalId(1)),
|
||||
),
|
||||
),
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::Return {
|
||||
value: Some(Operand::Copy(LocalId(2))),
|
||||
},
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
functions: vec![func],
|
||||
};
|
||||
|
||||
fold_constants(&mut module);
|
||||
|
||||
let block = &module.functions[0].blocks[0];
|
||||
|
||||
// The third statement (LocalId(2) = ...) should be folded to 15
|
||||
let StatementKind::Assign(_, rvalue) = &block.statements[2].kind;
|
||||
assert!(matches!(
|
||||
rvalue,
|
||||
Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
|
||||
));
|
||||
|
||||
// The return terminator should be updated to return the constant 15 directly
|
||||
assert!(matches!(
|
||||
&block.terminator.kind,
|
||||
TerminatorKind::Return {
|
||||
value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_cond_branch() {
|
||||
let func = MirFunction {
|
||||
name: "test".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::Unit,
|
||||
locals: vec![],
|
||||
blocks: vec![BasicBlock {
|
||||
id: BlockId(0),
|
||||
statements: vec![],
|
||||
terminator: Terminator {
|
||||
kind: TerminatorKind::CondBranch {
|
||||
cond: Operand::Constant(ConstantValue::Boolean(true)),
|
||||
target_true: BlockId(1),
|
||||
target_false: BlockId(2),
|
||||
},
|
||||
span: Span::new(0, 0),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
functions: vec![func],
|
||||
};
|
||||
|
||||
fold_constants(&mut module);
|
||||
|
||||
let block = &module.functions[0].blocks[0];
|
||||
|
||||
// The condition branch should be folded into an unconditional Goto
|
||||
assert!(matches!(
|
||||
block.terminator.kind,
|
||||
TerminatorKind::Goto { target: BlockId(1) }
|
||||
));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user