feat: add support for structs and member access
This commit is contained in:
+190
-10
@@ -92,7 +92,7 @@ impl CraneliftBackend {
|
||||
}
|
||||
|
||||
for func in &module.functions {
|
||||
self.compile_function(func);
|
||||
self.compile_function(func, &module.structs);
|
||||
|
||||
// Run Cranelift's optimization passes before emitting the text IR
|
||||
let mut ctrl_plane = ControlPlane::default();
|
||||
@@ -113,7 +113,11 @@ impl CraneliftBackend {
|
||||
}
|
||||
|
||||
/// Lowers a single MIR function into Cranelift IR.
|
||||
fn compile_function(&mut self, func: &MirFunction) {
|
||||
fn compile_function(
|
||||
&mut self,
|
||||
func: &MirFunction,
|
||||
structs: &HashMap<String, Vec<(String, Ty)>>,
|
||||
) {
|
||||
let mut sig = self.module.make_signature();
|
||||
|
||||
for param_id in &func.params {
|
||||
@@ -139,9 +143,8 @@ impl CraneliftBackend {
|
||||
let mut var_map = HashMap::new();
|
||||
let mut stack_slot_map = HashMap::new();
|
||||
for local in &func.locals {
|
||||
if local.address_taken {
|
||||
let cl_ty = Self::lower_type(&local.ty);
|
||||
let bytes = cl_ty.bytes();
|
||||
if local.address_taken || matches!(local.ty, Ty::Struct(_)) {
|
||||
let bytes = Self::type_size(&local.ty, structs);
|
||||
let slot = builder.create_sized_stack_slot(ir::StackSlotData::new(
|
||||
ir::StackSlotKind::ExplicitSlot,
|
||||
bytes,
|
||||
@@ -162,6 +165,7 @@ impl CraneliftBackend {
|
||||
locals: &func.locals,
|
||||
module: &mut self.module,
|
||||
func_ids: &self.func_ids,
|
||||
structs,
|
||||
};
|
||||
|
||||
if let Some(first_block) = func.blocks.first() {
|
||||
@@ -179,7 +183,14 @@ impl CraneliftBackend {
|
||||
if i == 0 {
|
||||
for (j, param_id) in func.params.iter().enumerate() {
|
||||
let val = trans.builder.block_params(cl_block)[j];
|
||||
if let Some(&slot) = trans.stack_slot_map.get(param_id) {
|
||||
|
||||
if matches!(func.locals[param_id.0].ty, Ty::Struct(_)) {
|
||||
let slot = trans.stack_slot_map[param_id];
|
||||
let dest_addr = trans.builder.ins().stack_addr(types::I64, slot, 0);
|
||||
let size =
|
||||
CraneliftBackend::type_size(&func.locals[param_id.0].ty, trans.structs);
|
||||
trans.emit_memcpy(dest_addr, val, size);
|
||||
} else if let Some(&slot) = trans.stack_slot_map.get(param_id) {
|
||||
trans.builder.ins().stack_store(val, slot, 0);
|
||||
} else {
|
||||
trans.builder.def_var(trans.var_map[param_id], val);
|
||||
@@ -212,10 +223,56 @@ impl CraneliftBackend {
|
||||
Ty::F32 => types::F32,
|
||||
Ty::F64 => types::F64,
|
||||
Ty::Bool => types::I8, // Booleans are represented as 8-bit integers
|
||||
Ty::Pointer(_) => types::I64, // Assume 64-bit environment pointers
|
||||
Ty::Pointer(_) | Ty::Struct(_) => types::I64, // Structs are passed by reference implicitly
|
||||
_ => unimplemented!("Unsupported type for Cranelift lowering: {:?}", ty),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_size(ty: &Ty, structs: &HashMap<String, Vec<(String, Ty)>>) -> u32 {
|
||||
match ty {
|
||||
Ty::I8 | Ty::U8 | Ty::Bool => 1,
|
||||
Ty::I16 | Ty::U16 => 2,
|
||||
Ty::I32 | Ty::U32 | Ty::F32 => 4,
|
||||
Ty::I64 | Ty::U64 | Ty::F64 | Ty::Pointer(_) => 8,
|
||||
Ty::Struct(name) => {
|
||||
let mut size = 0;
|
||||
if let Some(fields) = structs.get(name) {
|
||||
for (_, f_ty) in fields {
|
||||
let f_size = Self::type_size(f_ty, structs);
|
||||
let align = f_size.min(8);
|
||||
size = (size + align - 1) & !(align - 1);
|
||||
size += f_size;
|
||||
}
|
||||
size = (size + 7) & !7;
|
||||
}
|
||||
size
|
||||
}
|
||||
Ty::Unit => 0,
|
||||
Ty::Var(_) | Ty::Function(_, _) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn field_offset(
|
||||
struct_name: &str,
|
||||
field_name: &str,
|
||||
structs: &HashMap<String, Vec<(String, Ty)>>,
|
||||
) -> u32 {
|
||||
let mut offset = 0;
|
||||
if let Some(fields) = structs.get(struct_name) {
|
||||
for (f_name, f_ty) in fields {
|
||||
let f_size = Self::type_size(f_ty, structs);
|
||||
let align = f_size.min(8);
|
||||
offset = (offset + align - 1) & !(align - 1);
|
||||
|
||||
if f_name == field_name {
|
||||
return offset;
|
||||
}
|
||||
|
||||
offset += f_size;
|
||||
}
|
||||
}
|
||||
panic!("Field not found");
|
||||
}
|
||||
}
|
||||
|
||||
/// A visitor that traverses MIR basic blocks and instructions, emitting Cranelift IR instructions
|
||||
@@ -228,18 +285,29 @@ struct FunctionTranslator<'a> {
|
||||
locals: &'a [LocalDecl],
|
||||
module: &'a mut ObjectModule,
|
||||
func_ids: &'a HashMap<String, cranelift_module::FuncId>,
|
||||
structs: &'a HashMap<String, Vec<(String, Ty)>>,
|
||||
}
|
||||
|
||||
impl<'a> FunctionTranslator<'a> {
|
||||
fn translate_stmt(&mut self, stmt: &Statement) {
|
||||
match &stmt.kind {
|
||||
StatementKind::Assign(local_id, rvalue) => {
|
||||
let val = self.translate_rvalue(rvalue);
|
||||
if let Some(v) = val {
|
||||
if let Some(&slot) = self.stack_slot_map.get(local_id) {
|
||||
if let Ty::Struct(name) = &self.locals[local_id.0].ty {
|
||||
let dest_addr = self.builder.ins().stack_addr(types::I64, slot, 0);
|
||||
if let Some(src_addr) = self.translate_rvalue(rvalue) {
|
||||
let size = CraneliftBackend::type_size(
|
||||
&Ty::Struct(name.clone()),
|
||||
self.structs,
|
||||
);
|
||||
self.emit_memcpy(dest_addr, src_addr, size);
|
||||
}
|
||||
} else if let Some(v) = self.translate_rvalue(rvalue) {
|
||||
self.builder.ins().stack_store(v, slot, 0);
|
||||
}
|
||||
} else {
|
||||
let var = self.var_map[local_id];
|
||||
if let Some(v) = self.translate_rvalue(rvalue) {
|
||||
self.builder.def_var(var, v);
|
||||
}
|
||||
}
|
||||
@@ -249,6 +317,13 @@ impl<'a> FunctionTranslator<'a> {
|
||||
}
|
||||
StatementKind::Store { ptr, val } => {
|
||||
let ptr_val = self.translate_operand(ptr);
|
||||
let rval_ty = self.get_rvalue_type(val);
|
||||
if matches!(rval_ty, Ty::Struct(_)) {
|
||||
if let Some(src_addr) = self.translate_rvalue(val) {
|
||||
let size = CraneliftBackend::type_size(&rval_ty, self.structs);
|
||||
self.emit_memcpy(ptr_val, src_addr, size);
|
||||
}
|
||||
} else {
|
||||
if let Some(v) = self.translate_rvalue(val) {
|
||||
self.builder
|
||||
.ins()
|
||||
@@ -257,6 +332,7 @@ impl<'a> FunctionTranslator<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn translate_terminator(&mut self, term: &Terminator) {
|
||||
match &term.kind {
|
||||
@@ -303,7 +379,10 @@ impl<'a> FunctionTranslator<'a> {
|
||||
fn translate_operand(&mut self, op: &Operand) -> ir::Value {
|
||||
match op {
|
||||
Operand::Copy(local_id) => {
|
||||
if let Some(&slot) = self.stack_slot_map.get(local_id) {
|
||||
if matches!(self.locals[local_id.0].ty, Ty::Struct(_)) {
|
||||
let slot = self.stack_slot_map[local_id];
|
||||
self.builder.ins().stack_addr(types::I64, slot, 0)
|
||||
} else if let Some(&slot) = self.stack_slot_map.get(local_id) {
|
||||
let cl_ty = CraneliftBackend::lower_type(&self.locals[local_id.0].ty);
|
||||
self.builder.ins().stack_load(cl_ty, slot, 0)
|
||||
} else {
|
||||
@@ -577,6 +656,9 @@ impl<'a> FunctionTranslator<'a> {
|
||||
Ty::Pointer(inner) => *inner,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if matches!(inner_ty, Ty::Struct(_)) {
|
||||
Some(ptr_val)
|
||||
} else {
|
||||
let cl_ty = CraneliftBackend::lower_type(&inner_ty);
|
||||
Some(
|
||||
self.builder
|
||||
@@ -585,5 +667,103 @@ impl<'a> FunctionTranslator<'a> {
|
||||
)
|
||||
}
|
||||
}
|
||||
Rvalue::GetFieldPtr {
|
||||
base_ptr,
|
||||
struct_name,
|
||||
field_name,
|
||||
} => {
|
||||
let base = self.translate_operand(base_ptr);
|
||||
let offset = CraneliftBackend::field_offset(struct_name, field_name, self.structs);
|
||||
Some(self.builder.ins().iadd_imm(base, offset as i64))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_rvalue_type(&self, rvalue: &Rvalue) -> Ty {
|
||||
match rvalue {
|
||||
Rvalue::Use(op) => self.get_operand_type(op),
|
||||
Rvalue::UnaryOp(op, inner) => match op {
|
||||
UnaryOp::Deref => {
|
||||
if let Ty::Pointer(inner) = self.get_operand_type(inner) {
|
||||
*inner
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
UnaryOp::AddressOf => Ty::Pointer(Box::new(self.get_operand_type(inner))),
|
||||
_ => self.get_operand_type(inner),
|
||||
},
|
||||
Rvalue::BinaryOp(_, lhs, _) => self.get_operand_type(lhs),
|
||||
Rvalue::Cast(ty, _) => ty.clone(),
|
||||
Rvalue::Call(_, _, ty) => ty.clone(),
|
||||
Rvalue::AddressOf(local) => Ty::Pointer(Box::new(self.locals[local.0].ty.clone())),
|
||||
Rvalue::ReadPointer(ptr) => {
|
||||
if let Ty::Pointer(inner) = self.get_operand_type(ptr) {
|
||||
*inner
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
Rvalue::GetFieldPtr {
|
||||
struct_name,
|
||||
field_name,
|
||||
..
|
||||
} => {
|
||||
let fields = self.structs.get(struct_name).unwrap();
|
||||
let ty = fields
|
||||
.iter()
|
||||
.find(|(n, _)| n == field_name)
|
||||
.unwrap()
|
||||
.1
|
||||
.clone();
|
||||
Ty::Pointer(Box::new(ty))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_memcpy(&mut self, dest: ir::Value, src: ir::Value, mut size: u32) {
|
||||
let mut offset = 0;
|
||||
while size >= 8 {
|
||||
let val = self
|
||||
.builder
|
||||
.ins()
|
||||
.load(types::I64, ir::MemFlags::trusted(), src, offset);
|
||||
self.builder
|
||||
.ins()
|
||||
.store(ir::MemFlags::trusted(), val, dest, offset);
|
||||
size -= 8;
|
||||
offset += 8;
|
||||
}
|
||||
if size >= 4 {
|
||||
let val = self
|
||||
.builder
|
||||
.ins()
|
||||
.load(types::I32, ir::MemFlags::trusted(), src, offset);
|
||||
self.builder
|
||||
.ins()
|
||||
.store(ir::MemFlags::trusted(), val, dest, offset);
|
||||
size -= 4;
|
||||
offset += 4;
|
||||
}
|
||||
if size >= 2 {
|
||||
let val = self
|
||||
.builder
|
||||
.ins()
|
||||
.load(types::I16, ir::MemFlags::trusted(), src, offset);
|
||||
self.builder
|
||||
.ins()
|
||||
.store(ir::MemFlags::trusted(), val, dest, offset);
|
||||
size -= 2;
|
||||
offset += 2;
|
||||
}
|
||||
if size == 1 {
|
||||
let val = self
|
||||
.builder
|
||||
.ins()
|
||||
.load(types::I8, ir::MemFlags::trusted(), src, offset);
|
||||
self.builder
|
||||
.ins()
|
||||
.store(ir::MemFlags::trusted(), val, dest, offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+31
-1
@@ -63,6 +63,11 @@ pub enum DeclKind<P: Phase = Untyped> {
|
||||
params: Vec<P::ParamType>,
|
||||
return_type: P::ReturnType,
|
||||
},
|
||||
Struct {
|
||||
name: String,
|
||||
name_span: Span,
|
||||
fields: Vec<StructField>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
@@ -73,12 +78,19 @@ pub struct FunctionParam {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct StructField {
|
||||
pub name: String,
|
||||
pub name_span: Span,
|
||||
pub ty: Type,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Type {
|
||||
pub kind: TypeKind,
|
||||
pub span: Span,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TypeKind {
|
||||
I8,
|
||||
I16,
|
||||
@@ -92,6 +104,7 @@ pub enum TypeKind {
|
||||
F64,
|
||||
Bool,
|
||||
Pointer(Box<Type>),
|
||||
Struct(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -172,6 +185,23 @@ pub enum ExprKind<P: Phase = Untyped> {
|
||||
callee: Box<Expr<P>>,
|
||||
args: Vec<Expr<P>>,
|
||||
},
|
||||
Struct {
|
||||
name: String,
|
||||
name_span: Span,
|
||||
fields: Vec<FieldValue<P>>,
|
||||
},
|
||||
FieldAccess {
|
||||
expr: Box<Expr<P>>,
|
||||
field: String,
|
||||
field_span: Span,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct FieldValue<P: Phase = Untyped> {
|
||||
pub name: String,
|
||||
pub name_span: Span,
|
||||
pub value: Expr<P>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
|
||||
@@ -71,6 +71,7 @@ impl<'src> Lexer<'src> {
|
||||
"return" => TokenKind::Return,
|
||||
"let" => TokenKind::Let,
|
||||
"while" => TokenKind::While,
|
||||
"struct" => TokenKind::Struct,
|
||||
"break" => TokenKind::Break,
|
||||
"continue" => TokenKind::Continue,
|
||||
|
||||
@@ -251,7 +252,9 @@ mod test {
|
||||
#[test]
|
||||
fn identifiers() {
|
||||
assert_eq!(
|
||||
tokenize("HELLO _hello _0@ fn if else return let while break continue as foreign"),
|
||||
tokenize(
|
||||
"HELLO _hello _0@ fn if else return let while struct break continue as foreign"
|
||||
),
|
||||
vec![
|
||||
Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)),
|
||||
Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)),
|
||||
@@ -263,10 +266,11 @@ mod test {
|
||||
Token::new(TokenKind::Return, "return", Span::new(28, 34)),
|
||||
Token::new(TokenKind::Let, "let", Span::new(35, 38)),
|
||||
Token::new(TokenKind::While, "while", Span::new(39, 44)),
|
||||
Token::new(TokenKind::Break, "break", Span::new(45, 50)),
|
||||
Token::new(TokenKind::Continue, "continue", Span::new(51, 59)),
|
||||
Token::new(TokenKind::As, "as", Span::new(60, 62)),
|
||||
Token::new(TokenKind::Foreign, "foreign", Span::new(63, 70)),
|
||||
Token::new(TokenKind::Struct, "struct", Span::new(45, 51)),
|
||||
Token::new(TokenKind::Break, "break", Span::new(52, 57)),
|
||||
Token::new(TokenKind::Continue, "continue", Span::new(58, 66)),
|
||||
Token::new(TokenKind::As, "as", Span::new(67, 69)),
|
||||
Token::new(TokenKind::Foreign, "foreign", Span::new(70, 77)),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
+232
-11
@@ -117,7 +117,7 @@ impl<'src> Parser<'src> {
|
||||
Ok(decl) => decls.push(decl),
|
||||
Err(err) => {
|
||||
self.errors.push(err);
|
||||
self.synchronize(&[TokenKind::Fn]);
|
||||
self.synchronize(&[TokenKind::Fn, TokenKind::Foreign, TokenKind::Struct]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,7 @@ impl<'src> Parser<'src> {
|
||||
match peek_token.kind {
|
||||
TokenKind::Fn => self.parse_function_decl(),
|
||||
TokenKind::Foreign => self.parse_foreign_function_decl(),
|
||||
TokenKind::Struct => self.parse_struct_decl(),
|
||||
|
||||
_ => Err(ParseError::new(
|
||||
format!(
|
||||
@@ -225,6 +226,54 @@ impl<'src> Parser<'src> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Parses a struct declaration.
|
||||
///
|
||||
/// ```ebnf
|
||||
/// struct_decl = "struct" IDENTIFIER "{" { struct_field } "}" ;
|
||||
/// struct_field = IDENTIFIER ":" type [ "," ] ;
|
||||
/// ```
|
||||
fn parse_struct_decl(&mut self) -> ParseResult<Decl> {
|
||||
let struct_token = self.expect(TokenKind::Struct)?;
|
||||
|
||||
let (name, name_span) = {
|
||||
let ident_token = self.expect(TokenKind::Identifier)?;
|
||||
(ident_token.text.to_string(), ident_token.span)
|
||||
};
|
||||
|
||||
self.expect(TokenKind::LBrace)?;
|
||||
|
||||
let mut fields = Vec::new();
|
||||
while !self.is_at_eof() && !self.is_peek(TokenKind::RBrace) {
|
||||
let field_ident = self.expect(TokenKind::Identifier)?;
|
||||
self.expect(TokenKind::Colon)?;
|
||||
let ty = self.parse_type()?;
|
||||
|
||||
fields.push(StructField {
|
||||
name: field_ident.text.to_string(),
|
||||
name_span: field_ident.span,
|
||||
ty,
|
||||
});
|
||||
|
||||
if self.is_peek(TokenKind::Comma) {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let rbrace_token = self.expect(TokenKind::RBrace)?;
|
||||
let span = struct_token.span.join(rbrace_token.span);
|
||||
|
||||
Ok(Decl {
|
||||
kind: DeclKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
},
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parses the function parameter list.
|
||||
///
|
||||
/// ```ebnf
|
||||
@@ -265,6 +314,7 @@ impl<'src> Parser<'src> {
|
||||
/// | "u8" | "u16" | "u32" | "u64"
|
||||
/// | "f32" | "f64"
|
||||
/// | "bool" ;
|
||||
/// | IDENTIFIER
|
||||
/// ```
|
||||
pub fn parse_type(&mut self) -> ParseResult<Type> {
|
||||
let peek_token = self.peek_no_eof()?;
|
||||
@@ -314,6 +364,10 @@ impl<'src> Parser<'src> {
|
||||
self.advance();
|
||||
TypeKind::Bool
|
||||
}
|
||||
TokenKind::Identifier => {
|
||||
let token = self.advance().unwrap();
|
||||
TypeKind::Struct(token.text.to_string())
|
||||
}
|
||||
TokenKind::Star => {
|
||||
let star_token = self.advance().unwrap();
|
||||
let inner = self.parse_type()?;
|
||||
@@ -407,7 +461,7 @@ impl<'src> Parser<'src> {
|
||||
fn parse_if_stmt(&mut self) -> ParseResult<Stmt> {
|
||||
let if_token = self.expect(TokenKind::If)?;
|
||||
|
||||
let condition = self.parse_expr()?;
|
||||
let condition = self.parse_expr_no_struct()?;
|
||||
let consequence = self.parse_compound_stmt()?;
|
||||
|
||||
let alternative = if self.is_peek(TokenKind::Else) {
|
||||
@@ -446,7 +500,7 @@ impl<'src> Parser<'src> {
|
||||
/// ```
|
||||
fn parse_while_stmt(&mut self) -> ParseResult<Stmt> {
|
||||
let while_token = self.expect(TokenKind::While)?;
|
||||
let condition = self.parse_expr()?;
|
||||
let condition = self.parse_expr_no_struct()?;
|
||||
let body = self.parse_compound_stmt()?;
|
||||
let span = while_token.span.join(body.span);
|
||||
|
||||
@@ -575,12 +629,17 @@ impl<'src> Parser<'src> {
|
||||
|
||||
/// Parses an expression.
|
||||
pub fn parse_expr(&mut self) -> ParseResult<Expr> {
|
||||
self.parse_expr_bp(0)
|
||||
self.parse_expr_bp(0, true)
|
||||
}
|
||||
|
||||
/// Parses an expression, disallowing struct literals.
|
||||
pub fn parse_expr_no_struct(&mut self) -> ParseResult<Expr> {
|
||||
self.parse_expr_bp(0, false)
|
||||
}
|
||||
|
||||
/// Pratt parsing implementation for expressions.
|
||||
fn parse_expr_bp(&mut self, min_bp: u8) -> ParseResult<Expr> {
|
||||
let mut lhs = self.parse_leading_expr()?;
|
||||
fn parse_expr_bp(&mut self, min_bp: u8, allow_struct: bool) -> ParseResult<Expr> {
|
||||
let mut lhs = self.parse_leading_expr(allow_struct)?;
|
||||
|
||||
loop {
|
||||
let peek_token = self.peek_no_eof()?;
|
||||
@@ -594,7 +653,7 @@ impl<'src> Parser<'src> {
|
||||
}
|
||||
|
||||
self.advance(); // consume '='
|
||||
let rhs = self.parse_expr_bp(right_bp)?;
|
||||
let rhs = self.parse_expr_bp(right_bp, allow_struct)?;
|
||||
let span = lhs.span.join(rhs.span);
|
||||
|
||||
lhs = Expr {
|
||||
@@ -662,6 +721,28 @@ impl<'src> Parser<'src> {
|
||||
continue;
|
||||
}
|
||||
|
||||
if peek_token.kind == TokenKind::Dot {
|
||||
let left_bp = 30; // Field access has very high precedence
|
||||
if left_bp < min_bp {
|
||||
break;
|
||||
}
|
||||
self.advance(); // consume `.`
|
||||
|
||||
let field_token = self.expect(TokenKind::Identifier)?;
|
||||
let span = lhs.span.join(field_token.span);
|
||||
|
||||
lhs = Expr {
|
||||
kind: ExprKind::FieldAccess {
|
||||
expr: Box::new(lhs),
|
||||
field: field_token.text.to_string(),
|
||||
field_span: field_token.span,
|
||||
},
|
||||
ty: (),
|
||||
span,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else {
|
||||
break; // Not an infix operator
|
||||
};
|
||||
@@ -672,7 +753,7 @@ impl<'src> Parser<'src> {
|
||||
|
||||
self.advance(); // consume the operator
|
||||
|
||||
let rhs = self.parse_expr_bp(right_bp)?;
|
||||
let rhs = self.parse_expr_bp(right_bp, allow_struct)?;
|
||||
let span = lhs.span.join(rhs.span);
|
||||
|
||||
lhs = Expr {
|
||||
@@ -691,13 +772,48 @@ impl<'src> Parser<'src> {
|
||||
|
||||
/// Parses a leading expression such as identifiers, integer and boolean literals
|
||||
/// or prefix expressions.
|
||||
fn parse_leading_expr(&mut self) -> ParseResult<Expr> {
|
||||
fn parse_leading_expr(&mut self, allow_struct: bool) -> ParseResult<Expr> {
|
||||
let peek_token = self.peek_no_eof()?;
|
||||
|
||||
match peek_token.kind {
|
||||
TokenKind::Identifier => {
|
||||
let token = self.advance().unwrap();
|
||||
|
||||
if allow_struct && self.is_peek(TokenKind::LBrace) {
|
||||
self.advance(); // consume `{`
|
||||
|
||||
let mut fields = Vec::new();
|
||||
while !self.is_at_eof() && !self.is_peek(TokenKind::RBrace) {
|
||||
let field_token = self.expect(TokenKind::Identifier)?;
|
||||
self.expect(TokenKind::Colon)?;
|
||||
let value = self.parse_expr()?;
|
||||
|
||||
fields.push(FieldValue {
|
||||
name: field_token.text.to_string(),
|
||||
name_span: field_token.span,
|
||||
value,
|
||||
});
|
||||
|
||||
if self.is_peek(TokenKind::Comma) {
|
||||
self.advance();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let rbrace_token = self.expect(TokenKind::RBrace)?;
|
||||
let span = token.span.join(rbrace_token.span);
|
||||
|
||||
Ok(Expr {
|
||||
kind: ExprKind::Struct {
|
||||
name: token.text.to_string(),
|
||||
name_span: token.span,
|
||||
fields,
|
||||
},
|
||||
ty: (),
|
||||
span,
|
||||
})
|
||||
} else {
|
||||
Ok(Expr {
|
||||
kind: ExprKind::Identifier {
|
||||
name: token.text.to_string(),
|
||||
@@ -706,6 +822,7 @@ impl<'src> Parser<'src> {
|
||||
span: token.span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
TokenKind::IntegerLit => {
|
||||
let token = self.advance().unwrap();
|
||||
@@ -754,7 +871,7 @@ impl<'src> Parser<'src> {
|
||||
|
||||
TokenKind::LParen => {
|
||||
let lparen = self.advance().unwrap();
|
||||
let expr = self.parse_expr_bp(0)?;
|
||||
let expr = self.parse_expr()?; // Inner expressions allow struct literals
|
||||
let rparen = self.expect(TokenKind::RParen)?;
|
||||
|
||||
Ok(Expr {
|
||||
@@ -766,7 +883,7 @@ impl<'src> Parser<'src> {
|
||||
|
||||
kind if let Some((op, r_bp)) = self.prefix_operator(kind) => {
|
||||
let op_token = self.advance().unwrap();
|
||||
let rhs = self.parse_expr_bp(r_bp)?;
|
||||
let rhs = self.parse_expr_bp(r_bp, allow_struct)?;
|
||||
|
||||
Ok(Expr {
|
||||
ty: (),
|
||||
@@ -1356,4 +1473,108 @@ mod test {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn struct_decl() {
|
||||
assert_eq!(
|
||||
parse("struct Vec3 { x: f32, y: f32 }", Parser::parse_decl),
|
||||
Success(Decl {
|
||||
kind: DeclKind::Struct {
|
||||
name: "Vec3".to_string(),
|
||||
name_span: Span::new(7, 11),
|
||||
fields: vec![
|
||||
StructField {
|
||||
name: "x".to_string(),
|
||||
name_span: Span::new(14, 15),
|
||||
ty: Type {
|
||||
kind: TypeKind::F32,
|
||||
span: Span::new(17, 20)
|
||||
}
|
||||
},
|
||||
StructField {
|
||||
name: "y".to_string(),
|
||||
name_span: Span::new(22, 23),
|
||||
ty: Type {
|
||||
kind: TypeKind::F32,
|
||||
span: Span::new(25, 28)
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
span: Span::new(0, 30)
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn struct_literal() {
|
||||
assert_eq!(
|
||||
parse("Vec3 { x: 1.0, y: -3.5 }", Parser::parse_expr),
|
||||
Success(Expr {
|
||||
kind: ExprKind::Struct {
|
||||
name: "Vec3".to_string(),
|
||||
name_span: Span::new(0, 4),
|
||||
fields: vec![
|
||||
FieldValue {
|
||||
name: "x".to_string(),
|
||||
name_span: Span::new(7, 8),
|
||||
value: Expr {
|
||||
kind: ExprKind::Float { value: 1.0 },
|
||||
ty: (),
|
||||
span: Span::new(10, 13)
|
||||
}
|
||||
},
|
||||
FieldValue {
|
||||
name: "y".to_string(),
|
||||
name_span: Span::new(15, 16),
|
||||
value: Expr {
|
||||
kind: ExprKind::Unary {
|
||||
op: UnaryOp::Neg,
|
||||
expr: Box::new(Expr {
|
||||
kind: ExprKind::Float { value: 3.5 },
|
||||
ty: (),
|
||||
span: Span::new(19, 22)
|
||||
})
|
||||
},
|
||||
ty: (),
|
||||
span: Span::new(18, 22)
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
ty: (),
|
||||
span: Span::new(0, 24)
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_access() {
|
||||
assert_eq!(
|
||||
parse("a.b.c", Parser::parse_expr),
|
||||
Success(Expr {
|
||||
kind: ExprKind::FieldAccess {
|
||||
expr: Box::new(Expr {
|
||||
kind: ExprKind::FieldAccess {
|
||||
expr: Box::new(Expr {
|
||||
kind: ExprKind::Identifier {
|
||||
name: "a".to_string()
|
||||
},
|
||||
ty: (),
|
||||
span: Span::new(0, 1)
|
||||
}),
|
||||
field: "b".to_string(),
|
||||
field_span: Span::new(2, 3),
|
||||
},
|
||||
ty: (),
|
||||
span: Span::new(0, 3)
|
||||
}),
|
||||
field: "c".to_string(),
|
||||
field_span: Span::new(4, 5),
|
||||
},
|
||||
ty: (),
|
||||
span: Span::new(0, 5)
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ pub enum Ty {
|
||||
Var(usize),
|
||||
Function(Vec<Ty>, Box<Ty>),
|
||||
Pointer(Box<Ty>),
|
||||
Struct(String),
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
@@ -108,6 +109,7 @@ impl From<&TypeKind> for Ty {
|
||||
TypeKind::F64 => Ty::F64,
|
||||
TypeKind::Bool => Ty::Bool,
|
||||
TypeKind::Pointer(inner) => Ty::Pointer(Box::new(Ty::from(&inner.kind))),
|
||||
TypeKind::Struct(name) => Ty::Struct(name.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,6 +122,7 @@ impl TypedExpr {
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref, ..
|
||||
} => true,
|
||||
TypedExprKind::FieldAccess { expr, .. } => expr.is_lvalue(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -132,6 +135,7 @@ pub struct Sema {
|
||||
next_var: usize,
|
||||
subst: HashMap<usize, Ty>,
|
||||
scopes: Vec<HashMap<String, Ty>>,
|
||||
structs: HashMap<String, Vec<(String, Ty)>>,
|
||||
errors: Vec<SemanticError>,
|
||||
deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
|
||||
deferred_binary: Vec<(Span, Ty)>,
|
||||
@@ -148,6 +152,7 @@ impl Sema {
|
||||
next_var: 0,
|
||||
subst: HashMap::new(),
|
||||
scopes: Vec::new(),
|
||||
structs: HashMap::new(),
|
||||
errors: Vec::new(),
|
||||
deferred_unary_neg: Vec::new(),
|
||||
deferred_binary: Vec::new(),
|
||||
@@ -290,6 +295,13 @@ impl Sema {
|
||||
|
||||
self.bind(name, Ty::Function(param_tys, Box::new(ret_ty)));
|
||||
}
|
||||
DeclKind::Struct { name, fields, .. } => {
|
||||
let typed_fields = fields
|
||||
.iter()
|
||||
.map(|f| (f.name.clone(), Ty::from(&f.ty.kind)))
|
||||
.collect();
|
||||
self.structs.insert(name.clone(), typed_fields);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,6 +388,25 @@ impl Sema {
|
||||
span: decl.span,
|
||||
}
|
||||
}
|
||||
DeclKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
} => TypedDecl {
|
||||
kind: TypedDeclKind::Struct {
|
||||
name: name.clone(),
|
||||
name_span: *name_span,
|
||||
fields: fields
|
||||
.iter()
|
||||
.map(|f| StructField {
|
||||
name: f.name.clone(),
|
||||
name_span: f.name_span,
|
||||
ty: f.ty.clone(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
span: decl.span,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -786,6 +817,133 @@ impl Sema {
|
||||
span: expr.span,
|
||||
}
|
||||
}
|
||||
ExprKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
} => {
|
||||
let mut typed_fields = Vec::new();
|
||||
let mut provided_fields = HashMap::new();
|
||||
|
||||
if let Some(struct_def) = self.structs.get(name).cloned() {
|
||||
for field in fields {
|
||||
let typed_value = self.analyze_expr(&field.value);
|
||||
|
||||
if let Some((_, expected_ty)) =
|
||||
struct_def.iter().find(|(n, _)| n == &field.name)
|
||||
{
|
||||
if let Err(e) = self.unify(&typed_value.ty, expected_ty) {
|
||||
self.errors.push(SemanticError::new(e, field.value.span));
|
||||
}
|
||||
} else {
|
||||
self.errors.push(SemanticError::new(
|
||||
format!("struct `{}` has no field named `{}`", name, field.name),
|
||||
field.name_span,
|
||||
));
|
||||
}
|
||||
|
||||
if provided_fields
|
||||
.insert(field.name.clone(), field.name_span)
|
||||
.is_some()
|
||||
{
|
||||
self.errors.push(SemanticError::new(
|
||||
format!("field `{}` specified more than once", field.name),
|
||||
field.name_span,
|
||||
));
|
||||
}
|
||||
|
||||
typed_fields.push(FieldValue {
|
||||
name: field.name.clone(),
|
||||
name_span: field.name_span,
|
||||
value: typed_value,
|
||||
});
|
||||
}
|
||||
|
||||
for (expected_field, _) in struct_def {
|
||||
if !provided_fields.contains_key(&expected_field) {
|
||||
self.errors.push(SemanticError::new(
|
||||
format!(
|
||||
"missing field `{}` in initializer of `{}`",
|
||||
expected_field, name
|
||||
),
|
||||
expr.span,
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.errors.push(SemanticError::new(
|
||||
format!("undeclared struct `{}`", name),
|
||||
*name_span,
|
||||
));
|
||||
for field in fields {
|
||||
let typed_value = self.analyze_expr(&field.value);
|
||||
typed_fields.push(FieldValue {
|
||||
name: field.name.clone(),
|
||||
name_span: field.name_span,
|
||||
value: typed_value,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TypedExpr {
|
||||
kind: TypedExprKind::Struct {
|
||||
name: name.clone(),
|
||||
name_span: *name_span,
|
||||
fields: typed_fields,
|
||||
},
|
||||
ty: Ty::Struct(name.clone()),
|
||||
span: expr.span,
|
||||
}
|
||||
}
|
||||
ExprKind::FieldAccess {
|
||||
expr: inner_expr,
|
||||
field,
|
||||
field_span,
|
||||
} => {
|
||||
let typed_inner = self.analyze_expr(inner_expr);
|
||||
let result_ty = self.new_var();
|
||||
|
||||
let inner_ty_resolved = self.apply_subst(&typed_inner.ty);
|
||||
match inner_ty_resolved {
|
||||
Ty::Struct(ref struct_name) => {
|
||||
if let Some(struct_def) = self.structs.get(struct_name).cloned() {
|
||||
if let Some((_, field_ty)) = struct_def.iter().find(|(n, _)| n == field)
|
||||
{
|
||||
if let Err(e) = self.unify(&result_ty, field_ty) {
|
||||
self.errors.push(SemanticError::new(e, *field_span));
|
||||
}
|
||||
} else {
|
||||
self.errors.push(SemanticError::new(
|
||||
format!("no field `{}` on type `{}`", field, struct_name),
|
||||
*field_span,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ty::Var(_) => {
|
||||
self.errors.push(SemanticError::new(
|
||||
"type of expression must be known to access a field",
|
||||
inner_expr.span,
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
self.errors.push(SemanticError::new(
|
||||
format!("cannot access field `{}` on a non-struct type", field),
|
||||
*field_span,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
TypedExpr {
|
||||
kind: TypedExprKind::FieldAccess {
|
||||
expr: Box::new(typed_inner),
|
||||
field: field.clone(),
|
||||
field_span: *field_span,
|
||||
},
|
||||
ty: result_ty,
|
||||
span: expr.span,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -827,6 +985,15 @@ impl Sema {
|
||||
.collect(),
|
||||
return_type: self.apply_subst(&return_type),
|
||||
},
|
||||
TypedDeclKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
} => TypedDeclKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
},
|
||||
};
|
||||
|
||||
TypedDecl { kind, span }
|
||||
@@ -916,6 +1083,31 @@ impl Sema {
|
||||
callee: Box::new(self.apply_subst_expr(*callee)),
|
||||
args: args.into_iter().map(|a| self.apply_subst_expr(a)).collect(),
|
||||
},
|
||||
TypedExprKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields,
|
||||
} => TypedExprKind::Struct {
|
||||
name,
|
||||
name_span,
|
||||
fields: fields
|
||||
.into_iter()
|
||||
.map(|f| FieldValue {
|
||||
name: f.name,
|
||||
name_span: f.name_span,
|
||||
value: self.apply_subst_expr(f.value),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
TypedExprKind::FieldAccess {
|
||||
expr,
|
||||
field,
|
||||
field_span,
|
||||
} => TypedExprKind::FieldAccess {
|
||||
expr: Box::new(self.apply_subst_expr(*expr)),
|
||||
field,
|
||||
field_span,
|
||||
},
|
||||
};
|
||||
|
||||
TypedExpr { kind, ty, span }
|
||||
@@ -1308,4 +1500,39 @@ mod test {
|
||||
let src = "fn test() { let a: i32 = 5; let b: *f32 = &a as *f32; }";
|
||||
assert!(analyze(src).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_struct() {
|
||||
let src = "
|
||||
struct Vec3 { x: f32, y: f32, z: f32 }
|
||||
fn make_vec() -> Vec3 {
|
||||
return Vec3 { x: 1.0, y: 2.0, z: 3.0 };
|
||||
}
|
||||
fn get_x(v: Vec3) -> f32 {
|
||||
return v.x;
|
||||
}
|
||||
";
|
||||
assert!(analyze(src).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_struct_field() {
|
||||
let src = "struct Vec2 { x: f32, y: f32 } fn test(v: Vec2) -> f32 { return v.z; }";
|
||||
let errors = analyze(src).unwrap_err();
|
||||
assert!(
|
||||
errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("no field `z` on type `Vec2`"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_struct_initializer_field() {
|
||||
let src = "struct Vec2 { x: f32, y: f32 } fn test() -> Vec2 { return Vec2 { x: 1.0 }; }";
|
||||
let errors = analyze(src).unwrap_err();
|
||||
assert!(errors.iter().any(|e| {
|
||||
e.message
|
||||
.contains("missing field `y` in initializer of `Vec2`")
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ pub enum TokenKind {
|
||||
Return,
|
||||
Let,
|
||||
While,
|
||||
Struct,
|
||||
Break,
|
||||
Continue,
|
||||
|
||||
@@ -130,6 +131,7 @@ impl Display for TokenKind {
|
||||
TokenKind::Return => "`return`",
|
||||
TokenKind::Let => "`let`",
|
||||
TokenKind::While => "`while`",
|
||||
TokenKind::Struct => "`struct`",
|
||||
TokenKind::Break => "`break`",
|
||||
TokenKind::Continue => "`continue`",
|
||||
TokenKind::I8 => "`i8`",
|
||||
|
||||
+193
-32
@@ -12,6 +12,20 @@ impl MirBuilder {
|
||||
pub fn build(module: &TypedModule) -> MirModule {
|
||||
let mut extern_functions = Vec::new();
|
||||
let mut functions = Vec::new();
|
||||
let mut structs = HashMap::new();
|
||||
|
||||
// Collect struct layouts so the backend knows their sizes
|
||||
for decl in &module.decls {
|
||||
if let TypedDeclKind::Struct { name, fields, .. } = &decl.kind {
|
||||
structs.insert(
|
||||
name.clone(),
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| (f.name.clone(), Ty::from(&f.ty.kind)))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for decl in &module.decls {
|
||||
match &decl.kind {
|
||||
@@ -22,7 +36,23 @@ impl MirBuilder {
|
||||
return_type,
|
||||
body,
|
||||
} => {
|
||||
let mut builder = FuncBuilder::new(name.clone(), return_type.clone());
|
||||
let is_sret = matches!(return_type, Ty::Struct(_));
|
||||
let mir_return_type = if is_sret {
|
||||
Ty::Unit
|
||||
} else {
|
||||
return_type.clone()
|
||||
};
|
||||
let mut builder = FuncBuilder::new(name.clone(), mir_return_type);
|
||||
|
||||
// Implement implicit pass-by-reference for struct return values (sret)
|
||||
if is_sret {
|
||||
let sret_id = builder.new_local(
|
||||
"$sret".to_string(),
|
||||
Ty::Pointer(Box::new(return_type.clone())),
|
||||
);
|
||||
builder.params.push(sret_id);
|
||||
builder.sret_local = Some(sret_id);
|
||||
}
|
||||
|
||||
// Register parameters as local variables
|
||||
for (param_name, ty) in params {
|
||||
@@ -66,10 +96,12 @@ impl MirBuilder {
|
||||
return_type: return_type.clone(),
|
||||
});
|
||||
}
|
||||
TypedDeclKind::Struct { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
MirModule {
|
||||
structs,
|
||||
extern_functions,
|
||||
functions,
|
||||
}
|
||||
@@ -104,6 +136,8 @@ struct FuncBuilder {
|
||||
|
||||
/// Stack of `(continue_target, break_target)` for nested loops
|
||||
loop_stack: Vec<(BlockId, BlockId)>,
|
||||
/// Local ID mapped to the hidden `$sret` return pointer (if applicable)
|
||||
sret_local: Option<LocalId>,
|
||||
}
|
||||
|
||||
impl FuncBuilder {
|
||||
@@ -120,6 +154,7 @@ impl FuncBuilder {
|
||||
next_block_id: 0,
|
||||
scopes: vec![HashMap::new()],
|
||||
loop_stack: Vec::new(),
|
||||
sret_local: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,12 +358,27 @@ impl FuncBuilder {
|
||||
}
|
||||
}
|
||||
TypedStmtKind::Return { value } => {
|
||||
if let Some(sret_id) = self.sret_local {
|
||||
let val_op = self.lower_expr(value.as_ref().unwrap());
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Store {
|
||||
ptr: Operand::Copy(sret_id),
|
||||
val: Rvalue::Use(val_op),
|
||||
},
|
||||
span: stmt.span,
|
||||
});
|
||||
self.terminate(Terminator {
|
||||
kind: TerminatorKind::Return { value: None },
|
||||
span: stmt.span,
|
||||
});
|
||||
} else {
|
||||
let val_op = value.as_ref().map(|v| self.lower_expr(v));
|
||||
self.terminate(Terminator {
|
||||
kind: TerminatorKind::Return { value: val_op },
|
||||
span: stmt.span,
|
||||
});
|
||||
}
|
||||
}
|
||||
TypedStmtKind::Let {
|
||||
name,
|
||||
name_span: _,
|
||||
@@ -365,25 +415,8 @@ impl FuncBuilder {
|
||||
Operand::Constant(ConstantValue::Float(*value, expr.ty.clone()))
|
||||
}
|
||||
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
|
||||
TypedExprKind::Unary { op, expr: inner } => {
|
||||
match op {
|
||||
UnaryOp::AddressOf => match &inner.kind {
|
||||
TypedExprKind::Identifier { name } => {
|
||||
let id = self.lookup(name);
|
||||
self.locals[id.0].address_taken = true;
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: ptr_expr,
|
||||
} => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`!
|
||||
_ => unreachable!("invalid lvalue for addressof"),
|
||||
},
|
||||
TypedExprKind::Unary { op, expr: inner } => match op {
|
||||
UnaryOp::AddressOf => self.lower_address_of(inner),
|
||||
UnaryOp::Deref => {
|
||||
let inner_op = self.lower_expr(inner);
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
@@ -402,8 +435,7 @@ impl FuncBuilder {
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
TypedExprKind::Binary { op, lhs, rhs } => {
|
||||
let lhs_op = self.lower_expr(lhs);
|
||||
let rhs_op = self.lower_expr(rhs);
|
||||
@@ -427,11 +459,8 @@ impl FuncBuilder {
|
||||
span: expr.span,
|
||||
});
|
||||
}
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: ptr_expr,
|
||||
} => {
|
||||
let ptr_op = self.lower_expr(ptr_expr);
|
||||
_ => {
|
||||
let ptr_op = self.lower_address_of(lval);
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Store {
|
||||
ptr: ptr_op,
|
||||
@@ -440,7 +469,6 @@ impl FuncBuilder {
|
||||
span: expr.span,
|
||||
});
|
||||
}
|
||||
_ => unreachable!("invalid lval in MIR lowering"),
|
||||
}
|
||||
|
||||
rval_op
|
||||
@@ -466,20 +494,43 @@ impl FuncBuilder {
|
||||
};
|
||||
|
||||
let mut arg_ops = Vec::new();
|
||||
let is_sret = matches!(expr.ty, Ty::Struct(_));
|
||||
let mut sret_temp = None;
|
||||
|
||||
// Implement implicit sret struct passing to the callee
|
||||
if is_sret {
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
self.locals[temp.0].address_taken = true;
|
||||
|
||||
let ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(ptr_temp, Rvalue::AddressOf(temp)),
|
||||
span: expr.span,
|
||||
});
|
||||
|
||||
arg_ops.push(Operand::Copy(ptr_temp));
|
||||
sret_temp = Some(temp);
|
||||
}
|
||||
|
||||
for arg in args {
|
||||
arg_ops.push(self.lower_expr(arg));
|
||||
}
|
||||
|
||||
let rval = Rvalue::Call(callee_name, arg_ops, expr.ty.clone());
|
||||
let mir_ret_ty = if is_sret { Ty::Unit } else { expr.ty.clone() };
|
||||
let rval = Rvalue::Call(callee_name, arg_ops, mir_ret_ty.clone());
|
||||
|
||||
if expr.ty == Ty::Unit {
|
||||
if mir_ret_ty == Ty::Unit {
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::SideEffect(rval),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit assignments
|
||||
if let Some(temp) = sret_temp {
|
||||
Operand::Copy(temp)
|
||||
} else {
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit
|
||||
}
|
||||
} else {
|
||||
let temp = self.new_temp(mir_ret_ty);
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, rval),
|
||||
span: expr.span,
|
||||
@@ -487,6 +538,116 @@ impl FuncBuilder {
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
}
|
||||
TypedExprKind::Struct {
|
||||
name,
|
||||
name_span: _,
|
||||
fields,
|
||||
} => {
|
||||
let local_id = self.new_temp(expr.ty.clone());
|
||||
self.locals[local_id.0].address_taken = true;
|
||||
|
||||
let base_ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(base_ptr_temp, Rvalue::AddressOf(local_id)),
|
||||
span: expr.span,
|
||||
});
|
||||
|
||||
for field in fields {
|
||||
let val_op = self.lower_expr(&field.value);
|
||||
let field_ptr_temp =
|
||||
self.new_temp(Ty::Pointer(Box::new(field.value.ty.clone())));
|
||||
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(
|
||||
field_ptr_temp,
|
||||
Rvalue::GetFieldPtr {
|
||||
base_ptr: Operand::Copy(base_ptr_temp),
|
||||
struct_name: name.clone(),
|
||||
field_name: field.name.clone(),
|
||||
},
|
||||
),
|
||||
span: field.name_span,
|
||||
});
|
||||
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Store {
|
||||
ptr: Operand::Copy(field_ptr_temp),
|
||||
val: Rvalue::Use(val_op),
|
||||
},
|
||||
span: field.value.span,
|
||||
});
|
||||
}
|
||||
|
||||
Operand::Copy(local_id)
|
||||
}
|
||||
TypedExprKind::FieldAccess { .. } => {
|
||||
let ptr_op = self.lower_address_of(expr);
|
||||
let temp = self.new_temp(expr.ty.clone());
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::ReadPointer(ptr_op)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Safely computes the memory address of an expression without extracting or dereferencing
|
||||
/// its underlying evaluated contents explicitly.
|
||||
fn lower_address_of(&mut self, expr: &TypedExpr) -> Operand {
|
||||
match &expr.kind {
|
||||
TypedExprKind::Identifier { name } => {
|
||||
let id = self.lookup(name);
|
||||
self.locals[id.0].address_taken = true;
|
||||
let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
TypedExprKind::Unary {
|
||||
op: UnaryOp::Deref,
|
||||
expr: ptr_expr,
|
||||
} => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`!
|
||||
TypedExprKind::FieldAccess {
|
||||
expr: base,
|
||||
field,
|
||||
field_span: _,
|
||||
} => {
|
||||
let base_ptr = self.lower_address_of(base);
|
||||
let struct_name = match &base.ty {
|
||||
Ty::Struct(name) => name.clone(),
|
||||
_ => unreachable!("field access on non-struct"),
|
||||
};
|
||||
let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(
|
||||
temp,
|
||||
Rvalue::GetFieldPtr {
|
||||
base_ptr,
|
||||
struct_name,
|
||||
field_name: field.clone(),
|
||||
},
|
||||
),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
}
|
||||
_ => {
|
||||
let val_op = self.lower_expr(expr);
|
||||
if let Operand::Copy(id) = val_op {
|
||||
self.locals[id.0].address_taken = true;
|
||||
let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
|
||||
self.emit_stmt(Statement {
|
||||
kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)),
|
||||
span: expr.span,
|
||||
});
|
||||
Operand::Copy(temp)
|
||||
} else {
|
||||
unreachable!("cannot safely take address of constant rvalue")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,7 @@ mod test {
|
||||
#[test]
|
||||
fn test_eliminate_dead_blocks() {
|
||||
let mut module = MirModule {
|
||||
structs: HashMap::new(),
|
||||
functions: vec![MirFunction {
|
||||
name: "test_func".to_string(),
|
||||
params: vec![],
|
||||
@@ -145,6 +146,7 @@ mod test {
|
||||
#[test]
|
||||
fn test_eliminate_dead_cond_branch() {
|
||||
let mut module = MirModule {
|
||||
structs: HashMap::new(),
|
||||
functions: vec![MirFunction {
|
||||
name: "test_cond_func".to_string(),
|
||||
params: vec![],
|
||||
|
||||
@@ -94,6 +94,7 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, Cons
|
||||
}
|
||||
Rvalue::AddressOf(_) => {}
|
||||
Rvalue::ReadPointer(op) => propagate_operand(op, known_constants),
|
||||
Rvalue::GetFieldPtr { base_ptr, .. } => propagate_operand(base_ptr, known_constants),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,6 +335,7 @@ mod test {
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
structs: HashMap::new(),
|
||||
functions: vec![func],
|
||||
extern_functions: vec![],
|
||||
};
|
||||
@@ -382,6 +384,7 @@ mod test {
|
||||
};
|
||||
|
||||
let mut module = MirModule {
|
||||
structs: HashMap::new(),
|
||||
functions: vec![func],
|
||||
extern_functions: vec![],
|
||||
};
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::frontend::ast::{BinaryOp, UnaryOp};
|
||||
use crate::frontend::sema::Ty;
|
||||
use crate::frontend::token::Span;
|
||||
@@ -12,6 +14,7 @@ pub struct LocalId(pub usize);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MirModule {
|
||||
pub structs: HashMap<String, Vec<(String, Ty)>>,
|
||||
pub extern_functions: Vec<MirExternFunction>,
|
||||
pub functions: Vec<MirFunction>,
|
||||
}
|
||||
@@ -78,6 +81,11 @@ pub enum Rvalue {
|
||||
Call(String, Vec<Operand>, Ty),
|
||||
AddressOf(LocalId),
|
||||
ReadPointer(Operand),
|
||||
GetFieldPtr {
|
||||
base_ptr: Operand,
|
||||
struct_name: String,
|
||||
field_name: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// An atomic value used as inputs to instructions.
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
[code]
|
||||
foreign fn putchar(c: i32) -> i32;
|
||||
|
||||
struct Point {
|
||||
x: i32,
|
||||
y: i32
|
||||
}
|
||||
|
||||
fn print_num(n: i32) {
|
||||
// Simple hack to print a 2-digit number for testing
|
||||
putchar(48 + (n / 10));
|
||||
putchar(48 + (n % 10));
|
||||
putchar(10); // newline
|
||||
}
|
||||
|
||||
fn main() -> i32 {
|
||||
let p = Point { x: 40, y: 2 };
|
||||
|
||||
// 40 + 2 = 42
|
||||
print_num(p.x + p.y);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
[expected_return_code]
|
||||
0
|
||||
|
||||
[expected_output]
|
||||
42
|
||||
@@ -0,0 +1,38 @@
|
||||
[code]
|
||||
foreign fn putchar(c: i32) -> i32;
|
||||
|
||||
struct Rect {
|
||||
width: i32,
|
||||
height: i32
|
||||
}
|
||||
|
||||
fn modify_rect(r: *Rect) {
|
||||
let temp: Rect = *r;
|
||||
temp.width = temp.width + 10;
|
||||
temp.height = temp.height + 20;
|
||||
*r = temp;
|
||||
}
|
||||
|
||||
fn print_num(n: i32) {
|
||||
putchar(48 + (n / 10));
|
||||
putchar(48 + (n % 10));
|
||||
putchar(10);
|
||||
}
|
||||
|
||||
fn main() -> i32 {
|
||||
let r = Rect { width: 15, height: 25 };
|
||||
|
||||
modify_rect(&r);
|
||||
|
||||
print_num(r.width);
|
||||
print_num(r.height);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
[expected_return_code]
|
||||
0
|
||||
|
||||
[expected_output]
|
||||
25
|
||||
45
|
||||
Reference in New Issue
Block a user