feat: add basic block reordering pass and improve codegen naming

- Implement `reorder_blocks` (BBR) pass using DFS to maximize
fallthroughs.
- Update x86_64 backend to use actual function names in call
instructions instead of generic IDs.
- Replace the GCD test case in main with an iterative factorial test
module.
- Remove redundant validation check at the end of the optimization
pipeline.
This commit is contained in:
2026-04-27 10:56:47 +02:00
parent 9d94e3b81b
commit 5b8a0cb398
5 changed files with 174 additions and 90 deletions
-46
View File
@@ -1,46 +0,0 @@
#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
// Declare the external assembly function generated by the Rust backend
extern uint64_t gcd(uint64_t a, uint64_t b);
/*
 * Test driver for the externally assembled `gcd` routine.
 *
 * Runs a fixed table of (a, b, expected) rows through gcd(), prints one
 * PASS/FAIL line per row followed by a summary, and returns 0 only when
 * every row matched (1 otherwise).
 */
int main() {
    /* Each row is {a, b, expected gcd(a, b)}. */
    uint64_t test_cases[][3] = {
        {48, 18, 6},
        {54, 24, 6},
        {7, 13, 1},           /* coprime primes */
        {100, 10, 10},        /* one operand divides the other */
        {2740, 1760, 20},     /* larger arbitrary operands */
        {1234567890, 90, 90}, /* large first operand */
        {5, 0, 5}             /* edge case: b == 0 */
    };
    int num_cases = sizeof(test_cases) / sizeof(test_cases[0]);
    int passed = 0;
    int i = 0;

    printf("Running GCD Assembly Tests...\n");
    printf("-----------------------------\n");

    while (i < num_cases) {
        const uint64_t *row = test_cases[i];
        /* Hand the operands to the generated assembly routine. */
        uint64_t result = gcd(row[0], row[1]);

        if (result != row[2]) {
            printf("[FAIL] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 " (Expected: %" PRIu64 ")\n", row[0], row[1], result, row[2]);
        } else {
            printf("[PASS] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 "\n", row[0], row[1], result);
            passed++;
        }
        i++;
    }

    printf("-----------------------------\n");
    printf("Results: %d/%d passed.\n", passed, num_cases);

    /* Non-zero exit status signals at least one failing row. */
    if (passed == num_cases) {
        return 0;
    }
    return 1;
}
+10 -5
View File
@@ -540,10 +540,7 @@ impl<'a> X86Backend<'a> {
} }
} }
Instruction::Call { Instruction::Call {
dest, dest, func, args, ..
func: target_id,
args,
..
} => { } => {
let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
for (i, (_, arg_op)) in args.iter().enumerate() { for (i, (_, arg_op)) in args.iter().enumerate() {
@@ -559,7 +556,15 @@ impl<'a> X86Backend<'a> {
} }
} }
} }
writeln!(&mut self.assembly, " call function_{}", target_id.0).unwrap();
let function_name = self
.module
.functions
.iter()
.find_map(|f| (f.id == *func).then(|| f.name.clone()))
.unwrap();
writeln!(&mut self.assembly, " call {}", function_name).unwrap();
if args.len() > 6 { if args.len() > 6 {
let cleanup_size = (args.len() - 6) * 8; let cleanup_size = (args.len() - 6) * 8;
writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap(); writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap();
+84 -36
View File
@@ -1,60 +1,108 @@
use scarlett::{ use scarlett::{
backend::x86_64::X86Backend, backend::x86_64::X86Backend, builder::IrModuleBuilder, ir::*, passes, validate::validate_module,
builder::IrModuleBuilder,
ir::{ICmpOp, Operand, Type},
passes,
validate::validate_module,
}; };
fn main() { fn build_test_module() -> Module {
let mut module_builder = IrModuleBuilder::new(); let mut module_builder = IrModuleBuilder::new();
// 1. Build `gcd(a: i64, b: i64) -> i64` let i32_ty = Type::I32;
// 1. Define the Factorial Function
// factorial(n: i32) -> i32
let fact_id = module_builder.new_function_id();
{ {
let gcd_id = module_builder.new_function_id(); let f_builder = module_builder.new_function(
let builder = fact_id,
module_builder.new_function(gcd_id, "gcd", &[Type::I64, Type::I64], Type::I64); "factorial_iterative",
vec![&i32_ty],
i32_ty.clone(),
);
let ptr_x = builder.build_alloc(Type::I64); // 1. Block Definitions
let ptr_y = builder.build_alloc(Type::I64); let loop_cond_block = f_builder.create_block();
let loop_body_block = f_builder.create_block();
let exit_block = f_builder.create_block();
let param_0 = builder.get_param(0).unwrap(); // Allocate space for 'res' (accumulator) and 'i' (counter)
builder.build_store(Type::I64, ptr_x, param_0); let res_ptr = f_builder.build_alloc(i32_ty.clone());
let i_ptr = f_builder.build_alloc(i32_ty.clone());
let param_1 = builder.get_param(1).unwrap(); let n = f_builder.get_param(0).expect("Param 0 missing");
builder.build_store(Type::I64, ptr_y, param_1); let const_1 = Operand::Integer(1);
let loop_cond = builder.create_block(); // res = 1; i = n;
let loop_body = builder.create_block(); f_builder.build_store(i32_ty.clone(), res_ptr.clone(), const_1.clone());
let loop_merge = builder.create_block(); f_builder.build_store(i32_ty.clone(), i_ptr.clone(), n);
builder.build_jump(loop_cond); f_builder.build_jump(loop_cond_block);
builder.switch_to_block(loop_cond);
let val_y = builder.build_load(Type::I64, ptr_y); // 2. Loop Condition Block: i > 1
let cond = builder.build_icmp(ICmpOp::Ne, val_y, Operand::Integer(0)); f_builder.switch_to_block(loop_cond_block);
let current_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone());
let is_gt_1 = f_builder.build_icmp(ICmpOp::Ugt, current_i, const_1.clone());
builder.build_branch(cond, loop_body, loop_merge); // If i > 1 goto body, else goto exit
builder.switch_to_block(loop_body); f_builder.build_branch(is_gt_1, loop_body_block, exit_block);
let val_x = builder.build_load(Type::I64, ptr_x); // 3. Loop Body Block: res = res * i; i = i - 1;
let val_y = builder.build_load(Type::I64, ptr_y); f_builder.switch_to_block(loop_body_block);
let rem = builder.build_urem(Type::I64, val_x, val_y);
builder.build_store(Type::I64, ptr_x, val_y); let val_res = f_builder.build_load(i32_ty.clone(), res_ptr.clone());
builder.build_store(Type::I64, ptr_y, rem); let val_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone());
builder.build_jump(loop_cond); // res = res * i
builder.switch_to_block(loop_merge); let updated_res = f_builder.build_mul(i32_ty.clone(), val_res, val_i.clone());
f_builder.build_store(i32_ty.clone(), res_ptr.clone(), updated_res);
let val_x = builder.build_load(Type::I64, ptr_x); // i = i - 1
builder.build_return(Type::I64, val_x); let updated_i = f_builder.build_sub(i32_ty.clone(), val_i, const_1);
f_builder.build_store(i32_ty.clone(), i_ptr.clone(), updated_i);
// Jump back to condition
f_builder.build_jump(loop_cond_block);
// 4. Exit Block: return res
f_builder.switch_to_block(exit_block);
let final_res = f_builder.build_load(i32_ty.clone(), res_ptr);
f_builder.build_return(i32_ty.clone(), final_res);
}
module_builder.complete_function(); module_builder.complete_function();
// 2. Define the Main Function
let main_id = module_builder.new_function_id();
{
let f_builder = module_builder.new_function(
main_id,
"main",
vec![], // no params
i32_ty.clone(),
);
// 1. Allocate space for an integer on the stack
let ptr = f_builder.build_alloc(i32_ty.clone());
// 2. Store the value '5' into that pointer
let input_val = Operand::Integer(5);
f_builder.build_store(i32_ty.clone(), ptr.clone(), input_val);
// 3. Load the value back from the pointer
let loaded_val = f_builder.build_load(i32_ty.clone(), ptr);
// 4. Call factorial(loaded_val)
let args = [(i32_ty.clone(), loaded_val)];
let final_result = f_builder.build_call(i32_ty.clone(), fact_id, args.iter());
// 5. Return the result of the factorial call
f_builder.build_return(i32_ty.clone(), final_result);
}
module_builder.complete_function();
// Finalize the module
module_builder.finish()
} }
// 2. Finish, Validate, Optimize, and Compile fn main() {
let mut module = module_builder.finish(); let mut module = build_test_module();
validate_module(&module).expect("failed to validate module"); validate_module(&module).expect("failed to validate module");
passes::optimize(&mut module); passes::optimize(&mut module);
+76
View File
@@ -0,0 +1,76 @@
use std::collections::{HashMap, HashSet};
use crate::ir::*;
/// Runs the Block Reordering pass to maximize fallthroughs and minimize jumps.
pub fn reorder_blocks(module: &mut Module) {
for func in &mut module.functions {
reorder_blocks_in_func(func);
}
}
/// Reorders the basic blocks of a single function in place.
///
/// Blocks are laid out in depth-first pre-order starting from
/// `func.entry_block_id`:
///
/// * after a `Jump`, the target is placed next;
/// * after a `Branch`, the `else_block` chain is placed first and the
///   `then_block` chain afterwards, so the not-taken path sits directly
///   behind the branch and can be reached by falling through.
///
/// Blocks that are not reachable from the entry block are kept and appended
/// after the reachable ones in their original relative order, so the result
/// is identical from run to run.
///
/// A function with no blocks is left untouched.
///
/// # Panics
///
/// Panics if the entry block id or any visited terminator target does not
/// name a block of `func`; that indicates malformed IR.
fn reorder_blocks_in_func(func: &mut Function) {
    if func.blocks.is_empty() {
        return;
    }

    let block_count = func.blocks.len();

    // Index every block by id once so each successor lookup is O(1) instead
    // of a linear scan. The first occurrence wins, matching a front-to-back
    // search over the block list.
    let mut index_of: HashMap<BlockId, usize> = HashMap::with_capacity(block_count);
    for (idx, block) in func.blocks.iter().enumerate() {
        index_of.entry(block.id).or_insert(idx);
    }

    // 1. Depth-first pre-order walk using an explicit stack, so very long
    //    jump chains cannot overflow the call stack.
    let mut order: Vec<usize> = Vec::with_capacity(block_count);
    let mut visited: HashSet<BlockId> = HashSet::with_capacity(block_count);
    let mut pending: Vec<BlockId> = vec![func.entry_block_id];

    while let Some(id) = pending.pop() {
        if !visited.insert(id) {
            continue; // Loop back-edge or a block that was already placed.
        }

        let idx = *index_of
            .get(&id)
            .expect("block reordering: control flow targets a block that is not in the function");
        order.push(idx);

        match func.blocks[idx].terminator {
            Terminator::Jump(target) => {
                // Follow unconditional jumps immediately to chain the blocks.
                pending.push(target);
            }
            Terminator::Branch {
                then_block,
                else_block,
                ..
            } => {
                // The stack is LIFO: push `then_block` first so that
                // `else_block` is popped (and therefore placed) right after
                // the current block, making it the fallthrough candidate.
                pending.push(then_block);
                pending.push(else_block);
            }
            Terminator::Return { .. } | Terminator::Unknown => {}
        }
    }

    // 2. Rebuild the block vector in the computed order. Blocks are moved out
    //    of `slots` as they are placed, leaving `None` behind.
    let mut slots: Vec<Option<BasicBlock>> = func.blocks.drain(..).map(Some).collect();
    func.blocks
        .extend(order.iter().filter_map(|&idx| slots[idx].take()));

    // 3. Whatever is still in `slots` was never reached by the walk. Append
    //    it in original order; iterating a `HashMap` here would make the
    //    final layout vary between runs.
    func.blocks.extend(slots.into_iter().flatten());
}
+3 -2
View File
@@ -1,5 +1,6 @@
use crate::{ir::Module, validate::validate_module}; use crate::ir::Module;
pub mod bbr;
pub mod cfp; pub mod cfp;
pub mod cpp; pub mod cpp;
pub mod dce; pub mod dce;
@@ -12,6 +13,6 @@ pub fn optimize(module: &mut Module) {
cfp::fold_constants(module); cfp::fold_constants(module);
dce::eliminate_dead_code(module); dce::eliminate_dead_code(module);
cpp::propagate_copies(module); cpp::propagate_copies(module);
bbr::reorder_blocks(module);
des::destroy_ssa(module); des::destroy_ssa(module);
validate_module(module).expect("failed to validate module after optimization passes");
} }