From 5b8a0cb3983e9219af81343cc15a2b7ae36eda40 Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Mon, 27 Apr 2026 10:56:47 +0200 Subject: [PATCH] feat: add basic block reordering pass and improve codegen naming - Implement `reorder_blocks` (BBR) pass using DFS to maximize fallthroughs. - Update x86_64 backend to use actual function names in call instructions instead of generic IDs. - Replace the GCD test case in main with an iterative factorial test module. - Remove redundant validation check at the end of the optimization pipeline. --- harness.c | 46 ---------------- src/backend/x86_64.rs | 15 ++++-- src/main.rs | 122 +++++++++++++++++++++++++++++------------- src/passes/bbr.rs | 76 ++++++++++++++++++++++++++ src/passes/mod.rs | 5 +- 5 files changed, 174 insertions(+), 90 deletions(-) delete mode 100644 harness.c create mode 100644 src/passes/bbr.rs diff --git a/harness.c b/harness.c deleted file mode 100644 index 9fdf89a..0000000 --- a/harness.c +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include -#include - -// Declare the external assembly function generated by the Rust backend -extern uint64_t gcd(uint64_t a, uint64_t b); - -int main() { - // Define an array of test cases: {a, b, expected_result} - uint64_t test_cases[][3] = { - {48, 18, 6}, - {54, 24, 6}, - {7, 13, 1}, // Prime numbers - {100, 10, 10}, // One is a multiple of the other - {2740, 1760, 20}, // Larger random numbers - {1234567890, 90, 90}, // Testing 64-bit bounds - {5, 0, 5} // Edge case: b is 0 (Your IR handles this perfectly!) - }; - - int num_cases = sizeof(test_cases) / sizeof(test_cases[0]); - int passed = 0; - - printf("Running GCD Assembly Tests...\n"); - printf("-----------------------------\n"); - - for (int i = 0; i < num_cases; i++) { - uint64_t a = test_cases[i][0]; - uint64_t b = test_cases[i][1]; - uint64_t expected = test_cases[i][2]; - - // Call out into your compiled assembly! 
- uint64_t result = gcd(a, b); - - if (result == expected) { - printf("[PASS] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 "\n", a, b, result); - passed++; - } else { - printf("[FAIL] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 " (Expected: %" PRIu64 ")\n", a, b, result, expected); - } - } - - printf("-----------------------------\n"); - printf("Results: %d/%d passed.\n", passed, num_cases); - - return (passed == num_cases) ? 0 : 1; -} diff --git a/src/backend/x86_64.rs b/src/backend/x86_64.rs index 7186b3c..5719ef4 100644 --- a/src/backend/x86_64.rs +++ b/src/backend/x86_64.rs @@ -540,10 +540,7 @@ impl<'a> X86Backend<'a> { } } Instruction::Call { - dest, - func: target_id, - args, - .. + dest, func, args, .. } => { let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; for (i, (_, arg_op)) in args.iter().enumerate() { @@ -559,7 +556,15 @@ impl<'a> X86Backend<'a> { } } } - writeln!(&mut self.assembly, " call function_{}", target_id.0).unwrap(); + + let function_name = self + .module + .functions + .iter() + .find_map(|f| (f.id == *func).then(|| f.name.clone())) + .unwrap(); + + writeln!(&mut self.assembly, " call {}", function_name).unwrap(); if args.len() > 6 { let cleanup_size = (args.len() - 6) * 8; writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap(); diff --git a/src/main.rs b/src/main.rs index 24e83bb..779cbab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,60 +1,108 @@ use scarlett::{ - backend::x86_64::X86Backend, - builder::IrModuleBuilder, - ir::{ICmpOp, Operand, Type}, - passes, - validate::validate_module, + backend::x86_64::X86Backend, builder::IrModuleBuilder, ir::*, passes, validate::validate_module, }; -fn main() { +fn build_test_module() -> Module { let mut module_builder = IrModuleBuilder::new(); - // 1. Build `gcd(a: i64, b: i64) -> i64` + let i32_ty = Type::I32; + + // 1. 
Define the Factorial Function + // factorial(n: i32) -> i32 + let fact_id = module_builder.new_function_id(); { - let gcd_id = module_builder.new_function_id(); - let builder = - module_builder.new_function(gcd_id, "gcd", &[Type::I64, Type::I64], Type::I64); + let f_builder = module_builder.new_function( + fact_id, + "factorial_iterative", + vec![&i32_ty], + i32_ty.clone(), + ); - let ptr_x = builder.build_alloc(Type::I64); - let ptr_y = builder.build_alloc(Type::I64); + // 1. Block Definitions + let loop_cond_block = f_builder.create_block(); + let loop_body_block = f_builder.create_block(); + let exit_block = f_builder.create_block(); - let param_0 = builder.get_param(0).unwrap(); - builder.build_store(Type::I64, ptr_x, param_0); + // Allocate space for 'res' (accumulator) and 'i' (counter) + let res_ptr = f_builder.build_alloc(i32_ty.clone()); + let i_ptr = f_builder.build_alloc(i32_ty.clone()); - let param_1 = builder.get_param(1).unwrap(); - builder.build_store(Type::I64, ptr_y, param_1); + let n = f_builder.get_param(0).expect("Param 0 missing"); + let const_1 = Operand::Integer(1); - let loop_cond = builder.create_block(); - let loop_body = builder.create_block(); - let loop_merge = builder.create_block(); + // res = 1; i = n; + f_builder.build_store(i32_ty.clone(), res_ptr.clone(), const_1.clone()); + f_builder.build_store(i32_ty.clone(), i_ptr.clone(), n); - builder.build_jump(loop_cond); - builder.switch_to_block(loop_cond); + f_builder.build_jump(loop_cond_block); - let val_y = builder.build_load(Type::I64, ptr_y); - let cond = builder.build_icmp(ICmpOp::Ne, val_y, Operand::Integer(0)); + // 2. 
Loop Condition Block: i > 1 + f_builder.switch_to_block(loop_cond_block); + let current_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone()); + let is_gt_1 = f_builder.build_icmp(ICmpOp::Ugt, current_i, const_1.clone()); - builder.build_branch(cond, loop_body, loop_merge); - builder.switch_to_block(loop_body); + // If i > 1 goto body, else goto exit + f_builder.build_branch(is_gt_1, loop_body_block, exit_block); - let val_x = builder.build_load(Type::I64, ptr_x); - let val_y = builder.build_load(Type::I64, ptr_y); - let rem = builder.build_urem(Type::I64, val_x, val_y); + // 3. Loop Body Block: res = res * i; i = i - 1; + f_builder.switch_to_block(loop_body_block); - builder.build_store(Type::I64, ptr_x, val_y); - builder.build_store(Type::I64, ptr_y, rem); + let val_res = f_builder.build_load(i32_ty.clone(), res_ptr.clone()); + let val_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone()); - builder.build_jump(loop_cond); - builder.switch_to_block(loop_merge); + // res = res * i + let updated_res = f_builder.build_mul(i32_ty.clone(), val_res, val_i.clone()); + f_builder.build_store(i32_ty.clone(), res_ptr.clone(), updated_res); - let val_x = builder.build_load(Type::I64, ptr_x); - builder.build_return(Type::I64, val_x); + // i = i - 1 + let updated_i = f_builder.build_sub(i32_ty.clone(), val_i, const_1); + f_builder.build_store(i32_ty.clone(), i_ptr.clone(), updated_i); - module_builder.complete_function(); + // Jump back to condition + f_builder.build_jump(loop_cond_block); + + // 4. Exit Block: return res + f_builder.switch_to_block(exit_block); + let final_res = f_builder.build_load(i32_ty.clone(), res_ptr); + f_builder.build_return(i32_ty.clone(), final_res); } + module_builder.complete_function(); - // 2. Finish, Validate, Optimize, and Compile - let mut module = module_builder.finish(); + // 2. 
Define the Main Function + let main_id = module_builder.new_function_id(); + { + let f_builder = module_builder.new_function( + main_id, + "main", + vec![], // no params + i32_ty.clone(), + ); + + // 1. Allocate space for an integer on the stack + let ptr = f_builder.build_alloc(i32_ty.clone()); + + // 2. Store the value '5' into that pointer + let input_val = Operand::Integer(5); + f_builder.build_store(i32_ty.clone(), ptr.clone(), input_val); + + // 3. Load the value back from the pointer + let loaded_val = f_builder.build_load(i32_ty.clone(), ptr); + + // 4. Call factorial(loaded_val) + let args = [(i32_ty.clone(), loaded_val)]; + let final_result = f_builder.build_call(i32_ty.clone(), fact_id, args.iter()); + + // 5. Return the result of the factorial call + f_builder.build_return(i32_ty.clone(), final_result); + } + module_builder.complete_function(); + + // Finalize the module + module_builder.finish() +} + +fn main() { + let mut module = build_test_module(); validate_module(&module).expect("failed to validate module"); passes::optimize(&mut module); diff --git a/src/passes/bbr.rs b/src/passes/bbr.rs new file mode 100644 index 0000000..31e1ac6 --- /dev/null +++ b/src/passes/bbr.rs @@ -0,0 +1,76 @@ +use std::collections::{HashMap, HashSet}; + +use crate::ir::*; + +/// Runs the Block Reordering pass to maximize fallthroughs and minimize jumps. +pub fn reorder_blocks(module: &mut Module) { + for func in &mut module.functions { + reorder_blocks_in_func(func); + } +} + +fn reorder_blocks_in_func(func: &mut Function) { + if func.blocks.is_empty() { + return; + } + + let mut order = Vec::new(); + let mut visited = HashSet::new(); + + // 1. 
Recursive Depth-First Search to build the optimal chain of blocks + fn dfs( + id: BlockId, + blocks: &[BasicBlock], + visited: &mut HashSet<BlockId>, + order: &mut Vec<BlockId>, + ) { + if !visited.insert(id) { + return; // Stop if we hit a loop backedge or already visited block + } + order.push(id); + + let block = blocks.iter().find(|b| b.id == id).unwrap(); + + match block.terminator { + Terminator::Jump(target) => { + // Instantly follow unconditional jumps to chain them together + dfs(target, blocks, visited, order); + } + Terminator::Branch { + then_block, + else_block, + .. + } => { + // HEURISTIC: Visit the `else_block` immediately. + // This places it directly next to the current block in memory, + // allowing our x86-64 backend to omit the branch and fall through. + dfs(else_block, blocks, visited, order); + + // Then, recursively process the `then_block` path. + dfs(then_block, blocks, visited, order); + } + Terminator::Return { .. } | Terminator::Unknown => {} + } + } + + // Start traversing from the function's official entry point + dfs(func.entry_block_id, &func.blocks, &mut visited, &mut order); + + // 2. Rebuild the function's block vector in the newly computed order + let mut block_map: HashMap<BlockId, BasicBlock> = + func.blocks.drain(..).map(|b| (b.id, b)).collect(); + + // Push the blocks back into the function in our optimized DFS order + for id in order { + if let Some(block) = block_map.remove(&id) { + func.blocks.push(block); + } + } + + // 3. Append any unreachable blocks that the DFS missed. + // Note: If you run your Dead Code Elimination pass before this, + // `block_map` will already be completely empty by this point. 
+ for (_, block) in block_map { + func.blocks.push(block); + } +} diff --git a/src/passes/mod.rs b/src/passes/mod.rs index 0351180..d702e6a 100644 --- a/src/passes/mod.rs +++ b/src/passes/mod.rs @@ -1,5 +1,6 @@ -use crate::{ir::Module, validate::validate_module}; +use crate::ir::Module; +pub mod bbr; pub mod cfp; pub mod cpp; pub mod dce; @@ -12,6 +13,6 @@ pub fn optimize(module: &mut Module) { cfp::fold_constants(module); dce::eliminate_dead_code(module); cpp::propagate_copies(module); + bbr::reorder_blocks(module); des::destroy_ssa(module); - validate_module(module).expect("failed to validate module after optimization passes"); }