diff --git a/src/backend/cranelift.rs b/src/backend/cranelift.rs
index dd300ad..1caab3f 100644
--- a/src/backend/cranelift.rs
+++ b/src/backend/cranelift.rs
@@ -414,6 +414,74 @@ impl<'a> FunctionTranslator<'a> {
                     }
                 }
             }
+            Rvalue::Cast(to_ty, inner) => {
+                let inner_val = self.translate_operand(inner);
+                let from_ty = self.get_operand_type(inner);
+                let cl_to_ty = CraneliftBackend::lower_type(to_ty);
+
+                if from_ty == *to_ty {
+                    inner_val
+                } else if *to_ty == Ty::Bool {
+                    // Casting to bool is a non-zero test, as in the constant folder.
+                    // Truncating or saturating instead would turn `256 as bool` and
+                    // `0.5 as bool` into false and could leave a byte other than 0/1.
+                    if from_ty.is_float() {
+                        let zero = if from_ty.bit_width() == 32 {
+                            self.builder.ins().f32const(0.0_f32)
+                        } else {
+                            self.builder.ins().f64const(0.0_f64)
+                        };
+                        // `NotEqual` is also true for unordered operands, so NaN is truthy.
+                        self.builder.ins().fcmp(FloatCC::NotEqual, inner_val, zero)
+                    } else {
+                        self.builder.ins().icmp_imm(IntCC::NotEqual, inner_val, 0)
+                    }
+                } else {
+                    match (from_ty.is_float(), to_ty.is_float()) {
+                        (false, false) => {
+                            // Integer <-> Integer
+                            let from_width = from_ty.bit_width();
+                            let to_width = to_ty.bit_width();
+
+                            if to_width > from_width {
+                                if from_ty.is_signed() {
+                                    self.builder.ins().sextend(cl_to_ty, inner_val)
+                                } else {
+                                    self.builder.ins().uextend(cl_to_ty, inner_val)
+                                }
+                            } else if to_width < from_width {
+                                self.builder.ins().ireduce(cl_to_ty, inner_val)
+                            } else {
+                                inner_val // e.g. bitcasting between same-sized int and uint
+                            }
+                        }
+                        (true, true) => {
+                            // Float <-> Float
+                            if to_ty.bit_width() > from_ty.bit_width() {
+                                self.builder.ins().fpromote(cl_to_ty, inner_val)
+                            } else {
+                                self.builder.ins().fdemote(cl_to_ty, inner_val)
+                            }
+                        }
+                        (false, true) => {
+                            // Integer -> Float
+                            if from_ty.is_signed() {
+                                self.builder.ins().fcvt_from_sint(cl_to_ty, inner_val)
+                            } else {
+                                self.builder.ins().fcvt_from_uint(cl_to_ty, inner_val)
+                            }
+                        }
+                        (true, false) => {
+                            // Float -> Integer
+                            if to_ty.is_signed() {
+                                self.builder.ins().fcvt_to_sint_sat(cl_to_ty, inner_val)
+                            } else {
+                                self.builder.ins().fcvt_to_uint_sat(cl_to_ty, inner_val)
+                            }
+                        }
+                    }
+                }
+            }
         }
     }
 }
diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs
index 2363d48..28789a6 100644
--- a/src/frontend/ast.rs
+++ b/src/frontend/ast.rs
@@ -6,6 +6,7 @@ pub trait Phase: Debug + PartialEq + Eq {
     type ReturnType: Debug + PartialEq + Eq;
     type ParamType: Debug + PartialEq + Eq;
     type ExprType: Debug + PartialEq + Eq;
+    type CastType: Debug + PartialEq + Eq;
 }
 
 #[derive(Debug, PartialEq, Eq)]
@@ -15,6 +16,7 @@ impl Phase for Untyped {
     type ReturnType = Option<Type>;
     type ParamType = FunctionParam;
     type ExprType = ();
+    type CastType = Type;
 }
 
 #[derive(Debug, PartialEq, Eq)]
@@ -24,6 +26,7 @@ impl Phase for Typed {
     type ReturnType = Ty;
     type ParamType = (String, Ty);
     type ExprType = Ty;
+    type CastType = Ty;
 }
 
 pub type TypedModule = Module<Typed>;
@@ -154,6 +157,10 @@ pub enum ExprKind {
         lval: Box<Expr<P>>,
         rval: Box<Expr<P>>,
     },
+    Cast {
+        expr: Box<Expr<P>>,
+        ty: P::CastType,
+    },
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
diff --git a/src/frontend/lexer.rs b/src/frontend/lexer.rs
index d595df0..c4f45f5 100644
--- a/src/frontend/lexer.rs
+++ b/src/frontend/lexer.rs
@@ -65,6 +65,7 @@ impl<'src> Lexer<'src> {
                 match &self.source[start..self.cursor] {
                     "fn" => TokenKind::Fn,
                     "if" => TokenKind::If,
+                    "as" => TokenKind::As,
                     "else" => TokenKind::Else,
                     "return" => TokenKind::Return,
                     "let" => TokenKind::Let,
@@ -248,7 +249,7 @@ mod test {
     #[test]
     fn identifiers() {
         assert_eq!(
-            tokenize("HELLO _hello _0@ fn if else return let while break continue"),
+            tokenize("HELLO _hello _0@ fn if else return let while break continue as"),
             vec![
                 Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)),
                 Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)),
@@ -262,6 +263,7 @@ mod test {
                 Token::new(TokenKind::While, "while", Span::new(39, 44)),
                 Token::new(TokenKind::Break, "break", Span::new(45, 50)),
                 Token::new(TokenKind::Continue, "continue", Span::new(51, 59)),
+                Token::new(TokenKind::As, "as", Span::new(60, 62)),
             ]
         )
     }
diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs
index 4b58f09..800df80 100644
--- a/src/frontend/parser.rs
+++ b/src/frontend/parser.rs
@@ -559,6 +559,28 @@ impl<'src> Parser<'src> {
                 continue;
             }
 
+            if peek_token.kind == TokenKind::As {
+                let left_bp = 25; // high precedence, tighter than multiply
+                if left_bp < min_bp {
+                    break;
+                }
+                self.advance(); // consume `as`
+
+                let ty = self.parse_type()?;
+                let span = lhs.span.join(ty.span);
+
+                lhs = Expr {
+                    kind: ExprKind::Cast {
+                        expr: Box::new(lhs),
+                        ty,
+                    },
+                    ty: (),
+                    span,
+                };
+
+                continue;
+            }
+
             let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else {
                 break; // Not an infix operator
             };
@@ -1229,4 +1251,26 @@ mod test {
             })
         );
     }
+
+    #[test]
+    fn cast_expr() {
+        assert_eq!(
+            parse("5 as f32;", Parser::parse_expr),
+            Success(Expr {
+                kind: ExprKind::Cast {
+                    expr: Box::new(Expr {
+                        kind: ExprKind::Integer { value: 5 },
+                        ty: (),
+                        span: Span::new(0, 1)
+                    }),
+                    ty: Type {
+                        kind: TypeKind::F32,
+                        span: Span::new(5, 8)
+                    }
+                },
+                ty: (),
+                span: Span::new(0, 8)
+            })
+        );
+    }
 }
diff --git a/src/frontend/sema.rs b/src/frontend/sema.rs
index 2acc75e..0ad9e5e 100644
--- a/src/frontend/sema.rs
+++ b/src/frontend/sema.rs
@@ -61,6 +61,30 @@ impl Ty {
     pub fn is_numeric(&self) -> bool {
         self.is_integer() || self.is_float()
     }
+
+    /// Returns `true` if the type is signed (including floats).
+    pub fn is_signed(&self) -> bool {
+        matches!(
+            self,
+            Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 | Ty::F32 | Ty::F64
+        )
+    }
+
+    /// Returns `true` if the type is unsigned or a boolean.
+    pub fn is_unsigned(&self) -> bool {
+        matches!(self, Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 | Ty::Bool)
+    }
+
+    /// Returns the bit width of an integer, float, or bool type, or `0` for any other type.
+    pub fn bit_width(&self) -> usize {
+        match self {
+            Ty::I8 | Ty::U8 | Ty::Bool => 8,
+            Ty::I16 | Ty::U16 => 16,
+            Ty::I32 | Ty::U32 | Ty::F32 => 32,
+            Ty::I64 | Ty::U64 | Ty::F64 => 64,
+            _ => 0,
+        }
+    }
 }
 
 impl From<&TypeKind> for Ty {
@@ -93,6 +117,7 @@ pub struct Sema {
     deferred_binary: Vec<(Span, Ty)>,
     deferred_int_literals: Vec<(Span, Ty)>,
     deferred_float_literals: Vec<(Span, Ty)>,
+    deferred_casts: Vec<(Span, Ty, Ty)>,
     loop_depth: usize,
 }
 
@@ -108,6 +133,7 @@ impl Sema {
             deferred_binary: Vec::new(),
             deferred_int_literals: Vec::new(),
             deferred_float_literals: Vec::new(),
+            deferred_casts: Vec::new(),
             loop_depth: 0,
         }
     }
@@ -621,6 +647,22 @@ impl Sema {
                     span: expr.span,
                 }
             }
+            ExprKind::Cast { expr: inner, ty } => {
+                let typed_inner = self.analyze_expr(inner);
+                let target_ty = Ty::from(&ty.kind);
+
+                self.deferred_casts
+                    .push((expr.span, typed_inner.ty.clone(), target_ty.clone()));
+
+                TypedExpr {
+                    kind: TypedExprKind::Cast {
+                        expr: Box::new(typed_inner),
+                        ty: target_ty.clone(),
+                    },
+                    ty: target_ty,
+                    span: expr.span,
+                }
+            }
         }
     }
 
@@ -729,6 +771,10 @@ impl Sema {
                 lval: Box::new(self.apply_subst_expr(*lval)),
                 rval: Box::new(self.apply_subst_expr(*rval)),
             },
+            TypedExprKind::Cast { expr, ty } => TypedExprKind::Cast {
+                expr: Box::new(self.apply_subst_expr(*expr)),
+                ty: self.apply_subst(&ty),
+            },
         };
 
         TypedExpr { kind, ty, span }
@@ -859,6 +905,29 @@ impl Sema {
                     .push(SemanticError::new("expected float type for literal", span));
             }
         }
+
+        for (span, from_ty, to_ty) in std::mem::take(&mut self.deferred_casts) {
+            let from = self.apply_subst(&from_ty);
+            let to = self.apply_subst(&to_ty);
+
+            let final_from = if let Ty::Var(_) = from {
+                let default = Ty::I32; // Unconstrained fallback
+                let _ = self.unify(&from, &default);
+                default
+            } else {
+                from
+            };
+
+            let is_valid = (final_from.is_numeric() || final_from == Ty::Bool)
+                && (to.is_numeric() || to == Ty::Bool);
+
+            if !is_valid {
+                self.errors.push(SemanticError::new(
+                    format!("cannot cast from `{:?}` to `{:?}`", final_from, to),
+                    span,
+                ));
+            }
+        }
     }
 }
 
@@ -1078,4 +1147,10 @@ mod test {
                 .any(|e| e.message.contains("`break` outside of a loop"))
         );
     }
+
+    #[test]
+    fn valid_cast() {
+        let src = "fn test() { let a: f32 = 10.0; let b = a as i32; }";
+        assert!(analyze(src).is_ok());
+    }
 }
diff --git a/src/frontend/token.rs b/src/frontend/token.rs
index 9a88c68..fe7e660 100644
--- a/src/frontend/token.rs
+++ b/src/frontend/token.rs
@@ -58,6 +58,7 @@ pub enum TokenKind {
     // Keywords
     Fn,
     If,
+    As,
     Else,
     Return,
     Let,
@@ -120,6 +121,7 @@ impl Display for TokenKind {
             TokenKind::BooleanLit => "a boolean",
             TokenKind::FloatLit => "a float",
             TokenKind::Fn => "`fn`",
+            TokenKind::As => "`as`",
             TokenKind::If => "`if`",
             TokenKind::Else => "`else`",
             TokenKind::Return => "`return`",
diff --git a/src/middle/builder.rs b/src/middle/builder.rs
index a2f6447..541dd56 100644
--- a/src/middle/builder.rs
+++ b/src/middle/builder.rs
@@ -385,6 +385,20 @@ impl FuncBuilder {
 
                 rval_op
             }
+            TypedExprKind::Cast {
+                expr: inner,
+                ty: target_ty,
+            } => {
+                let inner_op = self.lower_expr(inner);
+                let temp = self.new_temp(target_ty.clone());
+
+                self.emit_stmt(Statement {
+                    kind: StatementKind::Assign(temp, Rvalue::Cast(target_ty.clone(), inner_op)),
+                    span: expr.span,
+                });
+
+                Operand::Copy(temp)
+            }
         }
     }
 }
diff --git a/src/middle/fold.rs b/src/middle/fold.rs
index 12cd094..801fa97 100644
--- a/src/middle/fold.rs
+++ b/src/middle/fold.rs
@@ -75,6 +75,7 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap propagate_operand(op, known_constants), } }
@@ -85,6 +86,7 @@ fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
         Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
             evaluate_binary(*op, l, r)
         }
+        Rvalue::Cast(to_ty, Operand::Constant(c)) => evaluate_cast(to_ty, c),
         _ => None,
     }
 }
@@ -184,6 +186,66 @@ fn evaluate_binary(
     }
 }
 
+/// Folds a cast of the constant `val` to `to_ty`.
+///
+/// Returns `None` when `to_ty` is not a float, integer, or bool. A signed integer
+/// source is sign-extended from its own width first, an `f32` result is rounded
+/// to single precision, and float-to-integer saturates at the bounds of `to_ty`
+/// (NaN becomes 0), like the `fcvt_to_{s,u}int_sat` instructions the backend
+/// emits for the same cast at run time.
+fn evaluate_cast(to_ty: &Ty, val: &ConstantValue) -> Option<ConstantValue> {
+    // Reads the low `bit_width` bits of an integer constant as a signed value.
+    let sign_extend = |bits: u64, ty: &Ty| {
+        let shift = 64 - ty.bit_width();
+        ((bits as i64) << shift) >> shift
+    };
+
+    if to_ty.is_float() {
+        // The payload is always an `f64`; round through `f32` for `f32` targets.
+        let single = to_ty.bit_width() == 32;
+        let f = match val {
+            ConstantValue::Integer(v, ty) if ty.is_signed() => {
+                let s = sign_extend(*v, ty);
+                if single { s as f32 as f64 } else { s as f64 }
+            }
+            ConstantValue::Integer(v, _) => {
+                if single { *v as f32 as f64 } else { *v as f64 }
+            }
+            ConstantValue::Float(v, _) => {
+                if single { *v as f32 as f64 } else { *v }
+            }
+            ConstantValue::Boolean(b) => {
+                if *b { 1.0 } else { 0.0 }
+            }
+        };
+        Some(ConstantValue::Float(f, to_ty.clone()))
+    } else if to_ty.is_integer() {
+        let bits = to_ty.bit_width();
+        let mask = if bits == 64 { u64::MAX } else { (1u64 << bits) - 1 };
+        let i = match val {
+            ConstantValue::Integer(v, ty) if ty.is_signed() => sign_extend(*v, ty) as u64,
+            ConstantValue::Integer(v, _) => *v,
+            ConstantValue::Float(v, _) if to_ty.is_signed() => {
+                // `as i64` saturates and maps NaN to 0; clamp to the target range.
+                let max = (mask >> 1) as i64;
+                (*v as i64).clamp(-max - 1, max) as u64
+            }
+            ConstantValue::Float(v, _) => (*v as u64).min(mask),
+            ConstantValue::Boolean(b) => u64::from(*b),
+        };
+        Some(ConstantValue::Integer(i & mask, to_ty.clone()))
+    } else if to_ty == &Ty::Bool {
+        let b = match val {
+            ConstantValue::Integer(v, _) => *v != 0,
+            ConstantValue::Float(v, _) => *v != 0.0,
+            ConstantValue::Boolean(b) => *b,
+        };
+        Some(ConstantValue::Boolean(b))
+    } else {
+        None
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/src/middle/mir.rs b/src/middle/mir.rs
index fedfa62..719d995 100644
--- a/src/middle/mir.rs
+++ b/src/middle/mir.rs
@@ -60,6 +60,7 @@ pub enum Rvalue {
     Use(Operand),
     UnaryOp(UnaryOp, Operand),
     BinaryOp(BinaryOp, Operand, Operand),
+    Cast(Ty, Operand),
 }
 
 /// An atomic value used as inputs to instructions.
diff --git a/tests/cast_float_to_float.test b/tests/cast_float_to_float.test
new file mode 100644
index 0000000..5cfd26f
--- /dev/null
+++ b/tests/cast_float_to_float.test
@@ -0,0 +1,21 @@
+[code]
+fn promote_f32_to_f64(x: f32) -> f64 {
+    return x as f64;
+}
+
+fn demote_f64_to_f32(x: f64) -> f32 {
+    return x as f32;
+}
+
+[harness]
+extern double promote_f32_to_f64(float x);
+extern float demote_f64_to_f32(double x);
+
+int main() {
+    if (promote_f32_to_f64(3.5f) != 3.5) return 1;
+    if (demote_f64_to_f32(2.25) != 2.25f) return 2;
+    return 0;
+}
+
+[expected_return_code]
+0
\ No newline at end of file
diff --git a/tests/cast_float_to_int.test b/tests/cast_float_to_int.test
new file mode 100644
index 0000000..f0e2e92
--- /dev/null
+++ b/tests/cast_float_to_int.test
@@ -0,0 +1,24 @@
+[code]
+fn cast_f32_to_i32(x: f32) -> i32 {
+    return x as i32;
+}
+
+fn cast_f64_to_u64(x: f64) -> u64 {
+    return x as u64;
+}
+
+[harness]
+#include <stdint.h>
+
+extern int32_t cast_f32_to_i32(float x);
+extern uint64_t cast_f64_to_u64(double x);
+
+int main() {
+    if (cast_f32_to_i32(3.14f) != 3) return 1;
+    if (cast_f32_to_i32(-2.9f) != -2) return 2;
+    if (cast_f64_to_u64(42.999) != 42) return 3;
+    return 0;
+}
+
+[expected_return_code]
+0
\ No newline at end of file
diff --git a/tests/cast_int_to_float.test b/tests/cast_int_to_float.test
new file mode 100644
index 0000000..3d7f861
--- /dev/null
+++ b/tests/cast_int_to_float.test
@@ -0,0 +1,24 @@
+[code]
+fn cast_i32_to_f32(x: i32) -> f32 {
+    return x as f32;
+}
+
+fn cast_u64_to_f64(x: u64) -> f64 {
+    return x as f64;
+}
+
+[harness]
+#include <stdint.h>
+
+extern float cast_i32_to_f32(int32_t x);
+extern double cast_u64_to_f64(uint64_t x);
+
+int main() {
+    if (cast_i32_to_f32(42) != 42.0f) return 1;
+    if (cast_i32_to_f32(-10) != -10.0f) return 2;
+    if (cast_u64_to_f64(1337) != 1337.0) return 3;
+    return 0;
+}
+
+[expected_return_code]
+0
\ No newline at end of file
diff --git a/tests/cast_int_to_int.test b/tests/cast_int_to_int.test
new file mode 100644
index 0000000..8d02d9b
--- /dev/null
+++ b/tests/cast_int_to_int.test
@@ -0,0 +1,23 @@
+[code]
+fn truncate_i32_to_i8(x: i32) -> i8 {
+    return x as i8;
+}
+
+fn extend_u8_to_i32(x: u8) -> i32 {
+    return x as i32;
+}
+
+[harness]
+#include <stdint.h>
+
+extern int8_t truncate_i32_to_i8(int32_t x);
+extern int32_t extend_u8_to_i32(uint8_t x);
+
+int main() {
+    if (truncate_i32_to_i8(257) != 1) return 1; // 257 = 0x0101, truncated to 8 bits is 0x01
+    if (extend_u8_to_i32(255) != 255) return 2;
+    return 0;
+}
+
+[expected_return_code]
+0
\ No newline at end of file