From 6f6c84eac4e6e5a3b0eddcf86b297724446a05ef Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Mon, 27 Apr 2026 13:36:17 +0200 Subject: [PATCH] feat: differentiate between signed and unsigned multiplication - Split `BinaryOp::Mul` into `BinaryOp::SMul` and `BinaryOp::UMul` in the IR. - Implement x86_64 codegen for `UMul` using the `mulq` instruction. - Update `X86Backend` to use `imulq` specifically for `SMul`. - Update constant folding and IR printer to support the new multiplication operations. - Fix the IR printer emitting `udiv` for `BinaryOp::SDiv`. - Optimize function epilogue by omitting the final `jmp` on fallthrough. - Update `main` test case to use `build_umul`. --- src/backend/x86_64.rs | 27 +++++++++++++++++++++++---- src/builder.rs | 11 ++++++++--- src/ir.rs | 3 ++- src/main.rs | 2 +- src/passes/cfp.rs | 3 ++- src/printer.rs | 5 +++-- 6 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/backend/x86_64.rs b/src/backend/x86_64.rs index 5719ef4..3dea3ab 100644 --- a/src/backend/x86_64.rs +++ b/src/backend/x86_64.rs @@ -422,7 +422,7 @@ impl<'a> X86Backend<'a> { let src2_str = self.resolve_op(src2, "%r11", allocs); match op { - BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::SMul => { let is_add = matches!(op, BinaryOp::Add); let dest_hw = !dest_str.starts_with('-'); let src1_hw = !src1_str.starts_with('-'); @@ -433,7 +433,7 @@ impl<'a> X86Backend<'a> { && dest_hw && src1_hw && src2_imm - && !matches!(op, BinaryOp::Mul) + && !matches!(op, BinaryOp::SMul) { let val: i64 = src2_str[1..].parse().unwrap(); let offset = if is_add { val } else { -val }; @@ -455,7 +455,7 @@ impl<'a> X86Backend<'a> { let mnemonic = match op { BinaryOp::Add => "addq", BinaryOp::Sub => "subq", - BinaryOp::Mul => "imulq", + BinaryOp::SMul => "imulq", _ => unreachable!(), }; if dest_str.starts_with('-') && src2_str.starts_with('-') { @@ -473,6 +473,21 @@ impl<'a> X86Backend<'a> { } } } + BinaryOp::UMul => { + // Unsigned multiply (mulq) strictly takes 1 operand and multiplies it 
by %rax. + writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); + + // mulq cannot take an immediate value, so route it through a scratch register if needed + if src2_str.starts_with('$') { + writeln!(&mut self.assembly, " movq {}, %r10", src2_str).unwrap(); + writeln!(&mut self.assembly, " mulq %r10").unwrap(); + } else { + writeln!(&mut self.assembly, " mulq {}", src2_str).unwrap(); + } + + // The lower 64 bits of the result are in %rax; mulq leaves the upper 64 bits in %rdx, which are not used here + self.emit_mov("%rax", &dest_str); + } BinaryOp::SDiv | BinaryOp::SRem => { writeln!(&mut self.assembly, " movq {}, %rax", src1_str).unwrap(); writeln!(&mut self.assembly, " cqto").unwrap(); @@ -594,7 +609,11 @@ impl<'a> X86Backend<'a> { let src = self.resolve_op(val, "%r10", allocs); self.emit_mov(&src, "%rax"); } - writeln!(&mut self.assembly, " jmp .L{}_epilogue", func_name).unwrap(); + + // If there is no next block we can fall through (the epilogue is assumed to follow the last block); otherwise emit a jump to it + if next_block_id.is_some() { + writeln!(&mut self.assembly, " jmp .L{}_epilogue", func_name).unwrap(); + } } Terminator::Jump(target) => { if Some(*target) != next_block_id { diff --git a/src/builder.rs b/src/builder.rs index e6249de..49810e6 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -247,9 +247,14 @@ impl IrFunctionBuilder { self.build_binary(result_ty, BinaryOp::Sub, lhs, rhs) } - /// Builds an `mul` instruction. - pub fn build_mul(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { - self.build_binary(result_ty, BinaryOp::Mul, lhs, rhs) + /// Builds an `smul` (signed multiply) instruction. + pub fn build_smul(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::SMul, lhs, rhs) + } + + /// Builds an `umul` (unsigned multiply) instruction. + pub fn build_umul(&mut self, result_ty: Type, lhs: Operand, rhs: Operand) -> Operand { + self.build_binary(result_ty, BinaryOp::UMul, lhs, rhs) } /// Builds an `udiv` instruction. 
diff --git a/src/ir.rs b/src/ir.rs index 3d331dc..ad7df65 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -39,7 +39,8 @@ pub enum BinaryOp { Sub, UDiv, SDiv, - Mul, + SMul, + UMul, SRem, URem, ICmp(ICmpOp), diff --git a/src/main.rs b/src/main.rs index 779cbab..7958495 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,7 +51,7 @@ fn build_test_module() -> Module { let val_i = f_builder.build_load(i32_ty.clone(), i_ptr.clone()); // res = res * i - let updated_res = f_builder.build_mul(i32_ty.clone(), val_res, val_i.clone()); + let updated_res = f_builder.build_umul(i32_ty.clone(), val_res, val_i.clone()); f_builder.build_store(i32_ty.clone(), res_ptr.clone(), updated_res); // i = i - 1 diff --git a/src/passes/cfp.rs b/src/passes/cfp.rs index 733fff7..6d9a517 100644 --- a/src/passes/cfp.rs +++ b/src/passes/cfp.rs @@ -146,7 +146,8 @@ fn evaluate_binary(op: BinaryOp, src1: &Operand, src2: &Operand) -> Option Some(Operand::Integer(a.wrapping_add(b))), BinaryOp::Sub => Some(Operand::Integer(a.wrapping_sub(b))), - BinaryOp::Mul => Some(Operand::Integer(a.wrapping_mul(b))), + BinaryOp::SMul => Some(Operand::Integer((a as i64).wrapping_mul(b as i64) as u64)), + BinaryOp::UMul => Some(Operand::Integer(a.wrapping_mul(b))), BinaryOp::UDiv => { if b != 0 { Some(Operand::Integer(a.wrapping_div(b))) diff --git a/src/printer.rs b/src/printer.rs index 6966549..0e4c361 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -54,9 +54,10 @@ impl Display for BinaryOp { match self { BinaryOp::Add => write!(f, "add"), BinaryOp::Sub => write!(f, "sub"), + BinaryOp::SDiv => write!(f, "sdiv"), BinaryOp::UDiv => write!(f, "udiv"), - BinaryOp::SDiv => write!(f, "udiv"), - BinaryOp::Mul => write!(f, "mul"), + BinaryOp::SMul => write!(f, "smul"), + BinaryOp::UMul => write!(f, "umul"), BinaryOp::SRem => write!(f, "srem"), BinaryOp::URem => write!(f, "urem"), BinaryOp::ICmp(icmp_op) => write!(f, "icmp {}", icmp_op),