feat: add support for structs and member access

This commit is contained in:
2026-04-22 23:49:13 +02:00
parent f3c93fa516
commit ec2aa771fa
12 changed files with 1006 additions and 101 deletions
+203 -23
View File
@@ -92,7 +92,7 @@ impl CraneliftBackend {
}
for func in &module.functions {
self.compile_function(func);
self.compile_function(func, &module.structs);
// Run Cranelift's optimization passes before emitting the text IR
let mut ctrl_plane = ControlPlane::default();
@@ -113,7 +113,11 @@ impl CraneliftBackend {
}
/// Lowers a single MIR function into Cranelift IR.
fn compile_function(&mut self, func: &MirFunction) {
fn compile_function(
&mut self,
func: &MirFunction,
structs: &HashMap<String, Vec<(String, Ty)>>,
) {
let mut sig = self.module.make_signature();
for param_id in &func.params {
@@ -139,9 +143,8 @@ impl CraneliftBackend {
let mut var_map = HashMap::new();
let mut stack_slot_map = HashMap::new();
for local in &func.locals {
if local.address_taken {
let cl_ty = Self::lower_type(&local.ty);
let bytes = cl_ty.bytes();
if local.address_taken || matches!(local.ty, Ty::Struct(_)) {
let bytes = Self::type_size(&local.ty, structs);
let slot = builder.create_sized_stack_slot(ir::StackSlotData::new(
ir::StackSlotKind::ExplicitSlot,
bytes,
@@ -162,6 +165,7 @@ impl CraneliftBackend {
locals: &func.locals,
module: &mut self.module,
func_ids: &self.func_ids,
structs,
};
if let Some(first_block) = func.blocks.first() {
@@ -179,7 +183,14 @@ impl CraneliftBackend {
if i == 0 {
for (j, param_id) in func.params.iter().enumerate() {
let val = trans.builder.block_params(cl_block)[j];
if let Some(&slot) = trans.stack_slot_map.get(param_id) {
if matches!(func.locals[param_id.0].ty, Ty::Struct(_)) {
let slot = trans.stack_slot_map[param_id];
let dest_addr = trans.builder.ins().stack_addr(types::I64, slot, 0);
let size =
CraneliftBackend::type_size(&func.locals[param_id.0].ty, trans.structs);
trans.emit_memcpy(dest_addr, val, size);
} else if let Some(&slot) = trans.stack_slot_map.get(param_id) {
trans.builder.ins().stack_store(val, slot, 0);
} else {
trans.builder.def_var(trans.var_map[param_id], val);
@@ -212,10 +223,56 @@ impl CraneliftBackend {
Ty::F32 => types::F32,
Ty::F64 => types::F64,
Ty::Bool => types::I8, // Booleans are represented as 8-bit integers
Ty::Pointer(_) => types::I64, // Assume 64-bit environment pointers
Ty::Pointer(_) | Ty::Struct(_) => types::I64, // Structs are passed by reference implicitly
_ => unimplemented!("Unsupported type for Cranelift lowering: {:?}", ty),
}
}
/// Returns the size in bytes that a value of type `ty` occupies in memory.
///
/// Scalars use their natural width and pointers are 8 bytes. A struct is laid out field by
/// field in declaration order: each field is aligned to its own size, capped at 8 bytes, and
/// the total is rounded up to a multiple of 8. This must stay in sync with
/// [`Self::field_offset`], which walks the same layout.
///
/// A struct name that is missing from `structs` yields a size of 0.
///
/// # Panics
///
/// Panics (via `unimplemented!`) for `Ty::Var` and `Ty::Function`.
fn type_size(ty: &Ty, structs: &HashMap<String, Vec<(String, Ty)>>) -> u32 {
    match ty {
        Ty::I8 | Ty::U8 | Ty::Bool => 1,
        Ty::I16 | Ty::U16 => 2,
        Ty::I32 | Ty::U32 | Ty::F32 => 4,
        Ty::I64 | Ty::U64 | Ty::F64 | Ty::Pointer(_) => 8,
        Ty::Struct(name) => {
            let mut size = 0;
            if let Some(fields) = structs.get(name) {
                for (_, f_ty) in fields {
                    let f_size = Self::type_size(f_ty, structs);
                    // Zero-sized fields (e.g. an empty struct) must still get an alignment
                    // of at least 1; with `align == 0` the `align - 1` below underflows and
                    // the mask would wipe out the size accumulated so far.
                    let align = f_size.clamp(1, 8);
                    size = (size + align - 1) & !(align - 1);
                    size += f_size;
                }
                // Pad the whole struct to an 8-byte multiple.
                size = (size + 7) & !7;
            }
            size
        }
        Ty::Unit => 0,
        Ty::Var(_) | Ty::Function(_, _) => unimplemented!(),
    }
}
/// Returns the byte offset of `field_name` inside the struct `struct_name`.
///
/// Fields are walked in declaration order; each one is aligned to its own size, capped at
/// 8 bytes, using the same layout rule as [`Self::type_size`].
///
/// # Panics
///
/// Panics if `struct_name` is not present in `structs` or has no field called `field_name`.
fn field_offset(
    struct_name: &str,
    field_name: &str,
    structs: &HashMap<String, Vec<(String, Ty)>>,
) -> u32 {
    let mut offset = 0;
    if let Some(fields) = structs.get(struct_name) {
        for (f_name, f_ty) in fields {
            let f_size = Self::type_size(f_ty, structs);
            // Zero-sized fields must still get an alignment of at least 1; with
            // `align == 0` the `align - 1` below underflows and the mask would reset
            // the offset accumulated so far.
            let align = f_size.clamp(1, 8);
            offset = (offset + align - 1) & !(align - 1);
            if f_name == field_name {
                return offset;
            }
            offset += f_size;
        }
    }
    panic!("Field not found");
}
}
/// A visitor that traverses MIR basic blocks and instructions, emitting Cranelift IR instructions
@@ -228,18 +285,29 @@ struct FunctionTranslator<'a> {
locals: &'a [LocalDecl],
module: &'a mut ObjectModule,
func_ids: &'a HashMap<String, cranelift_module::FuncId>,
structs: &'a HashMap<String, Vec<(String, Ty)>>,
}
impl<'a> FunctionTranslator<'a> {
fn translate_stmt(&mut self, stmt: &Statement) {
match &stmt.kind {
StatementKind::Assign(local_id, rvalue) => {
let val = self.translate_rvalue(rvalue);
if let Some(v) = val {
if let Some(&slot) = self.stack_slot_map.get(local_id) {
if let Some(&slot) = self.stack_slot_map.get(local_id) {
if let Ty::Struct(name) = &self.locals[local_id.0].ty {
let dest_addr = self.builder.ins().stack_addr(types::I64, slot, 0);
if let Some(src_addr) = self.translate_rvalue(rvalue) {
let size = CraneliftBackend::type_size(
&Ty::Struct(name.clone()),
self.structs,
);
self.emit_memcpy(dest_addr, src_addr, size);
}
} else if let Some(v) = self.translate_rvalue(rvalue) {
self.builder.ins().stack_store(v, slot, 0);
} else {
let var = self.var_map[local_id];
}
} else {
let var = self.var_map[local_id];
if let Some(v) = self.translate_rvalue(rvalue) {
self.builder.def_var(var, v);
}
}
@@ -249,10 +317,18 @@ impl<'a> FunctionTranslator<'a> {
}
StatementKind::Store { ptr, val } => {
let ptr_val = self.translate_operand(ptr);
if let Some(v) = self.translate_rvalue(val) {
self.builder
.ins()
.store(ir::MemFlags::trusted(), v, ptr_val, 0);
let rval_ty = self.get_rvalue_type(val);
if matches!(rval_ty, Ty::Struct(_)) {
if let Some(src_addr) = self.translate_rvalue(val) {
let size = CraneliftBackend::type_size(&rval_ty, self.structs);
self.emit_memcpy(ptr_val, src_addr, size);
}
} else {
if let Some(v) = self.translate_rvalue(val) {
self.builder
.ins()
.store(ir::MemFlags::trusted(), v, ptr_val, 0);
}
}
}
}
@@ -303,7 +379,10 @@ impl<'a> FunctionTranslator<'a> {
fn translate_operand(&mut self, op: &Operand) -> ir::Value {
match op {
Operand::Copy(local_id) => {
if let Some(&slot) = self.stack_slot_map.get(local_id) {
if matches!(self.locals[local_id.0].ty, Ty::Struct(_)) {
let slot = self.stack_slot_map[local_id];
self.builder.ins().stack_addr(types::I64, slot, 0)
} else if let Some(&slot) = self.stack_slot_map.get(local_id) {
let cl_ty = CraneliftBackend::lower_type(&self.locals[local_id.0].ty);
self.builder.ins().stack_load(cl_ty, slot, 0)
} else {
@@ -577,13 +656,114 @@ impl<'a> FunctionTranslator<'a> {
Ty::Pointer(inner) => *inner,
_ => unreachable!(),
};
let cl_ty = CraneliftBackend::lower_type(&inner_ty);
Some(
self.builder
.ins()
.load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0),
)
if matches!(inner_ty, Ty::Struct(_)) {
Some(ptr_val)
} else {
let cl_ty = CraneliftBackend::lower_type(&inner_ty);
Some(
self.builder
.ins()
.load(cl_ty, ir::MemFlags::trusted(), ptr_val, 0),
)
}
}
Rvalue::GetFieldPtr {
base_ptr,
struct_name,
field_name,
} => {
let base = self.translate_operand(base_ptr);
let offset = CraneliftBackend::field_offset(struct_name, field_name, self.structs);
Some(self.builder.ins().iadd_imm(base, offset as i64))
}
}
}
/// Computes the type an rvalue evaluates to, without emitting any instructions.
///
/// NOTE(review): binary operations report the type of their left operand; confirm this is
/// what callers expect for comparison operators.
///
/// # Panics
///
/// Panics if a dereference is applied to a non-pointer operand, or if a `GetFieldPtr`
/// names a struct or field that is not registered in `self.structs`.
fn get_rvalue_type(&self, rvalue: &Rvalue) -> Ty {
    match rvalue {
        Rvalue::Use(operand) => self.get_operand_type(operand),
        // Both spellings of a dereference yield the pointee of the operand's pointer type.
        Rvalue::UnaryOp(UnaryOp::Deref, pointer) | Rvalue::ReadPointer(pointer) => {
            let Ty::Pointer(pointee) = self.get_operand_type(pointer) else {
                unreachable!()
            };
            *pointee
        }
        Rvalue::UnaryOp(UnaryOp::AddressOf, operand) => {
            Ty::Pointer(Box::new(self.get_operand_type(operand)))
        }
        Rvalue::UnaryOp(_, operand) => self.get_operand_type(operand),
        Rvalue::BinaryOp(_, left, _) => self.get_operand_type(left),
        // Casts and calls carry their result type explicitly.
        Rvalue::Cast(result_ty, _) | Rvalue::Call(_, _, result_ty) => result_ty.clone(),
        Rvalue::AddressOf(local) => Ty::Pointer(Box::new(self.locals[local.0].ty.clone())),
        Rvalue::GetFieldPtr {
            struct_name,
            field_name,
            ..
        } => {
            let field_ty = self
                .structs
                .get(struct_name)
                .unwrap()
                .iter()
                .find(|(n, _)| n == field_name)
                .map(|(_, declared)| declared.clone())
                .unwrap();
            Ty::Pointer(Box::new(field_ty))
        }
    }
}
/// Emits an inline, fully unrolled copy of `size` bytes from `src` to `dest`.
///
/// The copy is performed with the widest chunks first: as many 8-byte load/store pairs as
/// fit, followed by at most one 4-byte, one 2-byte and one 1-byte pair for the remainder.
/// Every access uses `MemFlags::trusted()`.
fn emit_memcpy(&mut self, dest: ir::Value, src: ir::Value, mut size: u32) {
    let mut offset: i32 = 0;
    // Chunk widths in descending order. After the 8-byte pass the remainder is below 8,
    // so each narrower width is emitted at most once.
    for (chunk_ty, chunk_bytes) in [
        (types::I64, 8),
        (types::I32, 4),
        (types::I16, 2),
        (types::I8, 1),
    ] {
        while size >= chunk_bytes {
            let chunk = self
                .builder
                .ins()
                .load(chunk_ty, ir::MemFlags::trusted(), src, offset);
            self.builder
                .ins()
                .store(ir::MemFlags::trusted(), chunk, dest, offset);
            size -= chunk_bytes;
            offset += chunk_bytes as i32;
        }
    }
}
}
+31 -1
View File
@@ -63,6 +63,11 @@ pub enum DeclKind<P: Phase = Untyped> {
params: Vec<P::ParamType>,
return_type: P::ReturnType,
},
Struct {
name: String,
name_span: Span,
fields: Vec<StructField>,
},
}
#[derive(Debug, PartialEq, Eq)]
@@ -73,12 +78,19 @@ pub struct FunctionParam {
}
/// A single `name: type` field inside a struct declaration (`DeclKind::Struct`).
#[derive(Debug, PartialEq, Eq)]
pub struct StructField {
/// The field's identifier.
pub name: String,
/// Source span of the field's identifier.
pub name_span: Span,
/// The declared type of the field.
pub ty: Type,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Type {
pub kind: TypeKind,
pub span: Span,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeKind {
I8,
I16,
@@ -92,6 +104,7 @@ pub enum TypeKind {
F64,
Bool,
Pointer(Box<Type>),
Struct(String),
}
#[derive(Debug, PartialEq)]
@@ -172,6 +185,23 @@ pub enum ExprKind<P: Phase = Untyped> {
callee: Box<Expr<P>>,
args: Vec<Expr<P>>,
},
Struct {
name: String,
name_span: Span,
fields: Vec<FieldValue<P>>,
},
FieldAccess {
expr: Box<Expr<P>>,
field: String,
field_span: Span,
},
}
/// A single `name: value` initializer inside a struct literal (`ExprKind::Struct`).
#[derive(Debug, PartialEq)]
pub struct FieldValue<P: Phase = Untyped> {
/// Name of the field being initialized.
pub name: String,
/// Source span of the field name.
pub name_span: Span,
/// The expression whose value initializes the field.
pub value: Expr<P>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+9 -5
View File
@@ -71,6 +71,7 @@ impl<'src> Lexer<'src> {
"return" => TokenKind::Return,
"let" => TokenKind::Let,
"while" => TokenKind::While,
"struct" => TokenKind::Struct,
"break" => TokenKind::Break,
"continue" => TokenKind::Continue,
@@ -251,7 +252,9 @@ mod test {
#[test]
fn identifiers() {
assert_eq!(
tokenize("HELLO _hello _0@ fn if else return let while break continue as foreign"),
tokenize(
"HELLO _hello _0@ fn if else return let while struct break continue as foreign"
),
vec![
Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)),
Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)),
@@ -263,10 +266,11 @@ mod test {
Token::new(TokenKind::Return, "return", Span::new(28, 34)),
Token::new(TokenKind::Let, "let", Span::new(35, 38)),
Token::new(TokenKind::While, "while", Span::new(39, 44)),
Token::new(TokenKind::Break, "break", Span::new(45, 50)),
Token::new(TokenKind::Continue, "continue", Span::new(51, 59)),
Token::new(TokenKind::As, "as", Span::new(60, 62)),
Token::new(TokenKind::Foreign, "foreign", Span::new(63, 70)),
Token::new(TokenKind::Struct, "struct", Span::new(45, 51)),
Token::new(TokenKind::Break, "break", Span::new(52, 57)),
Token::new(TokenKind::Continue, "continue", Span::new(58, 66)),
Token::new(TokenKind::As, "as", Span::new(67, 69)),
Token::new(TokenKind::Foreign, "foreign", Span::new(70, 77)),
]
)
}
+239 -18
View File
@@ -117,7 +117,7 @@ impl<'src> Parser<'src> {
Ok(decl) => decls.push(decl),
Err(err) => {
self.errors.push(err);
self.synchronize(&[TokenKind::Fn]);
self.synchronize(&[TokenKind::Fn, TokenKind::Foreign, TokenKind::Struct]);
}
}
}
@@ -136,6 +136,7 @@ impl<'src> Parser<'src> {
match peek_token.kind {
TokenKind::Fn => self.parse_function_decl(),
TokenKind::Foreign => self.parse_foreign_function_decl(),
TokenKind::Struct => self.parse_struct_decl(),
_ => Err(ParseError::new(
format!(
@@ -225,6 +226,54 @@ impl<'src> Parser<'src> {
})
}
/// Parses a struct declaration.
///
/// Fields are separated by commas; a single trailing comma before the closing brace is
/// accepted, and the field list may be empty.
///
/// ```ebnf
/// struct_decl  = "struct" IDENTIFIER "{" [ struct_field { "," struct_field } [ "," ] ] "}" ;
/// struct_field = IDENTIFIER ":" type ;
/// ```
fn parse_struct_decl(&mut self) -> ParseResult<Decl> {
    let start_span = self.expect(TokenKind::Struct)?.span;

    let name_token = self.expect(TokenKind::Identifier)?;
    let name = name_token.text.to_string();
    let name_span = name_token.span;

    self.expect(TokenKind::LBrace)?;

    let mut fields = Vec::new();
    loop {
        if self.is_at_eof() || self.is_peek(TokenKind::RBrace) {
            break;
        }

        let field_name = self.expect(TokenKind::Identifier)?;
        self.expect(TokenKind::Colon)?;
        let ty = self.parse_type()?;
        fields.push(StructField {
            name: field_name.text.to_string(),
            name_span: field_name.span,
            ty,
        });

        // Without a separating comma the field list is over; the closing brace is
        // checked below.
        if !self.is_peek(TokenKind::Comma) {
            break;
        }
        self.advance();
    }

    let end_span = self.expect(TokenKind::RBrace)?.span;

    Ok(Decl {
        kind: DeclKind::Struct {
            name,
            name_span,
            fields,
        },
        span: start_span.join(end_span),
    })
}
/// Parses the function parameter list.
///
/// ```ebnf
@@ -265,6 +314,7 @@ impl<'src> Parser<'src> {
/// | "u8" | "u16" | "u32" | "u64"
/// | "f32" | "f64"
/// | "bool"
/// | IDENTIFIER ;
/// ```
pub fn parse_type(&mut self) -> ParseResult<Type> {
let peek_token = self.peek_no_eof()?;
@@ -314,6 +364,10 @@ impl<'src> Parser<'src> {
self.advance();
TypeKind::Bool
}
TokenKind::Identifier => {
let token = self.advance().unwrap();
TypeKind::Struct(token.text.to_string())
}
TokenKind::Star => {
let star_token = self.advance().unwrap();
let inner = self.parse_type()?;
@@ -407,7 +461,7 @@ impl<'src> Parser<'src> {
fn parse_if_stmt(&mut self) -> ParseResult<Stmt> {
let if_token = self.expect(TokenKind::If)?;
let condition = self.parse_expr()?;
let condition = self.parse_expr_no_struct()?;
let consequence = self.parse_compound_stmt()?;
let alternative = if self.is_peek(TokenKind::Else) {
@@ -446,7 +500,7 @@ impl<'src> Parser<'src> {
/// ```
fn parse_while_stmt(&mut self) -> ParseResult<Stmt> {
let while_token = self.expect(TokenKind::While)?;
let condition = self.parse_expr()?;
let condition = self.parse_expr_no_struct()?;
let body = self.parse_compound_stmt()?;
let span = while_token.span.join(body.span);
@@ -575,12 +629,17 @@ impl<'src> Parser<'src> {
/// Parses an expression.
pub fn parse_expr(&mut self) -> ParseResult<Expr> {
self.parse_expr_bp(0)
self.parse_expr_bp(0, true)
}
/// Parses an expression, disallowing struct literals.
///
/// Runs the Pratt parser with `allow_struct = false`, so an identifier followed by `{` is
/// parsed as a plain identifier instead of the start of a struct literal. The `if` and
/// `while` statement parsers use this for their conditions, which are immediately followed
/// by the `{` of a compound statement. Parenthesized sub-expressions are parsed with
/// `parse_expr` and therefore accept struct literals again.
pub fn parse_expr_no_struct(&mut self) -> ParseResult<Expr> {
self.parse_expr_bp(0, false)
}
/// Pratt parsing implementation for expressions.
fn parse_expr_bp(&mut self, min_bp: u8) -> ParseResult<Expr> {
let mut lhs = self.parse_leading_expr()?;
fn parse_expr_bp(&mut self, min_bp: u8, allow_struct: bool) -> ParseResult<Expr> {
let mut lhs = self.parse_leading_expr(allow_struct)?;
loop {
let peek_token = self.peek_no_eof()?;
@@ -594,7 +653,7 @@ impl<'src> Parser<'src> {
}
self.advance(); // consume '='
let rhs = self.parse_expr_bp(right_bp)?;
let rhs = self.parse_expr_bp(right_bp, allow_struct)?;
let span = lhs.span.join(rhs.span);
lhs = Expr {
@@ -662,6 +721,28 @@ impl<'src> Parser<'src> {
continue;
}
if peek_token.kind == TokenKind::Dot {
let left_bp = 30; // Field access has very high precedence
if left_bp < min_bp {
break;
}
self.advance(); // consume `.`
let field_token = self.expect(TokenKind::Identifier)?;
let span = lhs.span.join(field_token.span);
lhs = Expr {
kind: ExprKind::FieldAccess {
expr: Box::new(lhs),
field: field_token.text.to_string(),
field_span: field_token.span,
},
ty: (),
span,
};
continue;
}
let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else {
break; // Not an infix operator
};
@@ -672,7 +753,7 @@ impl<'src> Parser<'src> {
self.advance(); // consume the operator
let rhs = self.parse_expr_bp(right_bp)?;
let rhs = self.parse_expr_bp(right_bp, allow_struct)?;
let span = lhs.span.join(rhs.span);
lhs = Expr {
@@ -691,20 +772,56 @@ impl<'src> Parser<'src> {
/// Parses a leading expression such as identifiers, integer and boolean literals
/// or prefix expressions.
fn parse_leading_expr(&mut self) -> ParseResult<Expr> {
fn parse_leading_expr(&mut self, allow_struct: bool) -> ParseResult<Expr> {
let peek_token = self.peek_no_eof()?;
match peek_token.kind {
TokenKind::Identifier => {
let token = self.advance().unwrap();
Ok(Expr {
kind: ExprKind::Identifier {
name: token.text.to_string(),
},
ty: (),
span: token.span,
})
if allow_struct && self.is_peek(TokenKind::LBrace) {
self.advance(); // consume `{`
let mut fields = Vec::new();
while !self.is_at_eof() && !self.is_peek(TokenKind::RBrace) {
let field_token = self.expect(TokenKind::Identifier)?;
self.expect(TokenKind::Colon)?;
let value = self.parse_expr()?;
fields.push(FieldValue {
name: field_token.text.to_string(),
name_span: field_token.span,
value,
});
if self.is_peek(TokenKind::Comma) {
self.advance();
} else {
break;
}
}
let rbrace_token = self.expect(TokenKind::RBrace)?;
let span = token.span.join(rbrace_token.span);
Ok(Expr {
kind: ExprKind::Struct {
name: token.text.to_string(),
name_span: token.span,
fields,
},
ty: (),
span,
})
} else {
Ok(Expr {
kind: ExprKind::Identifier {
name: token.text.to_string(),
},
ty: (),
span: token.span,
})
}
}
TokenKind::IntegerLit => {
@@ -754,7 +871,7 @@ impl<'src> Parser<'src> {
TokenKind::LParen => {
let lparen = self.advance().unwrap();
let expr = self.parse_expr_bp(0)?;
let expr = self.parse_expr()?; // Inner expressions allow struct literals
let rparen = self.expect(TokenKind::RParen)?;
Ok(Expr {
@@ -766,7 +883,7 @@ impl<'src> Parser<'src> {
kind if let Some((op, r_bp)) = self.prefix_operator(kind) => {
let op_token = self.advance().unwrap();
let rhs = self.parse_expr_bp(r_bp)?;
let rhs = self.parse_expr_bp(r_bp, allow_struct)?;
Ok(Expr {
ty: (),
@@ -1356,4 +1473,108 @@ mod test {
})
);
}
/// A two-field struct declaration yields the declared name, both fields and exact spans.
#[test]
fn struct_decl() {
    let x_field = StructField {
        name: "x".to_string(),
        name_span: Span::new(14, 15),
        ty: Type {
            kind: TypeKind::F32,
            span: Span::new(17, 20),
        },
    };
    let y_field = StructField {
        name: "y".to_string(),
        name_span: Span::new(22, 23),
        ty: Type {
            kind: TypeKind::F32,
            span: Span::new(25, 28),
        },
    };
    let expected: Decl = Decl {
        kind: DeclKind::Struct {
            name: "Vec3".to_string(),
            name_span: Span::new(7, 11),
            fields: vec![x_field, y_field],
        },
        span: Span::new(0, 30),
    };

    assert_eq!(
        parse("struct Vec3 { x: f32, y: f32 }", Parser::parse_decl),
        Success(expected)
    );
}
/// A struct literal keeps its fields in source order and parses each value as an expression.
#[test]
fn struct_literal() {
    let x_value: Expr = Expr {
        kind: ExprKind::Float { value: 1.0 },
        ty: (),
        span: Span::new(10, 13),
    };
    // `-3.5` is a unary negation wrapping the float literal `3.5`.
    let y_value: Expr = Expr {
        kind: ExprKind::Unary {
            op: UnaryOp::Neg,
            expr: Box::new(Expr {
                kind: ExprKind::Float { value: 3.5 },
                ty: (),
                span: Span::new(19, 22),
            }),
        },
        ty: (),
        span: Span::new(18, 22),
    };
    let expected: Expr = Expr {
        kind: ExprKind::Struct {
            name: "Vec3".to_string(),
            name_span: Span::new(0, 4),
            fields: vec![
                FieldValue {
                    name: "x".to_string(),
                    name_span: Span::new(7, 8),
                    value: x_value,
                },
                FieldValue {
                    name: "y".to_string(),
                    name_span: Span::new(15, 16),
                    value: y_value,
                },
            ],
        },
        ty: (),
        span: Span::new(0, 24),
    };

    assert_eq!(
        parse("Vec3 { x: 1.0, y: -3.5 }", Parser::parse_expr),
        Success(expected)
    );
}
/// Chained field access is left-associative: `a.b.c` parses as `(a.b).c`.
#[test]
fn field_access() {
    let a: Expr = Expr {
        kind: ExprKind::Identifier {
            name: "a".to_string(),
        },
        ty: (),
        span: Span::new(0, 1),
    };
    let a_b: Expr = Expr {
        kind: ExprKind::FieldAccess {
            expr: Box::new(a),
            field: "b".to_string(),
            field_span: Span::new(2, 3),
        },
        ty: (),
        span: Span::new(0, 3),
    };
    let a_b_c: Expr = Expr {
        kind: ExprKind::FieldAccess {
            expr: Box::new(a_b),
            field: "c".to_string(),
            field_span: Span::new(4, 5),
        },
        ty: (),
        span: Span::new(0, 5),
    };

    assert_eq!(parse("a.b.c", Parser::parse_expr), Success(a_b_c));
}
}
+227
View File
@@ -42,6 +42,7 @@ pub enum Ty {
Var(usize),
Function(Vec<Ty>, Box<Ty>),
Pointer(Box<Ty>),
Struct(String),
}
impl Ty {
@@ -108,6 +109,7 @@ impl From<&TypeKind> for Ty {
TypeKind::F64 => Ty::F64,
TypeKind::Bool => Ty::Bool,
TypeKind::Pointer(inner) => Ty::Pointer(Box::new(Ty::from(&inner.kind))),
TypeKind::Struct(name) => Ty::Struct(name.clone()),
}
}
}
@@ -120,6 +122,7 @@ impl TypedExpr {
TypedExprKind::Unary {
op: UnaryOp::Deref, ..
} => true,
TypedExprKind::FieldAccess { expr, .. } => expr.is_lvalue(),
_ => false,
}
}
@@ -132,6 +135,7 @@ pub struct Sema {
next_var: usize,
subst: HashMap<usize, Ty>,
scopes: Vec<HashMap<String, Ty>>,
structs: HashMap<String, Vec<(String, Ty)>>,
errors: Vec<SemanticError>,
deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
deferred_binary: Vec<(Span, Ty)>,
@@ -148,6 +152,7 @@ impl Sema {
next_var: 0,
subst: HashMap::new(),
scopes: Vec::new(),
structs: HashMap::new(),
errors: Vec::new(),
deferred_unary_neg: Vec::new(),
deferred_binary: Vec::new(),
@@ -290,6 +295,13 @@ impl Sema {
self.bind(name, Ty::Function(param_tys, Box::new(ret_ty)));
}
DeclKind::Struct { name, fields, .. } => {
let typed_fields = fields
.iter()
.map(|f| (f.name.clone(), Ty::from(&f.ty.kind)))
.collect();
self.structs.insert(name.clone(), typed_fields);
}
}
}
@@ -376,6 +388,25 @@ impl Sema {
span: decl.span,
}
}
DeclKind::Struct {
name,
name_span,
fields,
} => TypedDecl {
kind: TypedDeclKind::Struct {
name: name.clone(),
name_span: *name_span,
fields: fields
.iter()
.map(|f| StructField {
name: f.name.clone(),
name_span: f.name_span,
ty: f.ty.clone(),
})
.collect(),
},
span: decl.span,
},
}
}
@@ -786,6 +817,133 @@ impl Sema {
span: expr.span,
}
}
ExprKind::Struct {
name,
name_span,
fields,
} => {
let mut typed_fields = Vec::new();
let mut provided_fields = HashMap::new();
if let Some(struct_def) = self.structs.get(name).cloned() {
for field in fields {
let typed_value = self.analyze_expr(&field.value);
if let Some((_, expected_ty)) =
struct_def.iter().find(|(n, _)| n == &field.name)
{
if let Err(e) = self.unify(&typed_value.ty, expected_ty) {
self.errors.push(SemanticError::new(e, field.value.span));
}
} else {
self.errors.push(SemanticError::new(
format!("struct `{}` has no field named `{}`", name, field.name),
field.name_span,
));
}
if provided_fields
.insert(field.name.clone(), field.name_span)
.is_some()
{
self.errors.push(SemanticError::new(
format!("field `{}` specified more than once", field.name),
field.name_span,
));
}
typed_fields.push(FieldValue {
name: field.name.clone(),
name_span: field.name_span,
value: typed_value,
});
}
for (expected_field, _) in struct_def {
if !provided_fields.contains_key(&expected_field) {
self.errors.push(SemanticError::new(
format!(
"missing field `{}` in initializer of `{}`",
expected_field, name
),
expr.span,
));
}
}
} else {
self.errors.push(SemanticError::new(
format!("undeclared struct `{}`", name),
*name_span,
));
for field in fields {
let typed_value = self.analyze_expr(&field.value);
typed_fields.push(FieldValue {
name: field.name.clone(),
name_span: field.name_span,
value: typed_value,
});
}
}
TypedExpr {
kind: TypedExprKind::Struct {
name: name.clone(),
name_span: *name_span,
fields: typed_fields,
},
ty: Ty::Struct(name.clone()),
span: expr.span,
}
}
ExprKind::FieldAccess {
expr: inner_expr,
field,
field_span,
} => {
let typed_inner = self.analyze_expr(inner_expr);
let result_ty = self.new_var();
let inner_ty_resolved = self.apply_subst(&typed_inner.ty);
match inner_ty_resolved {
Ty::Struct(ref struct_name) => {
if let Some(struct_def) = self.structs.get(struct_name).cloned() {
if let Some((_, field_ty)) = struct_def.iter().find(|(n, _)| n == field)
{
if let Err(e) = self.unify(&result_ty, field_ty) {
self.errors.push(SemanticError::new(e, *field_span));
}
} else {
self.errors.push(SemanticError::new(
format!("no field `{}` on type `{}`", field, struct_name),
*field_span,
));
}
}
}
Ty::Var(_) => {
self.errors.push(SemanticError::new(
"type of expression must be known to access a field",
inner_expr.span,
));
}
_ => {
self.errors.push(SemanticError::new(
format!("cannot access field `{}` on a non-struct type", field),
*field_span,
));
}
}
TypedExpr {
kind: TypedExprKind::FieldAccess {
expr: Box::new(typed_inner),
field: field.clone(),
field_span: *field_span,
},
ty: result_ty,
span: expr.span,
}
}
}
}
@@ -827,6 +985,15 @@ impl Sema {
.collect(),
return_type: self.apply_subst(&return_type),
},
TypedDeclKind::Struct {
name,
name_span,
fields,
} => TypedDeclKind::Struct {
name,
name_span,
fields,
},
};
TypedDecl { kind, span }
@@ -916,6 +1083,31 @@ impl Sema {
callee: Box::new(self.apply_subst_expr(*callee)),
args: args.into_iter().map(|a| self.apply_subst_expr(a)).collect(),
},
TypedExprKind::Struct {
name,
name_span,
fields,
} => TypedExprKind::Struct {
name,
name_span,
fields: fields
.into_iter()
.map(|f| FieldValue {
name: f.name,
name_span: f.name_span,
value: self.apply_subst_expr(f.value),
})
.collect(),
},
TypedExprKind::FieldAccess {
expr,
field,
field_span,
} => TypedExprKind::FieldAccess {
expr: Box::new(self.apply_subst_expr(*expr)),
field,
field_span,
},
};
TypedExpr { kind, ty, span }
@@ -1308,4 +1500,39 @@ mod test {
let src = "fn test() { let a: i32 = 5; let b: *f32 = &a as *f32; }";
assert!(analyze(src).is_ok());
}
/// Declaring a struct, returning a literal of it and reading a field all type-check.
#[test]
fn valid_struct() {
    let program = "
struct Vec3 { x: f32, y: f32, z: f32 }
fn make_vec() -> Vec3 {
return Vec3 { x: 1.0, y: 2.0, z: 3.0 };
}
fn get_x(v: Vec3) -> f32 {
return v.x;
}
";
    let outcome = analyze(program);
    assert!(outcome.is_ok());
}
/// Accessing a field the struct does not declare is reported as a semantic error.
#[test]
fn invalid_struct_field() {
    let errors = analyze("struct Vec2 { x: f32, y: f32 } fn test(v: Vec2) -> f32 { return v.z; }")
        .unwrap_err();
    let reported = errors
        .iter()
        .any(|err| err.message.contains("no field `z` on type `Vec2`"));
    assert!(reported);
}
/// A struct literal that omits a declared field is reported as a semantic error.
#[test]
fn missing_struct_initializer_field() {
    let errors =
        analyze("struct Vec2 { x: f32, y: f32 } fn test() -> Vec2 { return Vec2 { x: 1.0 }; }")
            .unwrap_err();
    let reported = errors.iter().any(|err| {
        err.message
            .contains("missing field `y` in initializer of `Vec2`")
    });
    assert!(reported);
}
}
+2
View File
@@ -64,6 +64,7 @@ pub enum TokenKind {
Return,
Let,
While,
Struct,
Break,
Continue,
@@ -130,6 +131,7 @@ impl Display for TokenKind {
TokenKind::Return => "`return`",
TokenKind::Let => "`let`",
TokenKind::While => "`while`",
TokenKind::Struct => "`struct`",
TokenKind::Break => "`break`",
TokenKind::Continue => "`continue`",
TokenKind::I8 => "`i8`",
+215 -54
View File
@@ -12,6 +12,20 @@ impl MirBuilder {
pub fn build(module: &TypedModule) -> MirModule {
let mut extern_functions = Vec::new();
let mut functions = Vec::new();
let mut structs = HashMap::new();
// Collect struct layouts so the backend knows their sizes
for decl in &module.decls {
if let TypedDeclKind::Struct { name, fields, .. } = &decl.kind {
structs.insert(
name.clone(),
fields
.iter()
.map(|f| (f.name.clone(), Ty::from(&f.ty.kind)))
.collect(),
);
}
}
for decl in &module.decls {
match &decl.kind {
@@ -22,7 +36,23 @@ impl MirBuilder {
return_type,
body,
} => {
let mut builder = FuncBuilder::new(name.clone(), return_type.clone());
let is_sret = matches!(return_type, Ty::Struct(_));
let mir_return_type = if is_sret {
Ty::Unit
} else {
return_type.clone()
};
let mut builder = FuncBuilder::new(name.clone(), mir_return_type);
// Implement implicit pass-by-reference for struct return values (sret)
if is_sret {
let sret_id = builder.new_local(
"$sret".to_string(),
Ty::Pointer(Box::new(return_type.clone())),
);
builder.params.push(sret_id);
builder.sret_local = Some(sret_id);
}
// Register parameters as local variables
for (param_name, ty) in params {
@@ -66,10 +96,12 @@ impl MirBuilder {
return_type: return_type.clone(),
});
}
TypedDeclKind::Struct { .. } => {}
}
}
MirModule {
structs,
extern_functions,
functions,
}
@@ -104,6 +136,8 @@ struct FuncBuilder {
/// Stack of `(continue_target, break_target)` for nested loops
loop_stack: Vec<(BlockId, BlockId)>,
/// Local ID mapped to the hidden `$sret` return pointer (if applicable)
sret_local: Option<LocalId>,
}
impl FuncBuilder {
@@ -120,6 +154,7 @@ impl FuncBuilder {
next_block_id: 0,
scopes: vec![HashMap::new()],
loop_stack: Vec::new(),
sret_local: None,
}
}
@@ -323,11 +358,26 @@ impl FuncBuilder {
}
}
TypedStmtKind::Return { value } => {
let val_op = value.as_ref().map(|v| self.lower_expr(v));
self.terminate(Terminator {
kind: TerminatorKind::Return { value: val_op },
span: stmt.span,
});
if let Some(sret_id) = self.sret_local {
let val_op = self.lower_expr(value.as_ref().unwrap());
self.emit_stmt(Statement {
kind: StatementKind::Store {
ptr: Operand::Copy(sret_id),
val: Rvalue::Use(val_op),
},
span: stmt.span,
});
self.terminate(Terminator {
kind: TerminatorKind::Return { value: None },
span: stmt.span,
});
} else {
let val_op = value.as_ref().map(|v| self.lower_expr(v));
self.terminate(Terminator {
kind: TerminatorKind::Return { value: val_op },
span: stmt.span,
});
}
}
TypedStmtKind::Let {
name,
@@ -365,45 +415,27 @@ impl FuncBuilder {
Operand::Constant(ConstantValue::Float(*value, expr.ty.clone()))
}
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
TypedExprKind::Unary { op, expr: inner } => {
match op {
UnaryOp::AddressOf => match &inner.kind {
TypedExprKind::Identifier { name } => {
let id = self.lookup(name);
self.locals[id.0].address_taken = true;
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::AddressOf(id)),
span: expr.span,
});
Operand::Copy(temp)
}
TypedExprKind::Unary {
op: UnaryOp::Deref,
expr: ptr_expr,
} => self.lower_expr(ptr_expr), // `&*ptr` is optimized right back to `ptr`!
_ => unreachable!("invalid lvalue for addressof"),
},
UnaryOp::Deref => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
_ => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
TypedExprKind::Unary { op, expr: inner } => match op {
UnaryOp::AddressOf => self.lower_address_of(inner),
UnaryOp::Deref => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::ReadPointer(inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
}
_ => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
},
TypedExprKind::Binary { op, lhs, rhs } => {
let lhs_op = self.lower_expr(lhs);
let rhs_op = self.lower_expr(rhs);
@@ -427,11 +459,8 @@ impl FuncBuilder {
span: expr.span,
});
}
TypedExprKind::Unary {
op: UnaryOp::Deref,
expr: ptr_expr,
} => {
let ptr_op = self.lower_expr(ptr_expr);
_ => {
let ptr_op = self.lower_address_of(lval);
self.emit_stmt(Statement {
kind: StatementKind::Store {
ptr: ptr_op,
@@ -440,7 +469,6 @@ impl FuncBuilder {
span: expr.span,
});
}
_ => unreachable!("invalid lval in MIR lowering"),
}
rval_op
@@ -466,20 +494,43 @@ impl FuncBuilder {
};
let mut arg_ops = Vec::new();
let is_sret = matches!(expr.ty, Ty::Struct(_));
let mut sret_temp = None;
// Implement implicit sret struct passing to the callee
if is_sret {
let temp = self.new_temp(expr.ty.clone());
self.locals[temp.0].address_taken = true;
let ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
self.emit_stmt(Statement {
kind: StatementKind::Assign(ptr_temp, Rvalue::AddressOf(temp)),
span: expr.span,
});
arg_ops.push(Operand::Copy(ptr_temp));
sret_temp = Some(temp);
}
for arg in args {
arg_ops.push(self.lower_expr(arg));
}
let rval = Rvalue::Call(callee_name, arg_ops, expr.ty.clone());
let mir_ret_ty = if is_sret { Ty::Unit } else { expr.ty.clone() };
let rval = Rvalue::Call(callee_name, arg_ops, mir_ret_ty.clone());
if expr.ty == Ty::Unit {
if mir_ret_ty == Ty::Unit {
self.emit_stmt(Statement {
kind: StatementKind::SideEffect(rval),
span: expr.span,
});
Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit assignments
if let Some(temp) = sret_temp {
Operand::Copy(temp)
} else {
Operand::Constant(ConstantValue::Boolean(false)) // Dummy value for Unit
}
} else {
let temp = self.new_temp(expr.ty.clone());
let temp = self.new_temp(mir_ret_ty);
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, rval),
span: expr.span,
@@ -487,6 +538,116 @@ impl FuncBuilder {
Operand::Copy(temp)
}
}
TypedExprKind::Struct {
name,
name_span: _,
fields,
} => {
let local_id = self.new_temp(expr.ty.clone());
self.locals[local_id.0].address_taken = true;
let base_ptr_temp = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
self.emit_stmt(Statement {
kind: StatementKind::Assign(base_ptr_temp, Rvalue::AddressOf(local_id)),
span: expr.span,
});
for field in fields {
let val_op = self.lower_expr(&field.value);
let field_ptr_temp =
self.new_temp(Ty::Pointer(Box::new(field.value.ty.clone())));
self.emit_stmt(Statement {
kind: StatementKind::Assign(
field_ptr_temp,
Rvalue::GetFieldPtr {
base_ptr: Operand::Copy(base_ptr_temp),
struct_name: name.clone(),
field_name: field.name.clone(),
},
),
span: field.name_span,
});
self.emit_stmt(Statement {
kind: StatementKind::Store {
ptr: Operand::Copy(field_ptr_temp),
val: Rvalue::Use(val_op),
},
span: field.value.span,
});
}
Operand::Copy(local_id)
}
TypedExprKind::FieldAccess { .. } => {
let ptr_op = self.lower_address_of(expr);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::ReadPointer(ptr_op)),
span: expr.span,
});
Operand::Copy(temp)
}
}
}
/// Lowers `expr` as a place and returns an operand holding a pointer to it, without
/// loading the value stored there.
///
/// * An identifier yields the address of the local it resolves to.
/// * `*ptr` yields the lowered `ptr` operand itself, so `&*ptr` collapses back to `ptr`.
/// * `base.field` takes the address of `base` recursively and emits a `GetFieldPtr`.
/// * Any other expression is lowered as a value first, and the address of the local
///   holding that value is taken.
///
/// Every local whose address is taken this way is flagged `address_taken`.
///
/// # Panics
/// Panics if a field access has a non-struct base type, or if an expression in the last
/// category lowers to anything other than `Operand::Copy`.
fn lower_address_of(&mut self, expr: &TypedExpr) -> Operand {
    // Work out which local we need the address of. Arms that already produce a pointer
    // operand on their own return early instead of falling through.
    let target = match &expr.kind {
        TypedExprKind::Unary {
            op: UnaryOp::Deref,
            expr: inner,
        } => return self.lower_expr(inner),
        TypedExprKind::FieldAccess {
            expr: base,
            field,
            field_span: _,
        } => {
            // The base is lowered before its type is inspected.
            let base_ptr = self.lower_address_of(base);
            let struct_name = if let Ty::Struct(name) = &base.ty {
                name.clone()
            } else {
                unreachable!("field access on non-struct")
            };
            let field_ptr = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
            let field_addr = Rvalue::GetFieldPtr {
                base_ptr,
                struct_name,
                field_name: field.clone(),
            };
            self.emit_stmt(Statement {
                kind: StatementKind::Assign(field_ptr, field_addr),
                span: expr.span,
            });
            return Operand::Copy(field_ptr);
        }
        TypedExprKind::Identifier { name } => self.lookup(name),
        // Remaining expressions: lower them as a value and use the local that holds it.
        _ => match self.lower_expr(expr) {
            Operand::Copy(id) => id,
            _ => unreachable!("cannot safely take address of constant rvalue"),
        },
    };

    // Shared tail for the identifier and value cases: flag the local, then emit
    // `addr = &target`.
    self.locals[target.0].address_taken = true;
    let addr = self.new_temp(Ty::Pointer(Box::new(expr.ty.clone())));
    self.emit_stmt(Statement {
        kind: StatementKind::Assign(addr, Rvalue::AddressOf(target)),
        span: expr.span,
    });
    Operand::Copy(addr)
}
}
+2
View File
@@ -94,6 +94,7 @@ mod test {
#[test]
fn test_eliminate_dead_blocks() {
let mut module = MirModule {
structs: HashMap::new(),
functions: vec![MirFunction {
name: "test_func".to_string(),
params: vec![],
@@ -145,6 +146,7 @@ mod test {
#[test]
fn test_eliminate_dead_cond_branch() {
let mut module = MirModule {
structs: HashMap::new(),
functions: vec![MirFunction {
name: "test_cond_func".to_string(),
params: vec![],
+3
View File
@@ -94,6 +94,7 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, Cons
}
Rvalue::AddressOf(_) => {}
Rvalue::ReadPointer(op) => propagate_operand(op, known_constants),
Rvalue::GetFieldPtr { base_ptr, .. } => propagate_operand(base_ptr, known_constants),
}
}
@@ -334,6 +335,7 @@ mod test {
};
let mut module = MirModule {
structs: HashMap::new(),
functions: vec![func],
extern_functions: vec![],
};
@@ -382,6 +384,7 @@ mod test {
};
let mut module = MirModule {
structs: HashMap::new(),
functions: vec![func],
extern_functions: vec![],
};
+8
View File
@@ -1,3 +1,5 @@
use std::collections::HashMap;
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
use crate::frontend::token::Span;
@@ -12,6 +14,7 @@ pub struct LocalId(pub usize);
/// A whole program in MIR form: struct definitions, external function declarations,
/// and the functions that have bodies.
#[derive(Debug)]
pub struct MirModule {
/// Struct table. Presumably keyed by struct name, with each entry listing the struct's
/// `(field name, field type)` pairs; whether the list is in declaration order is not
/// visible here — TODO confirm against the code that fills it.
pub structs: HashMap<String, Vec<(String, Ty)>>,
/// Declarations of externally defined functions (see `MirExternFunction`).
pub extern_functions: Vec<MirExternFunction>,
/// Functions lowered to MIR (see `MirFunction`).
pub functions: Vec<MirFunction>,
}
@@ -78,6 +81,11 @@ pub enum Rvalue {
Call(String, Vec<Operand>, Ty),
AddressOf(LocalId),
ReadPointer(Operand),
GetFieldPtr {
base_ptr: Operand,
struct_name: String,
field_name: String,
},
}
/// An atomic value used as inputs to instructions.
+29
View File
@@ -0,0 +1,29 @@
[code]
foreign fn putchar(c: i32) -> i32;
struct Point {
x: i32,
y: i32
}
fn print_num(n: i32) {
// Simple hack to print a 2-digit number for testing
putchar(48 + (n / 10));
putchar(48 + (n % 10));
putchar(10); // newline
}
fn main() -> i32 {
let p = Point { x: 40, y: 2 };
// 40 + 2 = 42
print_num(p.x + p.y);
return 0;
}
[expected_return_code]
0
[expected_output]
42
+38
View File
@@ -0,0 +1,38 @@
[code]
foreign fn putchar(c: i32) -> i32;
struct Rect {
width: i32,
height: i32
}
fn modify_rect(r: *Rect) {
let temp: Rect = *r;
temp.width = temp.width + 10;
temp.height = temp.height + 20;
*r = temp;
}
fn print_num(n: i32) {
putchar(48 + (n / 10));
putchar(48 + (n % 10));
putchar(10);
}
fn main() -> i32 {
let r = Rect { width: 15, height: 25 };
modify_rect(&r);
print_num(r.width);
print_num(r.height);
return 0;
}
[expected_return_code]
0
[expected_output]
25
45