diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs index 6a8b0a0..c281e1a 100644 --- a/src/backend/cranelift.rs +++ b/src/backend/cranelift.rs @@ -137,14 +137,27 @@ impl CraneliftBackend { } let mut var_map = HashMap::new(); + let mut stack_slot_map = HashMap::new(); for local in &func.locals { - let var = builder.declare_var(Self::lower_type(&local.ty)); - var_map.insert(local.id, var); + if local.address_taken { + let cl_ty = Self::lower_type(&local.ty); + let bytes = cl_ty.bytes(); + let slot = builder.create_sized_stack_slot(ir::StackSlotData::new( + ir::StackSlotKind::ExplicitSlot, + bytes, + 0, + )); + stack_slot_map.insert(local.id, slot); + } else { + let var = builder.declare_var(Self::lower_type(&local.ty)); + var_map.insert(local.id, var); + } } let mut trans = FunctionTranslator { builder, var_map, + stack_slot_map, block_map, locals: &func.locals, module: &mut self.module, @@ -166,7 +179,11 @@ impl CraneliftBackend { if i == 0 { for (j, param_id) in func.params.iter().enumerate() { let val = trans.builder.block_params(cl_block)[j]; - trans.builder.def_var(trans.var_map[param_id], val); + if let Some(&slot) = trans.stack_slot_map.get(param_id) { + trans.builder.ins().stack_store(val, slot, 0); + } else { + trans.builder.def_var(trans.var_map[param_id], val); + } } } @@ -195,6 +212,7 @@ impl CraneliftBackend { Ty::F32 => types::F32, Ty::F64 => types::F64, Ty::Bool => types::I8, // Booleans are represented as 8-bit integers + Ty::Pointer(_) => types::I64, // Assume 64-bit environment pointers _ => unimplemented!("Unsupported type for Cranelift lowering: {:?}", ty), } } @@ -205,6 +223,7 @@ impl CraneliftBackend { struct FunctionTranslator<'a> { builder: FunctionBuilder<'a>, var_map: HashMap, + stack_slot_map: HashMap, block_map: HashMap, locals: &'a [LocalDecl], module: &'a mut ObjectModule, @@ -217,13 +236,25 @@ impl<'a> FunctionTranslator<'a> { StatementKind::Assign(local_id, rvalue) => { let val = self.translate_rvalue(rvalue); if let Some(v) = val { - let var = self.var_map[local_id]; - self.builder.def_var(var, v); + if let Some(&slot) = self.stack_slot_map.get(local_id) { + self.builder.ins().stack_store(v, slot, 0); + } else { + let var = self.var_map[local_id]; + self.builder.def_var(var, v); + } } } StatementKind::SideEffect(rvalue) => { self.translate_rvalue(rvalue); } + StatementKind::Store { ptr, val } => { + let ptr_val = self.translate_operand(ptr); + if let Some(v) = self.translate_rvalue(val) { + self.builder + .ins() + .store(ir::MemFlags::trusted(), v, ptr_val, 0); + } + } } } @@ -272,8 +303,13 @@ impl<'a> FunctionTranslator<'a> { fn translate_operand(&mut self, op: &Operand) -> ir::Value { match op { Operand::Copy(local_id) => { - let var = self.var_map[local_id]; - self.builder.use_var(var) + if let Some(&slot) = self.stack_slot_map.get(local_id) { + let cl_ty = CraneliftBackend::lower_type(&self.locals[local_id.0].ty); + self.builder.ins().stack_load(cl_ty, slot, 0) + } else { + let var = self.var_map[local_id]; + self.builder.use_var(var) + } } Operand::Constant(ConstantValue::Integer(val, ty)) => { let cl_ty = CraneliftBackend::lower_type(ty); @@ -315,6 +351,9 @@ impl<'a> FunctionTranslator<'a> { .icmp(ir::condcodes::IntCC::Equal, inner_val, zero), ) } + UnaryOp::Deref | UnaryOp::AddressOf => { + unreachable!("handled as distinct Rvalues in MIR") + } } } Rvalue::BinaryOp(op, lhs, rhs) => { @@ -527,6 +566,24 @@ impl<'a> FunctionTranslator<'a> { Some(self.builder.inst_results(call_inst)[0]) } } + Rvalue::AddressOf(local_id) => { + let slot = self.stack_slot_map[local_id]; + Some(self.builder.ins().stack_addr(types::I64, slot, 0)) + } + Rvalue::ReadPointer(ptr_op) => { + let ptr_val = self.translate_operand(ptr_op); + let ptr_ty = self.get_operand_type(ptr_op); + let inner_ty = match ptr_ty { + Ty::Pointer(inner) => *inner, + _ => unreachable!(), + }; + let cl_ty = CraneliftBackend::lower_type(&inner_ty); + Some( + self.builder + .ins() + .load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0), + ) + } } } } diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index 8cebab6..f057c2c 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -91,6 +91,7 @@ pub enum TypeKind { F32, F64, Bool, + Pointer(Box), } #[derive(Debug, PartialEq)] @@ -177,6 +178,8 @@ pub enum ExprKind { pub enum UnaryOp { Neg, Not, + Deref, + AddressOf, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src/frontend/lexer.rs b/src/frontend/lexer.rs index 9461936..473776f 100644 --- a/src/frontend/lexer.rs +++ b/src/frontend/lexer.rs @@ -188,6 +188,7 @@ impl<'src> Lexer<'src> { '*' => token!(TokenKind::Star), '/' => token!(TokenKind::Slash), '%' => token!(TokenKind::Percent), + '&' => token!(TokenKind::Ampersand), '!' => token!(TokenKind::Bang, '=' => TokenKind::Unequal), '=' => token!(TokenKind::Assign, '=' => TokenKind::Equal), diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index fd2ca4e..eedc55e 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -314,6 +314,15 @@ impl<'src> Parser<'src> { self.advance(); TypeKind::Bool } + TokenKind::Star => { + let star_token = self.advance().unwrap(); + let inner = self.parse_type()?; + let span = star_token.span.join(inner.span); + return Ok(Type { + kind: TypeKind::Pointer(Box::new(inner)), + span, + }); + } _ => { return Err(ParseError::new( @@ -782,6 +791,8 @@ impl<'src> Parser<'src> { match op { TokenKind::Minus => Some((UnaryOp::Neg, 30)), TokenKind::Bang => Some((UnaryOp::Not, 30)), + TokenKind::Star => Some((UnaryOp::Deref, 30)), + TokenKind::Ampersand => Some((UnaryOp::AddressOf, 30)), _ => None, } diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs index dc5bb74..d7dce8a 100644 --- a/src/frontend/sema.rs +++ b/src/frontend/sema.rs @@ -41,6 +41,7 @@ pub enum Ty { Unit, Var(usize), Function(Vec, Box), + Pointer(Box), } impl Ty { @@ -85,6 +86,11 @@ impl Ty { _ => 0, } } + + /// Returns `true` if the type is a pointer. + pub fn is_pointer(&self) -> bool { + matches!(self, Ty::Pointer(_)) + } } impl From<&TypeKind> for Ty { @@ -101,6 +107,20 @@ impl From<&TypeKind> for Ty { TypeKind::F32 => Ty::F32, TypeKind::F64 => Ty::F64, TypeKind::Bool => Ty::Bool, + TypeKind::Pointer(inner) => Ty::Pointer(Box::new(Ty::from(&inner.kind))), + } + } +} + +impl TypedExpr { + /// Returns `true` if the expression produces a valid memory location that can be mutated. + pub fn is_lvalue(&self) -> bool { + match &self.kind { + TypedExprKind::Identifier { .. } => true, + TypedExprKind::Unary { + op: UnaryOp::Deref, .. + } => true, + _ => false, } } } @@ -187,6 +207,8 @@ impl Sema { Ty::Function(params, ret) } + Ty::Pointer(inner) => Ty::Pointer(Box::new(self.apply_subst(inner))), + t => t.clone(), } } @@ -200,6 +222,7 @@ impl Sema { Ty::Function(params, ret) => { params.iter().any(|p| self.occurs_check(v, p)) || self.occurs_check(v, &ret) } + Ty::Pointer(inner) => self.occurs_check(v, &inner), _ => false, } } @@ -235,6 +258,7 @@ impl Sema { self.unify(&r1, &r2) } + (Ty::Pointer(p1), Ty::Pointer(p2)) => self.unify(&p1, &p2), (a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)), } } @@ -617,6 +641,48 @@ impl Sema { } } + ExprKind::Unary { + op: UnaryOp::AddressOf, + expr: inner_expr, + } => { + let typed_inner = self.analyze_expr(inner_expr); + if !typed_inner.is_lvalue() { + self.errors.push(SemanticError::new( + "invalid operand for address-of operator", + inner_expr.span, + )); + } + TypedExpr { + ty: Ty::Pointer(Box::new(typed_inner.ty.clone())), + kind: TypedExprKind::Unary { + op: UnaryOp::AddressOf, + expr: Box::new(typed_inner), + }, + span: expr.span, + } + } + + ExprKind::Unary { + op: UnaryOp::Deref, + expr: inner_expr, + } => { + let typed_inner = self.analyze_expr(inner_expr); + let result_ty = self.new_var(); + if let Err(e) = + self.unify(&typed_inner.ty, &Ty::Pointer(Box::new(result_ty.clone()))) + { + self.errors.push(SemanticError::new(e, inner_expr.span)); + } + TypedExpr { + kind: TypedExprKind::Unary { + op: UnaryOp::Deref, + expr: Box::new(typed_inner), + }, + ty: result_ty, + span: expr.span, + } + } + ExprKind::Binary { op, lhs, rhs } => { let typed_lhs = self.analyze_expr(lhs); let typed_rhs = self.analyze_expr(rhs); @@ -657,12 +723,11 @@ impl Sema { let typed_rval = self.analyze_expr(rval); let typed_lval = self.analyze_expr(lval); - match &typed_lval.kind { - TypedExprKind::Identifier { .. } => {} - _ => self.errors.push(SemanticError::new( + if !typed_lval.is_lvalue() { + self.errors.push(SemanticError::new( "invalid left-hand side of assignment", lval.span, - )), + )); } if let Err(e) = self.unify(&typed_lval.ty, &typed_rval.ty) { @@ -994,8 +1059,9 @@ impl Sema { from }; - let is_valid = (final_from.is_numeric() || final_from == Ty::Bool) - && (to.is_numeric() || to == Ty::Bool); + let is_valid = + (final_from.is_numeric() || final_from == Ty::Bool || final_from.is_pointer()) + && (to.is_numeric() || to == Ty::Bool || to.is_pointer()); if !is_valid { self.errors.push(SemanticError::new( @@ -1229,4 +1295,17 @@ mod test { let src = "fn test() { let a: f32 = 10.0; let b = a as i32; }"; assert!(analyze(src).is_ok()); } + + #[test] + fn invalid_pointer_assignment() { + let src = "fn test() { let a: i32 = 5; let b: *f32 = &a; }"; + let errors = analyze(src).unwrap_err(); + assert!(errors.iter().any(|e| e.message.contains("type mismatch"))); + } + + #[test] + fn valid_pointer_cast() { + let src = "fn test() { let a: i32 = 5; let b: *f32 = &a as *f32; }"; + assert!(analyze(src).is_ok()); + } } diff --git a/src/frontend/token.rs b/src/frontend/token.rs index 549d85f..a7ce4da 100644 --- a/src/frontend/token.rs +++ b/src/frontend/token.rs @@ -86,6 +86,7 @@ pub enum TokenKind { Star, Slash, Percent, + Ampersand, Equal, Unequal, @@ -147,6 +148,7 @@ impl Display for TokenKind { TokenKind::Star => "`*`", TokenKind::Slash => "`/`", TokenKind::Percent => "`%`", + TokenKind::Ampersand => "`&`", TokenKind::Equal => "`==`", TokenKind::Unequal => "`!=`", TokenKind::LessThan => "`<`", diff --git a/src/middle/builder.rs b/src/middle/builder.rs index 2581c9b..ae912d5 100644 --- a/src/middle/builder.rs +++ b/src/middle/builder.rs @@ -139,6 +139,7 @@ impl FuncBuilder { ty, mutable: false, name: Some(name.clone()), + address_taken: false, }); self.scopes.last_mut().unwrap().insert(name, id); id @@ -152,6 +153,7 @@ impl FuncBuilder { ty, mutable: false, name: None, + address_taken: false, }); id } @@ -364,15 +366,43 @@ impl FuncBuilder { } TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)), TypedExprKind::Unary { op, expr: inner } => { - let inner_op = self.lower_expr(inner); - let temp = self.new_temp(expr.ty.clone()); - - self.emit_stmt(Statement { - kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), - span: expr.span, - }); - - Operand::Copy(temp) + match op { + UnaryOp::AddressOf => match &inner.kind { + TypedExprKind::Identifier { name } => { + let id = self.lookup(name); + self.locals[id.0].address_taken = true; + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)), + span: expr.span, + }); + Operand::Copy(temp) + } + TypedExprKind::Unary { + op: UnaryOp::Deref, + expr: ptr_expr, + } => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`! + _ => unreachable!("invalid lvalue for addressof"), + }, + UnaryOp::Deref => { + let inner_op = self.lower_expr(inner); + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)), + span: expr.span, + }); + Operand::Copy(temp) + } + _ => { + let inner_op = self.lower_expr(inner); + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), + span: expr.span, + }); + Operand::Copy(temp) + } + } } TypedExprKind::Binary { op, lhs, rhs } => { let lhs_op = self.lower_expr(lhs); @@ -389,15 +419,29 @@ impl FuncBuilder { TypedExprKind::Assign { lval, rval } => { let rval_op = self.lower_expr(rval); - let local_id = match &lval.kind { - TypedExprKind::Identifier { name } => self.lookup(name), - _ => panic!("invalid lval in MIR lowering"), - }; - - self.emit_stmt(Statement { - kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())), - span: expr.span, - }); + match &lval.kind { + TypedExprKind::Identifier { name } => { + let local_id = self.lookup(name); + self.emit_stmt(Statement { + kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())), + span: expr.span, + }); + } + TypedExprKind::Unary { + op: UnaryOp::Deref, + expr: ptr_expr, + } => { + let ptr_op = self.lower_expr(ptr_expr); + self.emit_stmt(Statement { + kind: StatementKind::Store { + ptr: ptr_op, + val: Rvalue::Use(rval_op.clone()), + }, + span: expr.span, + }); + } + _ => unreachable!("invalid lval in MIR lowering"), + } rval_op } diff --git a/src/middle/fold.rs b/src/middle/fold.rs index 8631a7b..5bb93e5 100644 --- a/src/middle/fold.rs +++ b/src/middle/fold.rs @@ -28,7 +28,9 @@ fn optimize_function(func: &mut MirFunction) { if let Some(constant) = evaluate_rvalue(rvalue) { // Replace the complex instruction with a simple constant use *rvalue = Rvalue::Use(Operand::Constant(constant.clone())); - known_constants.insert(*local, constant); + if !func.locals[local.0].address_taken { + known_constants.insert(*local, constant); + } } else { // Reassigned to a non-computable value; remove older cached inferences known_constants.remove(local); @@ -37,6 +39,10 @@ fn optimize_function(func: &mut MirFunction) { StatementKind::SideEffect(rvalue) => { propagate_rvalue(rvalue, &known_constants); } + StatementKind::Store { ptr, val } => { + propagate_operand(ptr, &known_constants); + propagate_rvalue(val, &known_constants); + } } } @@ -86,6 +92,8 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap {} + Rvalue::ReadPointer(op) => propagate_operand(op, known_constants), } } @@ -264,7 +272,29 @@ mod test { name: "test".to_string(), params: vec![], return_type: Ty::I32, - locals: vec![], + locals: vec![ + LocalDecl { + id: LocalId(0), + ty: Ty::I32, + mutable: false, + name: None, + address_taken: false, + }, + LocalDecl { + id: LocalId(1), + ty: Ty::I32, + mutable: false, + name: None, + address_taken: false, + }, + LocalDecl { + id: LocalId(2), + ty: Ty::I32, + mutable: false, + name: None, + address_taken: false, + }, + ], blocks: vec![BasicBlock { id: BlockId(0), statements: vec![ diff --git a/src/middle/mir.rs b/src/middle/mir.rs index e07eefb..d52e7c2 100644 --- a/src/middle/mir.rs +++ b/src/middle/mir.rs @@ -40,6 +40,8 @@ pub struct LocalDecl { pub mutable: bool, /// Contains the name if it is a user-declared variable, or `None` if it's a compiler temporary. pub name: Option, + /// Indicates if the address of this local was ever taken, forcing allocation via explicit memory stack slots instead of SSA pseudo-registers. + pub address_taken: bool, } /// A sequential list of non-branching statements followed by a single terminator. @@ -62,6 +64,8 @@ pub enum StatementKind { Assign(LocalId, Rvalue), /// Executes an Rvalue strictly for its side effects (e.g. FFI calling `Unit` functions) SideEffect(Rvalue), + /// Stores a value into the memory address specified by the pointer operand. + Store { ptr: Operand, val: Rvalue }, } /// Operations that produce a value. @@ -72,6 +76,8 @@ pub enum Rvalue { BinaryOp(BinaryOp, Operand, Operand), Cast(Ty, Operand), Call(String, Vec, Ty), + AddressOf(LocalId), + ReadPointer(Operand), } /// An atomic value used as inputs to instructions. diff --git a/tests/pointers.test b/tests/pointers.test new file mode 100644 index 0000000..4eb2481 --- /dev/null +++ b/tests/pointers.test @@ -0,0 +1,16 @@ +[code] +fn swap(a: *i32, b: *i32) { + let temp = *a; + *a = *b; + *b = temp; +} + +fn main() -> i32 { + let x = 10; + let y = 20; + swap(&x, &y); + return x - y; // 20 - 10 = 10 +} + +[expected_return_code] +10 \ No newline at end of file