Files
compiler-old/src/middle/fold.rs
T
2026-04-22 22:40:19 +02:00

356 lines
13 KiB
Rust

use std::collections::HashMap;
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
use crate::middle::mir::*;
/// Folds constant expressions and simplifies control flow branches where the condition is known.
pub fn fold_constants(module: &mut MirModule) {
for func in &mut module.functions {
optimize_function(func);
}
}
fn optimize_function(func: &mut MirFunction) {
for block in &mut func.blocks {
// Block-local constant tracking. This is safe even without strict SSA
// because compiler-generated temporaries are evaluated exactly once
// prior to being used within their basic block.
let mut known_constants = HashMap::new();
for stmt in &mut block.statements {
let StatementKind::Assign(local, rvalue) = &mut stmt.kind;
// Propagate any known constants downwards into the rvalue
propagate_rvalue(rvalue, &known_constants);
// Attempt to compute the rvalue
if let Some(constant) = evaluate_rvalue(rvalue) {
// Replace the complex instruction with a simple constant use
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
known_constants.insert(*local, constant);
} else {
// Reassigned to a non-computable value; remove older cached inferences
known_constants.remove(local);
}
}
// Propagate constants into the terminator
match &mut block.terminator.kind {
TerminatorKind::CondBranch {
cond,
target_true,
target_false,
} => {
propagate_operand(cond, &known_constants);
// If the condition is statically known, fold the branch into a simple Goto!
if let Operand::Constant(ConstantValue::Boolean(val)) = cond {
let target = if *val { *target_true } else { *target_false };
block.terminator.kind = TerminatorKind::Goto { target };
}
}
TerminatorKind::Return { value: Some(val) } => {
propagate_operand(val, &known_constants);
}
_ => {}
}
}
}
/// Replaces `operand` with a constant when it is a `Copy` of a local whose
/// value is recorded in `known_constants`; otherwise leaves it untouched.
fn propagate_operand(operand: &mut Operand, known_constants: &HashMap<LocalId, ConstantValue>) {
    // Look the local up first so the borrow of `operand` has ended before the
    // operand is overwritten.
    let replacement = match &*operand {
        Operand::Copy(local) => known_constants.get(local).cloned(),
        _ => None,
    };
    if let Some(value) = replacement {
        *operand = Operand::Constant(value);
    }
}
/// Substitutes known constants into every operand of `rvalue`.
fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, ConstantValue>) {
    match rvalue {
        // Single-operand forms all get the same treatment.
        Rvalue::Use(operand) | Rvalue::UnaryOp(_, operand) | Rvalue::Cast(_, operand) => {
            propagate_operand(operand, known_constants)
        }
        Rvalue::BinaryOp(_, left, right) => {
            propagate_operand(left, known_constants);
            propagate_operand(right, known_constants);
        }
    }
}
fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
match rvalue {
Rvalue::Use(Operand::Constant(c)) => Some(c.clone()),
Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c),
Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
evaluate_binary(*op, l, r)
}
Rvalue::Cast(to_ty, Operand::Constant(c)) => evaluate_cast(to_ty, c),
_ => None,
}
}
/// Applies the unary operator `op` to the constant `val`.
///
/// Supported combinations are negation of integers (wrapping) and floats, and
/// logical not of booleans. Every other combination yields `None`.
fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option<ConstantValue> {
    let is_neg = matches!(op, UnaryOp::Neg);
    let is_not = matches!(op, UnaryOp::Not);

    match val {
        ConstantValue::Integer(bits, ty) if is_neg => {
            let negated = bits.wrapping_neg();
            Some(ConstantValue::Integer(negated, ty.clone()))
        }
        ConstantValue::Float(number, ty) if is_neg => {
            Some(ConstantValue::Float(-number, ty.clone()))
        }
        ConstantValue::Boolean(flag) if is_not => Some(ConstantValue::Boolean(!flag)),
        _ => None,
    }
}
fn evaluate_binary(
op: BinaryOp,
lhs: &ConstantValue,
rhs: &ConstantValue,
) -> Option<ConstantValue> {
match (lhs, rhs) {
(ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => {
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
match op {
BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())),
BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())),
BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())),
// Avoid dividing by 0 during compile-time constant evaluation
BinaryOp::Div if *r != 0 => {
if is_signed {
// `overflowing_div` safely sidesteps the `i64::MIN / -1` panic
let (result, _) = (*l as i64).overflowing_div(*r as i64);
Some(ConstantValue::Integer(result as u64, ty.clone()))
} else {
Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone()))
}
}
BinaryOp::Rem if *r != 0 => {
if is_signed {
let (result, _) = (*l as i64).overflowing_rem(*r as i64);
Some(ConstantValue::Integer(result as u64, ty.clone()))
} else {
Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone()))
}
}
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) < (*r as i64)
} else {
l < r
})),
BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) <= (*r as i64)
} else {
l <= r
})),
BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) > (*r as i64)
} else {
l > r
})),
BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed {
(*l as i64) >= (*r as i64)
} else {
l >= r
})),
_ => None,
}
}
(ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op {
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
_ => None,
},
(ConstantValue::Float(l, ty), ConstantValue::Float(r, _)) => match op {
BinaryOp::Add => Some(ConstantValue::Float(l + r, ty.clone())),
BinaryOp::Sub => Some(ConstantValue::Float(l - r, ty.clone())),
BinaryOp::Mul => Some(ConstantValue::Float(l * r, ty.clone())),
BinaryOp::Div => Some(ConstantValue::Float(l / r, ty.clone())),
BinaryOp::Rem => Some(ConstantValue::Float(l % r, ty.clone())),
BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)),
BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)),
BinaryOp::Lt => Some(ConstantValue::Boolean(l < r)),
BinaryOp::Le => Some(ConstantValue::Boolean(l <= r)),
BinaryOp::Gt => Some(ConstantValue::Boolean(l > r)),
BinaryOp::Ge => Some(ConstantValue::Boolean(l >= r)),
},
_ => None,
}
}
/// Converts the constant `val` to the type `to_ty`.
///
/// * Float targets: integers are converted by value (signed sources are
///   sign-extended from their own width first), booleans become `1.0`/`0.0`.
/// * Integer targets: integers keep their value (signed sources are
///   sign-extended from their own width first), floats are converted with
///   `as`, booleans become `1`/`0`. The result is then truncated to the
///   target's bit width.
/// * `Bool` target: any non-zero number is `true`.
///
/// Any other target type yields `None`.
///
/// Integer payloads are raw `u64` bit patterns. A signed source is accepted
/// whether it arrives sign-extended to 64 bits or truncated to its own width.
/// The latter is the form this function itself emits, so chained casts such as
/// narrow-then-widen keep the sign.
fn evaluate_cast(to_ty: &Ty, val: &ConstantValue) -> Option<ConstantValue> {
    if to_ty.is_float() {
        let f = match val {
            ConstantValue::Integer(v, ty) => {
                if ty.is_signed() {
                    sign_extend_from_width(*v, ty) as f64
                } else {
                    *v as f64
                }
            }
            ConstantValue::Float(v, _) => *v,
            ConstantValue::Boolean(b) => {
                if *b {
                    1.0
                } else {
                    0.0
                }
            }
        };
        Some(ConstantValue::Float(f, to_ty.clone()))
    } else if to_ty.is_integer() {
        let i = match val {
            ConstantValue::Integer(v, ty) => {
                if ty.is_signed() {
                    // Recover the real (possibly negative) value before it is
                    // widened or truncated to the target width below.
                    sign_extend_from_width(*v, ty) as u64
                } else {
                    *v
                }
            }
            ConstantValue::Float(v, _) => {
                if to_ty.is_signed() {
                    (*v as i64) as u64
                } else {
                    *v as u64
                }
            }
            ConstantValue::Boolean(b) => {
                if *b {
                    1
                } else {
                    0
                }
            }
        };
        // Keep only the bits that fit the target type. `>= 64` (rather than
        // `== 64`) keeps the shift in range for any reported width.
        let bits = to_ty.bit_width();
        let mask = if bits >= 64 {
            u64::MAX
        } else {
            (1u64 << bits) - 1
        };
        Some(ConstantValue::Integer(i & mask, to_ty.clone()))
    } else if to_ty == &Ty::Bool {
        let b = match val {
            ConstantValue::Integer(v, _) => *v != 0,
            ConstantValue::Float(v, _) => *v != 0.0,
            ConstantValue::Boolean(b) => *b,
        };
        Some(ConstantValue::Boolean(b))
    } else {
        None
    }
}

/// Interprets the low `ty.bit_width()` bits of `value` as a signed integer.
///
/// A reported width of 0, or of 64 or more, reinterprets all 64 bits as-is so
/// the shifts below cannot overflow.
fn sign_extend_from_width(value: u64, ty: &Ty) -> i64 {
    let bits = ty.bit_width();
    if bits == 0 || bits >= 64 {
        return value as i64;
    }
    // Shift the type's sign bit up to bit 63, then arithmetic-shift back.
    let shift = 64 - bits;
    ((value as i64) << shift) >> shift
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::token::Span;

    /// Placeholder span used for every node in the fixtures.
    fn dummy_span() -> Span {
        Span::new(0, 0)
    }

    /// Builds an `I32` integer constant operand.
    fn i32_const(value: u64) -> Operand {
        Operand::Constant(ConstantValue::Integer(value, Ty::I32))
    }

    /// Builds the statement `local = rvalue`.
    fn assign(local: LocalId, rvalue: Rvalue) -> Statement {
        Statement {
            kind: StatementKind::Assign(local, rvalue),
            span: dummy_span(),
        }
    }

    /// Builds a module holding one function named "test" with a single block
    /// (`BlockId(0)`) made of `statements` followed by `terminator`.
    fn single_block_module(
        return_type: Ty,
        statements: Vec<Statement>,
        terminator: TerminatorKind,
    ) -> MirModule {
        let entry = BasicBlock {
            id: BlockId(0),
            statements,
            terminator: Terminator {
                kind: terminator,
                span: dummy_span(),
            },
        };
        let func = MirFunction {
            name: "test".to_string(),
            params: vec![],
            return_type,
            locals: vec![],
            blocks: vec![entry],
        };
        MirModule {
            functions: vec![func],
        }
    }

    #[test]
    fn test_fold_arithmetic() {
        // _0 = 5; _1 = 10; _2 = _0 + _1; return _2
        let mut module = single_block_module(
            Ty::I32,
            vec![
                assign(LocalId(0), Rvalue::Use(i32_const(5))),
                assign(LocalId(1), Rvalue::Use(i32_const(10))),
                assign(
                    LocalId(2),
                    Rvalue::BinaryOp(
                        BinaryOp::Add,
                        Operand::Copy(LocalId(0)),
                        Operand::Copy(LocalId(1)),
                    ),
                ),
            ],
            TerminatorKind::Return {
                value: Some(Operand::Copy(LocalId(2))),
            },
        );

        fold_constants(&mut module);

        let block = &module.functions[0].blocks[0];

        // The addition assigned to LocalId(2) must have collapsed to 15.
        let StatementKind::Assign(_, rvalue) = &block.statements[2].kind;
        assert!(matches!(
            rvalue,
            Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
        ));

        // The return must now yield the constant 15 directly.
        assert!(matches!(
            &block.terminator.kind,
            TerminatorKind::Return {
                value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32)))
            }
        ));
    }

    #[test]
    fn test_fold_cond_branch() {
        // A branch on the literal `true` with no preceding statements.
        let mut module = single_block_module(
            Ty::Unit,
            vec![],
            TerminatorKind::CondBranch {
                cond: Operand::Constant(ConstantValue::Boolean(true)),
                target_true: BlockId(1),
                target_false: BlockId(2),
            },
        );

        fold_constants(&mut module);

        let block = &module.functions[0].blocks[0];

        // The branch must have become an unconditional jump to the true target.
        assert!(matches!(
            block.terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
    }
}