feat: add basic block reordering pass and improve codegen naming
- Implement `reorder_blocks` (BBR) pass using DFS to maximize fallthroughs.
- Update x86_64 backend to use actual function names in call instructions instead of generic IDs.
- Replace the GCD test case in main with an iterative factorial test module.
- Remove redundant validation check at the end of the optimization pipeline.
This commit is contained in:
@@ -1,46 +0,0 @@
|
|||||||
#include <stdio.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <inttypes.h>
|
|
||||||
|
|
||||||
// Declare the external assembly function generated by the Rust backend
|
|
||||||
extern uint64_t gcd(uint64_t a, uint64_t b);
|
|
||||||
|
|
||||||
int main() {
|
|
||||||
// Define an array of test cases: {a, b, expected_result}
|
|
||||||
uint64_t test_cases[][3] = {
|
|
||||||
{48, 18, 6},
|
|
||||||
{54, 24, 6},
|
|
||||||
{7, 13, 1}, // Prime numbers
|
|
||||||
{100, 10, 10}, // One is a multiple of the other
|
|
||||||
{2740, 1760, 20}, // Larger random numbers
|
|
||||||
{1234567890, 90, 90}, // Testing 64-bit bounds
|
|
||||||
{5, 0, 5} // Edge case: b is 0 (Your IR handles this perfectly!)
|
|
||||||
};
|
|
||||||
|
|
||||||
int num_cases = sizeof(test_cases) / sizeof(test_cases[0]);
|
|
||||||
int passed = 0;
|
|
||||||
|
|
||||||
printf("Running GCD Assembly Tests...\n");
|
|
||||||
printf("-----------------------------\n");
|
|
||||||
|
|
||||||
for (int i = 0; i < num_cases; i++) {
|
|
||||||
uint64_t a = test_cases[i][0];
|
|
||||||
uint64_t b = test_cases[i][1];
|
|
||||||
uint64_t expected = test_cases[i][2];
|
|
||||||
|
|
||||||
// Call out into your compiled assembly!
|
|
||||||
uint64_t result = gcd(a, b);
|
|
||||||
|
|
||||||
if (result == expected) {
|
|
||||||
printf("[PASS] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 "\n", a, b, result);
|
|
||||||
passed++;
|
|
||||||
} else {
|
|
||||||
printf("[FAIL] gcd(%" PRIu64 ", %" PRIu64 ") = %" PRIu64 " (Expected: %" PRIu64 ")\n", a, b, result, expected);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("-----------------------------\n");
|
|
||||||
printf("Results: %d/%d passed.\n", passed, num_cases);
|
|
||||||
|
|
||||||
return (passed == num_cases) ? 0 : 1;
|
|
||||||
}
|
|
||||||
+10
-5
@@ -540,10 +540,7 @@ impl<'a> X86Backend<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Instruction::Call {
|
Instruction::Call {
|
||||||
dest,
|
dest, func, args, ..
|
||||||
func: target_id,
|
|
||||||
args,
|
|
||||||
..
|
|
||||||
} => {
|
} => {
|
||||||
let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
|
let arg_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
|
||||||
for (i, (_, arg_op)) in args.iter().enumerate() {
|
for (i, (_, arg_op)) in args.iter().enumerate() {
|
||||||
@@ -559,7 +556,15 @@ impl<'a> X86Backend<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
writeln!(&mut self.assembly, " call function_{}", target_id.0).unwrap();
|
|
||||||
|
let function_name = self
|
||||||
|
.module
|
||||||
|
.functions
|
||||||
|
.iter()
|
||||||
|
.find_map(|f| (f.id == *func).then(|| f.name.clone()))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
writeln!(&mut self.assembly, " call {}", function_name).unwrap();
|
||||||
if args.len() > 6 {
|
if args.len() > 6 {
|
||||||
let cleanup_size = (args.len() - 6) * 8;
|
let cleanup_size = (args.len() - 6) * 8;
|
||||||
writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap();
|
writeln!(&mut self.assembly, " addq ${}, %rsp", cleanup_size).unwrap();
|
||||||
|
|||||||
+85
-37
@@ -1,60 +1,108 @@
|
|||||||
use scarlett::{
|
use scarlett::{
|
||||||
backend::x86_64::X86Backend,
|
backend::x86_64::X86Backend, builder::IrModuleBuilder, ir::*, passes, validate::validate_module,
|
||||||
builder::IrModuleBuilder,
|
|
||||||
ir::{ICmpOp, Operand, Type},
|
|
||||||
passes,
|
|
||||||
validate::validate_module,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
fn main() {
|
fn build_test_module() -> Module {
|
||||||
let mut module_builder = IrModuleBuilder::new();
|
let mut module_builder = IrModuleBuilder::new();
|
||||||
|
|
||||||
// 1. Build `gcd(a: i64, b: i64) -> i64`
|
let i32_ty = Type::I32;
|
||||||
|
|
||||||
|
// 1. Define the Factorial Function
|
||||||
|
// factorial(n: i32) -> i32
|
||||||
|
let fact_id = module_builder.new_function_id();
|
||||||
{
|
{
|
||||||
let gcd_id = module_builder.new_function_id();
|
let f_builder = module_builder.new_function(
|
||||||
let builder =
|
fact_id,
|
||||||
module_builder.new_function(gcd_id, "gcd", &[Type::I64, Type::I64], Type::I64);
|
"factorial_iterative",
|
||||||
|
vec![&i32_ty],
|
||||||
|
i32_ty.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
let ptr_x = builder.build_alloc(Type::I64);
|
// 1. Block Definitions
|
||||||
let ptr_y = builder.build_alloc(Type::I64);
|
let loop_cond_block = f_builder.create_block();
|
||||||
|
let loop_body_block = f_builder.create_block();
|
||||||
|
let exit_block = f_builder.create_block();
|
||||||
|
|
||||||
let param_0 = builder.get_param(0).unwrap();
|
// Allocate space for 'res' (accumulator) and 'i' (counter)
|
||||||
builder.build_store(Type::I64, ptr_x, param_0);
|
let res_ptr = f_builder.build_alloc(i32_ty.clone());
|
||||||
|
let i_ptr = f_builder.build_alloc(i32_ty.clone());
|
||||||
|
|
||||||
let param_1 = builder.get_param(1).unwrap();
|
let n = f_builder.get_param(0).expect("Param 0 missing");
|
||||||
builder.build_store(Type::I64, ptr_y, param_1);
|
let const_1 = Operand::Integer(1);
|
||||||
|
|
||||||
let loop_cond = builder.create_block();
|
// res = 1; i = n;
|
||||||
let loop_body = builder.create_block();
|
f_builder.build_store(i32_ty.clone(), res_ptr.clone(), const_1.clone());
|
||||||
let loop_merge = builder.create_block();
|
f_builder.build_store(i32_ty.clone(), i_ptr.clone(), n);
|
||||||
|
|
||||||
builder.build_jump(loop_cond);
|
f_builder.build_jump(loop_cond_block);
|
||||||
builder.switch_to_block(loop_cond);
|
|
||||||
|
|
||||||
let val_y = builder.build_load(Type::I64, ptr_y);
|
// 2. Loop Condition Block: i > 1
|
||||||
let cond = builder.build_icmp(ICmpOp::Ne, val_y, Operand::Integer(0));
|
f_builder.switch_to_block(loop_cond_block);
|
||||||
|
let current_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone());
|
||||||
|
let is_gt_1 = f_builder.build_icmp(ICmpOp::Ugt, current_i, const_1.clone());
|
||||||
|
|
||||||
builder.build_branch(cond, loop_body, loop_merge);
|
// If i > 1 goto body, else goto exit
|
||||||
builder.switch_to_block(loop_body);
|
f_builder.build_branch(is_gt_1, loop_body_block, exit_block);
|
||||||
|
|
||||||
let val_x = builder.build_load(Type::I64, ptr_x);
|
// 3. Loop Body Block: res = res * i; i = i - 1;
|
||||||
let val_y = builder.build_load(Type::I64, ptr_y);
|
f_builder.switch_to_block(loop_body_block);
|
||||||
let rem = builder.build_urem(Type::I64, val_x, val_y);
|
|
||||||
|
|
||||||
builder.build_store(Type::I64, ptr_x, val_y);
|
let val_res = f_builder.build_load(i32_ty.clone(), res_ptr.clone());
|
||||||
builder.build_store(Type::I64, ptr_y, rem);
|
let val_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone());
|
||||||
|
|
||||||
builder.build_jump(loop_cond);
|
// res = res * i
|
||||||
builder.switch_to_block(loop_merge);
|
let updated_res = f_builder.build_mul(i32_ty.clone(), val_res, val_i.clone());
|
||||||
|
f_builder.build_store(i32_ty.clone(), res_ptr.clone(), updated_res);
|
||||||
|
|
||||||
let val_x = builder.build_load(Type::I64, ptr_x);
|
// i = i - 1
|
||||||
builder.build_return(Type::I64, val_x);
|
let updated_i = f_builder.build_sub(i32_ty.clone(), val_i, const_1);
|
||||||
|
f_builder.build_store(i32_ty.clone(), i_ptr.clone(), updated_i);
|
||||||
|
|
||||||
module_builder.complete_function();
|
// Jump back to condition
|
||||||
|
f_builder.build_jump(loop_cond_block);
|
||||||
|
|
||||||
|
// 4. Exit Block: return res
|
||||||
|
f_builder.switch_to_block(exit_block);
|
||||||
|
let final_res = f_builder.build_load(i32_ty.clone(), res_ptr);
|
||||||
|
f_builder.build_return(i32_ty.clone(), final_res);
|
||||||
}
|
}
|
||||||
|
module_builder.complete_function();
|
||||||
|
|
||||||
// 2. Finish, Validate, Optimize, and Compile
|
// 2. Define the Main Function
|
||||||
let mut module = module_builder.finish();
|
let main_id = module_builder.new_function_id();
|
||||||
|
{
|
||||||
|
let f_builder = module_builder.new_function(
|
||||||
|
main_id,
|
||||||
|
"main",
|
||||||
|
vec![], // no params
|
||||||
|
i32_ty.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// 1. Allocate space for an integer on the stack
|
||||||
|
let ptr = f_builder.build_alloc(i32_ty.clone());
|
||||||
|
|
||||||
|
// 2. Store the value '5' into that pointer
|
||||||
|
let input_val = Operand::Integer(5);
|
||||||
|
f_builder.build_store(i32_ty.clone(), ptr.clone(), input_val);
|
||||||
|
|
||||||
|
// 3. Load the value back from the pointer
|
||||||
|
let loaded_val = f_builder.build_load(i32_ty.clone(), ptr);
|
||||||
|
|
||||||
|
// 4. Call factorial(loaded_val)
|
||||||
|
let args = [(i32_ty.clone(), loaded_val)];
|
||||||
|
let final_result = f_builder.build_call(i32_ty.clone(), fact_id, args.iter());
|
||||||
|
|
||||||
|
// 5. Return the result of the factorial call
|
||||||
|
f_builder.build_return(i32_ty.clone(), final_result);
|
||||||
|
}
|
||||||
|
module_builder.complete_function();
|
||||||
|
|
||||||
|
// Finalize the module
|
||||||
|
module_builder.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let mut module = build_test_module();
|
||||||
|
|
||||||
validate_module(&module).expect("failed to validate module");
|
validate_module(&module).expect("failed to validate module");
|
||||||
passes::optimize(&mut module);
|
passes::optimize(&mut module);
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
use crate::ir::*;
|
||||||
|
|
||||||
|
/// Runs the Block Reordering pass to maximize fallthroughs and minimize jumps.
|
||||||
|
pub fn reorder_blocks(module: &mut Module) {
|
||||||
|
for func in &mut module.functions {
|
||||||
|
reorder_blocks_in_func(func);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reorder_blocks_in_func(func: &mut Function) {
|
||||||
|
if func.blocks.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut order = Vec::new();
|
||||||
|
let mut visited = HashSet::new();
|
||||||
|
|
||||||
|
// 1. Recursive Depth-First Search to build the optimal chain of blocks
|
||||||
|
fn dfs(
|
||||||
|
id: BlockId,
|
||||||
|
blocks: &[BasicBlock],
|
||||||
|
visited: &mut HashSet<BlockId>,
|
||||||
|
order: &mut Vec<BlockId>,
|
||||||
|
) {
|
||||||
|
if !visited.insert(id) {
|
||||||
|
return; // Stop if we hit a loop backedge or already visited block
|
||||||
|
}
|
||||||
|
order.push(id);
|
||||||
|
|
||||||
|
let block = blocks.iter().find(|b| b.id == id).unwrap();
|
||||||
|
|
||||||
|
match block.terminator {
|
||||||
|
Terminator::Jump(target) => {
|
||||||
|
// Instantly follow unconditional jumps to chain them together
|
||||||
|
dfs(target, blocks, visited, order);
|
||||||
|
}
|
||||||
|
Terminator::Branch {
|
||||||
|
then_block,
|
||||||
|
else_block,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// HEURISTIC: Visit the `else_block` immediately.
|
||||||
|
// This places it directly next to the current block in memory,
|
||||||
|
// allowing our x86-64 backend to omit the branch and fall through.
|
||||||
|
dfs(else_block, blocks, visited, order);
|
||||||
|
|
||||||
|
// Then, recursively process the `then_block` path.
|
||||||
|
dfs(then_block, blocks, visited, order);
|
||||||
|
}
|
||||||
|
Terminator::Return { .. } | Terminator::Unknown => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start traversing from the function's official entry point
|
||||||
|
dfs(func.entry_block_id, &func.blocks, &mut visited, &mut order);
|
||||||
|
|
||||||
|
// 2. Rebuild the function's block vector in the newly computed order
|
||||||
|
let mut block_map: HashMap<BlockId, BasicBlock> =
|
||||||
|
func.blocks.drain(..).map(|b| (b.id, b)).collect();
|
||||||
|
|
||||||
|
// Push the blocks back into the function in our optimized DFS order
|
||||||
|
for id in order {
|
||||||
|
if let Some(block) = block_map.remove(&id) {
|
||||||
|
func.blocks.push(block);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Append any unreachable blocks that the DFS missed.
|
||||||
|
// Note: If you run your Dead Code Elimination pass before this,
|
||||||
|
// `block_map` will already be completely empty by this point.
|
||||||
|
for (_, block) in block_map {
|
||||||
|
func.blocks.push(block);
|
||||||
|
}
|
||||||
|
}
|
||||||
+3
-2
@@ -1,5 +1,6 @@
|
|||||||
use crate::{ir::Module, validate::validate_module};
|
use crate::ir::Module;
|
||||||
|
|
||||||
|
pub mod bbr;
|
||||||
pub mod cfp;
|
pub mod cfp;
|
||||||
pub mod cpp;
|
pub mod cpp;
|
||||||
pub mod dce;
|
pub mod dce;
|
||||||
@@ -12,6 +13,6 @@ pub fn optimize(module: &mut Module) {
|
|||||||
cfp::fold_constants(module);
|
cfp::fold_constants(module);
|
||||||
dce::eliminate_dead_code(module);
|
dce::eliminate_dead_code(module);
|
||||||
cpp::propagate_copies(module);
|
cpp::propagate_copies(module);
|
||||||
|
bbr::reorder_blocks(module);
|
||||||
des::destroy_ssa(module);
|
des::destroy_ssa(module);
|
||||||
validate_module(module).expect("failed to validate module after optimization passes");
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user