feat: add support for *T pointers
This commit is contained in:
@@ -137,14 +137,27 @@ impl CraneliftBackend {
|
||||
}
|
||||
|
||||
let mut var_map = HashMap::new();
|
||||
let mut stack_slot_map = HashMap::new();
|
||||
for local in &func.locals {
|
||||
if local.address_taken {
|
||||
let cl_ty = Self::lower_type(&local.ty);
|
||||
let bytes = cl_ty.bytes();
|
||||
let slot = builder.create_sized_stack_slot(ir::StackSlotData::new(
|
||||
ir::StackSlotKind::ExplicitSlot,
|
||||
bytes,
|
||||
0,
|
||||
));
|
||||
stack_slot_map.insert(local.id, slot);
|
||||
} else {
|
||||
let var = builder.declare_var(Self::lower_type(&local.ty));
|
||||
var_map.insert(local.id, var);
|
||||
}
|
||||
}
|
||||
|
||||
let mut trans = FunctionTranslator {
|
||||
builder,
|
||||
var_map,
|
||||
stack_slot_map,
|
||||
block_map,
|
||||
locals: &func.locals,
|
||||
module: &mut self.module,
|
||||
@@ -166,9 +179,13 @@ impl CraneliftBackend {
|
||||
if i == 0 {
|
||||
for (j, param_id) in func.params.iter().enumerate() {
|
||||
let val = trans.builder.block_params(cl_block)[j];
|
||||
if let Some(&slot) = trans.stack_slot_map.get(param_id) {
|
||||
trans.builder.ins().stack_store(val, slot, 0);
|
||||
} else {
|
||||
trans.builder.def_var(trans.var_map[param_id], val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for stmt in &block.statements {
|
||||
trans.translate_stmt(stmt);
|
||||
@@ -195,6 +212,7 @@ impl CraneliftBackend {
|
||||
Ty::F32 => types::F32,
|
||||
Ty::F64 => types::F64,
|
||||
Ty::Bool => types::I8, // Booleans are represented as 8-bit integers
|
||||
Ty::Pointer(_) => types::I64, // Assume 64-bit environment pointers
|
||||
_ => unimplemented!("Unsupported type for Cranelift lowering: {:?}", ty),
|
||||
}
|
||||
}
|
||||
@@ -205,6 +223,7 @@ impl CraneliftBackend {
|
||||
struct FunctionTranslator<'a> {
|
||||
builder: FunctionBuilder<'a>,
|
||||
var_map: HashMap<LocalId, Variable>,
|
||||
stack_slot_map: HashMap<LocalId, ir::StackSlot>,
|
||||
block_map: HashMap<BlockId, ir::Block>,
|
||||
locals: &'a [LocalDecl],
|
||||
module: &'a mut ObjectModule,
|
||||
@@ -217,13 +236,25 @@ impl<'a> FunctionTranslator<'a> {
|
||||
StatementKind::Assign(local_id, rvalue) => {
|
||||
let val = self.translate_rvalue(rvalue);
|
||||
if let Some(v) = val {
|
||||
if let Some(&slot) = self.stack_slot_map.get(local_id) {
|
||||
self.builder.ins().stack_store(v, slot, 0);
|
||||
} else {
|
||||
let var = self.var_map[local_id];
|
||||
self.builder.def_var(var, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
StatementKind::SideEffect(rvalue) => {
|
||||
self.translate_rvalue(rvalue);
|
||||
}
|
||||
StatementKind::Store { ptr, val } => {
|
||||
let ptr_val = self.translate_operand(ptr);
|
||||
if let Some(v) = self.translate_rvalue(val) {
|
||||
self.builder
|
||||
.ins()
|
||||
.store(ir::MemFlags::trusted(), v, ptr_val, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,9 +303,14 @@ impl<'a> FunctionTranslator<'a> {
|
||||
fn translate_operand(&mut self, op: &Operand) -> ir::Value {
|
||||
match op {
|
||||
Operand::Copy(local_id) => {
|
||||
if let Some(&slot) = self.stack_slot_map.get(local_id) {
|
||||
let cl_ty = CraneliftBackend::lower_type(&self.locals[local_id.0].ty);
|
||||
self.builder.ins().stack_load(cl_ty, slot, 0)
|
||||
} else {
|
||||
let var = self.var_map[local_id];
|
||||
self.builder.use_var(var)
|
||||
}
|
||||
}
|
||||
Operand::Constant(ConstantValue::Integer(val, ty)) => {
|
||||
let cl_ty = CraneliftBackend::lower_type(ty);
|
||||
self.builder.ins().iconst(cl_ty, *val as i64)
|
||||
@@ -315,6 +351,9 @@ impl<'a> FunctionTranslator<'a> {
|
||||
.icmp(ir::condcodes::IntCC::Equal, inner_val, zero),
|
||||
)
|
||||
}
|
||||
UnaryOp::Deref | UnaryOp::AddressOf => {
|
||||
unreachable!("handled as distinct Rvalues in MIR")
|
||||
}
|
||||
}
|
||||
}
|
||||
Rvalue::BinaryOp(op, lhs, rhs) => {
|
||||
@@ -527,6 +566,24 @@ impl<'a> FunctionTranslator<'a> {
|
||||
Some(self.builder.inst_results(call_inst)[0])
|
||||
}
|
||||
}
|
||||
Rvalue::AddressOf(local_id) => {
|
||||
let slot = self.stack_slot_map[local_id];
|
||||
Some(self.builder.ins().stack_addr(types::I64, slot, 0))
|
||||
}
|
||||
Rvalue::ReadPointer(ptr_op) => {
|
||||
let ptr_val = self.translate_operand(ptr_op);
|
||||
let ptr_ty = self.get_operand_type(ptr_op);
|
||||
let inner_ty = match ptr_ty {
|
||||
Ty::Pointer(inner) => *inner,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let cl_ty = CraneliftBackend::lower_type(&inner_ty);
|
||||
Some(
|
||||
self.builder
|
||||
.ins()
|
||||
.load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +91,7 @@ pub enum TypeKind {
|
||||
F32,
|
||||
F64,
|
||||
Bool,
|
||||
Pointer(Box<Type>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -177,6 +178,8 @@ pub enum ExprKind<P: Phase = Untyped> {
|
||||
pub enum UnaryOp {
|
||||
Neg,
|
||||
Not,
|
||||
Deref,
|
||||
AddressOf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
|
||||
@@ -188,6 +188,7 @@ impl<'src> Lexer<'src> {
|
||||
'*' => token!(TokenKind::Star),
|
||||
'/' => token!(TokenKind::Slash),
|
||||
'%' => token!(TokenKind::Percent),
|
||||
'&' => token!(TokenKind::Ampersand),
|
||||
|
||||
'!' => token!(TokenKind::Bang, '=' => TokenKind::Unequal),
|
||||
'=' => token!(TokenKind::Assign, '=' => TokenKind::Equal),
|
||||
|
||||
@@ -314,6 +314,15 @@ impl<'src> Parser<'src> {
|
||||
self.advance();
|
||||
TypeKind::Bool
|
||||
}
|
||||
TokenKind::Star => {
|
||||
let star_token = self.advance().unwrap();
|
||||
let inner = self.parse_type()?;
|
||||
let span = star_token.span.join(inner.span);
|
||||
return Ok(Type {
|
||||
kind: TypeKind::Pointer(Box::new(inner)),
|
||||
span,
|
||||
});
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(ParseError::new(
|
||||
@@ -782,6 +791,8 @@ impl<'src> Parser<'src> {
|
||||
match op {
|
||||
TokenKind::Minus => Some((UnaryOp::Neg, 30)),
|
||||
TokenKind::Bang => Some((UnaryOp::Not, 30)),
|
||||
TokenKind::Star => Some((UnaryOp::Deref, 30)),
|
||||
TokenKind::Ampersand => Some((UnaryOp::AddressOf, 30)),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
|
||||
+85
-6
@@ -41,6 +41,7 @@ pub enum Ty {
|
||||
Unit,
|
||||
Var(usize),
|
||||
Function(Vec<Ty>, Box<Ty>),
|
||||
Pointer(Box<Ty>),
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
@@ -85,6 +86,11 @@ impl Ty {
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the type is a pointer.
|
||||
pub fn is_pointer(&self) -> bool {
|
||||
matches!(self, Ty::Pointer(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&TypeKind> for Ty {
|
||||
@@ -101,6 +107,20 @@ impl From<&TypeKind> for Ty {
|
||||
TypeKind::F32 => Ty::F32,
|
||||
TypeKind::F64 => Ty::F64,
|
||||
TypeKind::Bool => Ty::Bool,
|
||||
TypeKind::Pointer(inner) => Ty::Pointer(Box::new(Ty::from(&inner.kind))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TypedExpr {
|
||||
/// Returns `true` if the expression produces a valid memory location that can be mutated.
|
||||
pub fn is_lvalue(&self) -> bool {
|
||||
match &self.kind {
|
||||
TypedExprKind::Identifier { .. } => true,
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref, ..
|
||||
} => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,6 +207,8 @@ impl Sema {
|
||||
Ty::Function(params, ret)
|
||||
}
|
||||
|
||||
Ty::Pointer(inner) => Ty::Pointer(Box::new(self.apply_subst(inner))),
|
||||
|
||||
t => t.clone(),
|
||||
}
|
||||
}
|
||||
@@ -200,6 +222,7 @@ impl Sema {
|
||||
Ty::Function(params, ret) => {
|
||||
params.iter().any(|p| self.occurs_check(v, p)) || self.occurs_check(v, &ret)
|
||||
}
|
||||
Ty::Pointer(inner) => self.occurs_check(v, &inner),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -235,6 +258,7 @@ impl Sema {
|
||||
|
||||
self.unify(&r1, &r2)
|
||||
}
|
||||
(Ty::Pointer(p1), Ty::Pointer(p2)) => self.unify(&p1, &p2),
|
||||
(a, b) => Err(format!("type mismatch: expected {:?}, found {:?}", a, b)),
|
||||
}
|
||||
}
|
||||
@@ -617,6 +641,48 @@ impl Sema {
|
||||
}
|
||||
}
|
||||
|
||||
ExprKind::Unary {
|
||||
op: UnaryOp::AddressOf,
|
||||
expr: inner_expr,
|
||||
} => {
|
||||
let typed_inner = self.analyze_expr(inner_expr);
|
||||
if !typed_inner.is_lvalue() {
|
||||
self.errors.push(SemanticError::new(
|
||||
"invalid operand for address-of operator",
|
||||
inner_expr.span,
|
||||
));
|
||||
}
|
||||
TypedExpr {
|
||||
ty: Ty::Pointer(Box::new(typed_inner.ty.clone())),
|
||||
kind: TypedExprKind::Unary {
|
||||
op: UnaryOp::AddressOf,
|
||||
expr: Box::new(typed_inner),
|
||||
},
|
||||
span: expr.span,
|
||||
}
|
||||
}
|
||||
|
||||
ExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: inner_expr,
|
||||
} => {
|
||||
let typed_inner = self.analyze_expr(inner_expr);
|
||||
let result_ty = self.new_var();
|
||||
if let Err(e) =
|
||||
self.unify(&typed_inner.ty, &Ty::Pointer(Box::new(result_ty.clone())))
|
||||
{
|
||||
self.errors.push(SemanticError::new(e, inner_expr.span));
|
||||
}
|
||||
TypedExpr {
|
||||
kind: TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: Box::new(typed_inner),
|
||||
},
|
||||
ty: result_ty,
|
||||
span: expr.span,
|
||||
}
|
||||
}
|
||||
|
||||
ExprKind::Binary { op, lhs, rhs } => {
|
||||
let typed_lhs = self.analyze_expr(lhs);
|
||||
let typed_rhs = self.analyze_expr(rhs);
|
||||
@@ -657,12 +723,11 @@ impl Sema {
|
||||
let typed_rval = self.analyze_expr(rval);
|
||||
let typed_lval = self.analyze_expr(lval);
|
||||
|
||||
match &typed_lval.kind {
|
||||
TypedExprKind::Identifier { .. } => {}
|
||||
_ => self.errors.push(SemanticError::new(
|
||||
if !typed_lval.is_lvalue() {
|
||||
self.errors.push(SemanticError::new(
|
||||
"invalid left-hand side of assignment",
|
||||
lval.span,
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = self.unify(&typed_lval.ty, &typed_rval.ty) {
|
||||
@@ -994,8 +1059,9 @@ impl Sema {
|
||||
from
|
||||
};
|
||||
|
||||
let is_valid = (final_from.is_numeric() || final_from == Ty::Bool)
|
||||
&& (to.is_numeric() || to == Ty::Bool);
|
||||
let is_valid =
|
||||
(final_from.is_numeric() || final_from == Ty::Bool || final_from.is_pointer())
|
||||
&& (to.is_numeric() || to == Ty::Bool || to.is_pointer());
|
||||
|
||||
if !is_valid {
|
||||
self.errors.push(SemanticError::new(
|
||||
@@ -1229,4 +1295,17 @@ mod test {
|
||||
let src = "fn test() { let a: f32 = 10.0; let b = a as i32; }";
|
||||
assert!(analyze(src).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_pointer_assignment() {
|
||||
let src = "fn test() { let a: i32 = 5; let b: *f32 = &a; }";
|
||||
let errors = analyze(src).unwrap_err();
|
||||
assert!(errors.iter().any(|e| e.message.contains("type mismatch")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_pointer_cast() {
|
||||
let src = "fn test() { let a: i32 = 5; let b: *f32 = &a as *f32; }";
|
||||
assert!(analyze(src).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,6 +86,7 @@ pub enum TokenKind {
|
||||
Star,
|
||||
Slash,
|
||||
Percent,
|
||||
Ampersand,
|
||||
|
||||
Equal,
|
||||
Unequal,
|
||||
@@ -147,6 +148,7 @@ impl Display for TokenKind {
|
||||
TokenKind::Star => "`*`",
|
||||
TokenKind::Slash => "`/`",
|
||||
TokenKind::Percent => "`%`",
|
||||
TokenKind::Ampersand => "`&`",
|
||||
TokenKind::Equal => "`==`",
|
||||
TokenKind::Unequal => "`!=`",
|
||||
TokenKind::LessThan => "`<`",
|
||||
|
||||
+51
-7
@@ -139,6 +139,7 @@ impl FuncBuilder {
|
||||
ty,
|
||||
mutable: false,
|
||||
name: Some(name.clone()),
|
||||
address_taken: false,
|
||||
});
|
||||
self.scopes.last_mut().unwrap().insert(name, id);
|
||||
id
|
||||
@@ -152,6 +153,7 @@ impl FuncBuilder {
|
||||
ty,
|
||||
mutable: false,
|
||||
name: None,
|
||||
address_taken: false,
|
||||
});
|
||||
id
|
||||
}
|
||||
@@ -364,16 +366,44 @@ impl FuncBuilder {
|
||||
}
|
||||
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
|
||||
TypedExprKind::Unary { op, expr: inner } => {
|
||||
match op {
|
||||
UnaryOp::AddressOf => match &inner.kind {
|
||||
TypedExprKind::Identifier { name } => {
|
||||
let id = self.lookup(name);
|
||||
self.locals[id.0].address_taken = true;
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: ptr_expr,
|
||||
} => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`!
|
||||
_ => unreachable!("invalid lvalue for addressof"),
|
||||
},
|
||||
UnaryOp::Deref => {
|
||||
let inner_op = self.lower_expr(inner);
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
_ => {
|
||||
let inner_op = self.lower_expr(inner);
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
|
||||
span: expr.span,
|
||||
});
|
||||
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
}
|
||||
}
|
||||
TypedExprKind::Binary { op, lhs, rhs } => {
|
||||
let lhs_op = self.lower_expr(lhs);
|
||||
let rhs_op = self.lower_expr(rhs);
|
||||
@@ -389,15 +419,29 @@ impl FuncBuilder {
|
||||
TypedExprKind::Assign { lval, rval } => {
|
||||
let rval_op = self.lower_expr(rval);
|
||||
|
||||
let local_id = match &lval.kind {
|
||||
TypedExprKind::Identifier { name } => self.lookup(name),
|
||||
_ => panic!("invalid lval in MIR lowering"),
|
||||
};
|
||||
|
||||
match &lval.kind {
|
||||
TypedExprKind::Identifier { name } => {
|
||||
let local_id = self.lookup(name);
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())),
|
||||
span: expr.span,
|
||||
});
|
||||
}
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: ptr_expr,
|
||||
} => {
|
||||
let ptr_op = self.lower_expr(ptr_expr);
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Store {
|
||||
ptr: ptr_op,
|
||||
val: Rvalue::Use(rval_op.clone()),
|
||||
},
|
||||
span: expr.span,
|
||||
});
|
||||
}
|
||||
_ => unreachable!("invalid lval in MIR lowering"),
|
||||
}
|
||||
|
||||
rval_op
|
||||
}
|
||||
|
||||
+31
-1
@@ -28,7 +28,9 @@ fn optimize_function(func: &mut MirFunction) {
|
||||
if let Some(constant) = evaluate_rvalue(rvalue) {
|
||||
// Replace the complex instruction with a simple constant use
|
||||
*rvalue = Rvalue::Use(Operand::Constant(constant.clone()));
|
||||
if !func.locals[local.0].address_taken {
|
||||
known_constants.insert(*local, constant);
|
||||
}
|
||||
} else {
|
||||
// Reassigned to a non-computable value; remove older cached inferences
|
||||
known_constants.remove(local);
|
||||
@@ -37,6 +39,10 @@ fn optimize_function(func: &mut MirFunction) {
|
||||
StatementKind::SideEffect(rvalue) => {
|
||||
propagate_rvalue(rvalue, &known_constants);
|
||||
}
|
||||
StatementKind::Store { ptr, val } => {
|
||||
propagate_operand(ptr, &known_constants);
|
||||
propagate_rvalue(val, &known_constants);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,6 +92,8 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, Cons
|
||||
propagate_operand(arg, known_constants);
|
||||
}
|
||||
}
|
||||
Rvalue::AddressOf(_) => {}
|
||||
Rvalue::ReadPointer(op) => propagate_operand(op, known_constants),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +272,29 @@ mod test {
|
||||
name: "test".to_string(),
|
||||
params: vec![],
|
||||
return_type: Ty::I32,
|
||||
locals: vec![],
|
||||
locals: vec![
|
||||
LocalDecl {
|
||||
id: LocalId(0),
|
||||
ty: Ty::I32,
|
||||
mutable: false,
|
||||
name: None,
|
||||
address_taken: false,
|
||||
},
|
||||
LocalDecl {
|
||||
id: LocalId(1),
|
||||
ty: Ty::I32,
|
||||
mutable: false,
|
||||
name: None,
|
||||
address_taken: false,
|
||||
},
|
||||
LocalDecl {
|
||||
id: LocalId(2),
|
||||
ty: Ty::I32,
|
||||
mutable: false,
|
||||
name: None,
|
||||
address_taken: false,
|
||||
},
|
||||
],
|
||||
blocks: vec![BasicBlock {
|
||||
id: BlockId(0),
|
||||
statements: vec![
|
||||
|
||||
@@ -40,6 +40,8 @@ pub struct LocalDecl {
|
||||
pub mutable: bool,
|
||||
/// Contains the name if it is a user-declared variable, or `None` if it's a compiler temporary.
|
||||
pub name: Option<String>,
|
||||
/// Indicates if the address of this local was ever taken, forcing allocation via explicit memory stack slots instead of SSA pseudo-registers.
|
||||
pub address_taken: bool,
|
||||
}
|
||||
|
||||
/// A sequential list of non-branching statements followed by a single terminator.
|
||||
@@ -62,6 +64,8 @@ pub enum StatementKind {
|
||||
Assign(LocalId, Rvalue),
|
||||
/// Executes an Rvalue strictly for its side effects (e.g. FFI calling `Unit` functions)
|
||||
SideEffect(Rvalue),
|
||||
/// Stores a value into the memory address specified by the pointer operand.
|
||||
Store { ptr: Operand, val: Rvalue },
|
||||
}
|
||||
|
||||
/// Operations that produce a value.
|
||||
@@ -72,6 +76,8 @@ pub enum Rvalue {
|
||||
BinaryOp(BinaryOp, Operand, Operand),
|
||||
Cast(Ty, Operand),
|
||||
Call(String, Vec<Operand>, Ty),
|
||||
AddressOf(LocalId),
|
||||
ReadPointer(Operand),
|
||||
}
|
||||
|
||||
/// An atomic value used as inputs to instructions.
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
[code]
|
||||
fn swap(a: *i32, b: *i32) {
|
||||
let temp = *a;
|
||||
*a = *b;
|
||||
*b = temp;
|
||||
}
|
||||
|
||||
fn main() -> i32 {
|
||||
let x = 10;
|
||||
let y = 20;
|
||||
swap(&x, &y);
|
||||
return x - y; // 20 - 10 = 10
|
||||
}
|
||||
|
||||
[expected_return_code]
|
||||
10
|
||||
Reference in New Issue
Block a user