feat: add as type casting

This commit is contained in:
2026-04-22 22:40:19 +02:00
parent e66a4ee736
commit 041a49e574
13 changed files with 350 additions and 1 deletions
+53
View File
@@ -414,6 +414,59 @@ impl<'a> FunctionTranslator<'a> {
} }
} }
} }
Rvalue::Cast(to_ty, inner) => {
let inner_val = self.translate_operand(inner);
let from_ty = self.get_operand_type(inner);
let cl_to_ty = CraneliftBackend::lower_type(to_ty);
if from_ty == *to_ty {
inner_val
} else {
match (from_ty.is_float(), to_ty.is_float()) {
(false, false) => {
// Integer <-> Integer
let from_width = from_ty.bit_width();
let to_width = to_ty.bit_width();
if to_width > from_width {
if from_ty.is_signed() {
self.builder.ins().sextend(cl_to_ty, inner_val)
} else {
self.builder.ins().uextend(cl_to_ty, inner_val)
}
} else if to_width < from_width {
self.builder.ins().ireduce(cl_to_ty, inner_val)
} else {
inner_val // e.g. bitcasting between same-sized int and uint
}
}
(true, true) => {
// Float <-> Float
if to_ty.bit_width() > from_ty.bit_width() {
self.builder.ins().fpromote(cl_to_ty, inner_val)
} else {
self.builder.ins().fdemote(cl_to_ty, inner_val)
}
}
(false, true) => {
// Integer -> Float
if from_ty.is_signed() {
self.builder.ins().fcvt_from_sint(cl_to_ty, inner_val)
} else {
self.builder.ins().fcvt_from_uint(cl_to_ty, inner_val)
}
}
(true, false) => {
// Float -> Integer
if to_ty.is_signed() {
self.builder.ins().fcvt_to_sint_sat(cl_to_ty, inner_val)
} else {
self.builder.ins().fcvt_to_uint_sat(cl_to_ty, inner_val)
}
}
}
}
}
} }
} }
} }
+7
View File
@@ -6,6 +6,7 @@ pub trait Phase: Debug + PartialEq + Eq {
type ReturnType: Debug + PartialEq + Eq; type ReturnType: Debug + PartialEq + Eq;
type ParamType: Debug + PartialEq + Eq; type ParamType: Debug + PartialEq + Eq;
type ExprType: Debug + PartialEq + Eq; type ExprType: Debug + PartialEq + Eq;
type CastType: Debug + PartialEq + Eq;
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
@@ -15,6 +16,7 @@ impl Phase for Untyped {
type ReturnType = Option<Type>; type ReturnType = Option<Type>;
type ParamType = FunctionParam; type ParamType = FunctionParam;
type ExprType = (); type ExprType = ();
type CastType = Type;
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
@@ -24,6 +26,7 @@ impl Phase for Typed {
type ReturnType = Ty; type ReturnType = Ty;
type ParamType = (String, Ty); type ParamType = (String, Ty);
type ExprType = Ty; type ExprType = Ty;
type CastType = Ty;
} }
pub type TypedModule = Module<Typed>; pub type TypedModule = Module<Typed>;
@@ -154,6 +157,10 @@ pub enum ExprKind<P: Phase = Untyped> {
lval: Box<Expr<P>>, lval: Box<Expr<P>>,
rval: Box<Expr<P>>, rval: Box<Expr<P>>,
}, },
Cast {
expr: Box<Expr<P>>,
ty: P::CastType,
},
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+3 -1
View File
@@ -65,6 +65,7 @@ impl<'src> Lexer<'src> {
match &self.source[start..self.cursor] { match &self.source[start..self.cursor] {
"fn" => TokenKind::Fn, "fn" => TokenKind::Fn,
"if" => TokenKind::If, "if" => TokenKind::If,
"as" => TokenKind::As,
"else" => TokenKind::Else, "else" => TokenKind::Else,
"return" => TokenKind::Return, "return" => TokenKind::Return,
"let" => TokenKind::Let, "let" => TokenKind::Let,
@@ -248,7 +249,7 @@ mod test {
#[test] #[test]
fn identifiers() { fn identifiers() {
assert_eq!( assert_eq!(
tokenize("HELLO _hello _0@ fn if else return let while break continue"), tokenize("HELLO _hello _0@ fn if else return let while break continue as"),
vec![ vec![
Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)), Token::new(TokenKind::Identifier, "HELLO", Span::new(0, 5)),
Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)), Token::new(TokenKind::Identifier, "_hello", Span::new(6, 12)),
@@ -262,6 +263,7 @@ mod test {
Token::new(TokenKind::While, "while", Span::new(39, 44)), Token::new(TokenKind::While, "while", Span::new(39, 44)),
Token::new(TokenKind::Break, "break", Span::new(45, 50)), Token::new(TokenKind::Break, "break", Span::new(45, 50)),
Token::new(TokenKind::Continue, "continue", Span::new(51, 59)), Token::new(TokenKind::Continue, "continue", Span::new(51, 59)),
Token::new(TokenKind::As, "as", Span::new(60, 62)),
] ]
) )
} }
+44
View File
@@ -559,6 +559,28 @@ impl<'src> Parser<'src> {
continue; continue;
} }
if peek_token.kind == TokenKind::As {
let left_bp = 25; // high precedence, tighter than multiply
if left_bp < min_bp {
break;
}
self.advance(); // consume `as`
let ty = self.parse_type()?;
let span = lhs.span.join(ty.span);
lhs = Expr {
kind: ExprKind::Cast {
expr: Box::new(lhs),
ty,
},
ty: (),
span,
};
continue;
}
let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else { let Some((op, left_bp, right_bp)) = self.infix_operator(peek_token.kind) else {
break; // Not an infix operator break; // Not an infix operator
}; };
@@ -1229,4 +1251,26 @@ mod test {
}) })
); );
} }
#[test]
fn cast_expr() {
assert_eq!(
parse("5 as f32;", Parser::parse_expr),
Success(Expr {
kind: ExprKind::Cast {
expr: Box::new(Expr {
kind: ExprKind::Integer { value: 5 },
ty: (),
span: Span::new(0, 1)
}),
ty: Type {
kind: TypeKind::F32,
span: Span::new(5, 8)
}
},
ty: (),
span: Span::new(0, 8)
})
);
}
} }
+75
View File
@@ -61,6 +61,30 @@ impl Ty {
pub fn is_numeric(&self) -> bool { pub fn is_numeric(&self) -> bool {
self.is_integer() || self.is_float() self.is_integer() || self.is_float()
} }
/// Returns `true` if the type is signed (including floats).
pub fn is_signed(&self) -> bool {
matches!(
self,
Ty::I8 | Ty::I16 | Ty::I32 | Ty::I64 | Ty::F32 | Ty::F64
)
}
/// Returns `true` if the type is unsigned or a boolean.
pub fn is_unsigned(&self) -> bool {
matches!(self, Ty::U8 | Ty::U16 | Ty::U32 | Ty::U64 | Ty::Bool)
}
/// Returns the exact bit width of the type.
pub fn bit_width(&self) -> usize {
match self {
Ty::I8 | Ty::U8 | Ty::Bool => 8,
Ty::I16 | Ty::U16 => 16,
Ty::I32 | Ty::U32 | Ty::F32 => 32,
Ty::I64 | Ty::U64 | Ty::F64 => 64,
_ => 0,
}
}
} }
impl From<&TypeKind> for Ty { impl From<&TypeKind> for Ty {
@@ -93,6 +117,7 @@ pub struct Sema {
deferred_binary: Vec<(Span, Ty)>, deferred_binary: Vec<(Span, Ty)>,
deferred_int_literals: Vec<(Span, Ty)>, deferred_int_literals: Vec<(Span, Ty)>,
deferred_float_literals: Vec<(Span, Ty)>, deferred_float_literals: Vec<(Span, Ty)>,
deferred_casts: Vec<(Span, Ty, Ty)>,
loop_depth: usize, loop_depth: usize,
} }
@@ -108,6 +133,7 @@ impl Sema {
deferred_binary: Vec::new(), deferred_binary: Vec::new(),
deferred_int_literals: Vec::new(), deferred_int_literals: Vec::new(),
deferred_float_literals: Vec::new(), deferred_float_literals: Vec::new(),
deferred_casts: Vec::new(),
loop_depth: 0, loop_depth: 0,
} }
} }
@@ -621,6 +647,22 @@ impl Sema {
span: expr.span, span: expr.span,
} }
} }
ExprKind::Cast { expr: inner, ty } => {
let typed_inner = self.analyze_expr(inner);
let target_ty = Ty::from(&ty.kind);
self.deferred_casts
.push((expr.span, typed_inner.ty.clone(), target_ty.clone()));
TypedExpr {
kind: TypedExprKind::Cast {
expr: Box::new(typed_inner),
ty: target_ty.clone(),
},
ty: target_ty,
span: expr.span,
}
}
} }
} }
@@ -729,6 +771,10 @@ impl Sema {
lval: Box::new(self.apply_subst_expr(*lval)), lval: Box::new(self.apply_subst_expr(*lval)),
rval: Box::new(self.apply_subst_expr(*rval)), rval: Box::new(self.apply_subst_expr(*rval)),
}, },
TypedExprKind::Cast { expr, ty } => TypedExprKind::Cast {
expr: Box::new(self.apply_subst_expr(*expr)),
ty: self.apply_subst(&ty),
},
}; };
TypedExpr { kind, ty, span } TypedExpr { kind, ty, span }
@@ -859,6 +905,29 @@ impl Sema {
.push(SemanticError::new("expected float type for literal", span)); .push(SemanticError::new("expected float type for literal", span));
} }
} }
for (span, from_ty, to_ty) in std::mem::take(&mut self.deferred_casts) {
let from = self.apply_subst(&from_ty);
let to = self.apply_subst(&to_ty);
let final_from = if let Ty::Var(_) = from {
let default = Ty::I32; // Unconstrained fallback
let _ = self.unify(&from, &default);
default
} else {
from
};
let is_valid = (final_from.is_numeric() || final_from == Ty::Bool)
&& (to.is_numeric() || to == Ty::Bool);
if !is_valid {
self.errors.push(SemanticError::new(
format!("cannot cast from `{:?}` to `{:?}`", final_from, to),
span,
));
}
}
} }
} }
@@ -1078,4 +1147,10 @@ mod test {
.any(|e| e.message.contains("`break` outside of a loop")) .any(|e| e.message.contains("`break` outside of a loop"))
); );
} }
#[test]
fn valid_cast() {
let src = "fn test() { let a: f32 = 10.0; let b = a as i32; }";
assert!(analyze(src).is_ok());
}
} }
+2
View File
@@ -58,6 +58,7 @@ pub enum TokenKind {
// Keywords // Keywords
Fn, Fn,
If, If,
As,
Else, Else,
Return, Return,
Let, Let,
@@ -120,6 +121,7 @@ impl Display for TokenKind {
TokenKind::BooleanLit => "a boolean", TokenKind::BooleanLit => "a boolean",
TokenKind::FloatLit => "a float", TokenKind::FloatLit => "a float",
TokenKind::Fn => "`fn`", TokenKind::Fn => "`fn`",
TokenKind::As => "`as`",
TokenKind::If => "`if`", TokenKind::If => "`if`",
TokenKind::Else => "`else`", TokenKind::Else => "`else`",
TokenKind::Return => "`return`", TokenKind::Return => "`return`",
+14
View File
@@ -385,6 +385,20 @@ impl FuncBuilder {
rval_op rval_op
} }
TypedExprKind::Cast {
expr: inner,
ty: target_ty,
} => {
let inner_op = self.lower_expr(inner);
let temp = self.new_temp(target_ty.clone());
self.emit_stmt(Statement {
kind: StatementKind::Assign(temp, Rvalue::Cast(target_ty.clone(), inner_op)),
span: expr.span,
});
Operand::Copy(temp)
}
} }
} }
} }
+59
View File
@@ -75,6 +75,7 @@ fn propagate_rvalue(rvalue: &mut Rvalue, known_constants: &HashMap<LocalId, Cons
propagate_operand(lhs, known_constants); propagate_operand(lhs, known_constants);
propagate_operand(rhs, known_constants); propagate_operand(rhs, known_constants);
} }
Rvalue::Cast(_, op) => propagate_operand(op, known_constants),
} }
} }
@@ -85,6 +86,7 @@ fn evaluate_rvalue(rvalue: &Rvalue) -> Option<ConstantValue> {
Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => { Rvalue::BinaryOp(op, Operand::Constant(l), Operand::Constant(r)) => {
evaluate_binary(*op, l, r) evaluate_binary(*op, l, r)
} }
Rvalue::Cast(to_ty, Operand::Constant(c)) => evaluate_cast(to_ty, c),
_ => None, _ => None,
} }
} }
@@ -184,6 +186,63 @@ fn evaluate_binary(
} }
} }
fn evaluate_cast(to_ty: &Ty, val: &ConstantValue) -> Option<ConstantValue> {
if to_ty.is_float() {
let f = match val {
ConstantValue::Integer(v, ty) => {
if ty.is_signed() {
let shift = 64 - ty.bit_width();
(((*v as i64) << shift) >> shift) as f64
} else {
*v as f64
}
}
ConstantValue::Float(v, _) => *v,
ConstantValue::Boolean(b) => {
if *b {
1.0
} else {
0.0
}
}
};
Some(ConstantValue::Float(f, to_ty.clone()))
} else if to_ty.is_integer() {
let i = match val {
ConstantValue::Integer(v, _) => *v,
ConstantValue::Float(v, _) => {
if to_ty.is_signed() {
(*v as i64) as u64
} else {
*v as u64
}
}
ConstantValue::Boolean(b) => {
if *b {
1
} else {
0
}
}
};
let mask = if to_ty.bit_width() == 64 {
u64::MAX
} else {
(1u64 << to_ty.bit_width()) - 1
};
Some(ConstantValue::Integer(i & mask, to_ty.clone()))
} else if to_ty == &Ty::Bool {
let b = match val {
ConstantValue::Integer(v, _) => *v != 0,
ConstantValue::Float(v, _) => *v != 0.0,
ConstantValue::Boolean(b) => *b,
};
Some(ConstantValue::Boolean(b))
} else {
None
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
+1
View File
@@ -60,6 +60,7 @@ pub enum Rvalue {
Use(Operand), Use(Operand),
UnaryOp(UnaryOp, Operand), UnaryOp(UnaryOp, Operand),
BinaryOp(BinaryOp, Operand, Operand), BinaryOp(BinaryOp, Operand, Operand),
Cast(Ty, Operand),
} }
/// An atomic value used as inputs to instructions. /// An atomic value used as inputs to instructions.
+21
View File
@@ -0,0 +1,21 @@
[code]
fn promote_f32_to_f64(x: f32) -> f64 {
return x as f64;
}
fn demote_f64_to_f32(x: f64) -> f32 {
return x as f32;
}
[harness]
extern double promote_f32_to_f64(float x);
extern float demote_f64_to_f32(double x);
int main() {
if (promote_f32_to_f64(3.5f) != 3.5) return 1;
if (demote_f64_to_f32(2.25) != 2.25f) return 2;
return 0;
}
[expected_return_code]
0
+24
View File
@@ -0,0 +1,24 @@
[code]
fn cast_f32_to_i32(x: f32) -> i32 {
return x as i32;
}
fn cast_f64_to_u64(x: f64) -> u64 {
return x as u64;
}
[harness]
#include <stdint.h>
extern int32_t cast_f32_to_i32(float x);
extern uint64_t cast_f64_to_u64(double x);
int main() {
if (cast_f32_to_i32(3.14f) != 3) return 1;
if (cast_f32_to_i32(-2.9f) != -2) return 2;
if (cast_f64_to_u64(42.999) != 42) return 3;
return 0;
}
[expected_return_code]
0
+24
View File
@@ -0,0 +1,24 @@
[code]
fn cast_i32_to_f32(x: i32) -> f32 {
return x as f32;
}
fn cast_u64_to_f64(x: u64) -> f64 {
return x as f64;
}
[harness]
#include <stdint.h>
extern float cast_i32_to_f32(int32_t x);
extern double cast_u64_to_f64(uint64_t x);
int main() {
if (cast_i32_to_f32(42) != 42.0f) return 1;
if (cast_i32_to_f32(-10) != -10.0f) return 2;
if (cast_u64_to_f64(1337) != 1337.0) return 3;
return 0;
}
[expected_return_code]
0
+23
View File
@@ -0,0 +1,23 @@
[code]
fn truncate_i32_to_i8(x: i32) -> i8 {
return x as i8;
}
fn extend_u8_to_i32(x: u8) -> i32 {
return x as i32;
}
[harness]
#include <stdint.h>
extern int8_t truncate_i32_to_i8(int32_t x);
extern int32_t extend_u8_to_i32(uint8_t x);
int main() {
if (truncate_i32_to_i8(257) != 1) return 1; // 257 = 0x0101, truncated to 8 bits is 0x01
if (extend_u8_to_i32(255) != 255) return 2;
return 0;
}
[expected_return_code]
0