Files
compiler-old/src/middle/builder.rs
T
2026-04-22 22:40:19 +02:00

496 lines
17 KiB
Rust

use std::collections::HashMap;
use crate::frontend::ast::*;
use crate::frontend::sema::Ty;
use crate::middle::mir::*;
/// Translates a type-checked AST into MIR, a Control Flow Graph made of basic blocks.
pub struct MirBuilder;

impl MirBuilder {
    /// Lowers every declaration in `module` and gathers the results into a `MirModule`.
    ///
    /// Each function gets its own `FuncBuilder`: its parameters are declared as the
    /// first locals, a fresh entry block is opened, and the body is lowered into it.
    /// When the body leaves a block open (control can fall off the end), the block is
    /// sealed with `Return { value: None }` for `Ty::Unit` functions and with
    /// `Unreachable` for every other return type.
    pub fn build(module: &TypedModule) -> MirModule {
        let functions = module
            .decls
            .iter()
            .map(|decl| match &decl.kind {
                TypedDeclKind::Function {
                    name,
                    name_span: _,
                    params,
                    return_type,
                    body,
                } => {
                    let mut fb = FuncBuilder::new(name.clone(), return_type.clone());

                    // Parameters are declared before anything else, so they occupy
                    // the lowest `LocalId`s of the function.
                    for (param_name, param_ty) in params {
                        let id = fb.new_local(param_name.clone(), param_ty.clone());
                        fb.params.push(id);
                    }

                    let entry = fb.new_block();
                    fb.switch_to_block(entry);
                    fb.lower_stmt(body);

                    // A block that is still open here was never given an explicit
                    // terminator by the body; close it implicitly.
                    if fb.current_block.is_some() {
                        let kind = if *return_type == Ty::Unit {
                            TerminatorKind::Return { value: None }
                        } else {
                            TerminatorKind::Unreachable
                        };
                        fb.terminate(Terminator {
                            kind,
                            span: body.span,
                        });
                    }

                    fb.finish()
                }
            })
            .collect();

        MirModule { functions }
    }
}
/// Mutable state used while constructing a single `MirFunction`.
///
/// It keeps track of variable declarations, basic blocks, and handles the flattening
/// of nested AST nodes into a sequential Control Flow Graph. The builder is created
/// with `FuncBuilder::new`, driven through `lower_stmt` / `lower_expr`, and consumed
/// by `finish`.
struct FuncBuilder {
/// The name of the function being built; moved into the resulting `MirFunction`.
name: String,
/// The declared return type of the function; moved into the resulting `MirFunction`.
return_type: Ty,
/// `LocalId`s corresponding to the function's parameters, in the order they were pushed.
params: Vec<LocalId>,
/// All local variables and compiler-generated temporaries used in the function.
/// A local's `LocalId` is its index in this vector at the time it was pushed.
locals: Vec<LocalDecl>,
/// The sealed basic blocks of the function's control flow graph. A block is only
/// added here once it has been terminated; `finish` sorts them by block id.
blocks: Vec<BasicBlock>,
/// The block currently being populated with statements. If `None`, the builder is in "dead code" territory:
/// emitted statements and terminators are dropped until `switch_to_block` is called.
current_block: Option<BlockId>,
/// Statements buffered for the current block; moved into the block when it is terminated.
current_statements: Vec<Statement>,
/// The id that the next call to `new_block` will hand out.
next_block_id: usize,
/// Scoped mapping from user-defined variable names to their corresponding `LocalId`.
/// The innermost scope is last; lookups search from the last scope towards the first.
scopes: Vec<HashMap<String, LocalId>>,
/// Stack of `(continue_target, break_target)` for nested loops; the innermost loop is last.
loop_stack: Vec<(BlockId, BlockId)>,
}
impl FuncBuilder {
/// Creates a new `FuncBuilder` for a function with the given name and return type.
fn new(name: String, return_type: Ty) -> Self {
Self {
name,
return_type,
params: Vec::new(),
locals: Vec::new(),
blocks: Vec::new(),
current_block: None,
current_statements: Vec::new(),
next_block_id: 0,
scopes: vec![HashMap::new()],
loop_stack: Vec::new(),
}
}
fn enter_scope(&mut self) {
self.scopes.push(HashMap::new());
}
fn leave_scope(&mut self) {
self.scopes.pop();
}
/// Registers a new user-defined local variable and returns its `LocalId`.
fn new_local(&mut self, name: String, ty: Ty) -> LocalId {
let id = LocalId(self.locals.len());
self.locals.push(LocalDecl {
id,
ty,
mutable: false,
name: Some(name.clone()),
});
self.scopes.last_mut().unwrap().insert(name, id);
id
}
/// Creates a new compiler-generated temporary variable and returns its `LocalId`.
fn new_temp(&mut self, ty: Ty) -> LocalId {
let id = LocalId(self.locals.len());
self.locals.push(LocalDecl {
id,
ty,
mutable: false,
name: None,
});
id
}
fn lookup(&self, name: &str) -> LocalId {
for scope in self.scopes.iter().rev() {
if let Some(id) = scope.get(name) {
return *id;
}
}
panic!("undeclared variable `{}` in MIR lowering", name);
}
/// Allocates a new, empty basic block and returns its `BlockId`.
fn new_block(&mut self) -> BlockId {
let id = BlockId(self.next_block_id);
self.next_block_id += 1;
id
}
/// Sets the given block as the active insertion point for new statements.
/// Note: The previous block should be terminated before switching.
fn switch_to_block(&mut self, id: BlockId) {
assert!(self.current_statements.is_empty());
self.current_block = Some(id);
}
/// Appends a statement to the current basic block.
/// If the current block has already been terminated, the statement is ignored (dead code elimination).
fn emit_stmt(&mut self, stmt: Statement) {
if self.current_block.is_some() {
self.current_statements.push(stmt);
}
}
/// Terminates the current basic block, sealing it and finalizing its instructions.
/// Subsequent statements will be ignored until `switch_to_block` is called.
fn terminate(&mut self, terminator: Terminator) {
if let Some(id) = self.current_block.take() {
self.blocks.push(BasicBlock {
id,
statements: std::mem::take(&mut self.current_statements),
terminator,
});
}
}
/// Finalizes the function construction, returning the complete `MirFunction`.
fn finish(mut self) -> MirFunction {
// Ensure basic blocks are strictly ordered by their numerical IDs
self.blocks.sort_by_key(|b| b.id.0);
MirFunction {
name: self.name,
params: self.params,
return_type: self.return_type,
locals: self.locals,
blocks: self.blocks,
}
}
/// Recursively lowers a typed statement into MIR instructions and basic blocks.
fn lower_stmt(&mut self, stmt: &TypedStmt) {
match &stmt.kind {
TypedStmtKind::Compound { inner } => {
self.enter_scope();
for s in inner {
self.lower_stmt(s);
}
self.leave_scope();
}
TypedStmtKind::If {
condition,
then,
elze,
} => {
let cond_op = self.lower_expr(condition);
let then_block = self.new_block();
let merge_block = self.new_block();
let else_block = if elze.is_some() {
self.new_block()
} else {
merge_block
};
self.terminate(Terminator {
kind: TerminatorKind::CondBranch {
cond: cond_op,
target_true: then_block,
target_false: else_block,
},
span: stmt.span,
});
self.switch_to_block(then_block);
self.lower_stmt(then);
self.terminate(Terminator {
kind: TerminatorKind::Goto {
target: merge_block,
},
span: then.span,
});
if let Some(e) = elze {
self.switch_to_block(else_block);
self.lower_stmt(e);
self.terminate(Terminator {
kind: TerminatorKind::Goto {
target: merge_block,
},
span: e.span,
});
}
self.switch_to_block(merge_block);
}
TypedStmtKind::While { condition, body } => {
let cond_block = self.new_block();
let body_block = self.new_block();
let merge_block = self.new_block();
self.terminate(Terminator {
kind: TerminatorKind::Goto { target: cond_block },
span: stmt.span,
});
self.switch_to_block(cond_block);
let cond_op = self.lower_expr(condition);
self.terminate(Terminator {
kind: TerminatorKind::CondBranch {
cond: cond_op,
target_true: body_block,
target_false: merge_block,
},
span: condition.span,
});
self.switch_to_block(body_block);
self.loop_stack.push((cond_block, merge_block));
self.lower_stmt(body);
self.loop_stack.pop();
self.terminate(Terminator {
kind: TerminatorKind::Goto { target: cond_block },
span: body.span,
});
self.switch_to_block(merge_block);
}
TypedStmtKind::Break => {
if let Some(&(_, merge_block)) = self.loop_stack.last() {
self.terminate(Terminator {
kind: TerminatorKind::Goto {
target: merge_block,
},
span: stmt.span,
});
}
}
TypedStmtKind::Continue => {
if let Some(&(cond_block, _)) = self.loop_stack.last() {
self.terminate(Terminator {
kind: TerminatorKind::Goto { target: cond_block },
span: stmt.span,
});
}
}
TypedStmtKind::Return { value } => {
let val_op = value.as_ref().map(|v| self.lower_expr(v));
self.terminate(Terminator {
kind: TerminatorKind::Return { value: val_op },
span: stmt.span,
});
}
TypedStmtKind::Let {
name,
name_span: _,
ty,
initializer,
} => {
let local_id = self.new_local(name.clone(), ty.clone());
if let Some(init_expr) = initializer {
let val_op = self.lower_expr(init_expr);
self.emit_stmt(Statement {
kind: StatementKind::Assign(local_id, Rvalue::Use(val_op)),
span: stmt.span,
});
}
}
TypedStmtKind::Expression { expr } => {
self.lower_expr(expr);
}
}
}
/// Recursively lowers a typed expression into MIR instructions, returning an `Operand` representing its result.
fn lower_expr(&mut self, expr: &TypedExpr) -> Operand {
match &expr.kind {
TypedExprKind::Identifier { name } => {
let local = self.lookup(name);
Operand::Copy(local)
}
TypedExprKind::Integer { value } => {
Operand::Constant(ConstantValue::Integer(*value, expr.ty.clone()))
}
TypedExprKind::Float { value } => {
Operand::Constant(ConstantValue::Float(*value, expr.ty.clone()))
}
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
TypedExprKind::Unary { op, expr: inner } => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
TypedExprKind::Binary { op, lhs, rhs } => {
let lhs_op = self.lower_expr(lhs);
let rhs_op = self.lower_expr(rhs);
let temp = self.new_temp(expr.ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::BinaryOp(*op, lhs_op, rhs_op)),
span: expr.span,
});
Operand::Copy(temp)
}
TypedExprKind::Assign { lval, rval } => {
let rval_op = self.lower_expr(rval);
let local_id = match &lval.kind {
TypedExprKind::Identifier { name } => self.lookup(name),
_ => panic!("invalid lval in MIR lowering"),
};
self.emit_stmt(Statement {
kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())),
span: expr.span,
});
rval_op
}
TypedExprKind::Cast {
expr: inner,
ty: target_ty,
} => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(target_ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::Cast(target_ty.clone(), inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::parser::Parser;
    use crate::frontend::sema::Sema;

    /// Runs `source` through the parser and semantic analysis, then lowers the
    /// typed module to MIR. Panics with the collected errors if either front-end
    /// stage reports any.
    fn build_mir(source: &str) -> MirModule {
        let mut parser = Parser::new(source);
        let ast = parser.parse_module();
        if let Some(parse_errors) = parser.errors() {
            panic!("Parse errors: {:?}", parse_errors);
        }

        let mut sema = Sema::new();
        let typed = sema.analyze_module(&ast);
        if let Some(sema_errors) = sema.errors() {
            panic!("Semantic errors: {:?}", sema_errors);
        }

        MirBuilder::build(&typed)
    }

    /// A `while` loop lowers to entry, condition, body, and merge blocks.
    #[test]
    fn test_lower_while_loop() {
        let mir = build_mir("fn main() { while true { } }");
        let blocks = &mir.functions[0].blocks;

        // entry + condition + body + merge
        assert_eq!(blocks.len(), 4);

        // Block 0 (entry) jumps into the condition block.
        assert!(matches!(
            blocks[0].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));

        // Block 1 (condition) branches to the body on true, to the merge block on false.
        assert!(matches!(
            blocks[1].terminator.kind,
            TerminatorKind::CondBranch {
                target_true: BlockId(2),
                target_false: BlockId(3),
                ..
            }
        ));

        // Block 2 (body) loops back to the condition.
        assert!(matches!(
            blocks[2].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));

        // Block 3 (merge) receives the implicit unit return.
        assert!(matches!(
            blocks[3].terminator.kind,
            TerminatorKind::Return { value: None }
        ));
    }

    /// Only the first terminator in a block counts: `continue` seals the body
    /// block, and the `break` after it is dropped as dead code.
    #[test]
    fn test_lower_break_and_continue() {
        let mir = build_mir("fn main() { while true { continue; break; } }");
        let body_block = &mir.functions[0].blocks[2];

        // The body (BlockId(2)) must jump back to the condition block, not to the merge block.
        assert!(matches!(
            body_block.terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
    }

    /// A `let` with an initializer and a later assignment both become stores to
    /// the same local within a single block.
    #[test]
    fn test_lower_let_and_assign() {
        let mir = build_mir("fn main() { let a = 5; a = 10; }");
        let main_fn = &mir.functions[0];

        // Literals do not need temporaries, so `a` is the only local.
        assert_eq!(main_fn.locals.len(), 1);
        // Straight-line code stays in one block.
        assert_eq!(main_fn.blocks.len(), 1);

        let statements = &main_fn.blocks[0].statements;
        // One store for the initializer, one for the reassignment.
        assert_eq!(statements.len(), 2);
        assert!(matches!(
            statements[0].kind,
            StatementKind::Assign(LocalId(0), _)
        ));
        assert!(matches!(
            statements[1].kind,
            StatementKind::Assign(LocalId(0), _)
        ));
    }
}