commit 9d94e3b81b3b5ea340fb460820e94185ec425b5d Author: Jooris Hadeler Date: Sun Apr 26 19:17:57 2026 +0200 init: initial commit of the Scarlett framework - Initialize Rust project configuration (Cargo) and .gitignore - Implement core Intermediate Representation (IR), printer, and builder utilities - Add IR validation module for type checking and constraint verification - Introduce optimization passes: Mem2Reg, Constant Folding, Copy Propagation, Dead Code Elimination, and SSA Destruction - Implement x86_64 backend for assembly code generation - Add a C test harness and main entry point to generate, compile, and test a GCD assembly function diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..1fc8638 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "scarlett" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8fc83a9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "scarlett" +version = "0.1.0" +edition = "2024" + +[dependencies] diff --git a/harness.c b/harness.c new file mode 100644 index 0000000..9fdf89a --- /dev/null +++ b/harness.c @@ -0,0 +1,46 @@ +#include +#include +#include + +// Declare the external assembly function generated by the Rust backend +extern uint64_t gcd(uint64_t a, uint64_t b); + +int main() { + // Define an array of test cases: {a, b, expected_result} + uint64_t test_cases[][3] = { + {48, 18, 6}, + {54, 24, 6}, + {7, 13, 1}, // Prime numbers + {100, 10, 10}, // One is a multiple of the other + {2740, 1760, 20}, // Larger random numbers + {1234567890, 90, 90}, // Testing 64-bit bounds + {5, 0, 5} // Edge case: b is 0 (Your IR handles this perfectly!) + }; + + int num_cases = sizeof(test_cases) / sizeof(test_cases[0]); + int passed = 0; + + printf("Running GCD Assembly Tests...\n"); + printf("-----------------------------\n"); + + for (int i = 0; i < num_cases; i++) { + uint64_t a = test_cases[i][0]; + uint64_t b = test_cases[i][1]; + uint64_t expected = test_cases[i][2]; + + // Call out into your compiled assembly! + uint64_t result = gcd(a, b); + + if (result == expected) { + printf("[PASS] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 "\n", a, b, result); + passed++; + } else { + printf("[FAIL] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 " (Expected: %" PRIu64 ")\n", a, b, result, expected); + } + } + + printf("-----------------------------\n"); + printf("Results: %d/%d passed.\n", passed, num_cases); + + return (passed == num_cases) ? 0 : 1; +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs new file mode 100644 index 0000000..2a99bf5 --- /dev/null +++ b/src/backend/mod.rs @@ -0,0 +1 @@ +pub mod x86_64; diff --git a/src/backend/x86_64.rs b/src/backend/x86_64.rs new file mode 100644 index 0000000..7186b3c --- /dev/null +++ b/src/backend/x86_64.rs @@ -0,0 +1,708 @@ +use std::collections::{HashMap, HashSet}; +use std::fmt::Write; +use std::mem::take; + +use crate::ir::*; + +#[derive(Clone, Copy, Debug)] +enum Storage { + Hardware(&'static str), + Stack(usize), + Alloc(usize), // Represents the address `-offset(%rbp)` +} + +pub struct X86Backend<'a> { + assembly: String, + module: &'a Module, + live_intervals: HashMap, + allocations: HashMap, +} + +impl<'a> X86Backend<'a> { + pub fn new(module: &'a Module) -> Self { + Self { + assembly: String::new(), + module, + live_intervals: HashMap::new(), + allocations: HashMap::new(), + } + } + + pub fn compile_module(mut self) -> String { + for func in &self.module.functions { + self.compile_function(func); + } + + self.assembly + } + + fn resolve_op( + &mut self, + op: &Operand, + scratch_reg: &str, + allocs: &HashMap, + ) -> String { + match op { + Operand::Integer(v) => format!("${}", v), + Operand::Boolean(b) => format!("${}", if *b { 1 } else { 0 }), + Operand::Register(r) => match allocs.get(r).unwrap() { + Storage::Hardware(hw) => format!("%{}", hw), + Storage::Stack(off) => format!("-{}(%rbp)", off), + Storage::Alloc(off) => { + // Materialize the address dynamically into the provided scratch register + writeln!( + &mut self.assembly, + " leaq -{}(%rbp), {}", + off, scratch_reg + ) + .unwrap(); + scratch_reg.to_string() + } + }, + } + } + + fn mark_use(&mut self, reg: Register, idx: usize) { + if !self.allocations.contains_key(®) { + let entry = self.live_intervals.entry(reg).or_insert((idx, idx)); + entry.1 = idx; + } + } + + fn mark_op(&mut self, op: &Operand, idx: usize) { + if let Operand::Register(r) = op { + self.mark_use(*r, idx); + } + } + + fn compile_function(&mut self, func: &Function) { + // 0. Pre-Pass: Handle Allocations + let mut allocations: HashMap = HashMap::new(); + let mut next_stack_offset = 0; + + for block in &func.blocks { + for inst in &block.instructions { + if let Instruction::Alloc { dest, .. } = inst { + next_stack_offset += 8; + allocations.insert(*dest, Storage::Alloc(next_stack_offset)); + } + } + } + + // 1. Liveness Analysis & Call Tracking + let mut call_indices = Vec::new(); + let mut hints: HashMap = HashMap::new(); + let mut inst_idx = 0; + + for (_, reg) in &func.params { + self.mark_use(*reg, inst_idx); + } + inst_idx += 1; + + for block in &func.blocks { + for inst in &block.instructions { + match inst { + Instruction::Alloc { .. } => {} + Instruction::Call { dest, args, .. } => { + self.mark_use(*dest, inst_idx); + let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; + for (i, (_, arg_op)) in args.iter().enumerate() { + self.mark_op(arg_op, inst_idx); + if i < 6 + && let Operand::Register(r) = arg_op + { + hints.insert(*r, arg_regs[i]); + } + } + call_indices.push(inst_idx); + } + Instruction::Load { dest, src, .. } => { + self.mark_use(*dest, inst_idx); + self.mark_op(src, inst_idx); + } + Instruction::Store { dest, src, .. } => { + self.mark_op(dest, inst_idx); + self.mark_op(src, inst_idx); + } + Instruction::Assign { register, operand } => { + self.mark_use(*register, inst_idx); + self.mark_op(operand, inst_idx); + } + Instruction::Binary { + dest, src1, src2, .. + } => { + self.mark_use(*dest, inst_idx); + self.mark_op(src1, inst_idx); + self.mark_op(src2, inst_idx); + } + Instruction::Unary { dest, src, .. } => { + self.mark_use(*dest, inst_idx); + self.mark_op(src, inst_idx); + } + Instruction::Phi { dest, sources, .. } => { + self.mark_use(*dest, inst_idx); + for (op, _) in sources { + self.mark_op(op, inst_idx); + } + } + } + inst_idx += 1; + } + match &block.terminator { + Terminator::Branch { cond, .. } => self.mark_op(cond, inst_idx), + Terminator::Return { + value: Some(val), .. + } => self.mark_op(val, inst_idx), + _ => {} + } + inst_idx += 1; + } + + // 2. ABI-Aware Linear Scan Allocation + let mut free_callee_saved = vec!["rbx", "r12", "r13", "r14", "r15"]; + let mut free_caller_saved = vec!["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; + + let mut active: Vec<(Register, usize, &'static str, bool)> = Vec::new(); + let mut used_callee_saved = HashSet::new(); + + let live_intervals = take(&mut self.live_intervals); + let mut intervals_sorted: Vec<_> = live_intervals.into_iter().collect(); + intervals_sorted.sort_by_key(|(_, (start, _))| *start); + + for (reg, (start, end)) in intervals_sorted { + active.retain(|(_, active_end, hw_reg, is_caller)| { + if *active_end < start { + if *is_caller { + free_caller_saved.push(hw_reg); + } else { + free_callee_saved.push(hw_reg); + } + false + } else { + true + } + }); + + let crosses_call = call_indices + .iter() + .any(|&c_idx| c_idx > start && c_idx < end); + + let mut selected_hw = None; + let mut is_caller = false; + + if !crosses_call { + if let Some(&hint_reg) = hints.get(®) + && let Some(pos) = free_caller_saved.iter().position(|&r| r == hint_reg) + { + selected_hw = Some(free_caller_saved.remove(pos)); + is_caller = true; + } + if selected_hw.is_none() + && let Some(r) = free_caller_saved.pop() + { + selected_hw = Some(r); + is_caller = true; + } + } + + if selected_hw.is_none() + && let Some(r) = free_callee_saved.pop() + { + selected_hw = Some(r); + + used_callee_saved.insert(r); + is_caller = false; + } + + if let Some(hw_reg) = selected_hw { + allocations.insert(reg, Storage::Hardware(hw_reg)); + active.push((reg, end, hw_reg, is_caller)); + } else { + next_stack_offset += 8; + allocations.insert(reg, Storage::Stack(next_stack_offset)); + } + } + + let used_callee_saved: Vec<&'static str> = used_callee_saved.into_iter().collect(); + + // Determine if a Stack Frame is required + let needs_frame = !call_indices.is_empty() + || next_stack_offset > 0 + || !used_callee_saved.is_empty() + || func.params.len() > 6; + + // 3. Prologue + writeln!(&mut self.assembly, " .text").unwrap(); + writeln!(&mut self.assembly, " .globl {}", func.name).unwrap(); + writeln!(&mut self.assembly, " .p2align 4").unwrap(); + writeln!(&mut self.assembly, " .type {},@function", func.name).unwrap(); + writeln!(&mut self.assembly, "{}:", func.name).unwrap(); + + if needs_frame { + writeln!(&mut self.assembly, " pushq %rbp").unwrap(); + writeln!(&mut self.assembly, " movq %rsp, %rbp").unwrap(); + + for reg in &used_callee_saved { + writeln!(&mut self.assembly, " pushq %{}", reg).unwrap(); + } + + let s = next_stack_offset; + let rem = (used_callee_saved.len() * 8 + s) % 16; + let stack_adj = if rem != 0 { s + (16 - rem) } else { s }; + + if stack_adj > 0 { + writeln!(&mut self.assembly, " subq ${}, %rsp", stack_adj).unwrap(); + } + } + + // 4. Map ABI Arguments + let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; + for (i, (_, reg)) in func.params.iter().enumerate() { + let dest_str = self.format_dest(*reg, &allocations); + if i < 6 { + writeln!( + &mut self.assembly, + " movq %{}, {}", + arg_regs[i], dest_str + ) + .unwrap(); + } else { + let caller_offset = 16 + ((i - 6) * 8); + writeln!(&mut self.assembly, " movq {}(%rbp), %rax", caller_offset).unwrap(); + writeln!(&mut self.assembly, " movq %rax, {}", dest_str).unwrap(); + } + } + + // 5. Compile Blocks + let num_blocks = func.blocks.len(); + for i in 0..num_blocks { + let block = &func.blocks[i]; + let next_block_id = if i + 1 < num_blocks { + Some(func.blocks[i + 1].id) + } else { + None + }; + + writeln!(&mut self.assembly, ".L{}_block_{}:", func.name, block.id.0).unwrap(); + + // Peephole Optimization: Fuse an immediately preceding ICmp into the Branch + let mut fused_cmp = None; + if let Terminator::Branch { + cond: Operand::Register(cond_reg), + .. + } = block.terminator + && let Some(Instruction::Binary { + dest, + op: BinaryOp::ICmp(cmp_op), + src1, + src2, + .. + }) = block.instructions.last() + && cond_reg == *dest + { + fused_cmp = Some((*cmp_op, *src1, *src2)); + } + + // If we fused the comparison, we omit the last instruction from standard compilation + let inst_limit = if fused_cmp.is_some() { + block.instructions.len() - 1 + } else { + block.instructions.len() + }; + + for inst in &block.instructions[..inst_limit] { + self.compile_instruction(inst, &allocations); + } + + self.compile_terminator( + &block.terminator, + &func.name, + &allocations, + next_block_id, + fused_cmp, + ); + } + + // 6. Unified Epilogue + writeln!(&mut self.assembly, ".L{}_epilogue:", func.name).unwrap(); + + if needs_frame { + let pushes_size = used_callee_saved.len() * 8; + if pushes_size > 0 { + writeln!(&mut self.assembly, " leaq -{}(%rbp), %rsp", pushes_size).unwrap(); + for reg in used_callee_saved.iter().rev() { + writeln!(&mut self.assembly, " popq %{}", reg).unwrap(); + } + } else { + writeln!(&mut self.assembly, " movq %rbp, %rsp").unwrap(); + } + + writeln!(&mut self.assembly, " popq %rbp").unwrap(); + } + + writeln!(&mut self.assembly, " ret\n").unwrap(); + } + + fn compile_instruction(&mut self, inst: &Instruction, allocs: &HashMap) { + match inst { + Instruction::Alloc { .. } => {} // Stack space is already reserved in prologue + Instruction::Assign { register, operand } => { + let dest = self.format_dest(*register, allocs); + let src = self.resolve_op(operand, "%rax", allocs); + self.emit_mov(&src, &dest); + } + Instruction::Load { dest, src, .. } => { + let dest_str = self.format_dest(*dest, allocs); + + // OPTIMIZATION: If reading from an Alloc pointer, read directly from the stack offset + if let Operand::Register(r) = src + && let Some(Storage::Alloc(off)) = allocs.get(r) + { + self.emit_mov(&format!("-{}(%rbp)", off), &dest_str); + return; + } + + let src_str = self.resolve_op(src, "%rax", allocs); + let addr_reg = if src_str.starts_with('-') { + writeln!(&mut self.assembly, " movq {}, %rax", src_str).unwrap(); + "%rax".to_string() + } else { + src_str + }; + + writeln!(&mut self.assembly, " movq ({}), %r10", addr_reg).unwrap(); + self.emit_mov("%r10", &dest_str); + } + Instruction::Store { dest, src, .. } => { + let src_str = self.resolve_op(src, "%rax", allocs); + + // OPTIMIZATION: If writing to an Alloc pointer, write directly to the stack offset + if let Operand::Register(r) = dest + && let Some(Storage::Alloc(off)) = allocs.get(r) + { + self.emit_mov(&src_str, &format!("-{}(%rbp)", off)); + return; + } + + let dest_str = self.resolve_op(dest, "%r10", allocs); + let addr_reg = if dest_str.starts_with('-') { + writeln!(&mut self.assembly, " movq {}, %r10", dest_str).unwrap(); + "%r10".to_string() + } else { + dest_str + }; + + let val_reg = if src_str.starts_with('-') || src_str.starts_with('$') { + writeln!(&mut self.assembly, " movq {}, %rax", src_str).unwrap(); + "%rax".to_string() + } else { + src_str + }; + + writeln!(&mut self.assembly, " movq {}, ({})", val_reg, addr_reg).unwrap(); + } + Instruction::Unary { dest, op, src, .. } => { + let dest_str = self.format_dest(*dest, allocs); + let src_str = self.resolve_op(src, "%rax", allocs); + + self.emit_mov(&src_str, &dest_str); + match op { + UnaryOp::INeg => writeln!(&mut self.assembly, " negq {}", dest_str).unwrap(), + } + } + Instruction::Binary { + dest, + op, + src1, + src2, + .. + } => { + let dest_str = self.format_dest(*dest, allocs); + let src1_str = self.resolve_op(src1, "%r10", allocs); + let src2_str = self.resolve_op(src2, "%r11", allocs); + + match op { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => { + let is_add = matches!(op, BinaryOp::Add); + let dest_hw = !dest_str.starts_with('-'); + let src1_hw = !src1_str.starts_with('-'); + let src2_hw = !src2_str.starts_with('-'); + let src2_imm = src2_str.starts_with('$'); + + if dest_str != src1_str + && dest_hw + && src1_hw + && src2_imm + && !matches!(op, BinaryOp::Mul) + { + let val: i64 = src2_str[1..].parse().unwrap(); + let offset = if is_add { val } else { -val }; + writeln!( + &mut self.assembly, + " leaq {}({}), {}", + offset, src1_str, dest_str + ) + .unwrap(); + } else if dest_str != src1_str && dest_hw && src1_hw && src2_hw && is_add { + writeln!( + &mut self.assembly, + " leaq ({},{}), {}", + src1_str, src2_str, dest_str + ) + .unwrap(); + } else { + self.emit_mov(&src1_str, &dest_str); + let mnemonic = match op { + BinaryOp::Add => "addq", + BinaryOp::Sub => "subq", + BinaryOp::Mul => "imulq", + _ => unreachable!(), + }; + if dest_str.starts_with('-') && src2_str.starts_with('-') { + writeln!(&mut self.assembly, " movq {}, %rax", src2_str) + .unwrap(); + writeln!(&mut self.assembly, " {} %rax, {}", mnemonic, dest_str) + .unwrap(); + } else { + writeln!( + &mut self.assembly, + " {} {}, {}", + mnemonic, src2_str, dest_str + ) + .unwrap(); + } + } + } + BinaryOp::SDiv | BinaryOp::SRem => { + writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); + writeln!(&mut self.assembly, " cqto").unwrap(); + + if src2_str.starts_with('$') { + writeln!(&mut self.assembly, " movq {}, %r10", src2_str).unwrap(); + writeln!(&mut self.assembly, " idivq %r10").unwrap(); + } else { + writeln!(&mut self.assembly, " idivq {}", src2_str).unwrap(); + } + + let result_reg = if let BinaryOp::URem = op { + "%rdx" + } else { + "%rax" + }; + self.emit_mov(result_reg, &dest_str); + } + BinaryOp::UDiv | BinaryOp::URem => { + writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); + writeln!(&mut self.assembly, " cqto").unwrap(); + + if src2_str.starts_with('$') { + writeln!(&mut self.assembly, " movq {}, %r10", src2_str).unwrap(); + writeln!(&mut self.assembly, " divq %r10").unwrap(); + } else { + writeln!(&mut self.assembly, " divq {}", src2_str).unwrap(); + } + + let result_reg = if let BinaryOp::URem = op { + "%rdx" + } else { + "%rax" + }; + self.emit_mov(result_reg, &dest_str); + } + BinaryOp::ICmp(cmp) => { + if (src1_str.starts_with('-') && src2_str.starts_with('-')) + || src1_str.starts_with('$') + { + writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); + writeln!(&mut self.assembly, " cmpq {}, %rax", src2_str).unwrap(); + } else { + writeln!(&mut self.assembly, " cmpq {}, {}", src2_str, src1_str) + .unwrap(); + } + + let set_cc = match cmp { + ICmpOp::Eq => "sete", + ICmpOp::Ne => "setne", + ICmpOp::Slt => "setl", + ICmpOp::Sle => "setle", + ICmpOp::Sgt => "setg", + ICmpOp::Sge => "setge", + ICmpOp::Ult => "setb", + ICmpOp::Ule => "setbe", + ICmpOp::Ugt => "seta", + ICmpOp::Uge => "setae", + }; + + writeln!(&mut self.assembly, " {} %al", set_cc).unwrap(); + writeln!(&mut self.assembly, " movzbq %al, %rax").unwrap(); + self.emit_mov("%rax", &dest_str); + } + } + } + Instruction::Call { + dest, + func: target_id, + args, + .. + } => { + let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; + for (i, (_, arg_op)) in args.iter().enumerate() { + let src = self.resolve_op(arg_op, "%r10", allocs); + if i < 6 { + self.emit_mov(&src, &format!("%{}", arg_regs[i])); + } else { + if src.starts_with('-') { + writeln!(&mut self.assembly, " movq {}, %rax", src).unwrap(); + writeln!(&mut self.assembly, " pushq %rax").unwrap(); + } else { + writeln!(&mut self.assembly, " pushq {}", src).unwrap(); + } + } + } + writeln!(&mut self.assembly, " call function_{}", target_id.0).unwrap(); + if args.len() > 6 { + let cleanup_size = (args.len() - 6) * 8; + writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap(); + } + let dest_str = self.format_dest(*dest, allocs); + self.emit_mov("%rax", &dest_str); + } + Instruction::Phi { .. } => { + unreachable!( + "Phi nodes cannot be compiled to x86-64 natively. You must run an SSA-Destruction (Out-of-SSA) pass before code generation!" + ); + } + } + } + + fn compile_terminator( + &mut self, + term: &Terminator, + func_name: &str, + allocs: &HashMap, + next_block_id: Option, + fused_cmp: Option<(ICmpOp, Operand, Operand)>, + ) { + match term { + Terminator::Return { value, .. } => { + if let Some(val) = value { + let src = self.resolve_op(val, "%r10", allocs); + self.emit_mov(&src, "%rax"); + } + writeln!(&mut self.assembly, " jmp .L{}_epilogue", func_name).unwrap(); + } + Terminator::Jump(target) => { + if Some(*target) != next_block_id { + writeln!( + &mut self.assembly, + " jmp .L{}_block_{}", + func_name, target.0 + ) + .unwrap(); + } + } + Terminator::Branch { + cond, + then_block, + else_block, + } => { + // Determine the condition codes based on whether we fused an ICmp + let (jump_cond_true, jump_cond_false) = + if let Some((cmp_op, src1, src2)) = fused_cmp { + let src1_str = self.resolve_op(&src1, "%r10", allocs); + let src2_str = self.resolve_op(&src2, "%r10", allocs); + + // Emit the comparison directly inside the terminator + if (src1_str.starts_with('-') && src2_str.starts_with('-')) + || src1_str.starts_with('$') + { + writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); + writeln!(&mut self.assembly, " cmpq {}, %rax", src2_str).unwrap(); + } else { + writeln!(&mut self.assembly, " cmpq {}, {}", src2_str, src1_str) + .unwrap(); + } + + // Map IR ICmpOp to native AT&T condition suffixes (true_jump, false_jump) + match cmp_op { + ICmpOp::Eq => ("e", "ne"), + ICmpOp::Ne => ("ne", "e"), + ICmpOp::Slt => ("l", "ge"), + ICmpOp::Sle => ("le", "g"), + ICmpOp::Sgt => ("g", "le"), + ICmpOp::Sge => ("ge", "l"), + ICmpOp::Ult => ("b", "ae"), + ICmpOp::Ule => ("be", "a"), + ICmpOp::Ugt => ("a", "be"), + ICmpOp::Uge => ("ae", "b"), + } + } else { + // Standard fallback: evaluating an isolated boolean + let cond_str = self.resolve_op(cond, "%r10", allocs); + if cond_str.starts_with('$') { + writeln!(&mut self.assembly, " movq {}, %rax", cond_str).unwrap(); + writeln!(&mut self.assembly, " testq %rax, %rax").unwrap(); + } else { + writeln!(&mut self.assembly, " testq {}, {}", cond_str, cond_str) + .unwrap(); + } + ("nz", "z") // true = not zero, false = zero + }; + + // Fallthrough logic cleanly applied to dynamically `d jump conditions + if Some(*else_block) == next_block_id { + writeln!( + &mut self.assembly, + " j{} .L{}_block_{}", + jump_cond_true, func_name, then_block.0 + ) + .unwrap(); + } else if Some(*then_block) == next_block_id { + writeln!( + &mut self.assembly, + " j{} .L{}_block_{}", + jump_cond_false, func_name, else_block.0 + ) + .unwrap(); + } else { + writeln!( + &mut self.assembly, + " j{} .L{}_block_{}", + jump_cond_false, func_name, else_block.0 + ) + .unwrap(); + writeln!( + &mut self.assembly, + " jmp .L{}_block_{}", + func_name, then_block.0 + ) + .unwrap(); + } + } + Terminator::Unknown => panic!("Cannot compile Unknown terminator"), + } + } + + // Helpers + + fn format_dest(&self, reg: Register, allocs: &HashMap) -> String { + match allocs.get(®).unwrap() { + Storage::Hardware(hw) => format!("%{}", hw), + Storage::Stack(off) | Storage::Alloc(off) => format!("-{}(%rbp)", off), + } + } + + /// Emits a move instruction, gracefully handling memory-to-memory constraints + fn emit_mov(&mut self, src: &str, dest: &str) { + if src == dest { + return; + } + + if src.starts_with('-') && dest.starts_with('-') { + writeln!(&mut self.assembly, " movq {}, %rax", src).unwrap(); + writeln!(&mut self.assembly, " movq %rax, {}", dest).unwrap(); + } else { + writeln!(&mut self.assembly, " movq {}, {}", src, dest).unwrap(); + } + } +} diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..e6249de --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,327 @@ +use crate::ir::*; + +pub struct IrModuleBuilder { + next_function_id: usize, + current_function_builder: Option, + + module: Module, +} + +impl Default for IrModuleBuilder { + fn default() -> Self { + Self::new() + } +} + +impl IrModuleBuilder { + /// Create a new [IrModuleBuilder]. + pub fn new() -> Self { + Self { + next_function_id: 0, + current_function_builder: None, + + module: Module { + functions: Vec::new(), + }, + } + } + + /// Helper function for allocating a new [FunctionId]. + pub fn new_function_id(&mut self) -> FunctionId { + let function_id = FunctionId(self.next_function_id); + self.next_function_id += 1; + function_id + } + + /// Creates a new [IrFunctionBuilder] with the given [FunctionId], name, parameter and return [Type]s. + pub fn new_function<'a>( + &mut self, + id: FunctionId, + name: impl ToString, + params_tys: impl IntoIterator, + return_ty: Type, + ) -> &mut IrFunctionBuilder { + self.current_function_builder = + Some(IrFunctionBuilder::new(id, name, params_tys, return_ty)); + + self.current_function_builder.as_mut().unwrap() + } + + /// Completes the function building process. + pub fn complete_function(&mut self) { + let function_builder = self + .current_function_builder + .take() + .expect("please call `new_function` before calling `complete_function`"); + + let function = function_builder.finish(); + + self.module.functions.push(function); + } + + /// Finishes the building process returning the built [Module]. + pub fn finish(self) -> Module { + self.module + } +} + +pub struct IrFunctionBuilder { + next_register_id: usize, + next_block_id: usize, + + function: Function, + active_block_id: Option, +} + +impl IrFunctionBuilder { + /// Creates a new [IrFunctionBuilder] for a [Function] with the given [FunctionId], name, parameter and return [Type]s. + pub fn new<'a>( + id: FunctionId, + name: impl ToString, + params_tys: impl IntoIterator, + return_ty: Type, + ) -> Self { + let mut next_register_id = 0; + let mut params = Vec::new(); + + for param_ty in params_tys.into_iter().copied() { + let register = Register(next_register_id); + next_register_id += 1; + params.push((param_ty, register)); + } + + let entry_block_id = BlockId(0); + let entry_block = BasicBlock { + id: entry_block_id, + instructions: Vec::new(), + terminator: Terminator::Unknown, + }; + + let function = Function { + id, + name: name.to_string(), + params, + return_ty, + blocks: vec![entry_block], + entry_block_id, + }; + + Self { + next_register_id, + next_block_id: 1, + + function, + active_block_id: Some(entry_block_id), + } + } + + /// Finishes the building process returning the built [Function]. + pub fn finish(self) -> Function { + self.function + } + + /// Helper function for allocating new [Register]s. + fn allocate_register(&mut self) -> Register { + let register = Register(self.next_register_id); + self.next_register_id += 1; + register + } + + /// Retruns a mutable reference to the currently active [BasicBlock]. + fn get_active_block_mut(&mut self) -> &mut BasicBlock { + let active_block_id = self.active_block_id.expect("no active block selected"); + + self.function + .blocks + .iter_mut() + .find(|block| block.id == active_block_id) + .expect("failed to find BasicBlock by its id") + } + + /// Returns a reference to the currently active [BasicBlock]. + fn get_active_block(&self) -> &BasicBlock { + let active_block_id = self.active_block_id.expect("no active block selected"); + + self.function + .blocks + .iter() + .find(|block| block.id == active_block_id) + .expect("failed to find BasicBlock by its id") + } + + /// A helper function for inserting [Instruction]s into the active block. + fn insert_instruction(&mut self, instruction: Instruction) { + assert!(!self.is_block_sealed(), "cannot insert into sealed block"); + self.get_active_block_mut().instructions.push(instruction); + } + + /// A helper function for setting the [Terminator] of the active block. + fn set_terminator(&mut self, terminator: Terminator) { + assert!( + !self.is_block_sealed(), + "cannot set terminator of sealed block" + ); + + self.get_active_block_mut().terminator = terminator; + } + + /// A helper function for building [Instruction::Binary]. + fn build_binary( + &mut self, + result_ty: Type, + op: BinaryOp, + src1: Operand, + src2: Operand, + ) -> Operand { + let dest = self.allocate_register(); + self.insert_instruction(Instruction::Binary { + dest, + result_ty, + op, + src1, + src2, + }); + Operand::Register(dest) + } + + /// Fetches the parameter as an [Operand]. + pub fn get_param(&self, param_index: usize) -> Option { + self.function + .params + .get(param_index) + .map(|param| Operand::Register(param.1)) + } + + /// Creates a new [BasicBlock] and returns its [BlockId]. + pub fn create_block(&mut self) -> BlockId { + let id = BlockId(self.next_block_id); + self.next_block_id += 1; + self.function.blocks.push(BasicBlock { + id, + instructions: Vec::new(), + terminator: Terminator::Unknown, + }); + id + } + + /// Returns whether a [BasicBlock] has been sealed. + pub fn is_block_sealed(&self) -> bool { + !matches!(self.get_active_block().terminator, Terminator::Unknown) + } + + /// Sets the currently active block. + pub fn switch_to_block(&mut self, block_id: BlockId) { + self.active_block_id = Some(block_id); + } + + /// Builds a `alloc` instruction. + pub fn build_alloc(&mut self, ty: Type) -> Operand { + let dest = self.allocate_register(); + self.insert_instruction(Instruction::Alloc { dest, ty }); + Operand::Register(dest) + } + + /// Builds a `load` instruction. + pub fn build_load(&mut self, ty: Type, ptr: Operand) -> Operand { + let dest = self.allocate_register(); + self.insert_instruction(Instruction::Load { dest, ty, src: ptr }); + Operand::Register(dest) + } + + /// Builds a `store` instruction. + pub fn build_store(&mut self, ty: Type, ptr: Operand, value: Operand) { + self.insert_instruction(Instruction::Store { + dest: ptr, + ty, + src: value, + }); + } + + /// Builds an `add` instruction. + pub fn build_add(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::Add, lhs, rhs) + } + + /// Builds an `sub` instruction. + pub fn build_sub(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::Sub, lhs, rhs) + } + + /// Builds an `mul` instruction. + pub fn build_mul(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::Mul, lhs, rhs) + } + + /// Builds an `udiv` instruction. + pub fn build_udiv(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::UDiv, lhs, rhs) + } + + /// Builds an `sdiv` instruction. + pub fn build_sdiv(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::SDiv, lhs, rhs) + } + + /// Builds an `urem` instruction. + pub fn build_urem(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::URem, lhs, rhs) + } + + /// Builds an `srem` instruction. + pub fn build_srem(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::SRem, lhs, rhs) + } + + /// Builds an `icmp` instruction. + pub fn build_icmp(&mut self, cmp_op: ICmpOp, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(Type::Bool, BinaryOp::ICmp(cmp_op), lhs, rhs) + } + + /// Builds a `call` instruction. + pub fn build_call<'a>( + &mut self, + result_ty: Type, + func: FunctionId, + args: impl IntoIterator, + ) -> Operand { + let dest = self.allocate_register(); + let args = args.into_iter().copied().collect(); + self.insert_instruction(Instruction::Call { + dest, + result_ty, + func, + args, + }); + Operand::Register(dest) + } + + /// Builds a `branch` instruction. + pub fn build_branch(&mut self, cond: Operand, then_block: BlockId, else_block: BlockId) { + self.set_terminator(Terminator::Branch { + cond, + then_block, + else_block, + }); + } + + /// Builds a `jump` instruction. + pub fn build_jump(&mut self, target_block: BlockId) { + self.set_terminator(Terminator::Jump(target_block)); + } + + /// Builds a `return` instruction with a value. + pub fn build_return(&mut self, return_ty: Type, value: Operand) { + self.set_terminator(Terminator::Return { + return_ty, + value: Some(value), + }); + } + + /// Builds a `return` instruction without a value. + pub fn build_return_void(&mut self) { + self.set_terminator(Terminator::Return { + return_ty: Type::Void, + value: None, + }); + } +} diff --git a/src/ir.rs b/src/ir.rs new file mode 100644 index 0000000..3d331dc --- /dev/null +++ b/src/ir.rs @@ -0,0 +1,150 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Type { + Bool, + I8, + I16, + I32, + I64, + Ptr, + Void, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Register(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Operand { + Integer(u64), + Boolean(bool), + Register(Register), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ICmpOp { + Slt, + Sle, + Sgt, + Sge, + Ult, + Ule, + Ugt, + Uge, + Eq, + Ne, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BinaryOp { + Add, + Sub, + UDiv, + SDiv, + Mul, + SRem, + URem, + ICmp(ICmpOp), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum UnaryOp { + INeg, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum Instruction { + Alloc { + dest: Register, + ty: Type, + }, + + Load { + ty: Type, + dest: Register, + src: Operand, + }, + + Store { + ty: Type, + dest: Operand, + src: Operand, + }, + + Binary { + dest: Register, + result_ty: Type, + op: BinaryOp, + src1: Operand, + src2: Operand, + }, + + Unary { + dest: Register, + result_ty: Type, + op: UnaryOp, + src: Operand, + }, + + Call { + dest: Register, + result_ty: Type, + func: FunctionId, + args: Vec<(Type, Operand)>, + }, + + Assign { + register: Register, + operand: Operand, + }, + + Phi { + dest: Register, + result_ty: Type, + sources: Vec<(Operand, BlockId)>, + }, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum Terminator { + Branch { + cond: Operand, + then_block: BlockId, + else_block: BlockId, + }, + + Return { + return_ty: Type, + value: Option, + }, + + Jump(BlockId), + + Unknown, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct BlockId(pub usize); + +#[derive(Debug, PartialEq, Eq)] +pub struct BasicBlock { + pub id: BlockId, + pub instructions: Vec, + pub terminator: Terminator, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct FunctionId(pub usize); + +#[derive(Debug, PartialEq, Eq)] +pub struct Function { + pub id: FunctionId, + pub name: String, + pub params: Vec<(Type, Register)>, + pub return_ty: Type, + pub blocks: Vec, + pub entry_block_id: BlockId, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Module { + pub functions: Vec, +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c44c2a5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,6 @@ +pub mod backend; +pub mod builder; +pub mod ir; +pub mod passes; +pub mod printer; +pub mod validate; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..24e83bb --- /dev/null +++ b/src/main.rs @@ -0,0 +1,64 @@ +use scarlett::{ + backend::x86_64::X86Backend, + builder::IrModuleBuilder, + ir::{ICmpOp, Operand, Type}, + passes, + validate::validate_module, +}; + +fn main() { + let mut module_builder = IrModuleBuilder::new(); + + // 1. Build `gcd(a: i64, b: i64) -> i64` + { + let gcd_id = module_builder.new_function_id(); + let builder = + module_builder.new_function(gcd_id, "gcd", &[Type::I64, Type::I64], Type::I64); + + let ptr_x = builder.build_alloc(Type::I64); + let ptr_y = builder.build_alloc(Type::I64); + + let param_0 = builder.get_param(0).unwrap(); + builder.build_store(Type::I64, ptr_x, param_0); + + let param_1 = builder.get_param(1).unwrap(); + builder.build_store(Type::I64, ptr_y, param_1); + + let loop_cond = builder.create_block(); + let loop_body = builder.create_block(); + let loop_merge = builder.create_block(); + + builder.build_jump(loop_cond); + builder.switch_to_block(loop_cond); + + let val_y = builder.build_load(Type::I64, ptr_y); + let cond = builder.build_icmp(ICmpOp::Ne, val_y, Operand::Integer(0)); + + builder.build_branch(cond, loop_body, loop_merge); + builder.switch_to_block(loop_body); + + let val_x = builder.build_load(Type::I64, ptr_x); + let val_y = builder.build_load(Type::I64, ptr_y); + let rem = builder.build_urem(Type::I64, val_x, val_y); + + builder.build_store(Type::I64, ptr_x, val_y); + builder.build_store(Type::I64, ptr_y, rem); + + builder.build_jump(loop_cond); + builder.switch_to_block(loop_merge); + + let val_x = builder.build_load(Type::I64, ptr_x); + builder.build_return(Type::I64, val_x); + + module_builder.complete_function(); + } + + // 2. Finish, Validate, Optimize, and Compile + let mut module = module_builder.finish(); + + validate_module(&module).expect("failed to validate module"); + passes::optimize(&mut module); + + let assembly = X86Backend::new(&module).compile_module(); + println!("{}", assembly); +} diff --git a/src/passes/cfp.rs b/src/passes/cfp.rs new file mode 100644 index 0000000..733fff7 --- /dev/null +++ b/src/passes/cfp.rs @@ -0,0 +1,209 @@ +use std::collections::HashMap; + +use crate::ir::*; + +/// Runs the constant folding pass over the entire module, modifying it in place. +pub fn fold_constants(module: &mut Module) { + for func in &mut module.functions { + fold_function_constants(func); + } +} + +fn fold_function_constants(func: &mut Function) { + for block in &mut func.blocks { + let mut known_constants: HashMap = HashMap::new(); + + for inst in &mut block.instructions { + // 1. Substitute any known constants into the current instruction + substitute_constants_in_inst(inst, &known_constants); + + // 2. Evaluate and rewrite instructions where possible + match inst { + Instruction::Alloc { .. } => {} + Instruction::Assign { register, operand } => { + if is_constant(operand) { + known_constants.insert(*register, *operand); + } + } + Instruction::Binary { + dest, + op, + src1, + src2, + .. + } => { + if let Some(folded) = evaluate_binary(*op, src1, src2) { + known_constants.insert(*dest, folded); + + // Rewrite the evaluated Binary instruction into a clean Assign + *inst = Instruction::Assign { + register: *dest, + operand: folded, + }; + } + } + Instruction::Unary { dest, op, src, .. } => { + if let Some(folded) = evaluate_unary(*op, src) { + known_constants.insert(*dest, folded); + + // Rewrite the evaluated Unary instruction into a clean Assign + *inst = Instruction::Assign { + register: *dest, + operand: folded, + }; + } + } + // Memory and control flow boundaries cannot be statically folded here + Instruction::Load { .. } + | Instruction::Store { .. } + | Instruction::Call { .. } + | Instruction::Phi { .. } => {} + } + } + + // 3. Evaluate terminators + substitute_constants_in_terminator(&mut block.terminator, &known_constants); + + if let Terminator::Branch { + cond: Operand::Boolean(b), + then_block, + else_block, + } = block.terminator + { + block.terminator = Terminator::Jump(if b { then_block } else { else_block }); + } + } +} + +// --- Helper Functions --- + +fn is_constant(op: &Operand) -> bool { + matches!(op, Operand::Integer(_) | Operand::Boolean(_)) +} + +fn substitute_constants_in_inst(inst: &mut Instruction, constants: &HashMap) { + let replace = |op: &mut Operand| { + if let Operand::Register(r) = op + && let Some(c) = constants.get(r) + { + *op = *c; + } + }; + + match inst { + Instruction::Alloc { .. } => {} + Instruction::Assign { operand, .. } => replace(operand), + Instruction::Load { src, .. } => replace(src), + Instruction::Store { dest, src, .. } => { + replace(dest); + replace(src); + } + Instruction::Binary { src1, src2, .. } => { + replace(src1); + replace(src2); + } + Instruction::Unary { src, .. } => replace(src), + Instruction::Call { args, .. } => { + for (_, arg_op) in args { + replace(arg_op); + } + } + Instruction::Phi { sources, .. } => { + for (op, _) in sources.iter_mut() { + replace(op); + } + } + } +} + +fn substitute_constants_in_terminator( + term: &mut Terminator, + constants: &HashMap, +) { + let replace = |op: &mut Operand| { + if let Operand::Register(r) = op + && let Some(c) = constants.get(r) + { + *op = *c; + } + }; + + match term { + Terminator::Branch { cond, .. } => replace(cond), + Terminator::Return { + value: Some(val), .. + } => replace(val), + _ => {} + } +} + +fn evaluate_binary(op: BinaryOp, src1: &Operand, src2: &Operand) -> Option { + match (src1, src2) { + (Operand::Integer(a), Operand::Integer(b)) => { + let a = *a; + let b = *b; + + match op { + BinaryOp::Add => Some(Operand::Integer(a.wrapping_add(b))), + BinaryOp::Sub => Some(Operand::Integer(a.wrapping_sub(b))), + BinaryOp::Mul => Some(Operand::Integer(a.wrapping_mul(b))), + BinaryOp::UDiv => { + if b != 0 { + Some(Operand::Integer(a.wrapping_div(b))) + } else { + None + } + } + BinaryOp::SDiv => { + if b != 0 { + Some(Operand::Integer((a as i64).wrapping_div(b as i64) as u64)) + } else { + None + } + } + BinaryOp::URem => { + if b != 0 { + Some(Operand::Integer(a.wrapping_rem(b))) + } else { + None + } + } + BinaryOp::SRem => { + if b != 0 { + Some(Operand::Integer((a as i64).wrapping_rem(b as i64) as u64)) + } else { + None + } + } + BinaryOp::ICmp(cmp) => { + let res = match cmp { + ICmpOp::Eq => a == b, + ICmpOp::Ne => a != b, + ICmpOp::Ult => a < b, + ICmpOp::Ule => a <= b, + ICmpOp::Ugt => a > b, + ICmpOp::Uge => a >= b, + ICmpOp::Slt => (a as i64) < (b as i64), + ICmpOp::Sle => (a as i64) <= (b as i64), + ICmpOp::Sgt => (a as i64) > (b as i64), + ICmpOp::Sge => (a as i64) >= (b as i64), + }; + Some(Operand::Boolean(res)) + } + } + } + (Operand::Boolean(a), Operand::Boolean(b)) => match op { + BinaryOp::ICmp(ICmpOp::Eq) => Some(Operand::Boolean(a == b)), + BinaryOp::ICmp(ICmpOp::Ne) => Some(Operand::Boolean(a != b)), + _ => None, + }, + _ => None, + } +} + +fn evaluate_unary(op: UnaryOp, src: &Operand) -> Option { + match (op, src) { + (UnaryOp::INeg, Operand::Integer(a)) => Some(Operand::Integer(a.wrapping_neg())), + _ => None, + } +} diff --git a/src/passes/cpp.rs b/src/passes/cpp.rs new file mode 100644 index 0000000..5dfaeb2 --- /dev/null +++ b/src/passes/cpp.rs @@ -0,0 +1,105 @@ +use std::collections::HashMap; + +use crate::ir::*; + +/// Runs the Copy Propagation pass to eliminate redundant register-to-register moves. +pub fn propagate_copies(module: &mut Module) { + for func in &mut module.functions { + propagate_copies_in_func(func); + } +} + +fn propagate_copies_in_func(func: &mut Function) { + let mut aliases: HashMap = HashMap::new(); + + // Helper to fully resolve an operand through any chain of aliases + let resolve = |mut op: Operand, aliases: &HashMap| -> Operand { + while let Operand::Register(r) = op { + if let Some(aliased_to) = aliases.get(&r) { + op = *aliased_to; + } else { + break; + } + } + op + }; + + // 1. Scan for pure copy instructions (Assigning a Register to a Register) + for block in &func.blocks { + for inst in &block.instructions { + if let Instruction::Assign { register, operand } = inst { + // We only unconditionally propagate register-to-register copies here. + // The constant folding pass handles propagating literals. + if let Operand::Register(_) = operand { + let root_operand = resolve(*operand, &aliases); + aliases.insert(*register, root_operand); + } + } + } + } + + if aliases.is_empty() { + return; // Early exit if there are no copies to propagate + } + + // 2. Replace all uses of aliased registers with their root source + let replace = |op: &mut Operand| { + if let Operand::Register(r) = op + && let Some(alias) = aliases.get(r) + { + *op = *alias; + } + }; + + for block in &mut func.blocks { + for inst in &mut block.instructions { + match inst { + Instruction::Load { src, .. } => replace(src), + Instruction::Store { dest, src, .. } => { + replace(dest); + replace(src); + } + Instruction::Binary { src1, src2, .. } => { + replace(src1); + replace(src2); + } + Instruction::Unary { src, .. } => replace(src), + Instruction::Call { args, .. } => { + for (_, arg) in args { + replace(arg); + } + } + Instruction::Phi { sources, .. } => { + for (op, _) in sources { + replace(op); + } + } + Instruction::Assign { operand, .. } => replace(operand), + Instruction::Alloc { .. } => {} + } + } + + match &mut block.terminator { + Terminator::Branch { cond, .. } => replace(cond), + Terminator::Return { + value: Some(val), .. + } => replace(val), + _ => {} + } + } + + // 3. Clean up the now-useless copy instructions + for block in &mut func.blocks { + block.instructions.retain(|inst| { + // Drop any Assign instruction where the operand was a Register, + // since we just successfully propagated it everywhere it was used. + !matches!( + inst, + Instruction::Assign { + operand: Operand::Register(_), + .. + } + ) + }); + } +} diff --git a/src/passes/dce.rs b/src/passes/dce.rs new file mode 100644 index 0000000..a52c29d --- /dev/null +++ b/src/passes/dce.rs @@ -0,0 +1,134 @@ +use std::collections::HashSet; + +use crate::ir::*; + +/// Runs the dead code elimination pass over the entire module, modifying it in place. +pub fn eliminate_dead_code(module: &mut Module) { + for func in &mut module.functions { + eliminate_dead_code_in_func(func); + } +} + +fn eliminate_dead_code_in_func(func: &mut Function) { + let mut changed = true; + + // Loop until we complete a full pass without removing any instructions + while changed { + changed = false; + + // 1. UNREACHABLE BLOCK ELIMINATION + // Find all blocks reachable from the entry point + let mut reachable_blocks = HashSet::new(); + let mut worklist = vec![func.entry_block_id]; + reachable_blocks.insert(func.entry_block_id); + + while let Some(current_id) = worklist.pop() { + // Find the current block to inspect its terminator + if let Some(block) = func.blocks.iter().find(|b| b.id == current_id) { + match block.terminator { + Terminator::Branch { + then_block, + else_block, + .. + } => { + if reachable_blocks.insert(then_block) { + worklist.push(then_block); + } + if reachable_blocks.insert(else_block) { + worklist.push(else_block); + } + } + Terminator::Jump(target) => { + if reachable_blocks.insert(target) { + worklist.push(target); + } + } + Terminator::Return { .. } | Terminator::Unknown => {} + } + } + } + + // Remove any block that is not in the reachable set + let original_block_count = func.blocks.len(); + func.blocks + .retain(|block| reachable_blocks.contains(&block.id)); + + if func.blocks.len() != original_block_count { + changed = true; + } + + // 2. INSTRUCTION-LEVEL DEAD CODE ELIMINATION + // 2.1. Collect all registers that are currently being used + let mut used_registers = HashSet::new(); + + let mut mark_used = |op: &Operand| { + if let Operand::Register(r) = op { + used_registers.insert(*r); + } + }; + + for block in &func.blocks { + // Scan instructions for uses + for inst in &block.instructions { + match inst { + Instruction::Alloc { .. } => {} + Instruction::Load { src, .. } => mark_used(src), + Instruction::Store { dest, src, .. } => { + mark_used(dest); // The pointer being written to + mark_used(src); // The value being written + } + Instruction::Binary { src1, src2, .. } => { + mark_used(src1); + mark_used(src2); + } + Instruction::Unary { src, .. } => mark_used(src), + Instruction::Call { args, .. } => { + for (_, arg) in args { + mark_used(arg); + } + } + Instruction::Assign { operand, .. } => mark_used(operand), + Instruction::Phi { sources, .. } => { + sources.iter().for_each(|(op, _)| mark_used(op)) + } + } + } + + // Scan terminators for uses + match &block.terminator { + Terminator::Branch { cond, .. } => mark_used(cond), + Terminator::Return { + value: Some(val), .. + } => mark_used(val), + _ => {} + } + } + + // 2.2. Remove instructions whose destination registers are never used + for block in &mut func.blocks { + let original_len = block.instructions.len(); + + block.instructions.retain(|inst| { + match inst { + // PURE INSTRUCTIONS: Can be safely removed if their result is ignored + Instruction::Alloc { dest, .. } + | Instruction::Assign { register: dest, .. } + | Instruction::Binary { dest, .. } + | Instruction::Unary { dest, .. } + | Instruction::Load { dest, .. } + | Instruction::Phi { dest, .. } => used_registers.contains(dest), + + // SIDE-EFFECT INSTRUCTIONS: Must never be removed, + // even if their returned value is ignored. + // Store writes to memory. Call executes arbitrary function logic. + Instruction::Store { .. } | Instruction::Call { .. } => true, + } + }); + + // If the block length shrank, we removed dead code and must scan again + if block.instructions.len() != original_len { + changed = true; + } + } + } +} diff --git a/src/passes/des.rs b/src/passes/des.rs new file mode 100644 index 0000000..0ca9347 --- /dev/null +++ b/src/passes/des.rs @@ -0,0 +1,47 @@ +use std::collections::HashMap; + +use crate::ir::*; + +/// Runs the SSA-Destruction pass, converting Phi nodes into explicit Assign instructions. +pub fn destroy_ssa(module: &mut Module) { + for func in &mut module.functions { + destroy_ssa_in_func(func); + } +} + +fn destroy_ssa_in_func(func: &mut Function) { + // A map tracking Predecessor BlockId -> List of Assign instructions to inject + let mut pending_moves: HashMap> = HashMap::new(); + + // 1. Scan for Phi nodes and record the required move instructions + for block in &func.blocks { + for inst in &block.instructions { + if let Instruction::Phi { dest, sources, .. } = inst { + for (op, pred_id) in sources { + let assign = Instruction::Assign { + register: *dest, + operand: *op, + }; + // Queue the assignment to be inserted into the predecessor block + pending_moves.entry(*pred_id).or_default().push(assign); + } + } + } + } + + // 2. Mutate the blocks: Strip Phis and inject the queued Assigns + for block in &mut func.blocks { + // Strip out the conceptual Phi nodes + block + .instructions + .retain(|inst| !matches!(inst, Instruction::Phi { .. })); + + // If this block is a predecessor that needs to supply a value to a Phi, + // append the generated Assign instructions to the end of the block. + // Because they are appended to `instructions`, they will conceptually execute + // immediately before the block's `terminator`. + if let Some(mut moves) = pending_moves.remove(&block.id) { + block.instructions.append(&mut moves); + } + } +} diff --git a/src/passes/m2r.rs b/src/passes/m2r.rs new file mode 100644 index 0000000..2e449b6 --- /dev/null +++ b/src/passes/m2r.rs @@ -0,0 +1,260 @@ +use std::collections::{HashMap, HashSet}; + +use crate::ir::*; + +pub fn mem2reg(module: &mut Module) { + for func in &mut module.functions { + promote_allocs_in_func(func); + } +} + +fn promote_allocs_in_func(func: &mut Function) { + // 1. Identify Promotable Allocs + let mut promotable_allocs: HashMap = HashMap::new(); + let mut escaped_allocs: HashSet = HashSet::new(); + + for block in &func.blocks { + for inst in &block.instructions { + match inst { + Instruction::Alloc { dest, ty } => { + promotable_allocs.insert(*dest, *ty); + } + Instruction::Store { + src: Operand::Register(r), + .. + } => { + escaped_allocs.insert(*r); + // The dest is expected to be the alloc pointer, which is safe. + } + Instruction::Load { + src: Operand::Register(_), + .. + } => { + // Safe use of the alloc pointer + } + Instruction::Call { args, .. } => { + for (_, arg) in args { + if let Operand::Register(r) = arg { + escaped_allocs.insert(*r); + } + } + } + Instruction::Binary { src1, src2, .. } => { + if let Operand::Register(r) = src1 { + escaped_allocs.insert(*r); + } + if let Operand::Register(r) = src2 { + escaped_allocs.insert(*r); + } + } + Instruction::Unary { src, .. } | Instruction::Assign { operand: src, .. } => { + if let Operand::Register(r) = src { + escaped_allocs.insert(*r); + } + } + Instruction::Phi { sources, .. } => { + for (op, _) in sources { + if let Operand::Register(r) = op { + escaped_allocs.insert(*r); + } + } + } + _ => {} + } + } + } + + // Filter out allocs that escaped or were assigned to other pointers + promotable_allocs.retain(|reg, _| !escaped_allocs.contains(reg)); + + if promotable_allocs.is_empty() { + return; // Nothing to promote + } + + // 2. Build CFG (Predecessors & Successors) + let mut preds: HashMap> = HashMap::new(); + let mut succs: HashMap> = HashMap::new(); + + for block in &func.blocks { + preds.entry(block.id).or_default(); + let targets = match block.terminator { + Terminator::Branch { + then_block, + else_block, + .. + } => vec![then_block, else_block], + Terminator::Jump(target) => vec![target], + _ => vec![], + }; + for target in targets { + preds.entry(target).or_default().push(block.id); + succs.entry(block.id).or_default().push(target); + } + } + + // 3. Compute Reverse Post-Order (RPO) for predictable traversal + let mut rpo = Vec::new(); + let mut visited = HashSet::new(); + + fn dfs( + b: BlockId, + succs: &HashMap>, + visited: &mut HashSet, + rpo: &mut Vec, + ) { + visited.insert(b); + if let Some(targets) = succs.get(&b) { + for &t in targets { + if !visited.contains(&t) { + dfs(t, succs, visited, rpo); + } + } + } + rpo.push(b); + } + + dfs(func.entry_block_id, &succs, &mut visited, &mut rpo); + rpo.reverse(); + + // --- 4. Setup Register Generator & Definitions State --- + let mut max_reg = 0; + for block in &func.blocks { + for inst in &block.instructions { + let mut check_reg = |r: Register| { + if r.0 > max_reg { + max_reg = r.0; + } + }; + match inst { + Instruction::Alloc { dest, .. } + | Instruction::Load { dest, .. } + | Instruction::Assign { register: dest, .. } + | Instruction::Binary { dest, .. } + | Instruction::Unary { dest, .. } + | Instruction::Call { dest, .. } + | Instruction::Phi { dest, .. } => check_reg(*dest), + + Instruction::Store { .. } => {} + } + } + } + + let mut next_reg = || { + max_reg += 1; + Register(max_reg) + }; + + // Tracks the current active SSA value for an Alloc at the exit of a block + let mut block_out_defs: HashMap> = HashMap::new(); + + // Tracks the Phis we generate so we can wire them up later + // Map: JoinBlockId -> Map + let mut generated_phis: HashMap> = HashMap::new(); + + // Initialize the entry block with dummy zeroes for safety (uninitialized variables) + let mut initial_defs = HashMap::new(); + for &alloc_reg in promotable_allocs.keys() { + initial_defs.insert(alloc_reg, Operand::Integer(0)); + } + block_out_defs.insert(func.entry_block_id, initial_defs); + + // 5. Forward Propagate Definitions & Inject Phis + for &block_id in &rpo { + let block_preds = preds.get(&block_id).unwrap(); + let mut local_defs = HashMap::new(); + let mut phis_to_inject = Vec::new(); + + if block_id == func.entry_block_id { + local_defs = block_out_defs.get(&block_id).unwrap().clone(); + } else if block_preds.len() == 1 { + // Straight-line code: inherit exact definitions from predecessor + if let Some(pred_defs) = block_out_defs.get(&block_preds[0]) { + local_defs = pred_defs.clone(); + } + } else { + // Join Block: Inject empty Phi nodes for every promotable alloc + let mut block_phis = HashMap::new(); + for (&alloc_reg, &ty) in &promotable_allocs { + let phi_dest = next_reg(); + block_phis.insert(alloc_reg, phi_dest); + local_defs.insert(alloc_reg, Operand::Register(phi_dest)); + + phis_to_inject.push(Instruction::Phi { + dest: phi_dest, + result_ty: ty, + sources: Vec::new(), // Filled in Phase 6 + }); + } + generated_phis.insert(block_id, block_phis); + } + + // Rewrite the block's instructions + let block = func.blocks.iter_mut().find(|b| b.id == block_id).unwrap(); + let mut new_instructions = phis_to_inject; + + for inst in block.instructions.drain(..) { + match inst { + Instruction::Alloc { dest, .. } => { + if !promotable_allocs.contains_key(&dest) { + new_instructions.push(inst); + } + } + Instruction::Store { + dest: Operand::Register(dest_reg), + src, + .. + } if promotable_allocs.contains_key(&dest_reg) => { + // Update our tracked SSA state + local_defs.insert(dest_reg, src); + } + Instruction::Load { + dest, + src: Operand::Register(src_reg), + .. + } if promotable_allocs.contains_key(&src_reg) => { + // Replace Load with direct Assign from the active definition + let active_val = *local_defs.get(&src_reg).unwrap_or(&Operand::Integer(0)); + new_instructions.push(Instruction::Assign { + register: dest, + operand: active_val, + }); + } + _ => { + new_instructions.push(inst); + } + } + } + + block.instructions = new_instructions; + block_out_defs.insert(block_id, local_defs); + } + + // 6. Wire Up the Phi Sources + for &block_id in &rpo { + if let Some(block_phis) = generated_phis.get(&block_id) { + let block_preds = preds.get(&block_id).unwrap(); + let block = func.blocks.iter_mut().find(|b| b.id == block_id).unwrap(); + + for inst in &mut block.instructions { + if let Instruction::Phi { dest, sources, .. } = inst { + // Find which alloc this Phi belongs to + let alloc_reg = block_phis + .iter() + .find(|(_, d)| **d == *dest) + .map(|(a, _)| *a) + .unwrap(); + + // Wire up the exact values coming from each predecessor + for &pred_id in block_preds { + let pred_val = *block_out_defs + .get(&pred_id) + .and_then(|defs| defs.get(&alloc_reg)) + .unwrap_or(&Operand::Integer(0)); // Handle backedges for uninitialized vars + + sources.push((pred_val, pred_id)); + } + } + } + } + } +} diff --git a/src/passes/mod.rs b/src/passes/mod.rs new file mode 100644 index 0000000..0351180 --- /dev/null +++ b/src/passes/mod.rs @@ -0,0 +1,17 @@ +use crate::{ir::Module, validate::validate_module}; + +pub mod cfp; +pub mod cpp; +pub mod dce; +pub mod des; +pub mod m2r; + +/// Runs all the optimization passes. +pub fn optimize(module: &mut Module) { + m2r::mem2reg(module); + cfp::fold_constants(module); + dce::eliminate_dead_code(module); + cpp::propagate_copies(module); + des::destroy_ssa(module); + validate_module(module).expect("failed to validate module after optimization passes"); +} diff --git a/src/printer.rs b/src/printer.rs new file mode 100644 index 0000000..6966549 --- /dev/null +++ b/src/printer.rs @@ -0,0 +1,255 @@ +use std::fmt::{self, Display, Write}; + +use crate::ir::{ + BasicBlock, BinaryOp, BlockId, Function, ICmpOp, Instruction, Module, Operand, Register, + Terminator, Type, UnaryOp, +}; + +impl Display for Register { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "%{}", self.0) + } +} + +impl Display for Operand { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Operand::Integer(value) => write!(f, "${}", value), + Operand::Boolean(value) => write!(f, "${}", value), + Operand::Register(register) => register.fmt(f), + } + } +} + +impl Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Bool => write!(f, "bool"), + Type::I8 => write!(f, "i8"), + Type::I16 => write!(f, "i16"), + Type::I32 => write!(f, "i32"), + Type::I64 => write!(f, "i64"), + Type::Ptr => write!(f, "ptr"), + Type::Void => write!(f, "void"), + } + } +} + +impl Display for BlockId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "block_{}", self.0) + } +} + +impl Display for UnaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UnaryOp::INeg => write!(f, "ineg"), + } + } +} + +impl Display for BinaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BinaryOp::Add => write!(f, "add"), + BinaryOp::Sub => write!(f, "sub"), + BinaryOp::UDiv => write!(f, "udiv"), + BinaryOp::SDiv => write!(f, "udiv"), + BinaryOp::Mul => write!(f, "mul"), + BinaryOp::SRem => write!(f, "srem"), + BinaryOp::URem => write!(f, "urem"), + BinaryOp::ICmp(icmp_op) => write!(f, "icmp {}", icmp_op), + } + } +} + +impl Display for ICmpOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ICmpOp::Slt => write!(f, "slt"), + ICmpOp::Sle => write!(f, "sle"), + ICmpOp::Sgt => write!(f, "sgt"), + ICmpOp::Sge => write!(f, "sge"), + ICmpOp::Ult => write!(f, "ult"), + ICmpOp::Ule => write!(f, "ule"), + ICmpOp::Ugt => write!(f, "ugt"), + ICmpOp::Uge => write!(f, "uge"), + ICmpOp::Eq => write!(f, "eq"), + ICmpOp::Ne => write!(f, "ne"), + } + } +} + +impl Display for Module { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(IrPrinter::print(self)?.as_str()) + } +} + +pub struct IrPrinter<'a> { + module: &'a Module, + buffer: String, +} + +impl<'a> IrPrinter<'a> { + pub fn print(module: &'a Module) -> Result { + let mut printer = IrPrinter { + module, + buffer: String::new(), + }; + + printer.print_module()?; + + Ok(printer.buffer) + } + + fn print_module(&mut self) -> fmt::Result { + for (idx, function) in self.module.functions.iter().enumerate() { + if idx != 0 { + writeln!(&mut self.buffer)?; + } + + self.print_function(function)?; + } + + Ok(()) + } + + fn print_function(&mut self, function: &'a Function) -> fmt::Result { + let args = function + .params + .iter() + .map(|(ty, op)| format!("{} {}", ty, op)) + .collect::>() + .join(", "); + + writeln!( + &mut self.buffer, + "fn {}({}) -> {} {{", + function.name, args, function.return_ty + )?; + + for (idx, block) in function.blocks.iter().enumerate() { + if idx != 0 { + writeln!(&mut self.buffer)?; + } + + self.print_block(block)?; + } + + writeln!(&mut self.buffer, "}}") + } + + fn print_block(&mut self, block: &'a BasicBlock) -> fmt::Result { + writeln!(&mut self.buffer, " {}:", block.id)?; + + for instr in block.instructions.iter() { + write!(&mut self.buffer, " ")?; + self.print_instruction(instr)?; + writeln!(&mut self.buffer)?; + } + + write!(&mut self.buffer, " ")?; + self.print_terminator(&block.terminator)?; + writeln!(&mut self.buffer)?; + + Ok(()) + } + + fn print_instruction(&mut self, instruction: &'a Instruction) -> fmt::Result { + match instruction { + Instruction::Alloc { dest, ty } => { + write!(&mut self.buffer, "{} = alloc {}", dest, ty) + } + Instruction::Load { ty, dest, src } => { + write!(&mut self.buffer, "{} = load {} {}", dest, ty, src) + } + Instruction::Store { ty, dest, src } => { + write!(&mut self.buffer, "store {} {}, {}", ty, dest, src) + } + Instruction::Binary { + dest, + result_ty, + op, + src1, + src2, + } => write!( + &mut self.buffer, + "{} = {} {} {}, {}", + dest, op, result_ty, src1, src2 + ), + Instruction::Unary { + dest, + result_ty, + op, + src, + } => write!(&mut self.buffer, "{} = {} {} {}", dest, op, result_ty, src), + Instruction::Call { + dest, + result_ty, + func, + args, + } => { + let args = args + .iter() + .map(|(ty, op)| format!("{} {}", ty, op)) + .collect::>() + .join(", "); + + let func_name = self + .module + .functions + .iter() + .find_map(|f| (f.id == *func).then(|| f.name.clone())) + .unwrap(); + + write!( + &mut self.buffer, + "{} = call {} {}({})", + dest, result_ty, func_name, args + ) + } + Instruction::Assign { register, operand } => { + write!(&mut self.buffer, "{} = {}", register, operand) + } + Instruction::Phi { + dest, + result_ty, + sources, + } => { + let sources = sources + .iter() + .map(|(op, id)| format!("[{}, {}]", op, id)) + .collect::>() + .join(", "); + + write!(&mut self.buffer, "{} = phi {} {}", dest, result_ty, sources) + } + } + } + + fn print_terminator(&mut self, terminator: &'a Terminator) -> fmt::Result { + match terminator { + Terminator::Branch { + cond, + then_block, + else_block, + } => write!( + &mut self.buffer, + "branch {}, then {}, else {}", + cond, then_block, else_block + ), + Terminator::Return { + return_ty, + value: Some(value), + } => write!(&mut self.buffer, "return {} {}", return_ty, value), + Terminator::Return { + return_ty, + value: None, + } => write!(&mut self.buffer, "return {}", return_ty), + Terminator::Jump(block_id) => write!(&mut self.buffer, "jump {}", block_id), + Terminator::Unknown => write!(&mut self.buffer, "unknown"), + } + } +} diff --git a/src/validate.rs b/src/validate.rs new file mode 100644 index 0000000..d257624 --- /dev/null +++ b/src/validate.rs @@ -0,0 +1,203 @@ +use std::collections::HashMap; + +use crate::ir::*; + +/// Runs the type-checking pass over the entire module. +pub fn validate_module(module: &Module) -> Result<(), String> { + // 1. Collect all function return types to verify Call instructions globally + let mut function_signatures = HashMap::new(); + for func in &module.functions { + function_signatures.insert(func.id, func.return_ty); + } + + // 2. Verify each function in the module + for func in &module.functions { + validate_function(func, &function_signatures)?; + } + + Ok(()) +} + +fn validate_function( + func: &Function, + function_signatures: &HashMap, +) -> Result<(), String> { + let mut reg_types = HashMap::new(); + + // 1. Map function parameters to their register types + for (ty, reg) in &func.params { + reg_types.insert(*reg, *ty); + } + + // 2. Iteratively resolve register types + // We use a loop to handle `Assign` instructions that copy from a register + // defined later in the linear block scan. + let mut changed = true; + while changed { + changed = false; + for block in &func.blocks { + for inst in &block.instructions { + let (dest, ty) = match inst { + Instruction::Alloc { dest, .. } => (Some(*dest), Some(Type::Ptr)), + Instruction::Load { ty, dest, .. } => (Some(*dest), Some(*ty)), + Instruction::Binary { + result_ty, dest, .. + } => (Some(*dest), Some(*result_ty)), + Instruction::Unary { + result_ty, dest, .. + } => (Some(*dest), Some(*result_ty)), + Instruction::Call { + result_ty, dest, .. + } => (Some(*dest), Some(*result_ty)), + Instruction::Assign { register, operand } => { + let inferred_ty = match operand { + Operand::Register(r) => reg_types.get(r).copied(), + Operand::Boolean(_) => Some(Type::Bool), + Operand::Integer(_) => Some(Type::I64), // Default fallback for untyped literals + }; + (Some(*register), inferred_ty) + } + Instruction::Store { .. } => (None, None), + Instruction::Phi { + dest, result_ty, .. + } => (Some(*dest), Some(*result_ty)), + }; + + if let (Some(d), Some(t)) = (dest, ty) { + // Only flag as changed if we successfully resolved a new register + if let std::collections::hash_map::Entry::Vacant(e) = reg_types.entry(d) { + e.insert(t); + changed = true; + } + } + } + } + } + + // Helper closure to verify if an Operand matches an expected Type + let check_operand = |op: &Operand, expected: Type| -> Result<(), String> { + match op { + Operand::Register(r) => { + let actual = reg_types + .get(r) + .ok_or_else(|| format!("Unknown register {:?}", r))?; + + if *actual != expected { + return Err(format!( + "Type mismatch: expected {:?}, got {:?}", + expected, actual + )); + } + } + Operand::Boolean(_) => { + if expected != Type::Bool { + return Err(format!("Type mismatch: expected {:?}, got Bool", expected)); + } + } + Operand::Integer(_) => match expected { + Type::I8 | Type::I16 | Type::I32 | Type::I64 | Type::Ptr => {} + _ => { + return Err(format!("Cannot use integer literal as {:?}", expected)); + } + }, + } + + Ok(()) + }; + + // 3. Verify all instruction constraints + for block in &func.blocks { + for inst in &block.instructions { + match inst { + Instruction::Alloc { .. } => {} + Instruction::Assign { register, operand } => { + let expected_ty = *reg_types + .get(register) + .ok_or_else(|| format!("Unknown register {:?}", register))?; + check_operand(operand, expected_ty)?; + } + Instruction::Load { src, .. } => { + check_operand(src, Type::Ptr)?; + } + Instruction::Store { ty, dest, src } => { + check_operand(dest, Type::Ptr)?; + check_operand(src, *ty)?; + } + Instruction::Binary { + result_ty, + op, + src1, + src2, + .. + } => { + if let BinaryOp::ICmp(_) = op { + if *result_ty != Type::Bool { + return Err("ICmp result type must be Bool".to_string()); + } + + let op_type = match src1 { + Operand::Register(r) => *reg_types + .get(r) + .ok_or_else(|| format!("Unknown reg {:?}", r))?, + Operand::Boolean(_) => Type::Bool, + Operand::Integer(_) => match src2 { + Operand::Register(r) => *reg_types.get(r).unwrap_or(&Type::I64), + _ => Type::I64, + }, + }; + check_operand(src1, op_type)?; + check_operand(src2, op_type)?; + } else { + check_operand(src1, *result_ty)?; + check_operand(src2, *result_ty)?; + } + } + Instruction::Unary { result_ty, src, .. } => { + check_operand(src, *result_ty)?; + } + Instruction::Call { + result_ty, + func: target_id, + .. + } => { + let expected_ret = function_signatures + .get(target_id) + .ok_or_else(|| format!("Call to unknown function ID {:?}", target_id))?; + if result_ty != expected_ret { + return Err( + "Call result type does not match target function return type" + .to_string(), + ); + } + } + Instruction::Phi { + result_ty, sources, .. + } => { + for (op, _) in sources { + check_operand(op, *result_ty)?; + } + } + } + } + + // 4. Verify block terminators + match &block.terminator { + Terminator::Branch { cond, .. } => check_operand(cond, Type::Bool)?, + Terminator::Return { return_ty, value } => { + if *return_ty != func.return_ty { + return Err( + "Return terminator type does not match function definition".to_string() + ); + } + if let Some(val) = value { + check_operand(val, *return_ty)?; + } else if *return_ty != Type::Void { + return Err("Missing return value for non-void function".to_string()); + } + } + _ => {} + } + } + + Ok(()) +}