From 6727dd84e47174db4230291c0d535d03105fd85a Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Mon, 27 Apr 2026 20:53:44 +0200 Subject: [PATCH] feat(ir/backend): add support for floating-point types and arithmetic - Introduces F32 and F64 types and Float operand variant to the IR. - Implements floating-point binary operations (FAdd, FSub, FMul, FDiv, FRem, FCmp) and FNeg unary op. - Updates IR printer, validator, and builder to handle the new floating-point functionality. - Extends the constant folding pass to evaluate floating-point expressions at compile time. - Enhances x86_64 backend with XMM register support and floating-point codegen. - Implements a fixed-point iteration pass for register type resolution to correctly allocate GPR vs XMM registers. - Updates the linear scan allocator to manage multiple register classes (GPR, XMM). - Adds System V ABI compliant handling for floating-point function arguments and return values. - Includes comprehensive tests for IR validation, constant folding, and assembly generation. --- src/backend/x86_64/codegen.rs | 51 +++ src/backend/x86_64/mod.rs | 627 ++++++++++++++++++++++++++-------- src/backend/x86_64/types.rs | 182 +++++++++- src/builder.rs | 42 +++ src/ir.rs | 55 ++- src/main.rs | 44 ++- src/passes/cfp.rs | 41 ++- src/printer.rs | 35 +- src/validate.rs | 28 ++ tests/float_test.rs | 113 ++++++ 10 files changed, 1048 insertions(+), 170 deletions(-) create mode 100644 tests/float_test.rs diff --git a/src/backend/x86_64/codegen.rs b/src/backend/x86_64/codegen.rs index baee7e4..5a72d79 100644 --- a/src/backend/x86_64/codegen.rs +++ b/src/backend/x86_64/codegen.rs @@ -112,6 +112,57 @@ impl Codegen { self.instructions.push(Instruction::Setcc(cc, dest)); } + /// Emits a floating-point `mov` instruction (`movss` or `movsd`). 
+ pub fn emit_fmov(&mut self, width: OperandWidth, src: Operand, dest: Operand) { + if matches!(dest, Operand::Imm(_)) { + panic!("x86-64 error: Cannot use an immediate as a floating-point destination."); + } + + if matches!(src, Operand::Mem { .. }) && matches!(dest, Operand::Mem { .. }) { + panic!("x86-64 error: Floating-point memory-to-memory moves are not allowed."); + } + + match width { + OperandWidth::DWord => self.instructions.push(Instruction::Movss(src, dest)), + OperandWidth::QWord => self.instructions.push(Instruction::Movsd(src, dest)), + _ => panic!( + "x86-64 error: Invalid width for floating-point move: {:?}", + width + ), + } + } + + /// Emits a floating-point arithmetic instruction. + pub fn emit_fbinary( + &mut self, + op_f32: fn(Operand, Operand) -> Instruction, + op_f64: fn(Operand, Operand) -> Instruction, + width: OperandWidth, + src: Operand, + dest: Operand, + ) { + if matches!(dest, Operand::Imm(_)) { + panic!( + "x86-64 error: Destination of a floating-point operation cannot be an immediate." + ); + } + + if matches!(src, Operand::Mem { .. }) && matches!(dest, Operand::Mem { .. }) { + panic!( + "x86-64 error: Floating-point binary operations cannot act on two memory addresses." 
+ ); + } + + match width { + OperandWidth::DWord => self.instructions.push(op_f32(src, dest)), + OperandWidth::QWord => self.instructions.push(op_f64(src, dest)), + _ => panic!( + "x86-64 error: Invalid width for floating-point binary: {:?}", + width + ), + } + } + /// Helper to grab the finalized list of instructions pub fn finish(self) -> Vec { self.instructions diff --git a/src/backend/x86_64/mod.rs b/src/backend/x86_64/mod.rs index a217d94..a44eb5f 100644 --- a/src/backend/x86_64/mod.rs +++ b/src/backend/x86_64/mod.rs @@ -12,6 +12,7 @@ use crate::ir; #[derive(Clone, Copy, Debug)] enum Storage { Hardware(Gpr), + Xmm(Xmm), Stack(i32), Alloc(i32), } @@ -21,6 +22,7 @@ pub struct X86Backend<'a> { assembly: String, live_intervals: HashMap, allocations: HashMap, + reg_types: HashMap, } impl<'a> X86Backend<'a> { @@ -30,6 +32,7 @@ impl<'a> X86Backend<'a> { module, live_intervals: HashMap::new(), allocations: HashMap::new(), + reg_types: HashMap::new(), } } @@ -37,8 +40,25 @@ impl<'a> X86Backend<'a> { match op { ir::Operand::Integer(v) => Operand::Imm(*v as i64), ir::Operand::Boolean(b) => Operand::Imm(if *b { 1 } else { 0 }), + ir::Operand::Float(f) => { + // For simplicity, load into a scratch XMM register if it's a constant + // Real compilers would use a constant pool. 
+ let bits = f.to_bits(); + cg.emit_mov( + OperandWidth::QWord, + Operand::Imm(bits as i64), + Operand::Reg(scratch_base), + ); + cg.instructions.push(Instruction::Mov( + OperandWidth::QWord, + Operand::Reg(scratch_base), + Operand::Xmm(Xmm::Xmm15), + )); + Operand::Xmm(Xmm::Xmm15) + } ir::Operand::Register(r) => match self.allocations.get(r).unwrap() { Storage::Hardware(hw_gpr) => Operand::Reg(*hw_gpr), + Storage::Xmm(hw_xmm) => Operand::Xmm(*hw_xmm), &Storage::Stack(offset) => Operand::Mem { base: Gpr::Rbp, offset, @@ -59,6 +79,7 @@ impl<'a> X86Backend<'a> { fn resolve_dest(&self, reg: ir::Register) -> Operand { match self.allocations.get(&reg).unwrap() { &Storage::Hardware(hw_gpr) => Operand::Reg(hw_gpr), + &Storage::Xmm(hw_xmm) => Operand::Xmm(hw_xmm), &Storage::Stack(offset) | &Storage::Alloc(offset) => Operand::Mem { base: Gpr::Rbp, offset, @@ -88,15 +109,62 @@ impl<'a> X86Backend<'a> { } fn compile_function(&mut self, func: &ir::Function) { - // 0. Pre-Pass: Handle Allocations + self.reg_types.clear(); + self.allocations.clear(); + self.live_intervals.clear(); + + // 0. Pre-Pass: Handle Allocations and collect register types let mut next_stack_offset = 0; - for block in &func.blocks { - for inst in &block.instructions { - if let ir::Instruction::Alloc { dest, .. } = inst { - next_stack_offset -= 8; - self.allocations - .insert(*dest, Storage::Alloc(next_stack_offset)); + for (ty, reg) in &func.params { + self.reg_types.insert(*reg, *ty); + } + + let mut changed = true; + while changed { + changed = false; + for block in &func.blocks { + for inst in &block.instructions { + let (dest, ty) = match inst { + ir::Instruction::Alloc { dest, .. } => { + if !self.allocations.contains_key(dest) { + next_stack_offset -= 8; + self.allocations + .insert(*dest, Storage::Alloc(next_stack_offset)); + } + (Some(*dest), Some(ir::Type::Ptr)) + } + ir::Instruction::Binary { + dest, result_ty, .. 
+ } => (Some(*dest), Some(*result_ty)), + ir::Instruction::Unary { + dest, result_ty, .. + } => (Some(*dest), Some(*result_ty)), + ir::Instruction::Load { dest, ty, .. } => (Some(*dest), Some(*ty)), + ir::Instruction::Call { + dest, result_ty, .. + } => (Some(*dest), Some(*result_ty)), + ir::Instruction::Phi { + dest, result_ty, .. + } => (Some(*dest), Some(*result_ty)), + ir::Instruction::Copy { dest, src } => { + let ty = match src { + ir::Operand::Integer(_) => Some(ir::Type::I64), + ir::Operand::Boolean(_) => Some(ir::Type::Bool), + ir::Operand::Float(_) => Some(ir::Type::F64), + ir::Operand::Register(r) => self.reg_types.get(r).copied(), + }; + (Some(*dest), ty) + } + ir::Instruction::Store { .. } => (None, None), + }; + + if let (Some(d), Some(t)) = (dest, ty) { + if !self.reg_types.contains_key(&d) { + self.reg_types.insert(d, t); + changed = true; + } + } } } } @@ -118,9 +186,10 @@ impl<'a> X86Backend<'a> { ir::Instruction::Call { dest, args, .. } => { self.mark_use(*dest, inst_idx); let arg_regs = [Gpr::Rdi, Gpr::Rsi, Gpr::Rdx, Gpr::Rcx, Gpr::R8, Gpr::R9]; - for (i, (_, arg_op)) in args.iter().enumerate() { + for (i, (ty, arg_op)) in args.iter().enumerate() { self.mark_op(arg_op, inst_idx); if i < 6 + && !matches!(ty, ir::Type::F32 | ir::Type::F64) && let ir::Operand::Register(r) = arg_op { hints.insert(*r, arg_regs[i]); @@ -176,8 +245,26 @@ impl<'a> X86Backend<'a> { // 2. 
ABI-Aware Linear Scan Allocation let mut free_callee_saved = vec![Gpr::Rbx, Gpr::R12, Gpr::R13, Gpr::R14, Gpr::R15]; let mut free_caller_saved = vec![Gpr::Rdi, Gpr::Rsi, Gpr::Rdx, Gpr::Rcx, Gpr::R8, Gpr::R9]; + let mut free_xmm = vec![ + Xmm::Xmm0, + Xmm::Xmm1, + Xmm::Xmm2, + Xmm::Xmm3, + Xmm::Xmm4, + Xmm::Xmm5, + Xmm::Xmm6, + Xmm::Xmm7, + Xmm::Xmm8, + Xmm::Xmm9, + Xmm::Xmm10, + Xmm::Xmm11, + Xmm::Xmm12, + Xmm::Xmm13, + Xmm::Xmm14, + ]; - let mut active: Vec<(ir::Register, usize, Gpr, bool)> = Vec::new(); + let mut active_gpr: Vec<(ir::Register, usize, Gpr, bool)> = Vec::new(); + let mut active_xmm: Vec<(ir::Register, usize, Xmm)> = Vec::new(); let mut used_callee_saved = HashSet::new(); let live_intervals = take(&mut self.live_intervals); @@ -185,7 +272,7 @@ impl<'a> X86Backend<'a> { intervals_sorted.sort_by_key(|(_, (start, _))| *start); for (reg, (start, end)) in intervals_sorted { - active.retain(|(_, active_end, hw_reg, is_caller)| { + active_gpr.retain(|(_, active_end, hw_reg, is_caller)| { if *active_end < start { if *is_caller { free_caller_saved.push(*hw_reg); @@ -197,45 +284,66 @@ impl<'a> X86Backend<'a> { true } }); + active_xmm.retain(|(_, active_end, hw_reg)| { + if *active_end < start { + free_xmm.push(*hw_reg); + false + } else { + true + } + }); let crosses_call = call_indices .iter() .any(|&c_idx| c_idx > start && c_idx < end); - let mut selected_hw = None; - let mut is_caller = false; + let reg_ty = self.reg_types.get(&reg).unwrap(); + let is_float = matches!(reg_ty, ir::Type::F32 | ir::Type::F64); - if !crosses_call { - if let Some(&hint_reg) = hints.get(&reg) - && let Some(pos) = free_caller_saved.iter().position(|&r| r == hint_reg) - { - selected_hw = Some(free_caller_saved.remove(pos)); - is_caller = true; + if is_float { + if let Some(hw_xmm) = free_xmm.pop() { + self.allocations.insert(reg, Storage::Xmm(hw_xmm)); + active_xmm.push((reg, end, hw_xmm)); + } else { + next_stack_offset += 8; + self.allocations + .insert(reg, 
Storage::Stack(next_stack_offset)); } + } else { + let mut selected_hw = None; + let mut is_caller = false; + + if !crosses_call { + if let Some(&hint_reg) = hints.get(&reg) + && let Some(pos) = free_caller_saved.iter().position(|&r| r == hint_reg) + { + selected_hw = Some(free_caller_saved.remove(pos)); + is_caller = true; + } + if selected_hw.is_none() + && let Some(r) = free_caller_saved.pop() + { + selected_hw = Some(r); + is_caller = true; + } + } + if selected_hw.is_none() - && let Some(r) = free_caller_saved.pop() + && let Some(r) = free_callee_saved.pop() { selected_hw = Some(r); - is_caller = true; + used_callee_saved.insert(r); + is_caller = false; } - } - if selected_hw.is_none() - && let Some(r) = free_callee_saved.pop() - { - selected_hw = Some(r); - - used_callee_saved.insert(r); - is_caller = false; - } - - if let Some(hw_reg) = selected_hw { - self.allocations.insert(reg, Storage::Hardware(hw_reg)); - active.push((reg, end, hw_reg, is_caller)); - } else { - next_stack_offset += 8; - self.allocations - .insert(reg, Storage::Stack(next_stack_offset)); + if let Some(hw_reg) = selected_hw { + self.allocations.insert(reg, Storage::Hardware(hw_reg)); + active_gpr.push((reg, end, hw_reg, is_caller)); + } else { + next_stack_offset += 8; + self.allocations + .insert(reg, Storage::Stack(next_stack_offset)); + } } } @@ -291,25 +399,55 @@ impl<'a> X86Backend<'a> { // 4. 
Map ABI Arguments let arg_regs = [Gpr::Rdi, Gpr::Rsi, Gpr::Rdx, Gpr::Rcx, Gpr::R8, Gpr::R9]; - for (i, (ty, reg)) in func.params.iter().enumerate() { + let xmm_regs = [ + Xmm::Xmm0, + Xmm::Xmm1, + Xmm::Xmm2, + Xmm::Xmm3, + Xmm::Xmm4, + Xmm::Xmm5, + Xmm::Xmm6, + Xmm::Xmm7, + ]; + let mut gpr_idx = 0; + let mut xmm_idx = 0; + let mut stack_idx = 0; + + for (ty, reg) in &func.params { let width = OperandWidth::from_type(ty); let dest_op = self.resolve_dest(*reg); + let is_float = matches!(ty, ir::Type::F32 | ir::Type::F64); - if i < 6 { - pro_cg.emit_mov(width, Operand::Reg(arg_regs[i]), dest_op); - } else { - let caller_offset = 16 + ((i - 6) * 8); - let caller_mem = Operand::Mem { - base: Gpr::Rbp, - offset: caller_offset as i32, - }; - - // x86-64 constraint: Memory-to-memory moves must route through a register - if matches!(dest_op, Operand::Mem { .. }) { - pro_cg.emit_mov(width, caller_mem, Operand::Reg(Gpr::Rax)); - pro_cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + if is_float { + if xmm_idx < 8 { + pro_cg.emit_fmov(width, Operand::Xmm(xmm_regs[xmm_idx]), dest_op); + xmm_idx += 1; } else { - pro_cg.emit_mov(width, caller_mem, dest_op); + let caller_offset = 16 + (stack_idx * 8); + let caller_mem = Operand::Mem { + base: Gpr::Rbp, + offset: caller_offset as i32, + }; + pro_cg.emit_fmov(width, caller_mem, dest_op); + stack_idx += 1; + } + } else { + if gpr_idx < 6 { + pro_cg.emit_mov(width, Operand::Reg(arg_regs[gpr_idx]), dest_op); + gpr_idx += 1; + } else { + let caller_offset = 16 + (stack_idx * 8); + let caller_mem = Operand::Mem { + base: Gpr::Rbp, + offset: caller_offset as i32, + }; + if matches!(dest_op, Operand::Mem { .. 
}) { + pro_cg.emit_mov(width, caller_mem, Operand::Reg(Gpr::Rax)); + pro_cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + } else { + pro_cg.emit_mov(width, caller_mem, dest_op); + } + stack_idx += 1; } } } @@ -331,26 +469,31 @@ impl<'a> X86Backend<'a> { writeln!(&mut self.assembly, ".L{}_block_{}:", func.name, block.id.0).unwrap(); - // Peephole Optimization: Fuse an immediately preceding ICmp into the Branch - let mut fused_cmp = None; + // Peephole Optimization: Fuse an immediately preceding Comparison into the Branch + let mut fused_icmp = None; + let mut fused_fcmp = None; if let ir::Terminator::Branch { cond: ir::Operand::Register(cond_reg), .. } = block.terminator && let Some(ir::Instruction::Binary { dest, - op: ir::BinaryOp::ICmp(cmp_op), + op, src1, src2, .. }) = block.instructions.last() && cond_reg == *dest { - fused_cmp = Some((*cmp_op, *src1, *src2)); + match op { + ir::BinaryOp::ICmp(cmp_op) => fused_icmp = Some((*cmp_op, *src1, *src2)), + ir::BinaryOp::FCmp(cmp_op) => fused_fcmp = Some((*cmp_op, *src1, *src2)), + _ => {} + } } // If we fused the comparison, we omit the last instruction from standard compilation - let inst_limit = if fused_cmp.is_some() { + let inst_limit = if fused_icmp.is_some() || fused_fcmp.is_some() { block.instructions.len() - 1 } else { block.instructions.len() @@ -366,7 +509,8 @@ impl<'a> X86Backend<'a> { &block.terminator, &func.name, next_block_id, - fused_cmp, + fused_icmp, + fused_fcmp, &mut block_instructions, ); @@ -424,22 +568,42 @@ impl<'a> X86Backend<'a> { ir::Instruction::Alloc { .. } => {} // Handled in prologue ir::Instruction::Copy { dest, src } => { - // Assuming everything is 64-bit for a raw copy unless we track sizes perfectly - let width = OperandWidth::QWord; let dest_op = self.resolve_dest(*dest); let src_op = self.resolve_op(src, Gpr::Rax, &mut cg); - - if matches!(dest_op, Operand::Mem { .. }) && matches!(src_op, Operand::Mem { .. 
}) { - cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); - cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + let is_float = + matches!(src_op, Operand::Xmm(_)) || matches!(dest_op, Operand::Xmm(_)); + let width = if is_float { + let ty = self.reg_types.get(dest).unwrap(); + OperandWidth::from_type(ty) } else { - cg.emit_mov(width, src_op, dest_op); + OperandWidth::QWord + }; + + if is_float { + if matches!(dest_op, Operand::Mem { .. }) + && matches!(src_op, Operand::Mem { .. }) + { + cg.emit_fmov(width, src_op, Operand::Xmm(Xmm::Xmm15)); + cg.emit_fmov(width, Operand::Xmm(Xmm::Xmm15), dest_op); + } else { + cg.emit_fmov(width, src_op, dest_op); + } + } else { + if matches!(dest_op, Operand::Mem { .. }) + && matches!(src_op, Operand::Mem { .. }) + { + cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); + cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + } else { + cg.emit_mov(width, src_op, dest_op); + } } } ir::Instruction::Load { ty, dest, src } => { let width = OperandWidth::from_type(ty); let dest_op = self.resolve_dest(*dest); + let is_float = matches!(ty, ir::Type::F32 | ir::Type::F64); // If loading directly from an Alloc, we can bypass pointer materialization if let ir::Operand::Register(r) = src @@ -449,7 +613,11 @@ impl<'a> X86Backend<'a> { base: Gpr::Rbp, offset, }; - cg.emit_mov(width, mem, dest_op); + if is_float { + cg.emit_fmov(width, mem, dest_op); + } else { + cg.emit_mov(width, mem, dest_op); + } block_instructions.extend(cg.finish()); return; } @@ -472,17 +640,27 @@ impl<'a> X86Backend<'a> { offset: 0, }; - if matches!(dest_op, Operand::Mem { .. }) { - cg.emit_mov(width, deref, Operand::Reg(Gpr::R11)); - cg.emit_mov(width, Operand::Reg(Gpr::R11), dest_op); + if is_float { + if matches!(dest_op, Operand::Mem { .. 
}) { + cg.emit_fmov(width, deref, Operand::Xmm(Xmm::Xmm15)); + cg.emit_fmov(width, Operand::Xmm(Xmm::Xmm15), dest_op); + } else { + cg.emit_fmov(width, deref, dest_op); + } } else { - cg.emit_mov(width, deref, dest_op); + if matches!(dest_op, Operand::Mem { .. }) { + cg.emit_mov(width, deref, Operand::Reg(Gpr::R11)); + cg.emit_mov(width, Operand::Reg(Gpr::R11), dest_op); + } else { + cg.emit_mov(width, deref, dest_op); + } } } ir::Instruction::Store { ty, dest, src } => { let width = OperandWidth::from_type(ty); let src_op = self.resolve_op(src, Gpr::Rax, &mut cg); + let is_float = matches!(ty, ir::Type::F32 | ir::Type::F64); // If storing directly to an Alloc, bypass pointer materialization if let ir::Operand::Register(r) = dest @@ -492,11 +670,22 @@ impl<'a> X86Backend<'a> { base: Gpr::Rbp, offset, }; - if matches!(src_op, Operand::Mem { .. }) || matches!(src_op, Operand::Imm(_)) { - cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); - cg.emit_mov(width, Operand::Reg(Gpr::Rax), mem); + if is_float { + if matches!(src_op, Operand::Mem { .. }) { + cg.emit_fmov(width, src_op, Operand::Xmm(Xmm::Xmm15)); + cg.emit_fmov(width, Operand::Xmm(Xmm::Xmm15), mem); + } else { + cg.emit_fmov(width, src_op, mem); + } } else { - cg.emit_mov(width, src_op, mem); + if matches!(src_op, Operand::Mem { .. }) + || matches!(src_op, Operand::Imm(_)) + { + cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); + cg.emit_mov(width, Operand::Reg(Gpr::Rax), mem); + } else { + cg.emit_mov(width, src_op, mem); + } } block_instructions.extend(cg.finish()); return; @@ -517,11 +706,20 @@ impl<'a> X86Backend<'a> { offset: 0, }; - if matches!(src_op, Operand::Mem { .. }) || matches!(src_op, Operand::Imm(_)) { - cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); - cg.emit_mov(width, Operand::Reg(Gpr::Rax), deref); + if is_float { + if matches!(src_op, Operand::Mem { .. 
}) { + cg.emit_fmov(width, src_op, Operand::Xmm(Xmm::Xmm15)); + cg.emit_fmov(width, Operand::Xmm(Xmm::Xmm15), deref); + } else { + cg.emit_fmov(width, src_op, deref); + } } else { - cg.emit_mov(width, src_op, deref); + if matches!(src_op, Operand::Mem { .. }) || matches!(src_op, Operand::Imm(_)) { + cg.emit_mov(width, src_op, Operand::Reg(Gpr::Rax)); + cg.emit_mov(width, Operand::Reg(Gpr::Rax), deref); + } else { + cg.emit_mov(width, src_op, deref); + } } } @@ -534,51 +732,92 @@ impl<'a> X86Backend<'a> { } => { let width = OperandWidth::from_type(result_ty); let dest_op = self.resolve_dest(*dest); - let src1_op = self.resolve_op(src1, Gpr::R10, &mut cg); - let src2_op = self.resolve_op(src2, Gpr::R11, &mut cg); + let is_float = matches!(result_ty, ir::Type::F32 | ir::Type::F64); - match op { - ir::BinaryOp::Add | ir::BinaryOp::Sub | ir::BinaryOp::SMul => { - let x86_op = match op { - ir::BinaryOp::Add => Instruction::Add, - ir::BinaryOp::Sub => Instruction::Sub, - ir::BinaryOp::SMul => Instruction::IMul, - _ => unreachable!(), - }; + if is_float { + let src1_op = self.resolve_op(src1, Gpr::R10, &mut cg); + let src2_op = self.resolve_op(src2, Gpr::R11, &mut cg); - // Ensure dest = src1 before applying destructive x86 math - if src1_op != dest_op { - if matches!(src1_op, Operand::Mem { .. 
}) + if src1_op != dest_op { + cg.emit_fmov(width, src1_op, dest_op); + } + + match op { + ir::BinaryOp::FAdd => cg.emit_fbinary( + Instruction::Addss, + Instruction::Addsd, + width, + src2_op, + dest_op, + ), + ir::BinaryOp::FSub => cg.emit_fbinary( + Instruction::Subss, + Instruction::Subsd, + width, + src2_op, + dest_op, + ), + ir::BinaryOp::FMul => cg.emit_fbinary( + Instruction::Mulss, + Instruction::Mulsd, + width, + src2_op, + dest_op, + ), + ir::BinaryOp::FDiv => cg.emit_fbinary( + Instruction::Divss, + Instruction::Divsd, + width, + src2_op, + dest_op, + ), + _ => todo!("Implement other float binary ops: {:?}", op), + } + } else { + let src1_op = self.resolve_op(src1, Gpr::R10, &mut cg); + let src2_op = self.resolve_op(src2, Gpr::R11, &mut cg); + + match op { + ir::BinaryOp::Add | ir::BinaryOp::Sub | ir::BinaryOp::SMul => { + let x86_op = match op { + ir::BinaryOp::Add => Instruction::Add, + ir::BinaryOp::Sub => Instruction::Sub, + ir::BinaryOp::SMul => Instruction::IMul, + _ => unreachable!(), + }; + + if src1_op != dest_op { + if matches!(src1_op, Operand::Mem { .. }) + && matches!(dest_op, Operand::Mem { .. }) + { + cg.emit_mov(width, src1_op, Operand::Reg(Gpr::Rax)); + cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + } else { + cg.emit_mov(width, src1_op, dest_op); + } + } + + if matches!(src2_op, Operand::Mem { .. }) && matches!(dest_op, Operand::Mem { .. }) { - cg.emit_mov(width, src1_op, Operand::Reg(Gpr::Rax)); - cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + cg.emit_mov(width, src2_op, Operand::Reg(Gpr::Rax)); + cg.emit_binary(x86_op, width, Operand::Reg(Gpr::Rax), dest_op); } else { - cg.emit_mov(width, src1_op, dest_op); + cg.emit_binary(x86_op, width, src2_op, dest_op); } } - - if matches!(src2_op, Operand::Mem { .. }) - && matches!(dest_op, Operand::Mem { .. 
}) - { - cg.emit_mov(width, src2_op, Operand::Reg(Gpr::Rax)); - cg.emit_binary(x86_op, width, Operand::Reg(Gpr::Rax), dest_op); - } else { - cg.emit_binary(x86_op, width, src2_op, dest_op); + ir::BinaryOp::UMul => { + cg.emit_mov(width, src1_op, Operand::Reg(Gpr::Rax)); + if matches!(src2_op, Operand::Imm(_)) { + cg.emit_mov(width, src2_op, Operand::Reg(Gpr::R10)); + cg.emit_umul(width, Operand::Reg(Gpr::R10)); + } else { + cg.emit_umul(width, src2_op); + } + cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); } + _ => todo!("Implement other binary ops: {:?}", op), } - ir::BinaryOp::UMul => { - cg.emit_mov(width, src1_op, Operand::Reg(Gpr::Rax)); - if matches!(src2_op, Operand::Imm(_)) { - cg.emit_mov(width, src2_op, Operand::Reg(Gpr::R10)); - cg.emit_umul(width, Operand::Reg(Gpr::R10)); - } else { - cg.emit_umul(width, src2_op); - } - cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); - } - // Implement Div/Rem/Cmp similarly using cg.emit_* methods... - _ => unimplemented!("Implement other binary ops"), } } @@ -590,11 +829,46 @@ impl<'a> X86Backend<'a> { } => { let width = OperandWidth::from_type(result_ty); let dest_op = self.resolve_dest(*dest); - let src_op = self.resolve_op(src, Gpr::Rax, &mut cg); + let is_float = matches!(result_ty, ir::Type::F32 | ir::Type::F64); - cg.emit_mov(width, src_op, dest_op); - match op { - ir::UnaryOp::INeg => cg.instructions.push(Instruction::Neg(width, dest_op)), + if is_float { + let src_op = self.resolve_op(src, Gpr::Rax, &mut cg); + match op { + ir::UnaryOp::FNeg => { + // Negation: xor with -0.0 + let mask = if *result_ty == ir::Type::F32 { + 0x80000000u64 + } else { + 0x8000000000000000u64 + }; + cg.emit_mov( + OperandWidth::QWord, + Operand::Imm(mask as i64), + Operand::Reg(Gpr::Rax), + ); + cg.instructions.push(Instruction::Mov( + OperandWidth::QWord, + Operand::Reg(Gpr::Rax), + Operand::Xmm(Xmm::Xmm15), + )); + cg.emit_fmov(width, src_op, dest_op); + if *result_ty == ir::Type::F32 { + cg.instructions + 
.push(Instruction::Xorps(Operand::Xmm(Xmm::Xmm15), dest_op)); + } else { + cg.instructions + .push(Instruction::Xorpd(Operand::Xmm(Xmm::Xmm15), dest_op)); + } + } + _ => todo!("Implement other float unary ops"), + } + } else { + let src_op = self.resolve_op(src, Gpr::Rax, &mut cg); + cg.emit_mov(width, src_op, dest_op); + match op { + ir::UnaryOp::INeg => cg.instructions.push(Instruction::Neg(width, dest_op)), + _ => unreachable!(), + } } } @@ -606,18 +880,57 @@ impl<'a> X86Backend<'a> { } => { let width = OperandWidth::from_type(result_ty); let arg_regs = [Gpr::Rdi, Gpr::Rsi, Gpr::Rdx, Gpr::Rcx, Gpr::R8, Gpr::R9]; + let xmm_regs = [ + Xmm::Xmm0, + Xmm::Xmm1, + Xmm::Xmm2, + Xmm::Xmm3, + Xmm::Xmm4, + Xmm::Xmm5, + Xmm::Xmm6, + Xmm::Xmm7, + ]; + let mut gpr_idx = 0; + let mut xmm_idx = 0; + let mut stack_args = Vec::new(); - for (i, (ty, arg_op)) in args.iter().enumerate() { + for (ty, arg_op) in args { let arg_width = OperandWidth::from_type(ty); let src = self.resolve_op(arg_op, Gpr::R10, &mut cg); - if i < 6 { - cg.emit_mov(arg_width, src, Operand::Reg(arg_regs[i])); + let is_float = matches!(ty, ir::Type::F32 | ir::Type::F64); + + if is_float { + if xmm_idx < 8 { + cg.emit_fmov(arg_width, src, Operand::Xmm(xmm_regs[xmm_idx])); + xmm_idx += 1; + } else { + stack_args.push((arg_width, src, true)); + } + } else { + if gpr_idx < 6 { + cg.emit_mov(arg_width, src, Operand::Reg(arg_regs[gpr_idx])); + gpr_idx += 1; + } else { + stack_args.push((arg_width, src, false)); + } + } + } + + for (_, src, is_float) in stack_args.iter().rev() { + if *is_float { + // Push float by moving to GPR first then pushing + cg.instructions.push(Instruction::Mov( + OperandWidth::QWord, + *src, + Operand::Reg(Gpr::Rax), + )); + cg.emit_push(Operand::Reg(Gpr::Rax)); } else { if matches!(src, Operand::Mem { .. 
}) { - cg.emit_mov(OperandWidth::QWord, src, Operand::Reg(Gpr::Rax)); + cg.emit_mov(OperandWidth::QWord, *src, Operand::Reg(Gpr::Rax)); cg.emit_push(Operand::Reg(Gpr::Rax)); } else { - cg.emit_push(src); + cg.emit_push(*src); } } } @@ -628,29 +941,31 @@ impl<'a> X86Backend<'a> { .iter() .find_map(|f| (f.id == *func).then(|| f.name.clone())) .unwrap(); - cg.instructions.push(Instruction::Call(func_name)); - if args.len() > 6 { - let cleanup = ((args.len() - 6) * 8) as i64; + let stack_cleanup = stack_args.len() * 8; + if stack_cleanup > 0 { cg.emit_binary( Instruction::Add, OperandWidth::QWord, - Operand::Imm(cleanup), + Operand::Imm(stack_cleanup as i64), Operand::Reg(Gpr::Rsp), ); } if *result_ty != ir::Type::Void { let dest_op = self.resolve_dest(*dest); - cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + if matches!(result_ty, ir::Type::F32 | ir::Type::F64) { + cg.emit_fmov(width, Operand::Xmm(Xmm::Xmm0), dest_op); + } else { + cg.emit_mov(width, Operand::Reg(Gpr::Rax), dest_op); + } } } ir::Instruction::Phi { .. 
} => unreachable!("Run SSA Destruction pass before codegen!"), } - // Commit generated instructions to the block block_instructions.extend(cg.finish()); } @@ -659,7 +974,8 @@ impl<'a> X86Backend<'a> { term: &ir::Terminator, func_name: &str, next_block_id: Option, - fused_cmp: Option<(ir::ICmpOp, ir::Operand, ir::Operand)>, + fused_icmp: Option<(ir::ICmpOp, ir::Operand, ir::Operand)>, + fused_fcmp: Option<(ir::FCmpOp, ir::Operand, ir::Operand)>, block_instructions: &mut Vec, ) { let mut cg = Codegen::new(); @@ -669,10 +985,12 @@ impl<'a> X86Backend<'a> { if let Some(val) = value { let width = OperandWidth::from_type(return_ty); let src = self.resolve_op(val, Gpr::R10, &mut cg); - cg.emit_mov(width, src, Operand::Reg(Gpr::Rax)); + if matches!(return_ty, ir::Type::F32 | ir::Type::F64) { + cg.emit_fmov(width, src, Operand::Xmm(Xmm::Xmm0)); + } else { + cg.emit_mov(width, src, Operand::Reg(Gpr::Rax)); + } } - - // If there is no next block, we can fallthrough. Otherwise, emit a jump. if next_block_id.is_some() { cg.instructions .push(Instruction::Jmp(format!(".L{}_epilogue", func_name))); @@ -693,17 +1011,11 @@ impl<'a> X86Backend<'a> { then_block, else_block, } => { - let (cc_true, cc_false) = if let Some((cmp_op, src1, src2)) = fused_cmp { + let (cc_true, cc_false) = if let Some((cmp_op, src1, src2)) = fused_icmp { let src1_op = self.resolve_op(&src1, Gpr::R10, &mut cg); let src2_op = self.resolve_op(&src2, Gpr::R11, &mut cg); - - // Note: IR Binary ops only track the result type (Bool for ICmp). - // To be perfectly accurate in the future, we may want to track operand types. - // For now, we assume pointers/64-bit integers (QWord) for the comparison. let width = OperandWidth::QWord; - // x86-64 constraints: destination (src1 in AT&T) cannot be an immediate. - // Memory-to-memory comparisons are also invalid. if matches!(src1_op, Operand::Mem { .. }) && matches!(src2_op, Operand::Mem { .. 
}) || matches!(src1_op, Operand::Imm(_)) @@ -731,11 +1043,35 @@ impl<'a> X86Backend<'a> { ir::ICmpOp::Ugt => (ConditionCode::A, ConditionCode::Be), ir::ICmpOp::Uge => (ConditionCode::Ae, ConditionCode::B), } - } else { - // Fallback: evaluate an isolated boolean - let cond_op = self.resolve_op(cond, Gpr::R10, &mut cg); - let width = OperandWidth::Byte; // Booleans are 8-bit + } else if let Some((cmp_op, src1, src2)) = fused_fcmp { + let src1_op = self.resolve_op(&src1, Gpr::R10, &mut cg); + let src2_op = self.resolve_op(&src2, Gpr::R11, &mut cg); + // Get type of src1 to determine width + let ty = if let ir::Operand::Register(r) = src1 { + *self.reg_types.get(&r).unwrap() + } else { + ir::Type::F64 + }; + let width = OperandWidth::from_type(&ty); + if width == OperandWidth::DWord { + cg.instructions.push(Instruction::Ucomiss(src2_op, src1_op)); + } else { + cg.instructions.push(Instruction::Ucomisd(src2_op, src1_op)); + } + + match cmp_op { + ir::FCmpOp::Oeq => (ConditionCode::E, ConditionCode::Ne), + ir::FCmpOp::One => (ConditionCode::Ne, ConditionCode::E), + ir::FCmpOp::Olt => (ConditionCode::B, ConditionCode::Ae), + ir::FCmpOp::Ole => (ConditionCode::Be, ConditionCode::A), + ir::FCmpOp::Ogt => (ConditionCode::A, ConditionCode::Be), + ir::FCmpOp::Oge => (ConditionCode::Ae, ConditionCode::B), + _ => todo!("Implement other fcmp ops in branch: {:?}", cmp_op), + } + } else { + let cond_op = self.resolve_op(cond, Gpr::R10, &mut cg); + let width = OperandWidth::Byte; if matches!(cond_op, Operand::Imm(_)) || matches!(cond_op, Operand::Mem { .. 
}) { cg.emit_mov(width, cond_op, Operand::Reg(Gpr::Rax)); @@ -748,14 +1084,12 @@ impl<'a> X86Backend<'a> { cg.instructions .push(Instruction::Test(width, cond_op, cond_op)); } - (ConditionCode::Nz, ConditionCode::Z) }; let then_label = format!(".L{}_block_{}", func_name, then_block.0); let else_label = format!(".L{}_block_{}", func_name, else_block.0); - // Fallthrough logic cleanly applied to dynamically resolved jump conditions if Some(*else_block) == next_block_id { cg.instructions.push(Instruction::Jcc(cc_true, then_label)); } else if Some(*then_block) == next_block_id { @@ -768,7 +1102,6 @@ impl<'a> X86Backend<'a> { ir::Terminator::Unknown => panic!("Cannot compile Unknown terminator"), } - // Commit generated instructions to the block block_instructions.extend(cg.finish()); } } diff --git a/src/backend/x86_64/types.rs b/src/backend/x86_64/types.rs index 916fea0..354f95d 100644 --- a/src/backend/x86_64/types.rs +++ b/src/backend/x86_64/types.rs @@ -26,8 +26,8 @@ impl OperandWidth { match ty { Type::Bool | Type::I8 => OperandWidth::Byte, Type::I16 => OperandWidth::Word, - Type::I32 => OperandWidth::DWord, - Type::I64 | Type::Ptr => OperandWidth::QWord, + Type::I32 | Type::F32 => OperandWidth::DWord, + Type::I64 | Type::F64 | Type::Ptr => OperandWidth::QWord, Type::Void => panic!("x86-64 error: Cannot compute width of Void type"), } } @@ -132,6 +132,49 @@ impl Gpr { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Xmm { + Xmm0, + Xmm1, + Xmm2, + Xmm3, + Xmm4, + Xmm5, + Xmm6, + Xmm7, + Xmm8, + Xmm9, + Xmm10, + Xmm11, + Xmm12, + Xmm13, + Xmm14, + Xmm15, +} + +impl Xmm { + pub fn format(self) -> &'static str { + match self { + Xmm::Xmm0 => "%xmm0", + Xmm::Xmm1 => "%xmm1", + Xmm::Xmm2 => "%xmm2", + Xmm::Xmm3 => "%xmm3", + Xmm::Xmm4 => "%xmm4", + Xmm::Xmm5 => "%xmm5", + Xmm::Xmm6 => "%xmm6", + Xmm::Xmm7 => "%xmm7", + Xmm::Xmm8 => "%xmm8", + Xmm::Xmm9 => "%xmm9", + Xmm::Xmm10 => "%xmm10", + Xmm::Xmm11 => "%xmm11", + Xmm::Xmm12 => "%xmm12", + 
Xmm::Xmm13 => "%xmm13", + Xmm::Xmm14 => "%xmm14", + Xmm::Xmm15 => "%xmm15", + } + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum ConditionCode { E, @@ -172,8 +215,10 @@ impl fmt::Display for ConditionCode { pub enum Operand { /// An immediate constant, e.g. `$5` Imm(i64), - /// A hardware register + /// A hardware GPR register Reg(Gpr), + /// A hardware XMM register + Xmm(Xmm), /// A memory operand: `offset(base_reg)` Mem { base: Gpr, offset: i32 }, } @@ -189,6 +234,7 @@ impl Operand { match self { Operand::Imm(val) => format!("${}", val), Operand::Reg(gpr) => gpr.format_with_width(width).to_string(), + Operand::Xmm(xmm) => xmm.format().to_string(), // Memory addresses inherently use 64-bit pointers (QWord) for the base register, // even if the value being read/written is smaller. Operand::Mem { base, offset } => { @@ -210,6 +256,10 @@ pub enum Instruction { Movzx(OperandWidth, OperandWidth, Operand, Operand), // src_width, dest_width, src, dest Lea(OperandWidth, Operand, Operand), // width, src (mem), dest (reg) + // Floating point movement + Movss(Operand, Operand), // 32-bit + Movsd(Operand, Operand), // 64-bit + // Stack manipulation (typically 64-bit in x86-64) Push(Operand), Pop(Operand), @@ -224,9 +274,23 @@ pub enum Instruction { Div(OperandWidth, Operand), Cqto, // Sign-extend RAX into RDX:RAX + // Floating point arithmetic + Addss(Operand, Operand), + Addsd(Operand, Operand), + Subss(Operand, Operand), + Subsd(Operand, Operand), + Mulss(Operand, Operand), + Mulsd(Operand, Operand), + Divss(Operand, Operand), + Divsd(Operand, Operand), + Xorps(Operand, Operand), // Used for negation/clearing + Xorpd(Operand, Operand), + // Logic & Comparison Cmp(OperandWidth, Operand, Operand), // width, src, dest Test(OperandWidth, Operand, Operand), + Ucomiss(Operand, Operand), + Ucomisd(Operand, Operand), Setcc(ConditionCode, Operand), // e.g., sete %al // Control Flow @@ -268,6 +332,22 @@ impl fmt::Display for Instruction { dest.format_with_width(*w) ) } 
+ Instruction::Movss(src, dest) => { + write!( + f, + "movss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Movsd(src, dest) => { + write!( + f, + "movsd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } Instruction::Push(op) => { write!(f, "pushq {}", op.format_with_width(OperandWidth::QWord)) } @@ -316,6 +396,86 @@ impl fmt::Display for Instruction { Instruction::Cqto => { write!(f, "cqto") } + Instruction::Addss(src, dest) => { + write!( + f, + "addss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Addsd(src, dest) => { + write!( + f, + "addsd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } + Instruction::Subss(src, dest) => { + write!( + f, + "subss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Subsd(src, dest) => { + write!( + f, + "subsd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } + Instruction::Mulss(src, dest) => { + write!( + f, + "mulss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Mulsd(src, dest) => { + write!( + f, + "mulsd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } + Instruction::Divss(src, dest) => { + write!( + f, + "divss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Divsd(src, dest) => { + write!( + f, + "divsd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } + Instruction::Xorps(src, dest) => { + write!( + f, + "xorps {}, {}", + 
src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Xorpd(src, dest) => { + write!( + f, + "xorpd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } Instruction::Cmp(w, src, dest) => { write!( f, @@ -334,6 +494,22 @@ impl fmt::Display for Instruction { dest.format_with_width(*w) ) } + Instruction::Ucomiss(src, dest) => { + write!( + f, + "ucomiss {}, {}", + src.format_with_width(OperandWidth::DWord), + dest.format_with_width(OperandWidth::DWord) + ) + } + Instruction::Ucomisd(src, dest) => { + write!( + f, + "ucomisd {}, {}", + src.format_with_width(OperandWidth::QWord), + dest.format_with_width(OperandWidth::QWord) + ) + } Instruction::Setcc(cc, dest) => { // setcc strictly operates on 8-bit (byte) registers write!( diff --git a/src/builder.rs b/src/builder.rs index 49810e6..55134c9 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -282,6 +282,48 @@ impl IrFunctionBuilder { self.build_binary(Type::Bool, BinaryOp::ICmp(cmp_op), lhs, rhs) } + /// Builds an `fadd` instruction. + pub fn build_fadd(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::FAdd, lhs, rhs) + } + + /// Builds an `fsub` instruction. + pub fn build_fsub(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::FSub, lhs, rhs) + } + + /// Builds an `fmul` instruction. + pub fn build_fmul(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::FMul, lhs, rhs) + } + + /// Builds an `fdiv` instruction. + pub fn build_fdiv(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::FDiv, lhs, rhs) + } + + /// Builds an `frem` instruction. 
+ pub fn build_frem(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::FRem, lhs, rhs) + } + + /// Builds an `fcmp` instruction. + pub fn build_fcmp(&mut self, cmp_op: FCmpOp, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(Type::Bool, BinaryOp::FCmp(cmp_op), lhs, rhs) + } + + /// Builds an `fneg` instruction. + pub fn build_fneg(&mut self, result_ty: Type, src: Operand) -> Operand { + let dest = self.allocate_register(); + self.insert_instruction(Instruction::Unary { + dest, + result_ty, + op: UnaryOp::FNeg, + src, + }); + Operand::Register(dest) + } + /// Builds a `call` instruction. pub fn build_call<'a>( &mut self, diff --git a/src/ir.rs b/src/ir.rs index a04e425..41b96a2 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -5,6 +5,8 @@ pub enum Type { I16, I32, I64, + F32, + F64, Ptr, Void, } @@ -12,13 +14,39 @@ pub enum Type { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Register(pub usize); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Operand { Integer(u64), Boolean(bool), + Float(f64), Register(Register), } +impl Eq for Operand {} + +impl std::hash::Hash for Operand { + fn hash(&self, state: &mut H) { + match self { + Operand::Integer(i) => { + state.write_u8(0); + i.hash(state); + } + Operand::Boolean(b) => { + state.write_u8(1); + b.hash(state); + } + Operand::Float(f) => { + state.write_u8(2); + state.write_u64(f.to_bits()); + } + Operand::Register(r) => { + state.write_u8(3); + r.hash(state); + } + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ICmpOp { Slt, @@ -33,6 +61,24 @@ pub enum ICmpOp { Ne, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum FCmpOp { + Oeq, + Ogt, + Oge, + Olt, + Ole, + One, + Ord, + Uno, + Ueq, + Ugt, + Uge, + Ult, + Ule, + Une, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum BinaryOp { Add, @@ -44,11 +90,18 @@ pub enum BinaryOp { SRem, URem, 
ICmp(ICmpOp), + FAdd, + FSub, + FMul, + FDiv, + FRem, + FCmp(FCmpOp), } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum UnaryOp { INeg, + FNeg, } #[derive(Debug, PartialEq, Eq)] diff --git a/src/main.rs b/src/main.rs index 93309b1..9a3eadb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ fn build_test_module() -> Module { let mut module_builder = IrModuleBuilder::new(); let i32_ty = Type::I32; + let f64_ty = Type::F64; // 1. Define the Factorial Function // factorial(n: i32) -> i32 @@ -64,31 +65,44 @@ fn build_test_module() -> Module { } module_builder.complete_function(); - // 2. Define the Main Function + // 2. Define a Floating Point Function: hypotenuse_sq(a: f64, b: f64) -> f64 + let hypot_id = module_builder.new_function_id(); + { + let f_builder = + module_builder.new_function(hypot_id, "hypotenuse_sq", vec![&f64_ty, &f64_ty], f64_ty); + + let a = f_builder.get_param(0).unwrap(); + let b = f_builder.get_param(1).unwrap(); + + let a_sq = f_builder.build_fmul(f64_ty, a, a); + let b_sq = f_builder.build_fmul(f64_ty, b, b); + let res = f_builder.build_fadd(f64_ty, a_sq, b_sq); + + f_builder.build_return(f64_ty, res); + } + module_builder.complete_function(); + + // 3. Define the Main Function let main_id = module_builder.new_function_id(); { - let f_builder = module_builder.new_function( - main_id, - "main", - vec![], // no params - i32_ty, - ); + let f_builder = module_builder.new_function(main_id, "main", vec![], i32_ty); - // 1. Allocate space for an integer on the stack + // 1. Integer math: call factorial(5) let ptr = f_builder.build_alloc(i32_ty); - - // 2. Store the value '5' into that pointer let input_val = Operand::Integer(5); f_builder.build_store(i32_ty, ptr, input_val); - - // 3. Load the value back from the pointer let loaded_val = f_builder.build_load(i32_ty, ptr); - // 4. Call factorial(loaded_val) let args = [(i32_ty, loaded_val)]; let final_result = f_builder.build_call(i32_ty, fact_id, args.iter()); - // 5. 
Return the result of the factorial call + // 2. Floating point math: call hypotenuse_sq(3.0, 4.0) + let arg1 = Operand::Float(3.0); + let arg2 = Operand::Float(4.0); + let hypot_args = [(f64_ty, arg1), (f64_ty, arg2)]; + let _hypot_res = f_builder.build_call(f64_ty, hypot_id, hypot_args.iter()); + + // 3. Return the result of the factorial call f_builder.build_return(i32_ty, final_result); } module_builder.complete_function(); @@ -103,6 +117,8 @@ fn main() { validate_module(&module).expect("failed to validate module"); passes::optimize(&mut module); + println!("{}", module); + let assembly = X86Backend::new(&module).compile_module(); println!("{}", assembly); } diff --git a/src/passes/cfp.rs b/src/passes/cfp.rs index c24d1ab..c1cddcd 100644 --- a/src/passes/cfp.rs +++ b/src/passes/cfp.rs @@ -75,10 +75,11 @@ fn fold_function_constants(func: &mut Function) { } } -// --- Helper Functions --- - fn is_constant(op: &Operand) -> bool { - matches!(op, Operand::Integer(_) | Operand::Boolean(_)) + matches!( + op, + Operand::Integer(_) | Operand::Boolean(_) | Operand::Float(_) + ) } fn substitute_constants_in_inst(inst: &mut Instruction, constants: &HashMap) { @@ -191,6 +192,7 @@ fn evaluate_binary(op: BinaryOp, src1: &Operand, src2: &Operand) -> Option None, } } (Operand::Boolean(a), Operand::Boolean(b)) => match op { @@ -198,6 +200,38 @@ fn evaluate_binary(op: BinaryOp, src1: &Operand, src2: &Operand) -> Option Some(Operand::Boolean(a != b)), _ => None, }, + (Operand::Float(a), Operand::Float(b)) => { + let a = *a; + let b = *b; + + match op { + BinaryOp::FAdd => Some(Operand::Float(a + b)), + BinaryOp::FSub => Some(Operand::Float(a - b)), + BinaryOp::FMul => Some(Operand::Float(a * b)), + BinaryOp::FDiv => Some(Operand::Float(a / b)), + BinaryOp::FRem => Some(Operand::Float(a % b)), + BinaryOp::FCmp(cmp) => { + let res = match cmp { + FCmpOp::Oeq => a == b, + FCmpOp::Ogt => a > b, + FCmpOp::Oge => a >= b, + FCmpOp::Olt => a < b, + FCmpOp::Ole => a <= b, + FCmpOp::One => a 
!= b && !a.is_nan() && !b.is_nan(), + FCmpOp::Ord => !a.is_nan() && !b.is_nan(), + FCmpOp::Uno => a.is_nan() || b.is_nan(), + FCmpOp::Ueq => a == b || a.is_nan() || b.is_nan(), + FCmpOp::Ugt => a > b || a.is_nan() || b.is_nan(), + FCmpOp::Uge => a >= b || a.is_nan() || b.is_nan(), + FCmpOp::Ult => a < b || a.is_nan() || b.is_nan(), + FCmpOp::Ule => a <= b || a.is_nan() || b.is_nan(), + FCmpOp::Une => a != b || a.is_nan() || b.is_nan(), + }; + Some(Operand::Boolean(res)) + } + _ => None, + } + } _ => None, } } @@ -205,6 +239,7 @@ fn evaluate_binary(op: BinaryOp, src1: &Operand, src2: &Operand) -> Option Option { match (op, src) { (UnaryOp::INeg, Operand::Integer(a)) => Some(Operand::Integer(a.wrapping_neg())), + (UnaryOp::FNeg, Operand::Float(a)) => Some(Operand::Float(-a)), _ => None, } } diff --git a/src/printer.rs b/src/printer.rs index d414031..679e6ff 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -1,8 +1,8 @@ use std::fmt::{self, Display, Write}; use crate::ir::{ - BasicBlock, BinaryOp, BlockId, Function, ICmpOp, Instruction, Module, Operand, Register, - Terminator, Type, UnaryOp, + BasicBlock, BinaryOp, BlockId, FCmpOp, Function, ICmpOp, Instruction, Module, Operand, + Register, Terminator, Type, UnaryOp, }; impl Display for Register { @@ -16,6 +16,7 @@ impl Display for Operand { match self { Operand::Integer(value) => write!(f, "${}", value), Operand::Boolean(value) => write!(f, "${}", value), + Operand::Float(value) => write!(f, "${}", value), Operand::Register(register) => register.fmt(f), } } @@ -29,6 +30,8 @@ impl Display for Type { Type::I16 => write!(f, "i16"), Type::I32 => write!(f, "i32"), Type::I64 => write!(f, "i64"), + Type::F32 => write!(f, "f32"), + Type::F64 => write!(f, "f64"), Type::Ptr => write!(f, "ptr"), Type::Void => write!(f, "void"), } @@ -45,6 +48,7 @@ impl Display for UnaryOp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { UnaryOp::INeg => write!(f, "ineg"), + UnaryOp::FNeg => write!(f, "fneg"), } } } @@ 
-61,6 +65,12 @@ impl Display for BinaryOp { BinaryOp::SRem => write!(f, "srem"), BinaryOp::URem => write!(f, "urem"), BinaryOp::ICmp(icmp_op) => write!(f, "icmp {}", icmp_op), + BinaryOp::FAdd => write!(f, "fadd"), + BinaryOp::FSub => write!(f, "fsub"), + BinaryOp::FDiv => write!(f, "fdiv"), + BinaryOp::FMul => write!(f, "fmul"), + BinaryOp::FRem => write!(f, "frem"), + BinaryOp::FCmp(fcmp_op) => write!(f, "fcmp {}", fcmp_op), } } } @@ -82,6 +92,27 @@ impl Display for ICmpOp { } } +impl Display for FCmpOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FCmpOp::Oeq => write!(f, "oeq"), + FCmpOp::Ogt => write!(f, "ogt"), + FCmpOp::Oge => write!(f, "oge"), + FCmpOp::Olt => write!(f, "olt"), + FCmpOp::Ole => write!(f, "ole"), + FCmpOp::One => write!(f, "one"), + FCmpOp::Ord => write!(f, "ord"), + FCmpOp::Uno => write!(f, "uno"), + FCmpOp::Ueq => write!(f, "ueq"), + FCmpOp::Ugt => write!(f, "ugt"), + FCmpOp::Uge => write!(f, "uge"), + FCmpOp::Ult => write!(f, "ult"), + FCmpOp::Ule => write!(f, "ule"), + FCmpOp::Une => write!(f, "une"), + } + } +} + impl Display for Module { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(IrPrinter::print(self)?.as_str()) diff --git a/src/validate.rs b/src/validate.rs index 2e7ec1b..06f6918 100644 --- a/src/validate.rs +++ b/src/validate.rs @@ -57,6 +57,7 @@ fn validate_function( Operand::Register(r) => reg_types.get(r).copied(), Operand::Boolean(_) => Some(Type::Bool), Operand::Integer(_) => Some(Type::I64), // Default fallback for untyped literals + Operand::Float(_) => Some(Type::F64), // Default fallback }; (Some(*register), inferred_ty) } @@ -103,6 +104,12 @@ fn validate_function( return Err(format!("Cannot use integer literal as {:?}", expected)); } }, + Operand::Float(_) => match expected { + Type::F32 | Type::F64 => {} + _ => { + return Err(format!("Cannot use float literal as {:?}", expected)); + } + }, } Ok(()) @@ -150,6 +157,27 @@ fn validate_function( Operand::Register(r) 
=> *reg_types.get(r).unwrap_or(&Type::I64), _ => Type::I64, }, + Operand::Float(_) => match src2 { + Operand::Register(r) => *reg_types.get(r).unwrap_or(&Type::F64), + _ => Type::F64, + }, + }; + check_operand(src1, op_type)?; + check_operand(src2, op_type)?; + } else if let BinaryOp::FCmp(_) = op { + if *result_ty != Type::Bool { + return Err("FCmp result type must be Bool".to_string()); + } + + let op_type = match src1 { + Operand::Register(r) => *reg_types + .get(r) + .ok_or_else(|| format!("Unknown reg {:?}", r))?, + Operand::Float(_) => match src2 { + Operand::Register(r) => *reg_types.get(r).unwrap_or(&Type::F64), + _ => Type::F64, + }, + _ => return Err("FCmp operands must be floating point".to_string()), }; check_operand(src1, op_type)?; check_operand(src2, op_type)?; diff --git a/tests/float_test.rs b/tests/float_test.rs new file mode 100644 index 0000000..360cb55 --- /dev/null +++ b/tests/float_test.rs @@ -0,0 +1,113 @@ +use scarlett::{builder::IrModuleBuilder, ir::*, validate::validate_module}; + +#[test] +fn test_float_ir() { + let mut module_builder = IrModuleBuilder::new(); + let f64_ty = Type::F64; + + let func_id = module_builder.new_function_id(); + { + let f_builder = + module_builder.new_function(func_id, "float_math", vec![&f64_ty, &f64_ty], f64_ty); + + let a = f_builder.get_param(0).unwrap(); + let b = f_builder.get_param(1).unwrap(); + + let sum = f_builder.build_fadd(f64_ty, a, b); + let diff = f_builder.build_fsub(f64_ty, sum, a); + let prod = f_builder.build_fmul(f64_ty, diff, b); + let quot = f_builder.build_fdiv(f64_ty, prod, a); + let neg = f_builder.build_fneg(f64_ty, quot); + + let is_gt = f_builder.build_fcmp(FCmpOp::Ogt, neg, Operand::Float(0.0)); + + let then_block = f_builder.create_block(); + let else_block = f_builder.create_block(); + + f_builder.build_branch(is_gt, then_block, else_block); + + f_builder.switch_to_block(then_block); + f_builder.build_return(f64_ty, neg); + + f_builder.switch_to_block(else_block); + 
f_builder.build_return(f64_ty, Operand::Float(42.0)); + } + module_builder.complete_function(); + + let module = module_builder.finish(); + + validate_module(&module).expect("Module validation failed"); + + let printed = format!("{}", module); + println!("{}", printed); + + assert!(printed.contains("f64")); + assert!(printed.contains("fadd")); + assert!(printed.contains("fsub")); + assert!(printed.contains("fmul")); + assert!(printed.contains("fdiv")); + assert!(printed.contains("fneg")); + assert!(printed.contains("fcmp ogt")); +} + +#[test] +fn test_float_constant_folding() { + let mut module_builder = IrModuleBuilder::new(); + let f64_ty = Type::F64; + + let func_id = module_builder.new_function_id(); + { + let f_builder = module_builder.new_function(func_id, "fold_me", vec![], f64_ty); + + let a = Operand::Float(10.0); + let b = Operand::Float(2.0); + + let sum = f_builder.build_fadd(f64_ty, a, b); // 12.0 + let prod = f_builder.build_fmul(f64_ty, sum, b); // 24.0 + + f_builder.build_return(f64_ty, prod); + } + module_builder.complete_function(); + + let mut module = module_builder.finish(); + + validate_module(&module).expect("Module validation failed"); + + scarlett::passes::cfp::fold_constants(&mut module); + + let printed = format!("{}", module); + println!("{}", printed); + + assert!(printed.contains("$24")); + assert!(!printed.contains("fadd")); + assert!(!printed.contains("fmul")); +} + +#[test] +fn test_float_codegen() { + let mut module_builder = IrModuleBuilder::new(); + let f64_ty = Type::F64; + + let func_id = module_builder.new_function_id(); + { + let f_builder = + module_builder.new_function(func_id, "add_floats", vec![&f64_ty, &f64_ty], f64_ty); + + let a = f_builder.get_param(0).unwrap(); + let b = f_builder.get_param(1).unwrap(); + + let res = f_builder.build_fadd(f64_ty, a, b); + f_builder.build_return(f64_ty, res); + } + module_builder.complete_function(); + + let module = module_builder.finish(); + + let assembly = 
scarlett::backend::x86_64::X86Backend::new(&module).compile_module(); + println!("{}", assembly); + + assert!(assembly.contains("addsd")); + assert!(assembly.contains("movsd")); + assert!(assembly.contains("%xmm0")); + assert!(assembly.contains("%xmm1")); +}