diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index b495c6e1f6..1b74760dc0 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -425,12 +425,25 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { (BinaryOperator::Subtract, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.compile_sub(a, b)) } + (BinaryOperator::Multiply, JitValue::Int(a), JitValue::Int(b)) => { + JitValue::Int(self.builder.ins().imul(a, b)) + } (BinaryOperator::FloorDivide, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.builder.ins().sdiv(a, b)) } + (BinaryOperator::Divide, JitValue::Int(a), JitValue::Int(b)) => { + // Convert to float for regular division + let a_float = self.builder.ins().fcvt_from_sint(types::F64, a); + let b_float = self.builder.ins().fcvt_from_sint(types::F64, b); + JitValue::Float(self.builder.ins().fdiv(a_float, b_float)) + } (BinaryOperator::Modulo, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.builder.ins().srem(a, b)) } + // Todo: This should return int when possible + (BinaryOperator::Power, JitValue::Int(a), JitValue::Int(b)) => { + JitValue::Float(self.compile_ipow(a, b)) + } ( BinaryOperator::Lshift | BinaryOperator::Rshift, JitValue::Int(a), @@ -562,4 +575,154 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { .trapif(IntCC::Overflow, carry, TrapCode::IntegerOverflow); out } + fn compile_ipow(&mut self, a: Value, b: Value) -> Value { + // Convert base to float since result might not always be a Int + let float_base = self.builder.ins().fcvt_from_sint(types::F64, a); + + // Create code blocks + let check_block1 = self.builder.create_block(); + let check_block2 = self.builder.create_block(); + let check_block3 = self.builder.create_block(); + let handle_neg_exp = self.builder.create_block(); + let loop_block = self.builder.create_block(); + let continue_block = self.builder.create_block(); + let exit_block = self.builder.create_block(); + + // Set code block params + // Set code block params + self.builder.append_block_param(check_block1, types::F64); + self.builder.append_block_param(check_block1, types::I64); + + self.builder.append_block_param(check_block2, types::F64); + self.builder.append_block_param(check_block2, types::I64); + + self.builder.append_block_param(check_block3, types::F64); + self.builder.append_block_param(check_block3, types::I64); + + self.builder.append_block_param(handle_neg_exp, types::F64); + self.builder.append_block_param(handle_neg_exp, types::I64); + + self.builder.append_block_param(loop_block, types::F64); //base + self.builder.append_block_param(loop_block, types::F64); //result + self.builder.append_block_param(loop_block, types::I64); //exponent + + self.builder.append_block_param(continue_block, types::F64); //base + self.builder.append_block_param(continue_block, types::F64); //result + self.builder.append_block_param(continue_block, types::I64); //exponent + + self.builder.append_block_param(exit_block, types::F64); + + // Begin evaluating by jumping to first check block + self.builder.ins().jump(check_block1, &[float_base, b]); + + // Check block one: + // Checks if input is O ** n where n > 0 + // Jumps to exit_block as 0 if true + self.builder.switch_to_block(check_block1); + let paramsc1 = self.builder.block_params(check_block1); + let basec1 = paramsc1[0]; + let expc1 = paramsc1[1]; + let zero_f64 = self.builder.ins().f64const(0.0); + let zero_i64 = self.builder.ins().iconst(types::I64, 0); + let is_base_zero = self.builder.ins().fcmp(FloatCC::Equal, zero_f64, basec1); + let is_exp_positive = self + .builder + .ins() + .icmp(IntCC::SignedGreaterThan, expc1, zero_i64); + let is_zero_to_positive = self.builder.ins().band(is_base_zero, is_exp_positive); + self.builder + .ins() + .brnz(is_zero_to_positive, exit_block, &[zero_f64]); + self.builder.ins().jump(check_block2, &[basec1, expc1]); + + // Check block two: + // Checks if exponent is negative + // Jumps to a special handle_neg_exponent block if true + self.builder.switch_to_block(check_block2); + let paramsc2 = self.builder.block_params(check_block2); + let basec2 = paramsc2[0]; + let expc2 = paramsc2[1]; + let zero_i64 = self.builder.ins().iconst(types::I64, 0); + let is_neg = self + .builder + .ins() + .icmp(IntCC::SignedLessThan, expc2, zero_i64); + self.builder + .ins() + .brnz(is_neg, handle_neg_exp, &[basec2, expc2]); + self.builder.ins().jump(check_block3, &[basec2, expc2]); + + // Check block three: + // Checks if exponent is one + // jumps to exit block with the base of the exponents value + self.builder.switch_to_block(check_block3); + let paramsc3 = self.builder.block_params(check_block3); + let basec3 = paramsc3[0]; + let expc3 = paramsc3[1]; + let resc3 = self.builder.ins().f64const(1.0); + let one_i64 = self.builder.ins().iconst(types::I64, 1); + let is_one = self.builder.ins().icmp(IntCC::Equal, expc3, one_i64); + self.builder.ins().brnz(is_one, exit_block, &[basec3]); + self.builder.ins().jump(loop_block, &[basec3, resc3, expc3]); + + // Handles negative Exponents + // calculates x^(-n) = (1/x)^n + // then proceeds to the loop to evaluate + self.builder.switch_to_block(handle_neg_exp); + let paramshn = self.builder.block_params(handle_neg_exp); + let basehn = paramshn[0]; + let exphn = paramshn[1]; + let one_f64 = self.builder.ins().f64const(1.0); + let base_inverse = self.builder.ins().fdiv(one_f64, basehn); + let pos_exp = self.builder.ins().ineg(exphn); + self.builder + .ins() + .jump(loop_block, &[base_inverse, one_f64, pos_exp]); + + // Main loop block + // checks loop condition (exp > 0) + // Jumps to continue block if true, exit block if false + self.builder.switch_to_block(loop_block); + let paramslb = self.builder.block_params(loop_block); + let baselb = paramslb[0]; + let reslb = paramslb[1]; + let explb = paramslb[2]; + let zero = self.builder.ins().iconst(types::I64, 0); + let is_zero = self.builder.ins().icmp(IntCC::Equal, explb, zero); + self.builder.ins().brnz(is_zero, exit_block, &[reslb]); + self.builder + .ins() + .jump(continue_block, &[baselb, reslb, explb]); + + // Continue block + // Main math logic + // Always jumps back to loob_block + self.builder.switch_to_block(continue_block); + let paramscb = self.builder.block_params(continue_block); + let basecb = paramscb[0]; + let rescb = paramscb[1]; + let expcb = paramscb[2]; + let is_odd = self.builder.ins().band_imm(expcb, 1); + let is_odd = self.builder.ins().icmp_imm(IntCC::Equal, is_odd, 1); + let mul_result = self.builder.ins().fmul(rescb, basecb); + let new_result = self.builder.ins().select(is_odd, mul_result, rescb); + let squared_base = self.builder.ins().fmul(basecb, basecb); + let new_exp = self.builder.ins().sshr_imm(expcb, 1); + self.builder + .ins() + .jump(loop_block, &[squared_base, new_result, new_exp]); + + self.builder.switch_to_block(exit_block); + let result = self.builder.block_params(exit_block)[0]; + + self.builder.seal_block(check_block1); + self.builder.seal_block(check_block2); + self.builder.seal_block(check_block3); + self.builder.seal_block(handle_neg_exp); + self.builder.seal_block(loop_block); + self.builder.seal_block(continue_block); + self.builder.seal_block(exit_block); + + result + } } diff --git a/jit/tests/int_tests.rs b/jit/tests/int_tests.rs index 9ce3f3b4a6..353052df00 100644 --- a/jit/tests/int_tests.rs +++ b/jit/tests/int_tests.rs @@ -1,3 +1,5 @@ +use core::f64; + #[test] fn test_add() { let add = jit_function! { add(a:i64, b:i64) -> i64 => r##" @@ -23,6 +25,51 @@ fn test_sub() { assert_eq!(sub(-3, -10), Ok(7)); } +#[test] +fn test_mul() { + let mul = jit_function! { mul(a:i64, b:i64) -> i64 => r##" + def mul(a: int, b: int): + return a * b + "## }; + + assert_eq!(mul(5, 10), Ok(50)); + assert_eq!(mul(0, 5), Ok(0)); + assert_eq!(mul(5, 0), Ok(0)); + assert_eq!(mul(0, 0), Ok(0)); + assert_eq!(mul(-5, 10), Ok(-50)); + assert_eq!(mul(5, -10), Ok(-50)); + assert_eq!(mul(-5, -10), Ok(50)); + assert_eq!(mul(999999, 999999), Ok(999998000001)); + assert_eq!(mul(i64::MAX, 1), Ok(i64::MAX)); + assert_eq!(mul(1, i64::MAX), Ok(i64::MAX)); +} + +#[test] + +fn test_div() { + let div = jit_function! { div(a:i64, b:i64) -> f64 => r##" + def div(a: int, b: int): + return a / b + "## }; + + assert_eq!(div(0, 1), Ok(0.0)); + assert_eq!(div(5, 1), Ok(5.0)); + assert_eq!(div(5, 10), Ok(0.5)); + assert_eq!(div(5, 2), Ok(2.5)); + assert_eq!(div(12, 10), Ok(1.2)); + assert_eq!(div(7, 10), Ok(0.7)); + assert_eq!(div(-3, -1), Ok(3.0)); + assert_eq!(div(-3, 1), Ok(-3.0)); + assert_eq!(div(1, 1000), Ok(0.001)); + assert_eq!(div(1, 100000), Ok(0.00001)); + assert_eq!(div(2, 3), Ok(0.6666666666666666)); + assert_eq!(div(1, 3), Ok(0.3333333333333333)); + assert_eq!(div(i64::MAX, 2), Ok(4611686018427387904.0)); + assert_eq!(div(i64::MIN, 2), Ok(-4611686018427387904.0)); + assert_eq!(div(i64::MIN, -1), Ok(9223372036854775808.0)); // Overflow case + assert_eq!(div(i64::MIN, i64::MAX), Ok(-1.0)); +} + #[test] fn test_floor_div() { let floor_div = jit_function! { floor_div(a:i64, b:i64) -> i64 => r##" @@ -35,7 +82,28 @@ fn test_floor_div() { assert_eq!(floor_div(12, 10), Ok(1)); assert_eq!(floor_div(7, 10), Ok(0)); assert_eq!(floor_div(-3, -1), Ok(3)); - assert_eq!(floor_div(-3, 1), Ok(-3)); +} + +#[test] + +fn test_exp() { + let exp = jit_function! { exp(a: i64, b: i64) -> f64 => r##" + def exp(a: int, b: int): + return a ** b + "## }; + + assert_eq!(exp(2, 3), Ok(8.0)); + assert_eq!(exp(3, 2), Ok(9.0)); + assert_eq!(exp(5, 0), Ok(1.0)); + assert_eq!(exp(0, 0), Ok(1.0)); + assert_eq!(exp(-5, 0), Ok(1.0)); + assert_eq!(exp(0, 1), Ok(0.0)); + assert_eq!(exp(0, 5), Ok(0.0)); + assert_eq!(exp(-2, 2), Ok(4.0)); + assert_eq!(exp(-3, 4), Ok(81.0)); + assert_eq!(exp(-2, 3), Ok(-8.0)); + assert_eq!(exp(-3, 3), Ok(-27.0)); + assert_eq!(exp(1000, 2), Ok(1000000.0)); } #[test]