//! Lowering of the fully-typed AST into MIR, a control-flow-graph representation.
use std::collections::HashMap;
|
|
|
|
use crate::frontend::ast::*;
|
|
use crate::frontend::sema::Ty;
|
|
use crate::middle::mir::*;
|
|
|
|
/// Entry point for MIR construction.
///
/// This is a stateless unit type: all of the work happens in
/// [`MirBuilder::build`], which turns a fully-typed AST into a linear
/// Control Flow Graph (MIR).
pub struct MirBuilder;
|
|
|
|
impl MirBuilder {
|
|
/// Builds a `MirModule` from a `TypedModule`.
|
|
pub fn build(module: &TypedModule) -> MirModule {
|
|
let mut functions = Vec::new();
|
|
|
|
for decl in &module.decls {
|
|
match &decl.kind {
|
|
TypedDeclKind::Function {
|
|
name,
|
|
name_span: _,
|
|
params,
|
|
return_type,
|
|
body,
|
|
} => {
|
|
let mut builder = FuncBuilder::new(name.clone(), return_type.clone());
|
|
|
|
// Register parameters as local variables
|
|
for (param_name, ty) in params {
|
|
let local_id = builder.new_local(param_name.clone(), ty.clone());
|
|
builder.params.push(local_id);
|
|
}
|
|
|
|
let entry = builder.new_block();
|
|
builder.switch_to_block(entry);
|
|
|
|
builder.lower_stmt(body);
|
|
|
|
// If the final block was never terminated (e.g. missing an explicit return),
|
|
// we insert an implicit return for Unit, or an Unreachable trap otherwise.
|
|
if builder.current_block.is_some() {
|
|
let span = body.span;
|
|
if *return_type == Ty::Unit {
|
|
builder.terminate(Terminator {
|
|
kind: TerminatorKind::Return { value: None },
|
|
span,
|
|
});
|
|
} else {
|
|
builder.terminate(Terminator {
|
|
kind: TerminatorKind::Unreachable,
|
|
span,
|
|
});
|
|
}
|
|
}
|
|
|
|
functions.push(builder.finish());
|
|
}
|
|
}
|
|
}
|
|
|
|
MirModule { functions }
|
|
}
|
|
}
|
|
|
|
/// A helper struct that manages the state required to construct a single `MirFunction`.
|
|
///
|
|
/// It keeps track of variable declarations, basic blocks, and handles the flattening
|
|
/// of nested AST nodes into a sequential Control Flow Graph.
|
|
struct FuncBuilder {
|
|
/// The name of the function being built.
|
|
name: String,
|
|
/// The expected return type of the function.
|
|
return_type: Ty,
|
|
/// `LocalId`s corresponding to the function's parameters.
|
|
params: Vec<LocalId>,
|
|
/// All local variables and compiler-generated temporaries used in the function.
|
|
locals: Vec<LocalDecl>,
|
|
|
|
/// The basic blocks that make up the function's control flow graph.
|
|
blocks: Vec<BasicBlock>,
|
|
/// The block currently being populated with statements. If `None`, the builder is in "dead code" territory.
|
|
current_block: Option<BlockId>,
|
|
/// Statements buffered for the current block before a terminator is reached.
|
|
current_statements: Vec<Statement>,
|
|
/// Counter for generating unique `BlockId`s.
|
|
next_block_id: usize,
|
|
|
|
/// Scoped mapping from user-defined variable names to their corresponding `LocalId`.
|
|
scopes: Vec<HashMap<String, LocalId>>,
|
|
|
|
/// Stack of `(continue_target, break_target)` for nested loops
|
|
loop_stack: Vec<(BlockId, BlockId)>,
|
|
}
|
|
|
|
impl FuncBuilder {
|
|
/// Creates a new `FuncBuilder` for a function with the given name and return type.
|
|
fn new(name: String, return_type: Ty) -> Self {
|
|
Self {
|
|
name,
|
|
return_type,
|
|
params: Vec::new(),
|
|
locals: Vec::new(),
|
|
blocks: Vec::new(),
|
|
current_block: None,
|
|
current_statements: Vec::new(),
|
|
next_block_id: 0,
|
|
scopes: vec![HashMap::new()],
|
|
loop_stack: Vec::new(),
|
|
}
|
|
}
|
|
|
|
fn enter_scope(&mut self) {
|
|
self.scopes.push(HashMap::new());
|
|
}
|
|
|
|
fn leave_scope(&mut self) {
|
|
self.scopes.pop();
|
|
}
|
|
|
|
/// Registers a new user-defined local variable and returns its `LocalId`.
|
|
fn new_local(&mut self, name: String, ty: Ty) -> LocalId {
|
|
let id = LocalId(self.locals.len());
|
|
self.locals.push(LocalDecl {
|
|
id,
|
|
ty,
|
|
mutable: false,
|
|
name: Some(name.clone()),
|
|
});
|
|
self.scopes.last_mut().unwrap().insert(name, id);
|
|
id
|
|
}
|
|
|
|
/// Creates a new compiler-generated temporary variable and returns its `LocalId`.
|
|
fn new_temp(&mut self, ty: Ty) -> LocalId {
|
|
let id = LocalId(self.locals.len());
|
|
self.locals.push(LocalDecl {
|
|
id,
|
|
ty,
|
|
mutable: false,
|
|
name: None,
|
|
});
|
|
id
|
|
}
|
|
|
|
fn lookup(&self, name: &str) -> LocalId {
|
|
for scope in self.scopes.iter().rev() {
|
|
if let Some(id) = scope.get(name) {
|
|
return *id;
|
|
}
|
|
}
|
|
panic!("undeclared variable `{}` in MIR lowering", name);
|
|
}
|
|
|
|
/// Allocates a new, empty basic block and returns its `BlockId`.
|
|
fn new_block(&mut self) -> BlockId {
|
|
let id = BlockId(self.next_block_id);
|
|
self.next_block_id += 1;
|
|
id
|
|
}
|
|
|
|
/// Sets the given block as the active insertion point for new statements.
|
|
/// Note: The previous block should be terminated before switching.
|
|
fn switch_to_block(&mut self, id: BlockId) {
|
|
assert!(self.current_statements.is_empty());
|
|
self.current_block = Some(id);
|
|
}
|
|
|
|
/// Appends a statement to the current basic block.
|
|
/// If the current block has already been terminated, the statement is ignored (dead code elimination).
|
|
fn emit_stmt(&mut self, stmt: Statement) {
|
|
if self.current_block.is_some() {
|
|
self.current_statements.push(stmt);
|
|
}
|
|
}
|
|
|
|
/// Terminates the current basic block, sealing it and finalizing its instructions.
|
|
/// Subsequent statements will be ignored until `switch_to_block` is called.
|
|
fn terminate(&mut self, terminator: Terminator) {
|
|
if let Some(id) = self.current_block.take() {
|
|
self.blocks.push(BasicBlock {
|
|
id,
|
|
statements: std::mem::take(&mut self.current_statements),
|
|
terminator,
|
|
});
|
|
}
|
|
}
|
|
|
|
/// Finalizes the function construction, returning the complete `MirFunction`.
|
|
fn finish(mut self) -> MirFunction {
|
|
// Ensure basic blocks are strictly ordered by their numerical IDs
|
|
self.blocks.sort_by_key(|b| b.id.0);
|
|
|
|
MirFunction {
|
|
name: self.name,
|
|
params: self.params,
|
|
return_type: self.return_type,
|
|
locals: self.locals,
|
|
blocks: self.blocks,
|
|
}
|
|
}
|
|
|
|
/// Recursively lowers a typed statement into MIR instructions and basic blocks.
|
|
fn lower_stmt(&mut self, stmt: &TypedStmt) {
|
|
match &stmt.kind {
|
|
TypedStmtKind::Compound { inner } => {
|
|
self.enter_scope();
|
|
for s in inner {
|
|
self.lower_stmt(s);
|
|
}
|
|
self.leave_scope();
|
|
}
|
|
TypedStmtKind::If {
|
|
condition,
|
|
then,
|
|
elze,
|
|
} => {
|
|
let cond_op = self.lower_expr(condition);
|
|
|
|
let then_block = self.new_block();
|
|
let merge_block = self.new_block();
|
|
let else_block = if elze.is_some() {
|
|
self.new_block()
|
|
} else {
|
|
merge_block
|
|
};
|
|
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::CondBranch {
|
|
cond: cond_op,
|
|
target_true: then_block,
|
|
target_false: else_block,
|
|
},
|
|
span: stmt.span,
|
|
});
|
|
|
|
self.switch_to_block(then_block);
|
|
self.lower_stmt(then);
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto {
|
|
target: merge_block,
|
|
},
|
|
span: then.span,
|
|
});
|
|
|
|
if let Some(e) = elze {
|
|
self.switch_to_block(else_block);
|
|
self.lower_stmt(e);
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto {
|
|
target: merge_block,
|
|
},
|
|
span: e.span,
|
|
});
|
|
}
|
|
|
|
self.switch_to_block(merge_block);
|
|
}
|
|
TypedStmtKind::While { condition, body } => {
|
|
let cond_block = self.new_block();
|
|
let body_block = self.new_block();
|
|
let merge_block = self.new_block();
|
|
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto { target: cond_block },
|
|
span: stmt.span,
|
|
});
|
|
|
|
self.switch_to_block(cond_block);
|
|
let cond_op = self.lower_expr(condition);
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::CondBranch {
|
|
cond: cond_op,
|
|
target_true: body_block,
|
|
target_false: merge_block,
|
|
},
|
|
span: condition.span,
|
|
});
|
|
|
|
self.switch_to_block(body_block);
|
|
self.loop_stack.push((cond_block, merge_block));
|
|
self.lower_stmt(body);
|
|
self.loop_stack.pop();
|
|
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto { target: cond_block },
|
|
span: body.span,
|
|
});
|
|
|
|
self.switch_to_block(merge_block);
|
|
}
|
|
TypedStmtKind::Break => {
|
|
if let Some(&(_, merge_block)) = self.loop_stack.last() {
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto {
|
|
target: merge_block,
|
|
},
|
|
span: stmt.span,
|
|
});
|
|
}
|
|
}
|
|
TypedStmtKind::Continue => {
|
|
if let Some(&(cond_block, _)) = self.loop_stack.last() {
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Goto { target: cond_block },
|
|
span: stmt.span,
|
|
});
|
|
}
|
|
}
|
|
TypedStmtKind::Return { value } => {
|
|
let val_op = value.as_ref().map(|v| self.lower_expr(v));
|
|
self.terminate(Terminator {
|
|
kind: TerminatorKind::Return { value: val_op },
|
|
span: stmt.span,
|
|
});
|
|
}
|
|
TypedStmtKind::Let {
|
|
name,
|
|
name_span: _,
|
|
ty,
|
|
initializer,
|
|
} => {
|
|
let local_id = self.new_local(name.clone(), ty.clone());
|
|
|
|
if let Some(init_expr) = initializer {
|
|
let val_op = self.lower_expr(init_expr);
|
|
self.emit_stmt(Statement {
|
|
kind: StatementKind::Assign(local_id, Rvalue::Use(val_op)),
|
|
span: stmt.span,
|
|
});
|
|
}
|
|
}
|
|
TypedStmtKind::Expression { expr } => {
|
|
self.lower_expr(expr);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Recursively lowers a typed expression into MIR instructions, returning an `Operand` representing its result.
|
|
fn lower_expr(&mut self, expr: &TypedExpr) -> Operand {
|
|
match &expr.kind {
|
|
TypedExprKind::Identifier { name } => {
|
|
let local = self.lookup(name);
|
|
Operand::Copy(local)
|
|
}
|
|
TypedExprKind::Integer { value } => {
|
|
Operand::Constant(ConstantValue::Integer(*value, expr.ty.clone()))
|
|
}
|
|
TypedExprKind::Float { value } => {
|
|
Operand::Constant(ConstantValue::Float(*value, expr.ty.clone()))
|
|
}
|
|
TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)),
|
|
TypedExprKind::Unary { op, expr: inner } => {
|
|
let inner_op = self.lower_expr(inner);
|
|
let temp = self.new_temp(expr.ty.clone());
|
|
|
|
self.emit_stmt(Statement {
|
|
kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)),
|
|
span: expr.span,
|
|
});
|
|
|
|
Operand::Copy(temp)
|
|
}
|
|
TypedExprKind::Binary { op, lhs, rhs } => {
|
|
let lhs_op = self.lower_expr(lhs);
|
|
let rhs_op = self.lower_expr(rhs);
|
|
let temp = self.new_temp(expr.ty.clone());
|
|
|
|
self.emit_stmt(Statement {
|
|
kind: StatementKind::Assign(temp, Rvalue::BinaryOp(*op, lhs_op, rhs_op)),
|
|
span: expr.span,
|
|
});
|
|
|
|
Operand::Copy(temp)
|
|
}
|
|
TypedExprKind::Assign { lval, rval } => {
|
|
let rval_op = self.lower_expr(rval);
|
|
|
|
let local_id = match &lval.kind {
|
|
TypedExprKind::Identifier { name } => self.lookup(name),
|
|
_ => panic!("invalid lval in MIR lowering"),
|
|
};
|
|
|
|
self.emit_stmt(Statement {
|
|
kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())),
|
|
span: expr.span,
|
|
});
|
|
|
|
rval_op
|
|
}
|
|
TypedExprKind::Cast {
|
|
expr: inner,
|
|
ty: target_ty,
|
|
} => {
|
|
let inner_op = self.lower_expr(inner);
|
|
let temp = self.new_temp(target_ty.clone());
|
|
|
|
self.emit_stmt(Statement {
|
|
kind: StatementKind::Assign(temp, Rvalue::Cast(target_ty.clone(), inner_op)),
|
|
span: expr.span,
|
|
});
|
|
|
|
Operand::Copy(temp)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;
    use crate::frontend::parser::Parser;
    use crate::frontend::sema::Sema;

    /// Runs the front end (parsing, then semantic analysis) on `source` and
    /// lowers the result to MIR, panicking as soon as a stage reports errors.
    fn build_mir(source: &str) -> MirModule {
        let mut parser = Parser::new(source);
        let ast = parser.parse_module();
        if let Some(parse_errors) = parser.errors() {
            panic!("Parse errors: {:?}", parse_errors);
        }

        let mut sema = Sema::new();
        let typed = sema.analyze_module(&ast);
        if let Some(sema_errors) = sema.errors() {
            panic!("Semantic errors: {:?}", sema_errors);
        }

        MirBuilder::build(&typed)
    }

    #[test]
    fn test_lower_while_loop() {
        let mir = build_mir("fn main() { while true { } }");
        let blocks = &mir.functions[0].blocks;

        // Entry, condition, body and merge: four blocks in total.
        assert_eq!(blocks.len(), 4);

        // Block 0 (entry) jumps to the condition block.
        assert!(matches!(
            blocks[0].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
        // Block 1 (condition) branches to the body when true, to the merge block when false.
        assert!(matches!(
            blocks[1].terminator.kind,
            TerminatorKind::CondBranch {
                target_true: BlockId(2),
                target_false: BlockId(3),
                ..
            }
        ));
        // Block 2 (body) loops back to the condition block.
        assert!(matches!(
            blocks[2].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
        // Block 3 (merge) ends the function with the implicit return.
        assert!(matches!(
            blocks[3].terminator.kind,
            TerminatorKind::Return { value: None }
        ));
    }

    #[test]
    fn test_lower_break_and_continue() {
        let mir = build_mir("fn main() { while true { continue; break; } }");
        let blocks = &mir.functions[0].blocks;

        // `continue` comes first, so the body block (BlockId(2)) is sealed with
        // a jump back to the condition block. The `break` after it is dead code
        // and must not replace that terminator.
        assert!(matches!(
            blocks[2].terminator.kind,
            TerminatorKind::Goto { target: BlockId(1) }
        ));
    }

    #[test]
    fn test_lower_let_and_assign() {
        let mir = build_mir("fn main() { let a = 5; a = 10; }");
        let main_fn = &mir.functions[0];

        // `a` is the only local: literals do not need temporaries.
        assert_eq!(main_fn.locals.len(), 1);
        // Straight-line code stays in a single block.
        assert_eq!(main_fn.blocks.len(), 1);

        // One statement for the `let` initialization, one for the later assignment.
        let stmts = &main_fn.blocks[0].statements;
        assert_eq!(stmts.len(), 2);
        assert!(matches!(
            stmts[0].kind,
            StatementKind::Assign(LocalId(0), _)
        ));
        assert!(matches!(
            stmts[1].kind,
            StatementKind::Assign(LocalId(0), _)
        ));
    }
}
|