From ec2aa771fac27aeafba849bc7eb95c464b33bf3e Mon Sep 17 00:00:00 2001 From: Jooris Hadeler Date: Wed, 22 Apr 2026 23:49:13 +0200 Subject: [PATCH] feat: add support for structs and member access --- src/backend/cranelift.rs | 226 +++++++++++++++++++++++++++---- src/frontend/ast.rs | 32 ++++- src/frontend/lexer.rs | 14 +- src/frontend/parser.rs | 257 ++++++++++++++++++++++++++++++++--- src/frontend/sema.rs | 227 +++++++++++++++++++++++++++++++ src/frontend/token.rs | 2 + src/middle/builder.rs | 269 +++++++++++++++++++++++++++++-------- src/middle/dce.rs | 2 + src/middle/fold.rs | 3 + src/middle/mir.rs | 8 ++ tests/struct_basic.test | 29 ++++ tests/struct_pointers.test | 38 ++++++ 12 files changed, 1006 insertions(+), 101 deletions(-) create mode 100644 tests/struct_basic.test create mode 100644 tests/struct_pointers.test diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs index c281e1a..d0bd2cd 100644 --- a/src/backend/cranelift.rs +++ b/src/backend/cranelift.rs @@ -92,7 +92,7 @@ impl CraneliftBackend { } for func in &module.functions { - self.compile_function(func); + self.compile_function(func, &module.structs); // Run Cranelift's optimization passes before emitting the text IR let mut ctrl_plane = ControlPlane::default(); @@ -113,7 +113,11 @@ impl CraneliftBackend { } /// Lowers a single MIR function into Cranelift IR. - fn compile_function(&mut self, func: &MirFunction) { + fn compile_function( + &mut self, + func: &MirFunction, + structs: &HashMap>, + ) { let mut sig = self.module.make_signature(); for param_id in &func.params { @@ -139,9 +143,8 @@ impl CraneliftBackend { let mut var_map = HashMap::new(); let mut stack_slot_map = HashMap::new(); for local in &func.locals { - if local.address_taken { - let cl_ty = Self::lower_type(&local.ty); - let bytes = cl_ty.bytes(); + if local.address_taken || matches!(local.ty, Ty::Struct(_)) { + let bytes = Self::type_size(&local.ty, structs); let slot = builder.create_sized_stack_slot(ir::StackSlotData::new( ir::StackSlotKind::ExplicitSlot, bytes, @@ -162,6 +165,7 @@ impl CraneliftBackend { locals: &func.locals, module: &mut self.module, func_ids: &self.func_ids, + structs, }; if let Some(first_block) = func.blocks.first() { @@ -179,7 +183,14 @@ impl CraneliftBackend { if i == 0 { for (j, param_id) in func.params.iter().enumerate() { let val = trans.builder.block_params(cl_block)[j]; - if let Some(&slot) = trans.stack_slot_map.get(param_id) { + + if matches!(func.locals[param_id.0].ty, Ty::Struct(_)) { + let slot = trans.stack_slot_map[param_id]; + let dest_addr = trans.builder.ins().stack_addr(types::I64, slot, 0); + let size = + CraneliftBackend::type_size(&func.locals[param_id.0].ty, trans.structs); + trans.emit_memcpy(dest_addr, val, size); + } else if let Some(&slot) = trans.stack_slot_map.get(param_id) { trans.builder.ins().stack_store(val, slot, 0); } else { trans.builder.def_var(trans.var_map[param_id], val); @@ -212,10 +223,56 @@ impl CraneliftBackend { Ty::F32 => types::F32, Ty::F64 => types::F64, Ty::Bool => types::I8, // Booleans are represented as 8-bit integers - Ty::Pointer(_) => types::I64, // Assume 64-bit environment pointers + Ty::Pointer(_) | Ty::Struct(_) => types::I64, // Structs are passed by reference implicitly _ => unimplemented!("Unsupported type for Cranelift lowering: {:?}", ty), } } + + fn type_size(ty: &Ty, structs: &HashMap>) -> u32 { + match ty { + Ty::I8 | Ty::U8 | Ty::Bool => 1, + Ty::I16 | Ty::U16 => 2, + Ty::I32 | Ty::U32 | Ty::F32 => 4, + Ty::I64 | Ty::U64 | Ty::F64 | Ty::Pointer(_) => 8, + Ty::Struct(name) => { + let mut size = 0; + if let Some(fields) = structs.get(name) { + for (_, f_ty) in fields { + let f_size = Self::type_size(f_ty, structs); + let align = f_size.min(8); + size = (size + align - 1) & !(align - 1); + size += f_size; + } + size = (size + 7) & !7; + } + size + } + Ty::Unit => 0, + Ty::Var(_) | Ty::Function(_, _) => unimplemented!(), + } + } + + fn field_offset( + struct_name: &str, + field_name: &str, + structs: &HashMap>, + ) -> u32 { + let mut offset = 0; + if let Some(fields) = structs.get(struct_name) { + for (f_name, f_ty) in fields { + let f_size = Self::type_size(f_ty, structs); + let align = f_size.min(8); + offset = (offset + align - 1) & !(align - 1); + + if f_name == field_name { + return offset; + } + + offset += f_size; + } + } + panic!("Field not found"); + } } /// A visitor that traverses MIR basic blocks and instructions, emitting Cranelift IR instructions @@ -228,18 +285,29 @@ struct FunctionTranslator<'a> { locals: &'a [LocalDecl], module: &'a mut ObjectModule, func_ids: &'a HashMap, + structs: &'a HashMap>, } impl<'a> FunctionTranslator<'a> { fn translate_stmt(&mut self, stmt: &Statement) { match &stmt.kind { StatementKind::Assign(local_id, rvalue) => { - let val = self.translate_rvalue(rvalue); - if let Some(v) = val { - if let Some(&slot) = self.stack_slot_map.get(local_id) { + if let Some(&slot) = self.stack_slot_map.get(local_id) { + if let Ty::Struct(name) = &self.locals[local_id.0].ty { + let dest_addr = self.builder.ins().stack_addr(types::I64, slot, 0); + if let Some(src_addr) = self.translate_rvalue(rvalue) { + let size = CraneliftBackend::type_size( + &Ty::Struct(name.clone()), + self.structs, + ); + self.emit_memcpy(dest_addr, src_addr, size); + } + } else if let Some(v) = self.translate_rvalue(rvalue) { self.builder.ins().stack_store(v, slot, 0); - } else { - let var = self.var_map[local_id]; + } + } else { + let var = self.var_map[local_id]; + if let Some(v) = self.translate_rvalue(rvalue) { self.builder.def_var(var, v); } } @@ -249,10 +317,18 @@ impl<'a> FunctionTranslator<'a> { } StatementKind::Store { ptr, val } => { let ptr_val = self.translate_operand(ptr); - if let Some(v) = self.translate_rvalue(val) { - self.builder - .ins() - .store(ir::MemFlags::trusted(), v, ptr_val, 0); + let rval_ty = self.get_rvalue_type(val); + if matches!(rval_ty, Ty::Struct(_)) { + if let Some(src_addr) = self.translate_rvalue(val) { + let size = CraneliftBackend::type_size(&rval_ty, self.structs); + self.emit_memcpy(ptr_val, src_addr, size); + } + } else { + if let Some(v) = self.translate_rvalue(val) { + self.builder + .ins() + .store(ir::MemFlags::trusted(), v, ptr_val, 0); + } } } } @@ -303,7 +379,10 @@ impl<'a> FunctionTranslator<'a> { fn translate_operand(&mut self, op: &Operand) -> ir::Value { match op { Operand::Copy(local_id) => { - if let Some(&slot) = self.stack_slot_map.get(local_id) { + if matches!(self.locals[local_id.0].ty, Ty::Struct(_)) { + let slot = self.stack_slot_map[local_id]; + self.builder.ins().stack_addr(types::I64, slot, 0) + } else if let Some(&slot) = self.stack_slot_map.get(local_id) { let cl_ty = CraneliftBackend::lower_type(&self.locals[local_id.0].ty); self.builder.ins().stack_load(cl_ty, slot, 0) } else { @@ -577,13 +656,114 @@ impl<'a> FunctionTranslator<'a> { Ty::Pointer(inner) => *inner, _ => unreachable!(), }; - let cl_ty = CraneliftBackend::lower_type(&inner_ty); - Some( - self.builder - .ins() - .load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0), - ) + if matches!(inner_ty, Ty::Struct(_)) { + Some(ptr_val) + } else { + let cl_ty = CraneliftBackend::lower_type(&inner_ty); + Some( + self.builder + .ins() + .load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0), + ) + } + } + Rvalue::GetFieldPtr { + base_ptr, + struct_name, + field_name, + } => { + let base = self.translate_operand(base_ptr); + let offset = CraneliftBackend::field_offset(struct_name, field_name, self.structs); + Some(self.builder.ins().iadd_imm(base, offset as i64)) } } } + + fn get_rvalue_type(&self, rvalue: &Rvalue) -> Ty { + match rvalue { + Rvalue::Use(op) => self.get_operand_type(op), + Rvalue::UnaryOp(op, inner) => match op { + UnaryOp::Deref => { + if let Ty::Pointer(inner) = self.get_operand_type(inner) { + *inner + } else { + unreachable!() + } + } + UnaryOp::AddressOf => Ty::Pointer(Box::new(self.get_operand_type(inner))), + _ => self.get_operand_type(inner), + }, + Rvalue::BinaryOp(_, lhs, _) => self.get_operand_type(lhs), + Rvalue::Cast(ty, _) => ty.clone(), + Rvalue::Call(_, _, ty) => ty.clone(), + Rvalue::AddressOf(local) => Ty::Pointer(Box::new(self.locals[local.0].ty.clone())), + Rvalue::ReadPointer(ptr) => { + if let Ty::Pointer(inner) = self.get_operand_type(ptr) { + *inner + } else { + unreachable!() + } + } + Rvalue::GetFieldPtr { + struct_name, + field_name, + .. + } => { + let fields = self.structs.get(struct_name).unwrap(); + let ty = fields + .iter() + .find(|(n, _)| n == field_name) + .unwrap() + .1 + .clone(); + Ty::Pointer(Box::new(ty)) + } + } + } + + fn emit_memcpy(&mut self, dest: ir::Value, src: ir::Value, mut size: u32) { + let mut offset = 0; + while size >= 8 { + let val = self + .builder + .ins() + .load(types::I64, ir::MemFlags::trusted(), src, offset); + self.builder + .ins() + .store(ir::MemFlags::trusted(), val, dest, offset); + size -= 8; + offset += 8; + } + if size >= 4 { + let val = self + .builder + .ins() + .load(types::I32, ir::MemFlags::trusted(), src, offset); + self.builder + .ins() + .store(ir::MemFlags::trusted(), val, dest, offset); + size -= 4; + offset += 4; + } + if size >= 2 { + let val = self + .builder + .ins() + .load(types::I16, ir::MemFlags::trusted(), src, offset); + self.builder + .ins() + .store(ir::MemFlags::trusted(), val, dest, offset); + size -= 2; + offset += 2; + } + if size == 1 { + let val = self + .builder + .ins() + .load(types::I8, ir::MemFlags::trusted(), src, offset); + self.builder + .ins() + .store(ir::MemFlags::trusted(), val, dest, offset); + } + } } diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index f057c2c..766cab7 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -63,6 +63,11 @@ pub enum DeclKind { params: Vec, return_type: P::ReturnType, }, + Struct { + name: String, + name_span: Span, + fields: Vec, + }, } #[derive(Debug, PartialEq, Eq)] @@ -73,12 +78,19 @@ pub struct FunctionParam { } #[derive(Debug, PartialEq, Eq)] +pub struct StructField { + pub name: String, + pub name_span: Span, + pub ty: Type, +} + +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Type { pub kind: TypeKind, pub span: Span, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeKind { I8, I16, @@ -92,6 +104,7 @@ pub enum TypeKind { F64, Bool, Pointer(Box), + Struct(String), } #[derive(Debug, PartialEq)] @@ -172,6 +185,23 @@ pub enum ExprKind { callee: Box>, args: Vec>, }, + Struct { + name: String, + name_span: Span, + fields: Vec>, + }, + FieldAccess { + expr: Box>, + field: String, + field_span: Span, + }, +} + +#[derive(Debug, PartialEq)] +pub struct FieldValue { + pub name: String, + pub name_span: Span, + pub value: Expr

, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src/frontend/lexer.rs b/src/frontend/lexer.rs index 473776f..c20ef4e 100644 --- a/src/frontend/lexer.rs +++ b/src/frontend/lexer.rs @@ -71,6 +71,7 @@ impl<'src> Lexer<'src> { "return" => TokenKind::Return, "let" => TokenKind::Let, "while" => TokenKind::While, + "struct" => TokenKind::Struct, "break" => TokenKind::Break, "continue" => TokenKind::Continue, @@ -251,7 +252,9 @@ mod test { #[test] fn identifiers() { assert_eq!( - tokenize("HELLO _hello _0@ fn if else return let while break continue as foreign"), + tokenize( + "HELLO _hello _0@ fn if else return let while struct break continue as foreign" + ), vec![ Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)), Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)), @@ -263,10 +266,11 @@ mod test { Token::new(TokenKind::Return, "return", Span::new(28, 34)), Token::new(TokenKind::Let, "let", Span::new(35, 38)), Token::new(TokenKind::While, "while", Span::new(39, 44)), - Token::new(TokenKind::Break, "break", Span::new(45, 50)), - Token::new(TokenKind::Continue, "continue", Span::new(51, 59)), - Token::new(TokenKind::As, "as", Span::new(60, 62)), - Token::new(TokenKind::Foreign, "foreign", Span::new(63, 70)), + Token::new(TokenKind::Struct, "struct", Span::new(45, 51)), + Token::new(TokenKind::Break, "break", Span::new(52, 57)), + Token::new(TokenKind::Continue, "continue", Span::new(58, 66)), + Token::new(TokenKind::As, "as", Span::new(67, 69)), + Token::new(TokenKind::Foreign, "foreign", Span::new(70, 77)), ] ) } diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index eedc55e..983d401 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -117,7 +117,7 @@ impl<'src> Parser<'src> { Ok(decl) => decls.push(decl), Err(err) => { self.errors.push(err); - self.synchronize(&[TokenKind::Fn]); + self.synchronize(&[TokenKind::Fn, TokenKind::Foreign, TokenKind::Struct]); } } } @@ -136,6 +136,7 @@ impl<'src> Parser<'src> { match peek_token.kind { TokenKind::Fn => self.parse_function_decl(), TokenKind::Foreign => self.parse_foreign_function_decl(), + TokenKind::Struct => self.parse_struct_decl(), _ => Err(ParseError::new( format!( @@ -225,6 +226,54 @@ impl<'src> Parser<'src> { }) } + /// Parses a struct declaration. + /// + /// ```ebnf + /// struct_decl = "struct" IDENTIFIER "{" { struct_field } "}" ; + /// struct_field = IDENTIFIER ":" type [ "," ] ; + /// ``` + fn parse_struct_decl(&mut self) -> ParseResult { + let struct_token = self.expect(TokenKind::Struct)?; + + let (name, name_span) = { + let ident_token = self.expect(TokenKind::Identifier)?; + (ident_token.text.to_string(), ident_token.span) + }; + + self.expect(TokenKind::LBrace)?; + + let mut fields = Vec::new(); + while !self.is_at_eof() && !self.is_peek(TokenKind::RBrace) { + let field_ident = self.expect(TokenKind::Identifier)?; + self.expect(TokenKind::Colon)?; + let ty = self.parse_type()?; + + fields.push(StructField { + name: field_ident.text.to_string(), + name_span: field_ident.span, + ty, + }); + + if self.is_peek(TokenKind::Comma) { + self.advance(); + } else { + break; + } + } + + let rbrace_token = self.expect(TokenKind::RBrace)?; + let span = struct_token.span.join(rbrace_token.span); + + Ok(Decl { + kind: DeclKind::Struct { + name, + name_span, + fields, + }, + span, + }) + } + /// Parses the function parameter list. /// /// ```ebnf @@ -265,6 +314,7 @@ impl<'src> Parser<'src> { /// | "u8" | "u16" | "u32" | "u64" /// | "f32" | "f64" /// | "bool" ; + /// | IDENTIFIER /// ``` pub fn parse_type(&mut self) -> ParseResult { let peek_token = self.peek_no_eof()?; @@ -314,6 +364,10 @@ impl<'src> Parser<'src> { self.advance(); TypeKind::Bool } + TokenKind::Identifier => { + let token = self.advance().unwrap(); + TypeKind::Struct(token.text.to_string()) + } TokenKind::Star => { let star_token = self.advance().unwrap(); let inner = self.parse_type()?; @@ -407,7 +461,7 @@ impl<'src> Parser<'src> { fn parse_if_stmt(&mut self) -> ParseResult { let if_token = self.expect(TokenKind::If)?; - let condition = self.parse_expr()?; + let condition = self.parse_expr_no_struct()?; let consequence = self.parse_compound_stmt()?; let alternative = if self.is_peek(TokenKind::Else) { @@ -446,7 +500,7 @@ impl<'src> Parser<'src> { /// ``` fn parse_while_stmt(&mut self) -> ParseResult { let while_token = self.expect(TokenKind::While)?; - let condition = self.parse_expr()?; + let condition = self.parse_expr_no_struct()?; let body = self.parse_compound_stmt()?; let span = while_token.span.join(body.span); @@ -575,12 +629,17 @@ impl<'src> Parser<'src> { /// Parses an expression. pub fn parse_expr(&mut self) -> ParseResult { - self.parse_expr_bp(0) + self.parse_expr_bp(0, true) + } + + /// Parses an expression, disallowing struct literals. + pub fn parse_expr_no_struct(&mut self) -> ParseResult { + self.parse_expr_bp(0, false) } /// Pratt parsing implementation for expressions. - fn parse_expr_bp(&mut self, min_bp: u8) -> ParseResult { - let mut lhs = self.parse_leading_expr()?; + fn parse_expr_bp(&mut self, min_bp: u8, allow_struct: bool) -> ParseResult { + let mut lhs = self.parse_leading_expr(allow_struct)?; loop { let peek_token = self.peek_no_eof()?; @@ -594,7 +653,7 @@ impl<'src> Parser<'src> { } self.advance(); // consume '=' - let rhs = self.parse_expr_bp(right_bp)?; + let rhs = self.parse_expr_bp(right_bp, allow_struct)?; let span = lhs.span.join(rhs.span); lhs = Expr { @@ -662,6 +721,28 @@ impl<'src> Parser<'src> { continue; } + if peek_token.kind == TokenKind::Dot { + let left_bp = 30; // Field access has very high precedence + if left_bp < min_bp { + break; + } + self.advance(); // consume `.` + + let field_token = self.expect(TokenKind::Identifier)?; + let span = lhs.span.join(field_token.span); + + lhs = Expr { + kind: ExprKind::FieldAccess { + expr: Box::new(lhs), + field: field_token.text.to_string(), + field_span: field_token.span, + }, + ty: (), + span, + }; + continue; + } + let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else { break; // Not an infix operator }; @@ -672,7 +753,7 @@ impl<'src> Parser<'src> { self.advance(); // consume the operator - let rhs = self.parse_expr_bp(right_bp)?; + let rhs = self.parse_expr_bp(right_bp, allow_struct)?; let span = lhs.span.join(rhs.span); lhs = Expr { @@ -691,20 +772,56 @@ impl<'src> Parser<'src> { /// Parses a leading expression such as identifiers, integer and boolean literals /// or prefix expressions. - fn parse_leading_expr(&mut self) -> ParseResult { + fn parse_leading_expr(&mut self, allow_struct: bool) -> ParseResult { let peek_token = self.peek_no_eof()?; match peek_token.kind { TokenKind::Identifier => { let token = self.advance().unwrap(); - Ok(Expr { - kind: ExprKind::Identifier { - name: token.text.to_string(), - }, - ty: (), - span: token.span, - }) + if allow_struct && self.is_peek(TokenKind::LBrace) { + self.advance(); // consume `{` + + let mut fields = Vec::new(); + while !self.is_at_eof() && !self.is_peek(TokenKind::RBrace) { + let field_token = self.expect(TokenKind::Identifier)?; + self.expect(TokenKind::Colon)?; + let value = self.parse_expr()?; + + fields.push(FieldValue { + name: field_token.text.to_string(), + name_span: field_token.span, + value, + }); + + if self.is_peek(TokenKind::Comma) { + self.advance(); + } else { + break; + } + } + + let rbrace_token = self.expect(TokenKind::RBrace)?; + let span = token.span.join(rbrace_token.span); + + Ok(Expr { + kind: ExprKind::Struct { + name: token.text.to_string(), + name_span: token.span, + fields, + }, + ty: (), + span, + }) + } else { + Ok(Expr { + kind: ExprKind::Identifier { + name: token.text.to_string(), + }, + ty: (), + span: token.span, + }) + } } TokenKind::IntegerLit => { @@ -754,7 +871,7 @@ impl<'src> Parser<'src> { TokenKind::LParen => { let lparen = self.advance().unwrap(); - let expr = self.parse_expr_bp(0)?; + let expr = self.parse_expr()?; // Inner expressions allow struct literals let rparen = self.expect(TokenKind::RParen)?; Ok(Expr { @@ -766,7 +883,7 @@ impl<'src> Parser<'src> { kind if let Some((op, r_bp)) = self.prefix_operator(kind) => { let op_token = self.advance().unwrap(); - let rhs = self.parse_expr_bp(r_bp)?; + let rhs = self.parse_expr_bp(r_bp, allow_struct)?; Ok(Expr { ty: (), @@ -1356,4 +1473,108 @@ mod test { }) ); } + + #[test] + fn struct_decl() { + assert_eq!( + parse("struct Vec3 { x: f32, y: f32 }", Parser::parse_decl), + Success(Decl { + kind: DeclKind::Struct { + name: "Vec3".to_string(), + name_span: Span::new(7, 11), + fields: vec![ + StructField { + name: "x".to_string(), + name_span: Span::new(14, 15), + ty: Type { + kind: TypeKind::F32, + span: Span::new(17, 20) + } + }, + StructField { + name: "y".to_string(), + name_span: Span::new(22, 23), + ty: Type { + kind: TypeKind::F32, + span: Span::new(25, 28) + } + } + ], + }, + span: Span::new(0, 30) + }) + ); + } + + #[test] + fn struct_literal() { + assert_eq!( + parse("Vec3 { x: 1.0, y: -3.5 }", Parser::parse_expr), + Success(Expr { + kind: ExprKind::Struct { + name: "Vec3".to_string(), + name_span: Span::new(0, 4), + fields: vec![ + FieldValue { + name: "x".to_string(), + name_span: Span::new(7, 8), + value: Expr { + kind: ExprKind::Float { value: 1.0 }, + ty: (), + span: Span::new(10, 13) + } + }, + FieldValue { + name: "y".to_string(), + name_span: Span::new(15, 16), + value: Expr { + kind: ExprKind::Unary { + op: UnaryOp::Neg, + expr: Box::new(Expr { + kind: ExprKind::Float { value: 3.5 }, + ty: (), + span: Span::new(19, 22) + }) + }, + ty: (), + span: Span::new(18, 22) + } + } + ] + }, + ty: (), + span: Span::new(0, 24) + }) + ); + } + + #[test] + fn field_access() { + assert_eq!( + parse("a.b.c", Parser::parse_expr), + Success(Expr { + kind: ExprKind::FieldAccess { + expr: Box::new(Expr { + kind: ExprKind::FieldAccess { + expr: Box::new(Expr { + kind: ExprKind::Identifier { + name: "a".to_string() + }, + ty: (), + span: Span::new(0, 1) + }), + field: "b".to_string(), + field_span: Span::new(2, 3), + }, + ty: (), + span: Span::new(0, 3) + }), + field: "c".to_string(), + field_span: Span::new(4, 5), + }, + ty: (), + span: Span::new(0, 5) + }) + ); + } } diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs index d7dce8a..42e548e 100644 --- a/src/frontend/sema.rs +++ b/src/frontend/sema.rs @@ -42,6 +42,7 @@ pub enum Ty { Var(usize), Function(Vec, Box), Pointer(Box), + Struct(String), } impl Ty { @@ -108,6 +109,7 @@ impl From<&TypeKind> for Ty { TypeKind::F64 => Ty::F64, TypeKind::Bool => Ty::Bool, TypeKind::Pointer(inner) => Ty::Pointer(Box::new(Ty::from(&inner.kind))), + TypeKind::Struct(name) => Ty::Struct(name.clone()), } } } @@ -120,6 +122,7 @@ impl TypedExpr { TypedExprKind::Unary { op: UnaryOp::Deref, .. } => true, + TypedExprKind::FieldAccess { expr, .. } => expr.is_lvalue(), _ => false, } } @@ -132,6 +135,7 @@ pub struct Sema { next_var: usize, subst: HashMap, scopes: Vec>, + structs: HashMap>, errors: Vec, deferred_unary_neg: Vec<(Span, Ty, Ty, Option)>, deferred_binary: Vec<(Span, Ty)>, @@ -148,6 +152,7 @@ impl Sema { next_var: 0, subst: HashMap::new(), scopes: Vec::new(), + structs: HashMap::new(), errors: Vec::new(), deferred_unary_neg: Vec::new(), deferred_binary: Vec::new(), @@ -290,6 +295,13 @@ impl Sema { self.bind(name, Ty::Function(param_tys, Box::new(ret_ty))); } + DeclKind::Struct { name, fields, .. } => { + let typed_fields = fields + .iter() + .map(|f| (f.name.clone(), Ty::from(&f.ty.kind))) + .collect(); + self.structs.insert(name.clone(), typed_fields); + } } } @@ -376,6 +388,25 @@ impl Sema { span: decl.span, } } + DeclKind::Struct { + name, + name_span, + fields, + } => TypedDecl { + kind: TypedDeclKind::Struct { + name: name.clone(), + name_span: *name_span, + fields: fields + .iter() + .map(|f| StructField { + name: f.name.clone(), + name_span: f.name_span, + ty: f.ty.clone(), + }) + .collect(), + }, + span: decl.span, + }, } } @@ -786,6 +817,133 @@ impl Sema { span: expr.span, } } + ExprKind::Struct { + name, + name_span, + fields, + } => { + let mut typed_fields = Vec::new(); + let mut provided_fields = HashMap::new(); + + if let Some(struct_def) = self.structs.get(name).cloned() { + for field in fields { + let typed_value = self.analyze_expr(&field.value); + + if let Some((_, expected_ty)) = + struct_def.iter().find(|(n, _)| n == &field.name) + { + if let Err(e) = self.unify(&typed_value.ty, expected_ty) { + self.errors.push(SemanticError::new(e, field.value.span)); + } + } else { + self.errors.push(SemanticError::new( + format!("struct `{}` has no field named `{}`", name, field.name), + field.name_span, + )); + } + + if provided_fields + .insert(field.name.clone(), field.name_span) + .is_some() + { + self.errors.push(SemanticError::new( + format!("field `{}` specified more than once", field.name), + field.name_span, + )); + } + + typed_fields.push(FieldValue { + name: field.name.clone(), + name_span: field.name_span, + value: typed_value, + }); + } + + for (expected_field, _) in struct_def { + if !provided_fields.contains_key(&expected_field) { + self.errors.push(SemanticError::new( + format!( + "missing field `{}` in initializer of `{}`", + expected_field, name + ), + expr.span, + )); + } + } + } else { + self.errors.push(SemanticError::new( + format!("undeclared struct `{}`", name), + *name_span, + )); + for field in fields { + let typed_value = self.analyze_expr(&field.value); + typed_fields.push(FieldValue { + name: field.name.clone(), + name_span: field.name_span, + value: typed_value, + }); + } + } + + TypedExpr { + kind: TypedExprKind::Struct { + name: name.clone(), + name_span: *name_span, + fields: typed_fields, + }, + ty: Ty::Struct(name.clone()), + span: expr.span, + } + } + ExprKind::FieldAccess { + expr: inner_expr, + field, + field_span, + } => { + let typed_inner = self.analyze_expr(inner_expr); + let result_ty = self.new_var(); + + let inner_ty_resolved = self.apply_subst(&typed_inner.ty); + match inner_ty_resolved { + Ty::Struct(ref struct_name) => { + if let Some(struct_def) = self.structs.get(struct_name).cloned() { + if let Some((_, field_ty)) = struct_def.iter().find(|(n, _)| n == field) + { + if let Err(e) = self.unify(&result_ty, field_ty) { + self.errors.push(SemanticError::new(e, *field_span)); + } + } else { + self.errors.push(SemanticError::new( + format!("no field `{}` on type `{}`", field, struct_name), + *field_span, + )); + } + } + } + Ty::Var(_) => { + self.errors.push(SemanticError::new( + "type of expression must be known to access a field", + inner_expr.span, + )); + } + _ => { + self.errors.push(SemanticError::new( + format!("cannot access field `{}` on a non-struct type", field), + *field_span, + )); + } + } + + TypedExpr { + kind: TypedExprKind::FieldAccess { + expr: Box::new(typed_inner), + field: field.clone(), + field_span: *field_span, + }, + ty: result_ty, + span: expr.span, + } + } } } @@ -827,6 +985,15 @@ impl Sema { .collect(), return_type: self.apply_subst(&return_type), }, + TypedDeclKind::Struct { + name, + name_span, + fields, + } => TypedDeclKind::Struct { + name, + name_span, + fields, + }, }; TypedDecl { kind, span } @@ -916,6 +1083,31 @@ impl Sema { callee: Box::new(self.apply_subst_expr(*callee)), args: args.into_iter().map(|a| self.apply_subst_expr(a)).collect(), }, + TypedExprKind::Struct { + name, + name_span, + fields, + } => TypedExprKind::Struct { + name, + name_span, + fields: fields + .into_iter() + .map(|f| FieldValue { + name: f.name, + name_span: f.name_span, + value: self.apply_subst_expr(f.value), + }) + .collect(), + }, + TypedExprKind::FieldAccess { + expr, + field, + field_span, + } => TypedExprKind::FieldAccess { + expr: Box::new(self.apply_subst_expr(*expr)), + field, + field_span, + }, }; TypedExpr { kind, ty, span } @@ -1308,4 +1500,39 @@ mod test { let src = "fn test() { let a: i32 = 5; let b: *f32 = &a as *f32; }"; assert!(analyze(src).is_ok()); } + + #[test] + fn valid_struct() { + let src = " + struct Vec3 { x: f32, y: f32, z: f32 } + fn make_vec() -> Vec3 { + return Vec3 { x: 1.0, y: 2.0, z: 3.0 }; + } + fn get_x(v: Vec3) -> f32 { + return v.x; + } + "; + assert!(analyze(src).is_ok()); + } + + #[test] + fn invalid_struct_field() { + let src = "struct Vec2 { x: f32, y: f32 } fn test(v: Vec2) -> f32 { return v.z; }"; + let errors = analyze(src).unwrap_err(); + assert!( + errors + .iter() + .any(|e| e.message.contains("no field `z` on type `Vec2`")) + ); + } + + #[test] + fn missing_struct_initializer_field() { + let src = "struct Vec2 { x: f32, y: f32 } fn test() -> Vec2 { return Vec2 { x: 1.0 }; }"; + let errors = analyze(src).unwrap_err(); + assert!(errors.iter().any(|e| { + e.message + .contains("missing field `y` in initializer of `Vec2`") + })); + } } diff --git a/src/frontend/token.rs b/src/frontend/token.rs index a7ce4da..9c58600 100644 --- a/src/frontend/token.rs +++ b/src/frontend/token.rs @@ -64,6 +64,7 @@ pub enum TokenKind { Return, Let, While, + Struct, Break, Continue, @@ -130,6 +131,7 @@ impl Display for TokenKind { TokenKind::Return => "`return`", TokenKind::Let => "`let`", TokenKind::While => "`while`", + TokenKind::Struct => "`struct`", TokenKind::Break => "`break`", TokenKind::Continue => "`continue`", TokenKind::I8 => "`i8`", diff --git a/src/middle/builder.rs b/src/middle/builder.rs index ae912d5..6dcb447 100644 --- a/src/middle/builder.rs +++ b/src/middle/builder.rs @@ -12,6 +12,20 @@ impl MirBuilder { pub fn build(module: &TypedModule) -> MirModule { let mut extern_functions = Vec::new(); let mut functions = Vec::new(); + let mut structs = HashMap::new(); + + // Collect struct layouts so the backend knows their sizes + for decl in &module.decls { + if let TypedDeclKind::Struct { name, fields, .. } = &decl.kind { + structs.insert( + name.clone(), + fields + .iter() + .map(|f| (f.name.clone(), Ty::from(&f.ty.kind))) + .collect(), + ); + } + } for decl in &module.decls { match &decl.kind { @@ -22,7 +36,23 @@ impl MirBuilder { return_type, body, } => { - let mut builder = FuncBuilder::new(name.clone(), return_type.clone()); + let is_sret = matches!(return_type, Ty::Struct(_)); + let mir_return_type = if is_sret { + Ty::Unit + } else { + return_type.clone() + }; + let mut builder = FuncBuilder::new(name.clone(), mir_return_type); + + // Implement implicit pass-by-reference for struct return values (sret) + if is_sret { + let sret_id = builder.new_local( + "$sret".to_string(), + Ty::Pointer(Box::new(return_type.clone())), + ); + builder.params.push(sret_id); + builder.sret_local = Some(sret_id); + } // Register parameters as local variables for (param_name, ty) in params { @@ -66,10 +96,12 @@ impl MirBuilder { return_type: return_type.clone(), }); } + TypedDeclKind::Struct { .. } => {} } } MirModule { + structs, extern_functions, functions, } @@ -104,6 +136,8 @@ struct FuncBuilder { /// Stack of `(continue_target, break_target)` for nested loops loop_stack: Vec<(BlockId, BlockId)>, + /// Local ID mapped to the hidden `$sret` return pointer (if applicable) + sret_local: Option, } impl FuncBuilder { @@ -120,6 +154,7 @@ impl FuncBuilder { next_block_id: 0, scopes: vec![HashMap::new()], loop_stack: Vec::new(), + sret_local: None, } } @@ -323,11 +358,26 @@ impl FuncBuilder { } } TypedStmtKind::Return { value } => { - let val_op = value.as_ref().map(|v| self.lower_expr(v)); - self.terminate(Terminator { - kind: TerminatorKind::Return { value: val_op }, - span: stmt.span, - }); + if let Some(sret_id) = self.sret_local { + let val_op = self.lower_expr(value.as_ref().unwrap()); + self.emit_stmt(Statement { + kind: StatementKind::Store { + ptr: Operand::Copy(sret_id), + val: Rvalue::Use(val_op), + }, + span: stmt.span, + }); + self.terminate(Terminator { + kind: TerminatorKind::Return { value: None }, + span: stmt.span, + }); + } else { + let val_op = value.as_ref().map(|v| self.lower_expr(v)); + self.terminate(Terminator { + kind: TerminatorKind::Return { value: val_op }, + span: stmt.span, + }); + } } TypedStmtKind::Let { name, @@ -365,45 +415,27 @@ impl FuncBuilder { Operand::Constant(ConstantValue::Float(*value, expr.ty.clone())) } TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)), - TypedExprKind::Unary { op, expr: inner } => { - match op { - UnaryOp::AddressOf => match &inner.kind { - TypedExprKind::Identifier { name } => { - let id = self.lookup(name); - self.locals[id.0].address_taken = true; - let temp = self.new_temp(expr.ty.clone()); - self.emit_stmt(Statement { - kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)), - span: expr.span, - }); - Operand::Copy(temp) - } - TypedExprKind::Unary { - op: UnaryOp::Deref, - expr: ptr_expr, - } => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`! - _ => unreachable!("invalid lvalue for addressof"), - }, - UnaryOp::Deref => { - let inner_op = self.lower_expr(inner); - let temp = self.new_temp(expr.ty.clone()); - self.emit_stmt(Statement { - kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)), - span: expr.span, - }); - Operand::Copy(temp) - } - _ => { - let inner_op = self.lower_expr(inner); - let temp = self.new_temp(expr.ty.clone()); - self.emit_stmt(Statement { - kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), - span: expr.span, - }); - Operand::Copy(temp) - } + TypedExprKind::Unary { op, expr: inner } => match op { + UnaryOp::AddressOf => self.lower_address_of(inner), + UnaryOp::Deref => { + let inner_op = self.lower_expr(inner); + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)), + span: expr.span, + }); + Operand::Copy(temp) } - } + _ => { + let inner_op = self.lower_expr(inner); + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), + span: expr.span, + }); + Operand::Copy(temp) + } + }, TypedExprKind::Binary { op, lhs, rhs } => { let lhs_op = self.lower_expr(lhs); let rhs_op = self.lower_expr(rhs); @@ -427,11 +459,8 @@ impl FuncBuilder { span: expr.span, }); } - TypedExprKind::Unary { - op: UnaryOp::Deref, - expr: ptr_expr, - } => { - let ptr_op = self.lower_expr(ptr_expr); + _ => { + let ptr_op = self.lower_address_of(lval); self.emit_stmt(Statement { kind: StatementKind::Store { ptr: ptr_op, @@ -440,7 +469,6 @@ impl FuncBuilder { span: expr.span, }); } - _ => unreachable!("invalid lval in MIR lowering"), } rval_op @@ -466,20 +494,43 @@ impl FuncBuilder { }; let mut arg_ops = Vec::new(); + let is_sret = matches!(expr.ty, Ty::Struct(_)); + let mut sret_temp = None; + + // Implement implicit sret struct passing to the callee + if is_sret { + let temp = self.new_temp(expr.ty.clone()); + self.locals[temp.0].address_taken = true; + + let ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone()))); + self.emit_stmt(Statement { + kind: StatementKind::Assign(ptr_temp, Rvalue::AddressOf(temp)), + span: expr.span, + }); + + arg_ops.push(Operand::Copy(ptr_temp)); + sret_temp = Some(temp); + } + for arg in args { arg_ops.push(self.lower_expr(arg)); } - let rval = Rvalue::Call(callee_name, arg_ops, expr.ty.clone()); + let mir_ret_ty = if is_sret { Ty::Unit } else { expr.ty.clone() }; + let rval = Rvalue::Call(callee_name, arg_ops, mir_ret_ty.clone()); - if expr.ty == Ty::Unit { + if mir_ret_ty == Ty::Unit { self.emit_stmt(Statement { kind: StatementKind::SideEffect(rval), span: expr.span, }); - Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit assignments + if let Some(temp) = sret_temp { + Operand::Copy(temp) + } else { + Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit + } } else { - let temp = self.new_temp(expr.ty.clone()); + let temp = self.new_temp(mir_ret_ty); self.emit_stmt(Statement { kind: StatementKind::Assign(temp, rval), span: expr.span, @@ -487,6 +538,116 @@ impl FuncBuilder { Operand::Copy(temp) } } + TypedExprKind::Struct { + name, + name_span: _, + fields, + } => { + let local_id = self.new_temp(expr.ty.clone()); + self.locals[local_id.0].address_taken = true; + + let base_ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone()))); + self.emit_stmt(Statement { + kind: StatementKind::Assign(base_ptr_temp, Rvalue::AddressOf(local_id)), + span: expr.span, + }); + + for field in fields { + let val_op = self.lower_expr(&field.value); + let field_ptr_temp = + self.new_temp(Ty::Pointer(Box::new(field.value.ty.clone()))); + + self.emit_stmt(Statement { + kind: StatementKind::Assign( + field_ptr_temp, + Rvalue::GetFieldPtr { + base_ptr: Operand::Copy(base_ptr_temp), + struct_name: name.clone(), + field_name: field.name.clone(), + }, + ), + span: field.name_span, + }); + + self.emit_stmt(Statement { + kind: StatementKind::Store { + ptr: Operand::Copy(field_ptr_temp), + val: Rvalue::Use(val_op), + }, + span: field.value.span, + }); + } + + Operand::Copy(local_id) + } + TypedExprKind::FieldAccess { .. } => { + let ptr_op = self.lower_address_of(expr); + let temp = self.new_temp(expr.ty.clone()); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::ReadPointer(ptr_op)), + span: expr.span, + }); + Operand::Copy(temp) + } + } + } + + /// Safely computes the memory address of an expression without extracting or dereferencing + /// its underlying evaluated contents explicitly. + fn lower_address_of(&mut self, expr: &TypedExpr) -> Operand { + match &expr.kind { + TypedExprKind::Identifier { name } => { + let id = self.lookup(name); + self.locals[id.0].address_taken = true; + let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone()))); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)), + span: expr.span, + }); + Operand::Copy(temp) + } + TypedExprKind::Unary { + op: UnaryOp::Deref, + expr: ptr_expr, + } => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`! + TypedExprKind::FieldAccess { + expr: base, + field, + field_span: _, + } => { + let base_ptr = self.lower_address_of(base); + let struct_name = match &base.ty { + Ty::Struct(name) => name.clone(), + _ => unreachable!("field access on non-struct"), + }; + let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone()))); + self.emit_stmt(Statement { + kind: StatementKind::Assign( + temp, + Rvalue::GetFieldPtr { + base_ptr, + struct_name, + field_name: field.clone(), + }, + ), + span: expr.span, + }); + Operand::Copy(temp) + } + _ => { + let val_op = self.lower_expr(expr); + if let Operand::Copy(id) = val_op { + self.locals[id.0].address_taken = true; + let temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone()))); + self.emit_stmt(Statement { + kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)), + span: expr.span, + }); + Operand::Copy(temp) + } else { + unreachable!("cannot safely take address of constant rvalue") + } + } } } } diff --git a/src/middle/dce.rs b/src/middle/dce.rs index 8034250..df39e3a 100644 --- a/src/middle/dce.rs +++ b/src/middle/dce.rs @@ -94,6 +94,7 @@ mod test { #[test] fn test_eliminate_dead_blocks() { let mut module = MirModule { + structs: HashMap::new(), functions: vec![MirFunction { name: "test_func".to_string(), params: vec![], @@ -145,6 +146,7 @@ mod test { #[test] fn test_eliminate_dead_cond_branch() { let mut module = MirModule { + structs: HashMap::new(), functions: vec![MirFunction { name: "test_cond_func".to_string(), params: vec![], diff --git a/src/middle/fold.rs b/src/middle/fold.rs index 5bb93e5..e2b32ec 100644 --- a/src/middle/fold.rs +++ b/src/middle/fold.rs @@ -94,6 +94,7 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap {} Rvalue::ReadPointer(op) => propagate_operand(op, known_constants), + Rvalue::GetFieldPtr { base_ptr, .. } => propagate_operand(base_ptr, known_constants), } } @@ -334,6 +335,7 @@ mod test { }; let mut module = MirModule { + structs: HashMap::new(), functions: vec![func], extern_functions: vec![], }; @@ -382,6 +384,7 @@ mod test { }; let mut module = MirModule { + structs: HashMap::new(), functions: vec![func], extern_functions: vec![], }; diff --git a/src/middle/mir.rs b/src/middle/mir.rs index d52e7c2..bda93a5 100644 --- a/src/middle/mir.rs +++ b/src/middle/mir.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use crate::frontend::ast::{BinaryOp, UnaryOp}; use crate::frontend::sema::Ty; use crate::frontend::token::Span; @@ -12,6 +14,7 @@ pub struct LocalId(pub usize); #[derive(Debug)] pub struct MirModule { + pub structs: HashMap>, pub extern_functions: Vec, pub functions: Vec, } @@ -78,6 +81,11 @@ pub enum Rvalue { Call(String, Vec, Ty), AddressOf(LocalId), ReadPointer(Operand), + GetFieldPtr { + base_ptr: Operand, + struct_name: String, + field_name: String, + }, } /// An atomic value used as inputs to instructions. diff --git a/tests/struct_basic.test b/tests/struct_basic.test new file mode 100644 index 0000000..7b187f3 --- /dev/null +++ b/tests/struct_basic.test @@ -0,0 +1,29 @@ +[code] +foreign fn putchar(c: i32) -> i32; + +struct Point { + x: i32, + y: i32 +} + +fn print_num(n: i32) { + // Simple hack to print a 2-digit number for testing + putchar(48 + (n / 10)); + putchar(48 + (n % 10)); + putchar(10); // newline +} + +fn main() -> i32 { + let p = Point { x: 40, y: 2 }; + + // 40 + 2 = 42 + print_num(p.x + p.y); + + return 0; +} + +[expected_return_code] +0 + +[expected_output] +42 \ No newline at end of file diff --git a/tests/struct_pointers.test b/tests/struct_pointers.test new file mode 100644 index 0000000..d14865f --- /dev/null +++ b/tests/struct_pointers.test @@ -0,0 +1,38 @@ +[code] +foreign fn putchar(c: i32) -> i32; + +struct Rect { + width: i32, + height: i32 +} + +fn modify_rect(r: *Rect) { + let temp: Rect = *r; + temp.width = temp.width + 10; + temp.height = temp.height + 20; + *r = temp; +} + +fn print_num(n: i32) { + putchar(48 + (n / 10)); + putchar(48 + (n % 10)); + putchar(10); +} + +fn main() -> i32 { + let r = Rect { width: 15, height: 25 }; + + modify_rect(&r); + + print_num(r.width); + print_num(r.height); + + return 0; +} + +[expected_return_code] +0 + +[expected_output] +25 +45 \ No newline at end of file