use std::collections::HashMap; use crate::frontend::ast::*; use crate::frontend::sema::Ty; use crate::middle::mir::*; /// Lowers a fully-typed AST into a linear Control Flow Graph (MIR). pub struct MirBuilder; impl MirBuilder { /// Builds a `MirModule` from a `TypedModule`. pub fn build(module: &TypedModule) -> MirModule { let mut functions = Vec::new(); for decl in &module.decls { match &decl.kind { TypedDeclKind::Function { name, name_span: _, params, return_type, body, } => { let mut builder = FuncBuilder::new(name.clone(), return_type.clone()); // Register parameters as local variables for (param_name, ty) in params { let local_id = builder.new_local(param_name.clone(), ty.clone()); builder.params.push(local_id); } let entry = builder.new_block(); builder.switch_to_block(entry); builder.lower_stmt(body); // If the final block was never terminated (e.g. missing an explicit return), // we insert an implicit return for Unit, or an Unreachable trap otherwise. if builder.current_block.is_some() { let span = body.span; if *return_type == Ty::Unit { builder.terminate(Terminator { kind: TerminatorKind::Return { value: None }, span, }); } else { builder.terminate(Terminator { kind: TerminatorKind::Unreachable, span, }); } } functions.push(builder.finish()); } } } MirModule { functions } } } /// A helper struct that manages the state required to construct a single `MirFunction`. /// /// It keeps track of variable declarations, basic blocks, and handles the flattening /// of nested AST nodes into a sequential Control Flow Graph. struct FuncBuilder { /// The name of the function being built. name: String, /// The expected return type of the function. return_type: Ty, /// `LocalId`s corresponding to the function's parameters. params: Vec, /// All local variables and compiler-generated temporaries used in the function. locals: Vec, /// The basic blocks that make up the function's control flow graph. blocks: Vec, /// The block currently being populated with statements. If `None`, the builder is in "dead code" territory. current_block: Option, /// Statements buffered for the current block before a terminator is reached. current_statements: Vec, /// Counter for generating unique `BlockId`s. next_block_id: usize, /// Scoped mapping from user-defined variable names to their corresponding `LocalId`. scopes: Vec>, /// Stack of `(continue_target, break_target)` for nested loops loop_stack: Vec<(BlockId, BlockId)>, } impl FuncBuilder { /// Creates a new `FuncBuilder` for a function with the given name and return type. fn new(name: String, return_type: Ty) -> Self { Self { name, return_type, params: Vec::new(), locals: Vec::new(), blocks: Vec::new(), current_block: None, current_statements: Vec::new(), next_block_id: 0, scopes: vec![HashMap::new()], loop_stack: Vec::new(), } } fn enter_scope(&mut self) { self.scopes.push(HashMap::new()); } fn leave_scope(&mut self) { self.scopes.pop(); } /// Registers a new user-defined local variable and returns its `LocalId`. fn new_local(&mut self, name: String, ty: Ty) -> LocalId { let id = LocalId(self.locals.len()); self.locals.push(LocalDecl { id, ty, mutable: false, name: Some(name.clone()), }); self.scopes.last_mut().unwrap().insert(name, id); id } /// Creates a new compiler-generated temporary variable and returns its `LocalId`. fn new_temp(&mut self, ty: Ty) -> LocalId { let id = LocalId(self.locals.len()); self.locals.push(LocalDecl { id, ty, mutable: false, name: None, }); id } fn lookup(&self, name: &str) -> LocalId { for scope in self.scopes.iter().rev() { if let Some(id) = scope.get(name) { return *id; } } panic!("undeclared variable `{}` in MIR lowering", name); } /// Allocates a new, empty basic block and returns its `BlockId`. fn new_block(&mut self) -> BlockId { let id = BlockId(self.next_block_id); self.next_block_id += 1; id } /// Sets the given block as the active insertion point for new statements. /// Note: The previous block should be terminated before switching. fn switch_to_block(&mut self, id: BlockId) { assert!(self.current_statements.is_empty()); self.current_block = Some(id); } /// Appends a statement to the current basic block. /// If the current block has already been terminated, the statement is ignored (dead code elimination). fn emit_stmt(&mut self, stmt: Statement) { if self.current_block.is_some() { self.current_statements.push(stmt); } } /// Terminates the current basic block, sealing it and finalizing its instructions. /// Subsequent statements will be ignored until `switch_to_block` is called. fn terminate(&mut self, terminator: Terminator) { if let Some(id) = self.current_block.take() { self.blocks.push(BasicBlock { id, statements: std::mem::take(&mut self.current_statements), terminator, }); } } /// Finalizes the function construction, returning the complete `MirFunction`. fn finish(mut self) -> MirFunction { // Ensure basic blocks are strictly ordered by their numerical IDs self.blocks.sort_by_key(|b| b.id.0); MirFunction { name: self.name, params: self.params, return_type: self.return_type, locals: self.locals, blocks: self.blocks, } } /// Recursively lowers a typed statement into MIR instructions and basic blocks. fn lower_stmt(&mut self, stmt: &TypedStmt) { match &stmt.kind { TypedStmtKind::Compound { inner } => { self.enter_scope(); for s in inner { self.lower_stmt(s); } self.leave_scope(); } TypedStmtKind::If { condition, then, elze, } => { let cond_op = self.lower_expr(condition); let then_block = self.new_block(); let merge_block = self.new_block(); let else_block = if elze.is_some() { self.new_block() } else { merge_block }; self.terminate(Terminator { kind: TerminatorKind::CondBranch { cond: cond_op, target_true: then_block, target_false: else_block, }, span: stmt.span, }); self.switch_to_block(then_block); self.lower_stmt(then); self.terminate(Terminator { kind: TerminatorKind::Goto { target: merge_block, }, span: then.span, }); if let Some(e) = elze { self.switch_to_block(else_block); self.lower_stmt(e); self.terminate(Terminator { kind: TerminatorKind::Goto { target: merge_block, }, span: e.span, }); } self.switch_to_block(merge_block); } TypedStmtKind::While { condition, body } => { let cond_block = self.new_block(); let body_block = self.new_block(); let merge_block = self.new_block(); self.terminate(Terminator { kind: TerminatorKind::Goto { target: cond_block }, span: stmt.span, }); self.switch_to_block(cond_block); let cond_op = self.lower_expr(condition); self.terminate(Terminator { kind: TerminatorKind::CondBranch { cond: cond_op, target_true: body_block, target_false: merge_block, }, span: condition.span, }); self.switch_to_block(body_block); self.loop_stack.push((cond_block, merge_block)); self.lower_stmt(body); self.loop_stack.pop(); self.terminate(Terminator { kind: TerminatorKind::Goto { target: cond_block }, span: body.span, }); self.switch_to_block(merge_block); } TypedStmtKind::Break => { if let Some(&(_, merge_block)) = self.loop_stack.last() { self.terminate(Terminator { kind: TerminatorKind::Goto { target: merge_block, }, span: stmt.span, }); } } TypedStmtKind::Continue => { if let Some(&(cond_block, _)) = self.loop_stack.last() { self.terminate(Terminator { kind: TerminatorKind::Goto { target: cond_block }, span: stmt.span, }); } } TypedStmtKind::Return { value } => { let val_op = value.as_ref().map(|v| self.lower_expr(v)); self.terminate(Terminator { kind: TerminatorKind::Return { value: val_op }, span: stmt.span, }); } TypedStmtKind::Let { name, name_span: _, ty, initializer, } => { let local_id = self.new_local(name.clone(), ty.clone()); if let Some(init_expr) = initializer { let val_op = self.lower_expr(init_expr); self.emit_stmt(Statement { kind: StatementKind::Assign(local_id, Rvalue::Use(val_op)), span: stmt.span, }); } } TypedStmtKind::Expression { expr } => { self.lower_expr(expr); } } } /// Recursively lowers a typed expression into MIR instructions, returning an `Operand` representing its result. fn lower_expr(&mut self, expr: &TypedExpr) -> Operand { match &expr.kind { TypedExprKind::Identifier { name } => { let local = self.lookup(name); Operand::Copy(local) } TypedExprKind::Integer { value } => { Operand::Constant(ConstantValue::Integer(*value, expr.ty.clone())) } TypedExprKind::Float { value } => { Operand::Constant(ConstantValue::Float(*value, expr.ty.clone())) } TypedExprKind::Boolean { value } => Operand::Constant(ConstantValue::Boolean(*value)), TypedExprKind::Unary { op, expr: inner } => { let inner_op = self.lower_expr(inner); let temp = self.new_temp(expr.ty.clone()); self.emit_stmt(Statement { kind: StatementKind::Assign(temp, Rvalue::UnaryOp(*op, inner_op)), span: expr.span, }); Operand::Copy(temp) } TypedExprKind::Binary { op, lhs, rhs } => { let lhs_op = self.lower_expr(lhs); let rhs_op = self.lower_expr(rhs); let temp = self.new_temp(expr.ty.clone()); self.emit_stmt(Statement { kind: StatementKind::Assign(temp, Rvalue::BinaryOp(*op, lhs_op, rhs_op)), span: expr.span, }); Operand::Copy(temp) } TypedExprKind::Assign { lval, rval } => { let rval_op = self.lower_expr(rval); let local_id = match &lval.kind { TypedExprKind::Identifier { name } => self.lookup(name), _ => panic!("invalid lval in MIR lowering"), }; self.emit_stmt(Statement { kind: StatementKind::Assign(local_id, Rvalue::Use(rval_op.clone())), span: expr.span, }); rval_op } TypedExprKind::Cast { expr: inner, ty: target_ty, } => { let inner_op = self.lower_expr(inner); let temp = self.new_temp(target_ty.clone()); self.emit_stmt(Statement { kind: StatementKind::Assign(temp, Rvalue::Cast(target_ty.clone(), inner_op)), span: expr.span, }); Operand::Copy(temp) } } } } #[cfg(test)] mod test { use super::*; use crate::frontend::parser::Parser; use crate::frontend::sema::Sema; /// Helper function to parse, analyze, and build a MIR module from source code. fn build_mir(source: &str) -> MirModule { let mut parser = Parser::new(source); let module = parser.parse_module(); if let Some(errors) = parser.errors() { panic!("Parse errors: {:?}", errors); } let mut sema = Sema::new(); let typed_module = sema.analyze_module(&module); if let Some(errors) = sema.errors() { panic!("Semantic errors: {:?}", errors); } MirBuilder::build(&typed_module) } #[test] fn test_lower_while_loop() { let mir = build_mir("fn main() { while true { } }"); let func = &mir.functions[0]; // Ensure exactly 4 basic blocks are generated assert_eq!(func.blocks.len(), 4); // Block 0: Entry -> Goto(1) assert!(matches!( func.blocks[0].terminator.kind, TerminatorKind::Goto { target: BlockId(1) } )); // Block 1: Cond -> CondBranch(true -> 2, false -> 3) assert!(matches!( func.blocks[1].terminator.kind, TerminatorKind::CondBranch { target_true: BlockId(2), target_false: BlockId(3), .. } )); // Block 2: Body -> Goto(1) assert!(matches!( func.blocks[2].terminator.kind, TerminatorKind::Goto { target: BlockId(1) } )); // Block 3: Merge -> Return assert!(matches!( func.blocks[3].terminator.kind, TerminatorKind::Return { value: None } )); } #[test] fn test_lower_break_and_continue() { let mir = build_mir("fn main() { while true { continue; break; } }"); let func = &mir.functions[0]; // The body block (BlockId(2)) hits `continue` first, meaning it terminates immediately with a Goto back to the condition block. // The trailing `break` is correctly skipped by the builder as automatic dead code! assert!(matches!( func.blocks[2].terminator.kind, TerminatorKind::Goto { target: BlockId(1) } )); } #[test] fn test_lower_let_and_assign() { let mir = build_mir("fn main() { let a = 5; a = 10; }"); let func = &mir.functions[0]; assert_eq!(func.locals.len(), 1); // Only 1 variable 'a' (No temporaries generated for literals) assert_eq!(func.blocks.len(), 1); // No branches, so 1 block let block = &func.blocks[0]; assert_eq!(block.statements.len(), 2); // Includes both the let initialization and the later assignment assert!(matches!( block.statements[0].kind, StatementKind::Assign(LocalId(0), _) )); assert!(matches!( block.statements[1].kind, StatementKind::Assign(LocalId(0), _) )); } }