feat: unify AST structures, introduce MIR and update codegen

This commit is contained in:
2026-04-21 22:08:15 +02:00
parent 22023a8734
commit 3e0b5c5b00
10 changed files with 686 additions and 307 deletions
+154 -134
View File
@@ -13,8 +13,8 @@ use cranelift_object::{ObjectBuilder, ObjectModule};
use crate::frontend::{ use crate::frontend::{
ast::{BinaryOp, UnaryOp}, ast::{BinaryOp, UnaryOp},
sema::Ty, sema::Ty,
typed_ast::*,
}; };
use crate::middle::mir::*;
/// The backend responsible for lowering a `TypedModule` into Cranelift IR and /// The backend responsible for lowering a `TypedModule` into Cranelift IR and
/// generating native machine code object files. /// generating native machine code object files.
@@ -49,63 +49,53 @@ impl CraneliftBackend {
} }
} }
/// Compiles a fully typed AST module into native object code. /// Compiles a MIR module into native object code.
/// ///
/// Returns a tuple containing the generated Cranelift IR (as a human-readable string) and the assembled object file bytes. /// Returns a tuple containing the generated Cranelift IR (as a human-readable string) and the assembled object file bytes.
pub fn compile_module(mut self, module: &TypedModule) -> (String, Vec<u8>) { pub fn compile_module(mut self, module: &MirModule) -> (String, Vec<u8>) {
let mut ir_output = String::new(); let mut ir_output = String::new();
for decl in &module.decls { for func in &module.functions {
match decl { self.compile_function(func);
TypedDecl::Function {
name,
params,
return_type,
body,
} => {
self.compile_function(params, return_type, body);
// Run Cranelift's optimization passes before emitting the text IR // Run Cranelift's optimization passes before emitting the text IR
let mut ctrl_plane = ControlPlane::default(); let mut ctrl_plane = ControlPlane::default();
self.ctx self.ctx
.optimize(self.module.isa(), &mut ctrl_plane) .optimize(self.module.isa(), &mut ctrl_plane)
.unwrap(); .unwrap();
ir_output.push_str(&format!( ir_output.push_str(&format!(
"; Function: {}\n{}", "; Function: {}\n{}",
name, func.name,
self.ctx.func.to_string() self.ctx.func.to_string()
)); ));
ir_output.push('\n'); ir_output.push('\n');
let func_id = self let func_id = self
.module .module
.declare_function(name, Linkage::Export, &self.ctx.func.signature) .declare_function(&func.name, Linkage::Export, &self.ctx.func.signature)
.unwrap(); .unwrap();
self.module.define_function(func_id, &mut self.ctx).unwrap(); self.module.define_function(func_id, &mut self.ctx).unwrap();
self.module.clear_context(&mut self.ctx); self.module.clear_context(&mut self.ctx);
}
}
} }
let obj_bytes = self.module.finish().emit().unwrap(); let obj_bytes = self.module.finish().emit().unwrap();
(ir_output, obj_bytes) (ir_output, obj_bytes)
} }
/// Lowers a single function declaration into Cranelift IR. /// Lowers a single MIR function into Cranelift IR.
/// fn compile_function(&mut self, func: &MirFunction) {
/// This sets up the function signature, ABI parameters, entry block, and declares the parameters as local variables.
fn compile_function(&mut self, params: &[(String, Ty)], return_type: &Ty, body: &TypedStmt) {
let mut sig = self.module.make_signature(); let mut sig = self.module.make_signature();
for (_, ty) in params { for param_id in &func.params {
sig.params.push(AbiParam::new(Self::lower_type(ty))); let param_ty = &func.locals[param_id.0].ty;
sig.params.push(AbiParam::new(Self::lower_type(param_ty)));
} }
if return_type != &Ty::Unit { if func.return_type != Ty::Unit {
sig.returns sig.returns
.push(AbiParam::new(Self::lower_type(return_type))); .push(AbiParam::new(Self::lower_type(&func.return_type)));
} }
self.ctx.func.signature = sig; self.ctx.func.signature = sig;
@@ -113,23 +103,55 @@ impl CraneliftBackend {
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
let entry_block = builder.create_block(); let mut block_map = HashMap::new();
builder.append_block_params_for_function_params(entry_block); for block in &func.blocks {
builder.switch_to_block(entry_block); block_map.insert(block.id, builder.create_block());
builder.seal_block(entry_block);
let mut vars = HashMap::new();
for (i, (param_name, ty)) in params.iter().enumerate() {
let var = builder.declare_var(Self::lower_type(ty));
let val = builder.block_params(entry_block)[i];
builder.def_var(var, val);
vars.insert(param_name.clone(), var);
} }
let mut trans = FunctionTranslator { builder, vars }; let mut var_map = HashMap::new();
for local in &func.locals {
let var = builder.declare_var(Self::lower_type(&local.ty));
var_map.insert(local.id, var);
}
let mut trans = FunctionTranslator {
builder,
var_map,
block_map,
locals: &func.locals,
};
if let Some(first_block) = func.blocks.first() {
let entry_block = trans.block_map[&first_block.id];
trans
.builder
.append_block_params_for_function_params(entry_block);
}
for (i, block) in func.blocks.iter().enumerate() {
let cl_block = trans.block_map[&block.id];
trans.builder.switch_to_block(cl_block);
// Retrieve function arguments explicitly if this is the entry block
if i == 0 {
for (j, param_id) in func.params.iter().enumerate() {
let val = trans.builder.block_params(cl_block)[j];
trans.builder.def_var(trans.var_map[param_id], val);
}
}
for stmt in &block.statements {
trans.translate_stmt(stmt);
}
trans.translate_terminator(&block.terminator);
}
// Seal all blocks now that all branches are translated and predecessors are known
for cl_block in trans.block_map.values() {
trans.builder.seal_block(*cl_block);
}
trans.translate_stmt(body);
trans.builder.finalize(); trans.builder.finalize();
} }
@@ -146,109 +168,107 @@ impl CraneliftBackend {
} }
} }
/// A visitor that traverses typed statements and expressions, emitting Cranelift IR instructions into the current function builder. /// A visitor that traverses MIR basic blocks and instructions, emitting Cranelift IR instructions
/// into the current function builder.
struct FunctionTranslator<'a> { struct FunctionTranslator<'a> {
builder: FunctionBuilder<'a>, builder: FunctionBuilder<'a>,
vars: HashMap<String, Variable>, var_map: HashMap<LocalId, Variable>,
block_map: HashMap<BlockId, ir::Block>,
locals: &'a [LocalDecl],
} }
impl<'a> FunctionTranslator<'a> { impl<'a> FunctionTranslator<'a> {
/// Translates a statement, recursively compiling its inner components. fn translate_stmt(&mut self, stmt: &Statement) {
/// Returns `true` if the statement resulted in a basic block terminator. match &stmt.kind {
fn translate_stmt(&mut self, stmt: &TypedStmt) -> bool { StatementKind::Assign(local_id, rvalue) => {
match stmt { let val = self.translate_rvalue(rvalue);
TypedStmt::Compound { inner } => { let var = self.var_map[local_id];
for s in inner { self.builder.def_var(var, val);
if self.translate_stmt(s) {
return true;
}
}
false
}
TypedStmt::If {
condition,
then,
elze,
} => {
let cond_val = self.translate_expr(condition);
let then_block = self.builder.create_block();
let else_block = self.builder.create_block();
let merge_block = self.builder.create_block();
self.builder
.ins()
.brif(cond_val, then_block, &[], else_block, &[]);
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
let then_terminated = self.translate_stmt(then);
if !then_terminated {
self.builder.ins().jump(merge_block, &[]);
}
self.builder.switch_to_block(else_block);
self.builder.seal_block(else_block);
let else_terminated = elze
.as_ref()
.map(|stmt| self.translate_stmt(stmt))
.unwrap_or(false);
if !else_terminated {
self.builder.ins().jump(merge_block, &[]);
}
self.builder.switch_to_block(merge_block);
self.builder.seal_block(merge_block);
then_terminated && else_terminated
}
TypedStmt::Return { value } => {
if let Some(expr) = value {
let val = self.translate_expr(expr);
self.builder.ins().return_(&[val]);
} else {
self.builder.ins().return_(&[]);
}
true
} }
} }
} }
/// Translates an expression into a Cranelift IR value. fn translate_terminator(&mut self, term: &Terminator) {
/// Emits appropriate computation instructions based on operators and operand types. match &term.kind {
fn translate_expr(&mut self, expr: &TypedExpr) -> ir::Value { TerminatorKind::Goto { target } => {
match &expr.kind { self.builder.ins().jump(self.block_map[target], &[]);
TypedExprKind::Identifier { name } => {
let var = self.vars.get(name).expect("Undeclared variable");
self.builder.use_var(*var)
} }
TypedExprKind::Integer { value } => { TerminatorKind::CondBranch {
let ty = CraneliftBackend::lower_type(&expr.ty); cond,
self.builder.ins().iconst(ty, *value as i64) target_true,
target_false,
} => {
let cond_val = self.translate_operand(cond);
self.builder.ins().brif(
cond_val,
self.block_map[target_true],
&[],
self.block_map[target_false],
&[],
);
} }
TypedExprKind::Boolean { value } => { TerminatorKind::Return { value } => {
let ty = CraneliftBackend::lower_type(&expr.ty); if let Some(op) = value {
self.builder.ins().iconst(ty, if *value { 1 } else { 0 }) let val = self.translate_operand(op);
self.builder.ins().return_(&[val]);
} else {
self.builder.ins().return_(&[]);
}
} }
TypedExprKind::Unary { op, expr: inner } => { TerminatorKind::Unreachable => {
let inner_val = self.translate_expr(inner); self.builder.ins().trap(ir::TrapCode::user(5).unwrap());
}
}
}
fn get_operand_type(&self, op: &Operand) -> Ty {
match op {
Operand::Copy(local_id) => self.locals[local_id.0].ty.clone(),
Operand::Constant(ConstantValue::Integer(_, ty)) => ty.clone(),
Operand::Constant(ConstantValue::Boolean(_)) => Ty::Bool,
}
}
fn translate_operand(&mut self, op: &Operand) -> ir::Value {
match op {
Operand::Copy(local_id) => {
let var = self.var_map[local_id];
self.builder.use_var(var)
}
Operand::Constant(ConstantValue::Integer(val, ty)) => {
let cl_ty = CraneliftBackend::lower_type(ty);
self.builder.ins().iconst(cl_ty, *val as i64)
}
Operand::Constant(ConstantValue::Boolean(val)) => self
.builder
.ins()
.iconst(types::I8, if *val { 1 } else { 0 }),
}
}
fn translate_rvalue(&mut self, rvalue: &Rvalue) -> ir::Value {
match rvalue {
Rvalue::Use(op) => self.translate_operand(op),
Rvalue::UnaryOp(op, inner) => {
let inner_val = self.translate_operand(inner);
match op { match op {
UnaryOp::Neg => self.builder.ins().ineg(inner_val), UnaryOp::Neg => self.builder.ins().ineg(inner_val),
UnaryOp::Not => { UnaryOp::Not => {
// `!x` is equivalent to `x == 0` for booleans (0 or 1). let ty = self.get_operand_type(inner);
let ty = CraneliftBackend::lower_type(&inner.ty); let cl_ty = CraneliftBackend::lower_type(&ty);
let zero = self.builder.ins().iconst(ty, 0); let zero = self.builder.ins().iconst(cl_ty, 0);
self.builder self.builder
.ins() .ins()
.icmp(ir::condcodes::IntCC::Equal, inner_val, zero) .icmp(ir::condcodes::IntCC::Equal, inner_val, zero)
} }
} }
} }
TypedExprKind::Binary { op, lhs, rhs } => { Rvalue::BinaryOp(op, lhs, rhs) => {
let lhs_val = self.translate_expr(lhs); let lhs_val = self.translate_operand(lhs);
let rhs_val = self.translate_expr(rhs); let rhs_val = self.translate_operand(rhs);
let is_signed = matches!(lhs.ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64); let ty = self.get_operand_type(lhs);
let is_signed = matches!(ty, Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64);
match op { match op {
BinaryOp::Add => self.builder.ins().iadd(lhs_val, rhs_val), BinaryOp::Add => self.builder.ins().iadd(lhs_val, rhs_val),
+58 -23
View File
@@ -1,24 +1,58 @@
use crate::frontend::sema::Ty;
use crate::frontend::token::Span; use crate::frontend::token::Span;
use std::fmt::Debug;
#[derive(Debug, PartialEq, Eq)] pub trait Phase: Debug + PartialEq + Eq {
pub struct Module { type ReturnType: Debug + PartialEq + Eq;
pub decls: Vec<Decl>, type ParamType: Debug + PartialEq + Eq;
type ExprType: Debug + PartialEq + Eq;
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct Decl { pub struct Untyped;
pub kind: DeclKind,
impl Phase for Untyped {
type ReturnType = Option<Type>;
type ParamType = FunctionParam;
type ExprType = ();
}
#[derive(Debug, PartialEq, Eq)]
pub struct Typed;
impl Phase for Typed {
type ReturnType = Ty;
type ParamType = (String, Ty);
type ExprType = Ty;
}
pub type TypedModule = Module<Typed>;
pub type TypedDecl = Decl<Typed>;
pub type TypedDeclKind = DeclKind<Typed>;
pub type TypedStmt = Stmt<Typed>;
pub type TypedStmtKind = StmtKind<Typed>;
pub type TypedExpr = Expr<Typed>;
pub type TypedExprKind = ExprKind<Typed>;
#[derive(Debug, PartialEq, Eq)]
pub struct Module<P: Phase = Untyped> {
pub decls: Vec<Decl<P>>,
}
#[derive(Debug, PartialEq, Eq)]
pub struct Decl<P: Phase = Untyped> {
pub kind: DeclKind<P>,
pub span: Span, pub span: Span,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum DeclKind { pub enum DeclKind<P: Phase = Untyped> {
Function { Function {
name: String, name: String,
name_span: Span, name_span: Span,
params: Vec<FunctionParam>, params: Vec<P::ParamType>,
return_type: Option<Type>, return_type: P::ReturnType,
body: Stmt, body: Stmt<P>,
}, },
} }
@@ -49,34 +83,35 @@ pub enum TypeKind {
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct Stmt { pub struct Stmt<P: Phase = Untyped> {
pub kind: StmtKind, pub kind: StmtKind<P>,
pub span: Span, pub span: Span,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum StmtKind { pub enum StmtKind<P: Phase = Untyped> {
Compound { Compound {
inner: Vec<Stmt>, inner: Vec<Stmt<P>>,
}, },
If { If {
condition: Expr, condition: Expr<P>,
then: Box<Stmt>, then: Box<Stmt<P>>,
elze: Option<Box<Stmt>>, elze: Option<Box<Stmt<P>>>,
}, },
Return { Return {
value: Option<Expr>, value: Option<Expr<P>>,
}, },
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct Expr { pub struct Expr<P: Phase = Untyped> {
pub kind: ExprKind, pub kind: ExprKind<P>,
pub ty: P::ExprType,
pub span: Span, pub span: Span,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum ExprKind { pub enum ExprKind<P: Phase = Untyped> {
Identifier { Identifier {
name: String, name: String,
}, },
@@ -88,12 +123,12 @@ pub enum ExprKind {
}, },
Unary { Unary {
op: UnaryOp, op: UnaryOp,
expr: Box<Expr>, expr: Box<Expr<P>>,
}, },
Binary { Binary {
op: BinaryOp, op: BinaryOp,
lhs: Box<Expr>, lhs: Box<Expr<P>>,
rhs: Box<Expr>, rhs: Box<Expr<P>>,
}, },
} }
-1
View File
@@ -3,4 +3,3 @@ pub mod lexer;
pub mod parser; pub mod parser;
pub mod sema; pub mod sema;
pub mod token; pub mod token;
pub mod typed_ast;
+29
View File
@@ -431,6 +431,7 @@ impl<'src> Parser<'src> {
lhs: Box::new(lhs), lhs: Box::new(lhs),
rhs: Box::new(rhs), rhs: Box::new(rhs),
}, },
ty: (),
span, span,
}; };
} }
@@ -451,6 +452,7 @@ impl<'src> Parser<'src> {
kind: ExprKind::Identifier { kind: ExprKind::Identifier {
name: token.text.to_string(), name: token.text.to_string(),
}, },
ty: (),
span: token.span, span: token.span,
}) })
} }
@@ -472,6 +474,7 @@ impl<'src> Parser<'src> {
Ok(Expr { Ok(Expr {
kind: ExprKind::Integer { value }, kind: ExprKind::Integer { value },
ty: (),
span: token.span, span: token.span,
}) })
} }
@@ -483,6 +486,7 @@ impl<'src> Parser<'src> {
kind: ExprKind::Boolean { kind: ExprKind::Boolean {
value: token.text == "true", value: token.text == "true",
}, },
ty: (),
span: token.span, span: token.span,
}) })
} }
@@ -494,6 +498,7 @@ impl<'src> Parser<'src> {
Ok(Expr { Ok(Expr {
kind: expr.kind, kind: expr.kind,
ty: (),
span: lparen.span.join(rparen.span), span: lparen.span.join(rparen.span),
}) })
} }
@@ -503,6 +508,7 @@ impl<'src> Parser<'src> {
let rhs = self.parse_expr_bp(r_bp)?; let rhs = self.parse_expr_bp(r_bp)?;
Ok(Expr { Ok(Expr {
ty: (),
span: op_token.span.join(rhs.span), span: op_token.span.join(rhs.span),
kind: ExprKind::Unary { kind: ExprKind::Unary {
op, op,
@@ -598,6 +604,7 @@ mod test {
parse("0xBEEF;", Parser::parse_expr), parse("0xBEEF;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Integer { value: 0xBEEF }, kind: ExprKind::Integer { value: 0xBEEF },
ty: (),
span: Span::new(0, 6) span: Span::new(0, 6)
}) })
); );
@@ -606,6 +613,7 @@ mod test {
parse("0o777;", Parser::parse_expr), parse("0o777;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Integer { value: 0o777 }, kind: ExprKind::Integer { value: 0o777 },
ty: (),
span: Span::new(0, 5) span: Span::new(0, 5)
}) })
); );
@@ -614,6 +622,7 @@ mod test {
parse("0b1001;", Parser::parse_expr), parse("0b1001;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Integer { value: 0b1001 }, kind: ExprKind::Integer { value: 0b1001 },
ty: (),
span: Span::new(0, 6) span: Span::new(0, 6)
}) })
); );
@@ -622,6 +631,7 @@ mod test {
parse("1337;", Parser::parse_expr), parse("1337;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Integer { value: 1337 }, kind: ExprKind::Integer { value: 1337 },
ty: (),
span: Span::new(0, 4) span: Span::new(0, 4)
}) })
); );
@@ -633,6 +643,7 @@ mod test {
parse("true;", Parser::parse_expr), parse("true;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Boolean { value: true }, kind: ExprKind::Boolean { value: true },
ty: (),
span: Span::new(0, 4) span: Span::new(0, 4)
}) })
); );
@@ -641,6 +652,7 @@ mod test {
parse("false;", Parser::parse_expr), parse("false;", Parser::parse_expr),
Success(Expr { Success(Expr {
kind: ExprKind::Boolean { value: false }, kind: ExprKind::Boolean { value: false },
ty: (),
span: Span::new(0, 5) span: Span::new(0, 5)
}) })
); );
@@ -655,9 +667,11 @@ mod test {
op: UnaryOp::Neg, op: UnaryOp::Neg,
expr: Box::new(Expr { expr: Box::new(Expr {
kind: ExprKind::Integer { value: 5 }, kind: ExprKind::Integer { value: 5 },
ty: (),
span: Span::new(1, 2) span: Span::new(1, 2)
}) })
}, },
ty: (),
span: Span::new(0, 2) span: Span::new(0, 2)
}) })
); );
@@ -672,6 +686,7 @@ mod test {
op: BinaryOp::Add, op: BinaryOp::Add,
lhs: Box::new(Expr { lhs: Box::new(Expr {
kind: ExprKind::Integer { value: 12 }, kind: ExprKind::Integer { value: 12 },
ty: (),
span: Span::new(0, 2) span: Span::new(0, 2)
}), }),
rhs: Box::new(Expr { rhs: Box::new(Expr {
@@ -679,16 +694,20 @@ mod test {
op: BinaryOp::Mul, op: BinaryOp::Mul,
lhs: Box::new(Expr { lhs: Box::new(Expr {
kind: ExprKind::Integer { value: 3 }, kind: ExprKind::Integer { value: 3 },
ty: (),
span: Span::new(5, 6) span: Span::new(5, 6)
}), }),
rhs: Box::new(Expr { rhs: Box::new(Expr {
kind: ExprKind::Integer { value: 6 }, kind: ExprKind::Integer { value: 6 },
ty: (),
span: Span::new(9, 10) span: Span::new(9, 10)
}) })
}, },
ty: (),
span: Span::new(5, 10) span: Span::new(5, 10)
}) })
}, },
ty: (),
span: Span::new(0, 10) span: Span::new(0, 10)
}) })
); );
@@ -710,6 +729,7 @@ mod test {
kind: StmtKind::Return { kind: StmtKind::Return {
value: Some(Expr { value: Some(Expr {
kind: ExprKind::Integer { value: 0 }, kind: ExprKind::Integer { value: 0 },
ty: (),
span: Span::new(7, 8) span: Span::new(7, 8)
}) })
}, },
@@ -726,6 +746,7 @@ mod test {
kind: StmtKind::If { kind: StmtKind::If {
condition: Expr { condition: Expr {
kind: ExprKind::Boolean { value: true }, kind: ExprKind::Boolean { value: true },
ty: (),
span: Span::new(3, 7) span: Span::new(3, 7)
}, },
then: Box::new(Stmt { then: Box::new(Stmt {
@@ -818,15 +839,18 @@ mod test {
kind: ExprKind::Identifier { kind: ExprKind::Identifier {
name: "a".to_string() name: "a".to_string()
}, },
ty: (),
span: Span::new(39, 40) span: Span::new(39, 40)
}), }),
rhs: Box::new(Expr { rhs: Box::new(Expr {
kind: ExprKind::Identifier { kind: ExprKind::Identifier {
name: "b".to_string() name: "b".to_string()
}, },
ty: (),
span: Span::new(43, 44) span: Span::new(43, 44)
}) })
}, },
ty: (),
span: Span::new(39, 44) span: Span::new(39, 44)
}) })
}, },
@@ -852,13 +876,16 @@ mod test {
kind: ExprKind::Identifier { kind: ExprKind::Identifier {
name: "a".to_string() name: "a".to_string()
}, },
ty: (),
span: Span::new(0, 1) span: Span::new(0, 1)
}), }),
rhs: Box::new(Expr { rhs: Box::new(Expr {
kind: ExprKind::Integer { value: 5 }, kind: ExprKind::Integer { value: 5 },
ty: (),
span: Span::new(5, 6) span: Span::new(5, 6)
}) })
}, },
ty: (),
span: Span::new(0, 6) span: Span::new(0, 6)
}) })
); );
@@ -873,9 +900,11 @@ mod test {
op: UnaryOp::Not, op: UnaryOp::Not,
expr: Box::new(Expr { expr: Box::new(Expr {
kind: ExprKind::Boolean { value: true }, kind: ExprKind::Boolean { value: true },
ty: (),
span: Span::new(1, 5) span: Span::new(1, 5)
}) })
}, },
ty: (),
span: Span::new(0, 5) span: Span::new(0, 5)
}) })
); );
+58 -88
View File
@@ -2,7 +2,6 @@ use std::collections::HashMap;
use crate::frontend::ast::*; use crate::frontend::ast::*;
use crate::frontend::token::Span; use crate::frontend::token::Span;
use crate::frontend::typed_ast::*;
/// A structured error produced during semantic analysis, carrying a human-readable /// A structured error produced during semantic analysis, carrying a human-readable
/// message and the [Span] of the offending AST node for precise diagnostics. /// message and the [Span] of the offending AST node for precise diagnostics.
@@ -79,7 +78,6 @@ pub struct Sema {
deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>, deferred_unary_neg: Vec<(Span, Ty, Ty, Option<u64>)>,
deferred_binary: Vec<(Span, Ty)>, deferred_binary: Vec<(Span, Ty)>,
deferred_literals: Vec<(Span, Ty)>, deferred_literals: Vec<(Span, Ty)>,
is_reachable: bool,
} }
impl Sema { impl Sema {
@@ -93,7 +91,6 @@ impl Sema {
deferred_unary_neg: Vec::new(), deferred_unary_neg: Vec::new(),
deferred_binary: Vec::new(), deferred_binary: Vec::new(),
deferred_literals: Vec::new(), deferred_literals: Vec::new(),
is_reachable: true,
} }
} }
@@ -245,10 +242,10 @@ impl Sema {
match &decl.kind { match &decl.kind {
DeclKind::Function { DeclKind::Function {
name, name,
name_span,
params, params,
return_type, return_type,
body, body,
..
} => { } => {
let mut typed_params = Vec::new(); let mut typed_params = Vec::new();
@@ -265,24 +262,19 @@ impl Sema {
.map(|t| Ty::from(&t.kind)) .map(|t| Ty::from(&t.kind))
.unwrap_or(Ty::Unit); .unwrap_or(Ty::Unit);
self.is_reachable = true;
let typed_body = self.analyze_stmt(body, &expected_ret_ty); let typed_body = self.analyze_stmt(body, &expected_ret_ty);
if expected_ret_ty != Ty::Unit && self.is_reachable {
self.errors.push(SemanticError::new(
"not all control paths return a value",
decl.span,
));
}
self.leave_scope(); self.leave_scope();
TypedDecl::Function { TypedDecl {
name: name.clone(), kind: TypedDeclKind::Function {
params: typed_params, name: name.clone(),
return_type: expected_ret_ty, name_span: *name_span,
body: typed_body, params: typed_params,
return_type: expected_ret_ty,
body: typed_body,
},
span: decl.span,
} }
} }
} }
@@ -294,22 +286,19 @@ impl Sema {
match &stmt.kind { match &stmt.kind {
StmtKind::Compound { inner } => { StmtKind::Compound { inner } => {
let mut typed_inner = Vec::new(); let mut typed_inner = Vec::new();
let mut reported_unreachable = false;
self.enter_scope(); self.enter_scope();
for s in inner { for s in inner {
if !self.is_reachable && !reported_unreachable {
self.errors
.push(SemanticError::new("unreachable statement", s.span));
reported_unreachable = true;
}
typed_inner.push(self.analyze_stmt(s, expected_ret_ty)); typed_inner.push(self.analyze_stmt(s, expected_ret_ty));
} }
self.leave_scope(); self.leave_scope();
TypedStmt::Compound { inner: typed_inner } TypedStmt {
kind: TypedStmtKind::Compound { inner: typed_inner },
span: stmt.span,
}
} }
StmtKind::If { StmtKind::If {
condition, condition,
@@ -322,29 +311,16 @@ impl Sema {
self.errors.push(SemanticError::new(err, condition.span)); self.errors.push(SemanticError::new(err, condition.span));
} }
let initial_reachable = self.is_reachable;
self.is_reachable = initial_reachable;
let typed_then = self.analyze_stmt(then, expected_ret_ty); let typed_then = self.analyze_stmt(then, expected_ret_ty);
let reachable_after_then = self.is_reachable; let typed_elze = elze.as_ref().map(|e| self.analyze_stmt(e, expected_ret_ty));
let typed_elze = elze.as_ref().map(|e| { TypedStmt {
self.is_reachable = initial_reachable; kind: TypedStmtKind::If {
self.analyze_stmt(e, expected_ret_ty) condition: typed_condition,
}); then: Box::new(typed_then),
elze: typed_elze.map(Box::new),
let reachable_after_else = if elze.is_some() { },
self.is_reachable span: stmt.span,
} else {
initial_reachable
};
self.is_reachable = reachable_after_then || reachable_after_else;
TypedStmt::If {
condition: typed_condition,
then: Box::new(typed_then),
elze: typed_elze.map(Box::new),
} }
} }
StmtKind::Return { value } => { StmtKind::Return { value } => {
@@ -355,19 +331,21 @@ impl Sema {
self.errors.push(SemanticError::new(err, expr.span)); self.errors.push(SemanticError::new(err, expr.span));
} }
self.is_reachable = false; TypedStmt {
kind: TypedStmtKind::Return {
TypedStmt::Return { value: Some(typed_expr),
value: Some(typed_expr), },
span: stmt.span,
} }
} else { } else {
if let Err(err) = self.unify(&Ty::Unit, expected_ret_ty) { if let Err(err) = self.unify(&Ty::Unit, expected_ret_ty) {
self.errors.push(SemanticError::new(err, stmt.span)); self.errors.push(SemanticError::new(err, stmt.span));
} }
self.is_reachable = false; TypedStmt {
kind: TypedStmtKind::Return { value: None },
TypedStmt::Return { value: None } span: stmt.span,
}
} }
} }
} }
@@ -392,6 +370,7 @@ impl Sema {
TypedExpr { TypedExpr {
kind: TypedExprKind::Identifier { name: name.clone() }, kind: TypedExprKind::Identifier { name: name.clone() },
ty, ty,
span: expr.span,
} }
} }
@@ -402,12 +381,14 @@ impl Sema {
TypedExpr { TypedExpr {
kind: TypedExprKind::Integer { value: *value }, kind: TypedExprKind::Integer { value: *value },
ty, ty,
span: expr.span,
} }
} }
ExprKind::Boolean { value } => TypedExpr { ExprKind::Boolean { value } => TypedExpr {
kind: TypedExprKind::Boolean { value: *value }, kind: TypedExprKind::Boolean { value: *value },
ty: Ty::Bool, ty: Ty::Bool,
span: expr.span,
}, },
ExprKind::Unary { ExprKind::Unary {
@@ -435,6 +416,7 @@ impl Sema {
expr: Box::new(typed_inner), expr: Box::new(typed_inner),
}, },
ty: result_ty, ty: result_ty,
span: expr.span,
} }
} }
@@ -454,6 +436,7 @@ impl Sema {
expr: Box::new(typed_inner), expr: Box::new(typed_inner),
}, },
ty: Ty::Bool, ty: Ty::Bool,
span: expr.span,
} }
} }
@@ -490,6 +473,7 @@ impl Sema {
rhs: Box::new(typed_rhs), rhs: Box::new(typed_rhs),
}, },
ty: result_ty, ty: result_ty,
span: expr.span,
} }
} }
} }
@@ -497,9 +481,11 @@ impl Sema {
/// Recursively applies the final resolved type substitutions to a typed declaration. /// Recursively applies the final resolved type substitutions to a typed declaration.
fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl { fn apply_subst_decl(&self, decl: TypedDecl) -> TypedDecl {
match decl { let span = decl.span;
TypedDecl::Function { let kind = match decl.kind {
TypedDeclKind::Function {
name, name,
name_span,
params, params,
return_type, return_type,
body, body,
@@ -509,45 +495,52 @@ impl Sema {
.map(|(n, ty)| (n, self.apply_subst(&ty))) .map(|(n, ty)| (n, self.apply_subst(&ty)))
.collect(); .collect();
TypedDecl::Function { TypedDeclKind::Function {
name, name,
name_span,
params, params,
return_type: self.apply_subst(&return_type), return_type: self.apply_subst(&return_type),
body: self.apply_subst_stmt(body), body: self.apply_subst_stmt(body),
} }
} }
} };
TypedDecl { kind, span }
} }
/// Recursively applies the final resolved type substitutions to a typed statement. /// Recursively applies the final resolved type substitutions to a typed statement.
fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt { fn apply_subst_stmt(&self, stmt: TypedStmt) -> TypedStmt {
match stmt { let span = stmt.span;
TypedStmt::Compound { inner } => TypedStmt::Compound { let kind = match stmt.kind {
TypedStmtKind::Compound { inner } => TypedStmtKind::Compound {
inner: inner inner: inner
.into_iter() .into_iter()
.map(|s| self.apply_subst_stmt(s)) .map(|s| self.apply_subst_stmt(s))
.collect(), .collect(),
}, },
TypedStmt::If { TypedStmtKind::If {
condition, condition,
then, then,
elze, elze,
} => TypedStmt::If { } => TypedStmtKind::If {
condition: self.apply_subst_expr(condition), condition: self.apply_subst_expr(condition),
then: Box::new(self.apply_subst_stmt(*then)), then: Box::new(self.apply_subst_stmt(*then)),
elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))), elze: elze.map(|s| Box::new(self.apply_subst_stmt(*s))),
}, },
TypedStmt::Return { value } => TypedStmt::Return { TypedStmtKind::Return { value } => TypedStmtKind::Return {
value: value.map(|e| self.apply_subst_expr(e)), value: value.map(|e| self.apply_subst_expr(e)),
}, },
} };
TypedStmt { kind, span }
} }
/// Recursively applies the final resolved type substitutions to a typed expression. /// Recursively applies the final resolved type substitutions to a typed expression.
fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr { fn apply_subst_expr(&self, expr: TypedExpr) -> TypedExpr {
let ty = self.apply_subst(&expr.ty); let ty = self.apply_subst(&expr.ty);
let span = expr.span;
let kind = match expr.kind { let kind = match expr.kind {
TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name }, TypedExprKind::Identifier { name } => TypedExprKind::Identifier { name },
TypedExprKind::Integer { value } => TypedExprKind::Integer { value }, TypedExprKind::Integer { value } => TypedExprKind::Integer { value },
@@ -565,7 +558,7 @@ impl Sema {
}, },
}; };
TypedExpr { kind, ty } TypedExpr { kind, ty, span }
} }
/// Resolves all deferred type constraints accumulated during analysis, such as /// Resolves all deferred type constraints accumulated during analysis, such as
@@ -682,9 +675,9 @@ impl Sema {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::frontend::{ use crate::frontend::{
ast::TypedModule,
parser::Parser, parser::Parser,
sema::{Sema, SemanticError}, sema::{Sema, SemanticError},
typed_ast::TypedModule,
}; };
fn analyze(source: &str) -> Result<TypedModule, Vec<SemanticError>> { fn analyze(source: &str) -> Result<TypedModule, Vec<SemanticError>> {
@@ -804,27 +797,4 @@ mod test {
let src = "fn test() { if 12 {} }"; let src = "fn test() { if 12 {} }";
assert!(analyze(src).is_err()); assert!(analyze(src).is_err());
} }
#[test]
fn not_all_paths_return() {
let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } else { } }";
assert!(analyze(src).is_err());
let src = "fn test() -> i32 { }";
assert!(analyze(src).is_err());
let src = "fn test(a: i32) -> i32 { if a < 5 { return 5; } return 10; }";
assert!(analyze(src).is_ok());
}
#[test]
fn unreachable_code() {
let src = "fn test() -> i32 { return 5; return 10; }";
let errors = analyze(src).unwrap_err();
assert!(
errors
.iter()
.any(|e| e.message.contains("unreachable statement"))
);
}
} }
-60
View File
@@ -1,60 +0,0 @@
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
#[derive(Debug, PartialEq, Eq)]
pub struct TypedModule {
pub decls: Vec<TypedDecl>,
}
#[derive(Debug, PartialEq, Eq)]
pub enum TypedDecl {
Function {
name: String,
params: Vec<(String, Ty)>,
return_type: Ty,
body: TypedStmt,
},
}
#[derive(Debug, PartialEq, Eq)]
pub enum TypedStmt {
Compound {
inner: Vec<TypedStmt>,
},
If {
condition: TypedExpr,
then: Box<TypedStmt>,
elze: Option<Box<TypedStmt>>,
},
Return {
value: Option<TypedExpr>,
},
}
#[derive(Debug, PartialEq, Eq)]
pub struct TypedExpr {
pub kind: TypedExprKind,
pub ty: Ty,
}
#[derive(Debug, PartialEq, Eq)]
pub enum TypedExprKind {
Identifier {
name: String,
},
Integer {
value: u64,
},
Boolean {
value: bool,
},
Unary {
op: UnaryOp,
expr: Box<TypedExpr>,
},
Binary {
op: BinaryOp,
lhs: Box<TypedExpr>,
rhs: Box<TypedExpr>,
},
}
+5 -1
View File
@@ -4,9 +4,11 @@ use clap::Parser as ClapParser;
use crate::frontend::parser::Parser; use crate::frontend::parser::Parser;
use crate::frontend::sema::Sema; use crate::frontend::sema::Sema;
use crate::middle::builder::MirBuilder;
pub mod backend; pub mod backend;
pub mod frontend; pub mod frontend;
pub mod middle;
use crate::backend::cranelift::CraneliftBackend; use crate::backend::cranelift::CraneliftBackend;
@@ -55,8 +57,10 @@ fn main() {
exit(1); exit(1);
} }
let mir_module = MirBuilder::build(&typed_module);
let backend = CraneliftBackend::new(); let backend = CraneliftBackend::new();
let (ir, obj_bytes) = backend.compile_module(&typed_module); let (ir, obj_bytes) = backend.compile_module(&mir_module);
if cli.emit_ir { if cli.emit_ir {
println!("{}", ir); println!("{}", ir);
+282
View File
@@ -0,0 +1,282 @@
use std::collections::HashMap;
use crate::frontend::ast::*;
use crate::frontend::sema::Ty;
use crate::middle::mir::*;
/// Lowers a fully-typed AST into a linear Control Flow Graph (MIR).
pub struct MirBuilder;
impl MirBuilder {
/// Builds a `MirModule` from a `TypedModule`.
pub fn build(module: &TypedModule) -> MirModule {
let mut functions = Vec::new();
for decl in &module.decls {
match &decl.kind {
TypedDeclKind::Function {
name,
name_span: _,
params,
return_type,
body,
} => {
let mut builder = FuncBuilder::new(name.clone(), return_type.clone());
// Register parameters as local variables
for (param_name, ty) in params {
let local_id = builder.new_local(param_name.clone(), ty.clone());
builder.params.push(local_id);
}
let entry = builder.new_block();
builder.switch_to_block(entry);
builder.lower_stmt(body);
// If the final block was never terminated (e.g. missing an explicit return),
// we insert an implicit return for Unit, or an Unreachable trap otherwise.
if builder.current_block.is_some() {
let span = body.span;
if *return_type == Ty::Unit {
builder.terminate(Terminator {
kind: TerminatorKind::Return { value: None },
span,
});
} else {
builder.terminate(Terminator {
kind: TerminatorKind::Unreachable,
span,
});
}
}
functions.push(builder.finish());
}
}
}
MirModule { functions }
}
}
/// A helper struct that manages the state required to construct a single `MirFunction`.
///
/// It keeps track of variable declarations, basic blocks, and handles the flattening
/// of nested AST nodes into a sequential Control Flow Graph.
struct FuncBuilder {
/// The name of the function being built.
name: String,
/// The expected return type of the function.
return_type: Ty,
/// `LocalId`s corresponding to the function's parameters.
params: Vec<LocalId>,
/// All local variables and compiler-generated temporaries used in the function.
locals: Vec<LocalDecl>,
/// The basic blocks that make up the function's control flow graph.
blocks: Vec<BasicBlock>,
/// The block currently being populated with statements. If `None`, the builder is in "dead code" territory.
current_block: Option<BlockId>,
/// Statements buffered for the current block before a terminator is reached.
current_statements: Vec<Statement>,
/// Counter for generating unique `BlockId`s.
next_block_id: usize,
/// Mapping from user-defined variable names to their corresponding `LocalId`.
vars: HashMap<String, LocalId>,
}
impl FuncBuilder {
/// Creates a new `FuncBuilder` for a function with the given name and return type.
fn new(name: String, return_type: Ty) -> Self {
Self {
name,
return_type,
params: Vec::new(),
locals: Vec::new(),
blocks: Vec::new(),
current_block: None,
current_statements: Vec::new(),
next_block_id: 0,
vars: HashMap::new(),
}
}
/// Registers a new user-defined local variable and returns its `LocalId`.
fn new_local(&mut self, name: String, ty: Ty) -> LocalId {
let id = LocalId(self.locals.len());
self.locals.push(LocalDecl {
id,
ty,
mutable: false,
name: Some(name.clone()),
});
self.vars.insert(name, id);
id
}
/// Creates a new compiler-generated temporary variable and returns its `LocalId`.
fn new_temp(&mut self, ty: Ty) -> LocalId {
let id = LocalId(self.locals.len());
self.locals.push(LocalDecl {
id,
ty,
mutable: false,
name: None,
});
id
}
/// Allocates a new, empty basic block and returns its `BlockId`.
fn new_block(&mut self) -> BlockId {
let id = BlockId(self.next_block_id);
self.next_block_id += 1;
id
}
/// Sets the given block as the active insertion point for new statements.
/// Note: The previous block should be terminated before switching.
fn switch_to_block(&mut self, id: BlockId) {
assert!(self.current_statements.is_empty());
self.current_block = Some(id);
}
/// Appends a statement to the current basic block.
/// If the current block has already been terminated, the statement is ignored (dead code elimination).
fn emit_stmt(&mut self, stmt: Statement) {
if self.current_block.is_some() {
self.current_statements.push(stmt);
}
}
/// Terminates the current basic block, sealing it and finalizing its instructions.
/// Subsequent statements will be ignored until `switch_to_block` is called.
fn terminate(&mut self, terminator: Terminator) {
if let Some(id) = self.current_block.take() {
self.blocks.push(BasicBlock {
id,
statements: std::mem::take(&mut self.current_statements),
terminator,
});
}
}
/// Finalizes the function construction, returning the complete `MirFunction`.
fn finish(mut self) -> MirFunction {
// Ensure basic blocks are strictly ordered by their numerical IDs
self.blocks.sort_by_key(|b| b.id.0);
MirFunction {
name: self.name,
params: self.params,
return_type: self.return_type,
locals: self.locals,
blocks: self.blocks,
}
}
/// Recursively lowers a typed statement into MIR instructions and basic blocks.
fn lower_stmt(&mut self, stmt: &TypedStmt) {
match &stmt.kind {
TypedStmtKind::Compound { inner } => {
for s in inner {
self.lower_stmt(s);
}
}
TypedStmtKind::If {
condition,
then,
elze,
} => {
let cond_op = self.lower_expr(condition);
let then_block = self.new_block();
let merge_block = self.new_block();
let else_block = if elze.is_some() {
self.new_block()
} else {
merge_block
};
self.terminate(Terminator {
kind: TerminatorKind::CondBranch {
cond: cond_op,
target_true: then_block,
target_false: else_block,
},
span: stmt.span,
});
self.switch_to_block(then_block);
self.lower_stmt(then);
self.terminate(Terminator {
kind: TerminatorKind::Goto {
target: merge_block,
},
span: then.span,
});
if let Some(e) = elze {
self.switch_to_block(else_block);
self.lower_stmt(e);
self.terminate(Terminator {
kind: TerminatorKind::Goto {
target: merge_block,
},
span: e.span,
});
}
self.switch_to_block(merge_block);
}
TypedStmtKind::Return { value } => {
let val_op = value.as_ref().map(|v| self.lower_expr(v));
self.terminate(Terminator {
kind: TerminatorKind::Return { value: val_op },
span: stmt.span,
});
}
}
}
/// Recursively lowers a typed expression into MIR instructions, returning an `Operand` representing its result.
fn lower_expr(&mut self, expr: &TypedExpr) -> Operand {
match &expr.kind {
TypedExprKind::Identifier { name } => {
let local = *self
.vars
.get(name)
.expect("undeclared variable in MIR lowering");
Operand::Copy(local)
}
TypedExprKind::Integer { value } => {
Operand::Constant(ConstantValue::Integer(*value, expr.ty.clone()))
}
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
TypedExprKind::Unary { op, expr: inner } => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
TypedExprKind::Binary { op, lhs, rhs } => {
let lhs_op = self.lower_expr(lhs);
let rhs_op = self.lower_expr(rhs);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::BinaryOp(*op, lhs_op, rhs_op)),
span: expr.span,
});
Operand::Copy(temp)
}
}
}
}
+98
View File
@@ -0,0 +1,98 @@
use crate::frontend::ast::{BinaryOp, UnaryOp};
use crate::frontend::sema::Ty;
use crate::frontend::token::Span;
/// A strongly-typed reference to a basic block.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockId(pub usize);
/// A strongly-typed reference to a local variable or temporary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LocalId(pub usize);
#[derive(Debug)]
pub struct MirModule {
pub functions: Vec<MirFunction>,
}
#[derive(Debug)]
pub struct MirFunction {
pub name: String,
pub params: Vec<LocalId>,
pub return_type: Ty,
pub locals: Vec<LocalDecl>,
pub blocks: Vec<BasicBlock>,
}
/// A declaration of a local variable or compiler-generated temporary.
#[derive(Debug)]
pub struct LocalDecl {
pub id: LocalId,
pub ty: Ty,
pub mutable: bool,
/// Contains the name if it is a user-declared variable, or `None` if it's a compiler temporary.
pub name: Option<String>,
}
/// A sequential list of non-branching statements followed by a single terminator.
#[derive(Debug)]
pub struct BasicBlock {
pub id: BlockId,
pub statements: Vec<Statement>,
pub terminator: Terminator,
}
#[derive(Debug)]
pub struct Statement {
pub kind: StatementKind,
pub span: Span,
}
#[derive(Debug)]
pub enum StatementKind {
/// Assigns the result of an Rvalue to a local variable or temporary.
Assign(LocalId, Rvalue),
}
/// Operations that produce a value.
#[derive(Debug)]
pub enum Rvalue {
Use(Operand),
UnaryOp(UnaryOp, Operand),
BinaryOp(BinaryOp, Operand, Operand),
}
/// An atomic value used as inputs to instructions.
#[derive(Debug, Clone)]
pub enum Operand {
Copy(LocalId),
Constant(ConstantValue),
}
#[derive(Debug, Clone)]
pub enum ConstantValue {
Integer(u64, Ty),
Boolean(bool),
}
#[derive(Debug)]
pub struct Terminator {
pub kind: TerminatorKind,
pub span: Span,
}
#[derive(Debug)]
pub enum TerminatorKind {
Goto {
target: BlockId,
},
CondBranch {
cond: Operand,
target_true: BlockId,
target_false: BlockId,
},
Return {
value: Option<Operand>,
},
Unreachable,
}
+2
View File
@@ -0,0 +1,2 @@
pub mod builder;
pub mod mir;