356 lines
13 KiB
Rust
356 lines
13 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use crate::frontend::ast::{BinaryOp, UnaryOp};
|
|
use crate::frontend::sema::Ty;
|
|
use crate::middle::mir::*;
|
|
|
|
/// Folds constant expressions and simplifies control flow branches where the condition is known.
|
|
pub fn fold_constants(module: &mut MirModule) {
|
|
for func in &mut module.functions {
|
|
optimize_function(func);
|
|
}
|
|
}
|
|
|
|
fn optimize_function(func: &mut MirFunction) {
|
|
for block in &mut func.blocks {
|
|
// Block-local constant tracking. This is safe even without strict SSA
|
|
// because compiler-generated temporaries are evaluated exactly once
|
|
// prior to being used within their basic block.
|
|
let mut known_constants = HashMap::new();
|
|
|
|
for stmt in &mut block.statements {
|
|
let StatementKind::Assign(local, rvalue) = &mut stmt.kind;
|
|
|
|
// Propagate any known constants downwards into the rvalue
|
|
propagate_rvalue(rvalue, &known_constants);
|
|
|
|
// Attempt to compute the rvalue
|
|
if let Some(constant) = evaluate_rvalue(rvalue) {
|
|
// Replace the complex instruction with a simple constant use
|
|
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
|
|
known_constants.insert(*local, constant);
|
|
} else {
|
|
// Reassigned to a non-computable value; remove older cached inferences
|
|
known_constants.remove(local);
|
|
}
|
|
}
|
|
|
|
// Propagate constants into the terminator
|
|
match &mut block.terminator.kind {
|
|
TerminatorKind::CondBranch {
|
|
cond,
|
|
target_true,
|
|
target_false,
|
|
} => {
|
|
propagate_operand(cond, &known_constants);
|
|
|
|
// If the condition is statically known, fold the branch into a simple Goto!
|
|
if let Operand::Constant(ConstantValue::Boolean(val)) = cond {
|
|
let target = if *val { *target_true } else { *target_false };
|
|
|
|
block.terminator.kind = TerminatorKind::Goto { target };
|
|
}
|
|
}
|
|
TerminatorKind::Return { value: Some(val) } => {
|
|
propagate_operand(val, &known_constants);
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn propagate_operand(operand: &mut Operand, known_constants: &HashMap<LocalId, ConstantValue>) {
|
|
if let Operand::Copy(local) = operand
|
|
&& let Some(constant) = known_constants.get(local)
|
|
{
|
|
*operand = Operand::Constant(constant.clone());
|
|
}
|
|
}
|
|
|
|
fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, ConstantValue>) {
|
|
match rvalue {
|
|
Rvalue::Use(op) => propagate_operand(op, known_constants),
|
|
Rvalue::UnaryOp(_, op) => propagate_operand(op, known_constants),
|
|
Rvalue::BinaryOp(_, lhs, rhs) => {
|
|
propagate_operand(lhs, known_constants);
|
|
propagate_operand(rhs, known_constants);
|
|
}
|
|
Rvalue::Cast(_, op) => propagate_operand(op, known_constants),
|
|
}
|
|
}
|
|
|
|
fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
|
|
match rvalue {
|
|
Rvalue::Use(Operand::Constant(c)) => Some(c.clone()),
|
|
Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c),
|
|
Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
|
|
evaluate_binary(*op, l, r)
|
|
}
|
|
Rvalue::Cast(to_ty, Operand::Constant(c)) => evaluate_cast(to_ty, c),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option<ConstantValue> {
|
|
match (op, val) {
|
|
(UnaryOp::Neg, ConstantValue::Integer(v, ty)) => {
|
|
Some(ConstantValue::Integer(v.wrapping_neg(), ty.clone()))
|
|
}
|
|
(UnaryOp::Neg, ConstantValue::Float(v, ty)) => Some(ConstantValue::Float(-v, ty.clone())),
|
|
(UnaryOp::Not, ConstantValue::Boolean(b)) => Some(ConstantValue::Boolean(!b)),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn evaluate_binary(
|
|
op: BinaryOp,
|
|
lhs: &ConstantValue,
|
|
rhs: &ConstantValue,
|
|
) -> Option<ConstantValue> {
|
|
match (lhs, rhs) {
|
|
(ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => {
|
|
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
|
|
|
|
match op {
|
|
BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())),
|
|
BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())),
|
|
BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())),
|
|
|
|
// Avoid dividing by 0 during compile-time constant evaluation
|
|
BinaryOp::Div if *r != 0 => {
|
|
if is_signed {
|
|
// `overflowing_div` safely sidesteps the `i64::MIN / -1` panic
|
|
let (result, _) = (*l as i64).overflowing_div(*r as i64);
|
|
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
|
} else {
|
|
Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone()))
|
|
}
|
|
}
|
|
BinaryOp::Rem if *r != 0 => {
|
|
if is_signed {
|
|
let (result, _) = (*l as i64).overflowing_rem(*r as i64);
|
|
Some(ConstantValue::Integer(result as u64, ty.clone()))
|
|
} else {
|
|
Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone()))
|
|
}
|
|
}
|
|
|
|
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
|
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
|
|
|
BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed {
|
|
(*l as i64) < (*r as i64)
|
|
} else {
|
|
l < r
|
|
})),
|
|
BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed {
|
|
(*l as i64) <= (*r as i64)
|
|
} else {
|
|
l <= r
|
|
})),
|
|
BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed {
|
|
(*l as i64) > (*r as i64)
|
|
} else {
|
|
l > r
|
|
})),
|
|
BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed {
|
|
(*l as i64) >= (*r as i64)
|
|
} else {
|
|
l >= r
|
|
})),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
(ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op {
|
|
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
|
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
|
_ => None,
|
|
},
|
|
|
|
(ConstantValue::Float(l, ty), ConstantValue::Float(r, _)) => match op {
|
|
BinaryOp::Add => Some(ConstantValue::Float(l + r, ty.clone())),
|
|
BinaryOp::Sub => Some(ConstantValue::Float(l - r, ty.clone())),
|
|
BinaryOp::Mul => Some(ConstantValue::Float(l * r, ty.clone())),
|
|
BinaryOp::Div => Some(ConstantValue::Float(l / r, ty.clone())),
|
|
BinaryOp::Rem => Some(ConstantValue::Float(l % r, ty.clone())),
|
|
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
|
|
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
|
|
BinaryOp::Lt => Some(ConstantValue::Boolean(l < r)),
|
|
BinaryOp::Le => Some(ConstantValue::Boolean(l <= r)),
|
|
BinaryOp::Gt => Some(ConstantValue::Boolean(l > r)),
|
|
BinaryOp::Ge => Some(ConstantValue::Boolean(l >= r)),
|
|
},
|
|
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn evaluate_cast(to_ty: &Ty, val: &ConstantValue) -> Option<ConstantValue> {
|
|
if to_ty.is_float() {
|
|
let f = match val {
|
|
ConstantValue::Integer(v, ty) => {
|
|
if ty.is_signed() {
|
|
let shift = 64 - ty.bit_width();
|
|
(((*v as i64) << shift) >> shift) as f64
|
|
} else {
|
|
*v as f64
|
|
}
|
|
}
|
|
ConstantValue::Float(v, _) => *v,
|
|
ConstantValue::Boolean(b) => {
|
|
if *b {
|
|
1.0
|
|
} else {
|
|
0.0
|
|
}
|
|
}
|
|
};
|
|
Some(ConstantValue::Float(f, to_ty.clone()))
|
|
} else if to_ty.is_integer() {
|
|
let i = match val {
|
|
ConstantValue::Integer(v, _) => *v,
|
|
ConstantValue::Float(v, _) => {
|
|
if to_ty.is_signed() {
|
|
(*v as i64) as u64
|
|
} else {
|
|
*v as u64
|
|
}
|
|
}
|
|
ConstantValue::Boolean(b) => {
|
|
if *b {
|
|
1
|
|
} else {
|
|
0
|
|
}
|
|
}
|
|
};
|
|
let mask = if to_ty.bit_width() == 64 {
|
|
u64::MAX
|
|
} else {
|
|
(1u64 << to_ty.bit_width()) - 1
|
|
};
|
|
Some(ConstantValue::Integer(i & mask, to_ty.clone()))
|
|
} else if to_ty == &Ty::Bool {
|
|
let b = match val {
|
|
ConstantValue::Integer(v, _) => *v != 0,
|
|
ConstantValue::Float(v, _) => *v != 0.0,
|
|
ConstantValue::Boolean(b) => *b,
|
|
};
|
|
Some(ConstantValue::Boolean(b))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::token::Span;

    /// Placeholder span for synthetic MIR nodes.
    fn no_span() -> Span {
        Span::new(0, 0)
    }

    /// Builds the statement `local = rvalue`.
    fn assign(local: LocalId, rvalue: Rvalue) -> Statement {
        Statement {
            kind: StatementKind::Assign(local, rvalue),
            span: no_span(),
        }
    }

    /// An `i32` constant operand.
    fn i32_const(value: u64) -> Operand {
        Operand::Constant(ConstantValue::Integer(value, Ty::I32))
    }

    /// Wraps one basic block (`BlockId(0)`) into a single-function module named `test`.
    fn single_block_module(
        return_type: Ty,
        statements: Vec<Statement>,
        terminator: TerminatorKind,
    ) -> MirModule {
        let block = BasicBlock {
            id: BlockId(0),
            statements,
            terminator: Terminator {
                kind: terminator,
                span: no_span(),
            },
        };

        MirModule {
            functions: vec![MirFunction {
                name: "test".to_string(),
                params: vec![],
                return_type,
                locals: vec![],
                blocks: vec![block],
            }],
        }
    }

    #[test]
    fn test_fold_arithmetic() {
        let mut module = single_block_module(
            Ty::I32,
            vec![
                assign(LocalId(0), Rvalue::Use(i32_const(5))),
                assign(LocalId(1), Rvalue::Use(i32_const(10))),
                assign(
                    LocalId(2),
                    Rvalue::BinaryOp(
                        BinaryOp::Add,
                        Operand::Copy(LocalId(0)),
                        Operand::Copy(LocalId(1)),
                    ),
                ),
            ],
            TerminatorKind::Return {
                value: Some(Operand::Copy(LocalId(2))),
            },
        );

        fold_constants(&mut module);
        let block = &module.functions[0].blocks[0];

        // `_2 = _0 + _1` must collapse to the literal 15 ...
        let StatementKind::Assign(_, folded) = &block.statements[2].kind;
        assert!(matches!(
            folded,
            Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
        ));

        // ... and the return must yield that constant directly rather than copying `_2`.
        assert!(matches!(
            &block.terminator.kind,
            TerminatorKind::Return {
                value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
            }
        ));
    }

    #[test]
    fn test_fold_cond_branch() {
        let mut module = single_block_module(
            Ty::Unit,
            Vec::new(),
            TerminatorKind::CondBranch {
                cond: Operand::Constant(ConstantValue::Boolean(true)),
                target_true: BlockId(1),
                target_false: BlockId(2),
            },
        );

        fold_constants(&mut module);

        // A branch on a literal `true` becomes an unconditional jump to the true target.
        assert!(matches!(
            &module.functions[0].blocks[0].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
    }
}
|