Add JIT compilation support for integer multiplication, division, and exponents (#5561)

* Initial commit for power function

* Float power jit developed

* Addded support for Floats and Ints

* Integration Testing for JITPower implementation.

* Update instructions.rs

cranelift more like painlift

* Update instructions.rs

* Update instructions.rs

* initial commit for making stable PR ready features

* fixed final edge case for compile_ipow

* fixed final edge case for compile_ipow

* commenting out compile_ipow

* fixed spelling errors

* removed unused tests

* forgot to run clippy

---------

Co-authored-by: Nicholas Paulick <paulicknicholas@gmail.com>
Co-authored-by: Nick <nick@Samanthas-MBP.wi.rr.com>
Co-authored-by: JoeLoparco <loparcojoseph@gmail.com>
Co-authored-by: Nathan Rusch <nathan.rusch@icloud.com>
Co-authored-by: dohear <daniel.ohear@marquette.edu>
This commit is contained in:
Daniel O'Hear
2025-03-05 20:41:45 +00:00
committed by GitHub
parent 7fea1e1b4a
commit 58ebf04bac
2 changed files with 232 additions and 1 deletions

View File

@@ -425,12 +425,25 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
(BinaryOperator::Subtract, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.compile_sub(a, b))
}
(BinaryOperator::Multiply, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().imul(a, b))
}
(BinaryOperator::FloorDivide, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().sdiv(a, b))
}
(BinaryOperator::Divide, JitValue::Int(a), JitValue::Int(b)) => {
// Convert to float for regular division
let a_float = self.builder.ins().fcvt_from_sint(types::F64, a);
let b_float = self.builder.ins().fcvt_from_sint(types::F64, b);
JitValue::Float(self.builder.ins().fdiv(a_float, b_float))
}
(BinaryOperator::Modulo, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Int(self.builder.ins().srem(a, b))
}
// Todo: This should return int when possible
(BinaryOperator::Power, JitValue::Int(a), JitValue::Int(b)) => {
JitValue::Float(self.compile_ipow(a, b))
}
(
BinaryOperator::Lshift | BinaryOperator::Rshift,
JitValue::Int(a),
@@ -562,4 +575,154 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
.trapif(IntCC::Overflow, carry, TrapCode::IntegerOverflow);
out
}
fn compile_ipow(&mut self, a: Value, b: Value) -> Value {
// Convert base to float since result might not always be a Int
let float_base = self.builder.ins().fcvt_from_sint(types::F64, a);
// Create code blocks
let check_block1 = self.builder.create_block();
let check_block2 = self.builder.create_block();
let check_block3 = self.builder.create_block();
let handle_neg_exp = self.builder.create_block();
let loop_block = self.builder.create_block();
let continue_block = self.builder.create_block();
let exit_block = self.builder.create_block();
// Set code block params
// Set code block params
self.builder.append_block_param(check_block1, types::F64);
self.builder.append_block_param(check_block1, types::I64);
self.builder.append_block_param(check_block2, types::F64);
self.builder.append_block_param(check_block2, types::I64);
self.builder.append_block_param(check_block3, types::F64);
self.builder.append_block_param(check_block3, types::I64);
self.builder.append_block_param(handle_neg_exp, types::F64);
self.builder.append_block_param(handle_neg_exp, types::I64);
self.builder.append_block_param(loop_block, types::F64); //base
self.builder.append_block_param(loop_block, types::F64); //result
self.builder.append_block_param(loop_block, types::I64); //exponent
self.builder.append_block_param(continue_block, types::F64); //base
self.builder.append_block_param(continue_block, types::F64); //result
self.builder.append_block_param(continue_block, types::I64); //exponent
self.builder.append_block_param(exit_block, types::F64);
// Begin evaluating by jumping to first check block
self.builder.ins().jump(check_block1, &[float_base, b]);
// Check block one:
// Checks if input is O ** n where n > 0
// Jumps to exit_block as 0 if true
self.builder.switch_to_block(check_block1);
let paramsc1 = self.builder.block_params(check_block1);
let basec1 = paramsc1[0];
let expc1 = paramsc1[1];
let zero_f64 = self.builder.ins().f64const(0.0);
let zero_i64 = self.builder.ins().iconst(types::I64, 0);
let is_base_zero = self.builder.ins().fcmp(FloatCC::Equal, zero_f64, basec1);
let is_exp_positive = self
.builder
.ins()
.icmp(IntCC::SignedGreaterThan, expc1, zero_i64);
let is_zero_to_positive = self.builder.ins().band(is_base_zero, is_exp_positive);
self.builder
.ins()
.brnz(is_zero_to_positive, exit_block, &[zero_f64]);
self.builder.ins().jump(check_block2, &[basec1, expc1]);
// Check block two:
// Checks if exponent is negative
// Jumps to a special handle_neg_exponent block if true
self.builder.switch_to_block(check_block2);
let paramsc2 = self.builder.block_params(check_block2);
let basec2 = paramsc2[0];
let expc2 = paramsc2[1];
let zero_i64 = self.builder.ins().iconst(types::I64, 0);
let is_neg = self
.builder
.ins()
.icmp(IntCC::SignedLessThan, expc2, zero_i64);
self.builder
.ins()
.brnz(is_neg, handle_neg_exp, &[basec2, expc2]);
self.builder.ins().jump(check_block3, &[basec2, expc2]);
// Check block three:
// Checks if exponent is one
// jumps to exit block with the base of the exponents value
self.builder.switch_to_block(check_block3);
let paramsc3 = self.builder.block_params(check_block3);
let basec3 = paramsc3[0];
let expc3 = paramsc3[1];
let resc3 = self.builder.ins().f64const(1.0);
let one_i64 = self.builder.ins().iconst(types::I64, 1);
let is_one = self.builder.ins().icmp(IntCC::Equal, expc3, one_i64);
self.builder.ins().brnz(is_one, exit_block, &[basec3]);
self.builder.ins().jump(loop_block, &[basec3, resc3, expc3]);
// Handles negative Exponents
// calculates x^(-n) = (1/x)^n
// then proceeds to the loop to evaluate
self.builder.switch_to_block(handle_neg_exp);
let paramshn = self.builder.block_params(handle_neg_exp);
let basehn = paramshn[0];
let exphn = paramshn[1];
let one_f64 = self.builder.ins().f64const(1.0);
let base_inverse = self.builder.ins().fdiv(one_f64, basehn);
let pos_exp = self.builder.ins().ineg(exphn);
self.builder
.ins()
.jump(loop_block, &[base_inverse, one_f64, pos_exp]);
// Main loop block
// checks loop condition (exp > 0)
// Jumps to continue block if true, exit block if false
self.builder.switch_to_block(loop_block);
let paramslb = self.builder.block_params(loop_block);
let baselb = paramslb[0];
let reslb = paramslb[1];
let explb = paramslb[2];
let zero = self.builder.ins().iconst(types::I64, 0);
let is_zero = self.builder.ins().icmp(IntCC::Equal, explb, zero);
self.builder.ins().brnz(is_zero, exit_block, &[reslb]);
self.builder
.ins()
.jump(continue_block, &[baselb, reslb, explb]);
// Continue block
// Main math logic
// Always jumps back to loob_block
self.builder.switch_to_block(continue_block);
let paramscb = self.builder.block_params(continue_block);
let basecb = paramscb[0];
let rescb = paramscb[1];
let expcb = paramscb[2];
let is_odd = self.builder.ins().band_imm(expcb, 1);
let is_odd = self.builder.ins().icmp_imm(IntCC::Equal, is_odd, 1);
let mul_result = self.builder.ins().fmul(rescb, basecb);
let new_result = self.builder.ins().select(is_odd, mul_result, rescb);
let squared_base = self.builder.ins().fmul(basecb, basecb);
let new_exp = self.builder.ins().sshr_imm(expcb, 1);
self.builder
.ins()
.jump(loop_block, &[squared_base, new_result, new_exp]);
self.builder.switch_to_block(exit_block);
let result = self.builder.block_params(exit_block)[0];
self.builder.seal_block(check_block1);
self.builder.seal_block(check_block2);
self.builder.seal_block(check_block3);
self.builder.seal_block(handle_neg_exp);
self.builder.seal_block(loop_block);
self.builder.seal_block(continue_block);
self.builder.seal_block(exit_block);
result
}
}

View File

@@ -1,3 +1,5 @@
use core::f64;
#[test]
fn test_add() {
let add = jit_function! { add(a:i64, b:i64) -> i64 => r##"
@@ -23,6 +25,51 @@ fn test_sub() {
assert_eq!(sub(-3, -10), Ok(7));
}
#[test]
fn test_mul() {
let mul = jit_function! { mul(a:i64, b:i64) -> i64 => r##"
def mul(a: int, b: int):
return a * b
"## };
assert_eq!(mul(5, 10), Ok(50));
assert_eq!(mul(0, 5), Ok(0));
assert_eq!(mul(5, 0), Ok(0));
assert_eq!(mul(0, 0), Ok(0));
assert_eq!(mul(-5, 10), Ok(-50));
assert_eq!(mul(5, -10), Ok(-50));
assert_eq!(mul(-5, -10), Ok(50));
assert_eq!(mul(999999, 999999), Ok(999998000001));
assert_eq!(mul(i64::MAX, 1), Ok(i64::MAX));
assert_eq!(mul(1, i64::MAX), Ok(i64::MAX));
}
#[test]
fn test_div() {
let div = jit_function! { div(a:i64, b:i64) -> f64 => r##"
def div(a: int, b: int):
return a / b
"## };
assert_eq!(div(0, 1), Ok(0.0));
assert_eq!(div(5, 1), Ok(5.0));
assert_eq!(div(5, 10), Ok(0.5));
assert_eq!(div(5, 2), Ok(2.5));
assert_eq!(div(12, 10), Ok(1.2));
assert_eq!(div(7, 10), Ok(0.7));
assert_eq!(div(-3, -1), Ok(3.0));
assert_eq!(div(-3, 1), Ok(-3.0));
assert_eq!(div(1, 1000), Ok(0.001));
assert_eq!(div(1, 100000), Ok(0.00001));
assert_eq!(div(2, 3), Ok(0.6666666666666666));
assert_eq!(div(1, 3), Ok(0.3333333333333333));
assert_eq!(div(i64::MAX, 2), Ok(4611686018427387904.0));
assert_eq!(div(i64::MIN, 2), Ok(-4611686018427387904.0));
assert_eq!(div(i64::MIN, -1), Ok(9223372036854775808.0)); // Overflow case
assert_eq!(div(i64::MIN, i64::MAX), Ok(-1.0));
}
#[test]
fn test_floor_div() {
let floor_div = jit_function! { floor_div(a:i64, b:i64) -> i64 => r##"
@@ -35,7 +82,28 @@ fn test_floor_div() {
assert_eq!(floor_div(12, 10), Ok(1));
assert_eq!(floor_div(7, 10), Ok(0));
assert_eq!(floor_div(-3, -1), Ok(3));
assert_eq!(floor_div(-3, 1), Ok(-3));
}
#[test]
fn test_exp() {
let exp = jit_function! { exp(a: i64, b: i64) -> f64 => r##"
def exp(a: int, b: int):
return a ** b
"## };
assert_eq!(exp(2, 3), Ok(8.0));
assert_eq!(exp(3, 2), Ok(9.0));
assert_eq!(exp(5, 0), Ok(1.0));
assert_eq!(exp(0, 0), Ok(1.0));
assert_eq!(exp(-5, 0), Ok(1.0));
assert_eq!(exp(0, 1), Ok(0.0));
assert_eq!(exp(0, 5), Ok(0.0));
assert_eq!(exp(-2, 2), Ok(4.0));
assert_eq!(exp(-3, 4), Ok(81.0));
assert_eq!(exp(-2, 3), Ok(-8.0));
assert_eq!(exp(-3, 3), Ok(-27.0));
assert_eq!(exp(1000, 2), Ok(1000000.0));
}
#[test]