use std::collections::HashMap; use crate::frontend::ast::{BinaryOp, UnaryOp}; use crate::frontend::sema::Ty; use crate::middle::mir::*; /// Folds constant expressions and simplifies control flow branches where the condition is known. pub fn fold_constants(module: &mut MirModule) { for func in &mut module.functions { optimize_function(func); } } fn optimize_function(func: &mut MirFunction) { for block in &mut func.blocks { // Block-local constant tracking. This is safe even without strict SSA // because compiler-generated temporaries are evaluated exactly once // prior to being used within their basic block. let mut known_constants = HashMap::new(); for stmt in &mut block.statements { let StatementKind::Assign(local, rvalue) = &mut stmt.kind; // Propagate any known constants downwards into the rvalue propagate_rvalue(rvalue, &known_constants); // Attempt to compute the rvalue if let Some(constant) = evaluate_rvalue(rvalue) { // Replace the complex instruction with a simple constant use *rvalue = Rvalue::Use(Operand::Constant(constant.clone())); known_constants.insert(*local, constant); } else { // Reassigned to a non-computable value; remove older cached inferences known_constants.remove(local); } } // Propagate constants into the terminator match &mut block.terminator.kind { TerminatorKind::CondBranch { cond, target_true, target_false, } => { propagate_operand(cond, &known_constants); // If the condition is statically known, fold the branch into a simple Goto! if let Operand::Constant(ConstantValue::Boolean(val)) = cond { let target = if *val { *target_true } else { *target_false }; block.terminator.kind = TerminatorKind::Goto { target }; } } TerminatorKind::Return { value: Some(val) } => { propagate_operand(val, &known_constants); } _ => {} } } } fn propagate_operand(operand: &mut Operand, known_constants: &HashMap) { if let Operand::Copy(local) = operand && let Some(constant) = known_constants.get(local) { *operand = Operand::Constant(constant.clone()); } } fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap) { match rvalue { Rvalue::Use(op) => propagate_operand(op, known_constants), Rvalue::UnaryOp(_, op) => propagate_operand(op, known_constants), Rvalue::BinaryOp(_, lhs, rhs) => { propagate_operand(lhs, known_constants); propagate_operand(rhs, known_constants); } Rvalue::Cast(_, op) => propagate_operand(op, known_constants), } } fn evaluate_rvalue(rvalue: &Rvalue) -> Option { match rvalue { Rvalue::Use(Operand::Constant(c)) => Some(c.clone()), Rvalue::UnaryOp(op, Operand::Constant(c)) => evaluate_unary(*op, c), Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => { evaluate_binary(*op, l, r) } Rvalue::Cast(to_ty, Operand::Constant(c)) => evaluate_cast(to_ty, c), _ => None, } } fn evaluate_unary(op: UnaryOp, val: &ConstantValue) -> Option { match (op, val) { (UnaryOp::Neg, ConstantValue::Integer(v, ty)) => { Some(ConstantValue::Integer(v.wrapping_neg(), ty.clone())) } (UnaryOp::Neg, ConstantValue::Float(v, ty)) => Some(ConstantValue::Float(-v, ty.clone())), (UnaryOp::Not, ConstantValue::Boolean(b)) => Some(ConstantValue::Boolean(!b)), _ => None, } } fn evaluate_binary( op: BinaryOp, lhs: &ConstantValue, rhs: &ConstantValue, ) -> Option { match (lhs, rhs) { (ConstantValue::Integer(l, ty), ConstantValue::Integer(r, _)) => { let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64); match op { BinaryOp::Add => Some(ConstantValue::Integer(l.wrapping_add(*r), ty.clone())), BinaryOp::Sub => Some(ConstantValue::Integer(l.wrapping_sub(*r), ty.clone())), BinaryOp::Mul => Some(ConstantValue::Integer(l.wrapping_mul(*r), ty.clone())), // Avoid dividing by 0 during compile-time constant evaluation BinaryOp::Div if *r != 0 => { if is_signed { // `overflowing_div` safely sidesteps the `i64::MIN / -1` panic let (result, _) = (*l as i64).overflowing_div(*r as i64); Some(ConstantValue::Integer(result as u64, ty.clone())) } else { Some(ConstantValue::Integer(l.wrapping_div(*r), ty.clone())) } } BinaryOp::Rem if *r != 0 => { if is_signed { let (result, _) = (*l as i64).overflowing_rem(*r as i64); Some(ConstantValue::Integer(result as u64, ty.clone())) } else { Some(ConstantValue::Integer(l.wrapping_rem(*r), ty.clone())) } } BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)), BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)), BinaryOp::Lt => Some(ConstantValue::Boolean(if is_signed { (*l as i64) < (*r as i64) } else { l < r })), BinaryOp::Le => Some(ConstantValue::Boolean(if is_signed { (*l as i64) <= (*r as i64) } else { l <= r })), BinaryOp::Gt => Some(ConstantValue::Boolean(if is_signed { (*l as i64) > (*r as i64) } else { l > r })), BinaryOp::Ge => Some(ConstantValue::Boolean(if is_signed { (*l as i64) >= (*r as i64) } else { l >= r })), _ => None, } } (ConstantValue::Boolean(l), ConstantValue::Boolean(r)) => match op { BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)), BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)), _ => None, }, (ConstantValue::Float(l, ty), ConstantValue::Float(r, _)) => match op { BinaryOp::Add => Some(ConstantValue::Float(l + r, ty.clone())), BinaryOp::Sub => Some(ConstantValue::Float(l - r, ty.clone())), BinaryOp::Mul => Some(ConstantValue::Float(l * r, ty.clone())), BinaryOp::Div => Some(ConstantValue::Float(l / r, ty.clone())), BinaryOp::Rem => Some(ConstantValue::Float(l % r, ty.clone())), BinaryOp::Eq => Some(ConstantValue::Boolean(l == r)), BinaryOp::Neq => Some(ConstantValue::Boolean(l != r)), BinaryOp::Lt => Some(ConstantValue::Boolean(l < r)), BinaryOp::Le => Some(ConstantValue::Boolean(l <= r)), BinaryOp::Gt => Some(ConstantValue::Boolean(l > r)), BinaryOp::Ge => Some(ConstantValue::Boolean(l >= r)), }, _ => None, } } fn evaluate_cast(to_ty: &Ty, val: &ConstantValue) -> Option { if to_ty.is_float() { let f = match val { ConstantValue::Integer(v, ty) => { if ty.is_signed() { let shift = 64 - ty.bit_width(); (((*v as i64) << shift) >> shift) as f64 } else { *v as f64 } } ConstantValue::Float(v, _) => *v, ConstantValue::Boolean(b) => { if *b { 1.0 } else { 0.0 } } }; Some(ConstantValue::Float(f, to_ty.clone())) } else if to_ty.is_integer() { let i = match val { ConstantValue::Integer(v, _) => *v, ConstantValue::Float(v, _) => { if to_ty.is_signed() { (*v as i64) as u64 } else { *v as u64 } } ConstantValue::Boolean(b) => { if *b { 1 } else { 0 } } }; let mask = if to_ty.bit_width() == 64 { u64::MAX } else { (1u64 << to_ty.bit_width()) - 1 }; Some(ConstantValue::Integer(i & mask, to_ty.clone())) } else if to_ty == &Ty::Bool { let b = match val { ConstantValue::Integer(v, _) => *v != 0, ConstantValue::Float(v, _) => *v != 0.0, ConstantValue::Boolean(b) => *b, }; Some(ConstantValue::Boolean(b)) } else { None } } #[cfg(test)] mod test { use super::*; use crate::frontend::token::Span; #[test] fn test_fold_arithmetic() { let func = MirFunction { name: "test".to_string(), params: vec![], return_type: Ty::I32, locals: vec![], blocks: vec![BasicBlock { id: BlockId(0), statements: vec![ Statement { kind: StatementKind::Assign( LocalId(0), Rvalue::Use(Operand::Constant(ConstantValue::Integer(5, Ty::I32))), ), span: Span::new(0, 0), }, Statement { kind: StatementKind::Assign( LocalId(1), Rvalue::Use(Operand::Constant(ConstantValue::Integer(10, Ty::I32))), ), span: Span::new(0, 0), }, Statement { kind: StatementKind::Assign( LocalId(2), Rvalue::BinaryOp( BinaryOp::Add, Operand::Copy(LocalId(0)), Operand::Copy(LocalId(1)), ), ), span: Span::new(0, 0), }, ], terminator: Terminator { kind: TerminatorKind::Return { value: Some(Operand::Copy(LocalId(2))), }, span: Span::new(0, 0), }, }], }; let mut module = MirModule { functions: vec![func], }; fold_constants(&mut module); let block = &module.functions[0].blocks[0]; // The third statement (LocalId(2) = ...) should be folded to 15 let StatementKind::Assign(_, rvalue) = &block.statements[2].kind; assert!(matches!( rvalue, Rvalue::Use(Operand::Constant(ConstantValue::Integer(15, Ty::I32))) )); // The return terminator should be updated to return the constant 15 directly assert!(matches!( &block.terminator.kind, TerminatorKind::Return { value: Some(Operand::Constant(ConstantValue::Integer(15, Ty::I32))) } )); } #[test] fn test_fold_cond_branch() { let func = MirFunction { name: "test".to_string(), params: vec![], return_type: Ty::Unit, locals: vec![], blocks: vec![BasicBlock { id: BlockId(0), statements: vec![], terminator: Terminator { kind: TerminatorKind::CondBranch { cond: Operand::Constant(ConstantValue::Boolean(true)), target_true: BlockId(1), target_false: BlockId(2), }, span: Span::new(0, 0), }, }], }; let mut module = MirModule { functions: vec![func], }; fold_constants(&mut module); let block = &module.functions[0].blocks[0]; // The condition branch should be folded into an unconditional Goto assert!(matches!( block.terminator.kind, TerminatorKind::Goto { target: BlockId(1) } )); } }