diff --git a/Cargo.lock b/Cargo.lock index e1f720698..b3b8da765 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -67,10 +73,10 @@ dependencies = [ ] [[package]] -name = "arrayvec" -version = "0.7.6" +name = "arbitrary" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" [[package]] name = "ascii" @@ -156,6 +162,9 @@ name = "bumpalo" version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +dependencies = [ + "allocator-api2", +] [[package]] name = "bytemuck" @@ -309,118 +318,145 @@ dependencies = [ [[package]] name = "cranelift" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea1b0c164043c16a8ece6813eef609ac2262a32a0bb0f5ed6eecf5d7bfb79ba8" +checksum = "a71de5e59f616d79d14d2c71aa2799ce898241d7f10f7e64a4997014b4000a28" dependencies = [ "cranelift-codegen", "cranelift-frontend", + "cranelift-module", ] [[package]] name = "cranelift-bforest" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52056f6d0584484b57fa6c1a65c1fcb15f3780d8b6a758426d9e3084169b2ddd" +checksum = "e15d04a0ce86cb36ead88ad68cf693ffd6cda47052b9e0ac114bc47fd9cd23c4" dependencies = [ "cranelift-entity", ] [[package]] -name = "cranelift-codegen" -version = "0.88.2" +name = "cranelift-bitset" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fed94c8770dc25d01154c3ffa64ed0b3ba9d583736f305fed7beebe5d9cf74" +checksum = "7c6e3969a7ce267259ce244b7867c5d3bc9e65b0a87e81039588dfdeaede9f34" + +[[package]] +name = "cranelift-codegen" +version = "0.116.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c22032c4cb42558371cf516bb47f26cdad1819d3475c133e93c49f50ebf304e" dependencies = [ - "arrayvec", "bumpalo", "cranelift-bforest", + "cranelift-bitset", "cranelift-codegen-meta", "cranelift-codegen-shared", + "cranelift-control", "cranelift-entity", "cranelift-isle", + "gimli", + "hashbrown 0.14.5", "log", "regalloc2", + "rustc-hash 2.1.1", + "serde", "smallvec", - "target-lexicon", + "target-lexicon 0.13.2", ] [[package]] name = "cranelift-codegen-meta" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c451b81faf237d11c7e4f3165eeb6bac61112762c5cfe7b4c0fb7241474358f" +checksum = "c904bc71c61b27fc57827f4a1379f29de64fe95653b620a3db77d59655eee0b8" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c940133198426d26128f08be2b40b0bd117b84771fd36798969c4d712d81fc" +checksum = "40180f5497572f644ce88c255480981ae2ec1d7bb4d8e0c0136a13b87a2f2ceb" + +[[package]] +name = "cranelift-control" +version = "0.116.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d132c6d0bd8a489563472afc171759da0707804a65ece7ceb15a8c6d7dd5ef" +dependencies = [ + "arbitrary", +] [[package]] name = "cranelift-entity" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87a0f1b2fdc18776956370cf8d9b009ded3f855350c480c1c52142510961f352" +checksum = "4b2d0d9618275474fbf679dd018ac6e009acbd6ae6850f6a67be33fb3b00b323" +dependencies = [ + "cranelift-bitset", +] [[package]] name = "cranelift-frontend" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34897538b36b216cc8dd324e73263596d51b8cf610da6498322838b2546baf8a" +checksum = "4fac41e16729107393174b0c9e3730fb072866100e1e64e80a1a963b2e484d57" dependencies = [ "cranelift-codegen", "log", "smallvec", - "target-lexicon", + "target-lexicon 0.13.2", ] [[package]] name = "cranelift-isle" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b2629a569fae540f16a76b70afcc87ad7decb38dc28fa6c648ac73b51e78470" +checksum = "1ca20d576e5070044d0a72a9effc2deacf4d6aa650403189d8ea50126483944d" [[package]] name = "cranelift-jit" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "625be33ce54cf906c408f5ad9d08caa6e2a09e52d05fd0bd1bd95b132bfbba73" +checksum = "5e65c42755a719b09662b00c700daaf76cc35d5ace1f5c002ad404b591ff1978" dependencies = [ "anyhow", "cranelift-codegen", + "cranelift-control", "cranelift-entity", "cranelift-module", "cranelift-native", "libc", "log", "region", - "target-lexicon", - "windows-sys 0.36.1", + "target-lexicon 0.13.2", + "wasmtime-jit-icache-coherence", + "windows-sys 0.59.0", ] [[package]] name = "cranelift-module" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "883f8d42e07fd6b283941688f6c41a9e3b97fbf2b4ddcfb2756e675b86dc5edb" +checksum = "4d55612bebcf16ff7306c8a6f5bdb6d45662b8aa1ee058ecce8807ad87db719b" dependencies = [ "anyhow", "cranelift-codegen", + "cranelift-control", ] [[package]] name = "cranelift-native" -version = "0.88.2" +version = "0.116.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20937dab4e14d3e225c5adfc9c7106bafd4ac669bdb43027b911ff794c6fb318" +checksum = "b8dee82f3f1f2c4cba9177f1cc5e350fe98764379bcd29340caa7b01f85076c7" dependencies = [ "cranelift-codegen", "libc", - "target-lexicon", + "target-lexicon 0.13.2", ] [[package]] @@ -658,6 +694,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + [[package]] name = "fd-lock" version = "4.0.2" @@ -737,15 +779,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" -[[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -802,6 +835,17 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +dependencies = [ + "fallible-iterator", + "indexmap", + "stable_deref_trait", +] + [[package]] name = "glob" version = "0.3.2" @@ -1186,10 +1230,10 @@ dependencies = [ ] [[package]] -name = "mach" -version = "0.3.2" +name = "mach2" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" dependencies = [ "libc", ] @@ -1668,7 +1712,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" dependencies = [ "once_cell", - "target-lexicon", + "target-lexicon 0.12.16", ] [[package]] @@ -1840,13 +1884,15 @@ dependencies = [ [[package]] name = "regalloc2" -version = "0.3.2" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d43a209257d978ef079f3d446331d0f1794f5e0fc19b306a199983857833a779" +checksum = "145c1c267e14f20fb0f88aa76a1c5ffec42d592c1d28b3cd9148ae35916158d3" dependencies = [ - "fxhash", + "allocator-api2", + "bumpalo", + "hashbrown 0.15.2", "log", - "slice-group-by", + "rustc-hash 2.1.1", "smallvec", ] @@ -1881,14 +1927,14 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "region" -version = "2.2.0" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877e54ea2adcd70d80e9179344c97f93ef0dffd6b03e1f4529e6e83ab2fa9ae0" +checksum = "e6b6ebd13bc009aef9cd476c1310d49ac354d36e240cf1bd753290f3dc7199a7" dependencies = [ "bitflags 1.3.2", "libc", - "mach", - "winapi", + "mach2", + "windows-sys 0.52.0", ] [[package]] @@ -1918,6 +1964,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -2128,7 +2180,7 @@ dependencies = [ "num-traits", "phf", "phf_codegen", - "rustc-hash", + "rustc-hash 1.1.0", "rustpython-ast", "rustpython-parser-core", "tiny-keccak", @@ -2522,12 +2574,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" -[[package]] -name = "slice-group-by" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "826167069c09b99d56f31e9ae5c99049e932a98c9dc2dac47645b08dbbf76ba7" - [[package]] name = "smallvec" version = "1.14.0" @@ -2544,6 +2590,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" @@ -2635,6 +2687,12 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + [[package]] name = "termcolor" version = "1.4.1" @@ -3118,6 +3176,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasmtime-jit-icache-coherence" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec5e8552e01692e6c2e5293171704fed8abdec79d1a6995a0870ab190e5747d1" +dependencies = [ + "anyhow", + "cfg-if", + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -3196,19 +3266,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-sys" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" -dependencies = [ - "windows_aarch64_msvc 0.36.1", - "windows_i686_gnu 0.36.1", - "windows_i686_msvc 0.36.1", - "windows_x86_64_gnu 0.36.1", - "windows_x86_64_msvc 0.36.1", -] - [[package]] name = "windows-sys" version = "0.48.0" @@ -3279,12 +3336,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3297,12 +3348,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_i686_gnu" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3321,12 +3366,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3339,12 +3378,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_x86_64_gnu" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3369,12 +3402,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/jit/Cargo.toml b/jit/Cargo.toml index cc26eb59a..f59293a7e 100644 --- a/jit/Cargo.toml +++ b/jit/Cargo.toml @@ -16,9 +16,9 @@ rustpython-compiler-core = { workspace = true } num-traits = { workspace = true } thiserror = { workspace = true } -cranelift = "0.88.0" -cranelift-jit = "0.88.0" -cranelift-module = "0.88.0" +cranelift = "0.116.1" +cranelift-jit = "0.116.1" +cranelift-module = "0.116.1" [dependencies.libffi] version = "3.1.0" diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 1b74760dc..bf30e51d7 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; #[repr(u16)] enum CustomTrapCode { /// Raised when shifting by a negative number - NegativeShiftCount = 0, + NegativeShiftCount = 1, } #[derive(Clone)] @@ -56,6 +56,12 @@ impl JitValue { } } +#[derive(Clone)] +struct DDValue { + hi: Value, + lo: Value, +} + pub struct FunctionCompiler<'a, 'b> { builder: &'a mut FunctionBuilder<'b>, stack: Vec, @@ -123,14 +129,14 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { fn boolean_val(&mut self, val: JitValue) -> Result { match val { JitValue::Float(val) => { - let zero = self.builder.ins().f64const(0); + let zero = self.builder.ins().f64const(0.0); let val = self.builder.ins().fcmp(FloatCC::NotEqual, val, zero); - Ok(self.builder.ins().bint(types::I8, val)) + Ok(val) } JitValue::Int(val) => { let zero = self.builder.ins().iconst(types::I64, 0); let val = self.builder.ins().icmp(IntCC::NotEqual, val, zero); - Ok(self.builder.ins().bint(types::I8, val)) + Ok(val) } JitValue::Bool(val) => Ok(val), JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)), @@ -151,38 +157,60 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { func_ref: FuncRef, bytecode: &CodeObject, ) -> Result<(), JitCompileError> { - // TODO: figure out if this is sufficient -- previously individual labels were associated - // pretty much per-bytecode that uses them, or at least per "type" of block -- in theory an - // if block and a with block might jump to the same place. Now it's all "flattened", so - // there might be less distinction between different types of blocks going off - // label_targets alone let label_targets = bytecode.label_targets(); - let mut arg_state = OpArgState::default(); - for (offset, instruction) in bytecode.instructions.iter().enumerate() { - let (instruction, arg) = arg_state.get(*instruction); - let label = Label(offset as u32); - if label_targets.contains(&label) { - let block = self.get_or_create_block(label); - // If the current block is not terminated/filled just jump - // into the new block. - if !self.builder.is_filled() { - self.builder.ins().jump(block, &[]); + // Track whether we have "returned" in the current block + let mut in_unreachable_code = false; + + for (offset, &raw_instr) in bytecode.instructions.iter().enumerate() { + let label = Label(offset as u32); + let (instruction, arg) = arg_state.get(raw_instr); + + // If this is a label that some earlier jump can target, + // treat it as the start of a new reachable block: + if label_targets.contains(&label) { + // Create or get the block for this label: + let target_block = self.get_or_create_block(label); + + // If the current block isn't terminated, jump: + if let Some(cur) = self.builder.current_block() { + if cur != target_block && self.builder.func.layout.last_inst(cur).is_none() { + self.builder.ins().jump(target_block, &[]); + } + } + // Switch to the target block + if self.builder.current_block() != Some(target_block) { + self.builder.switch_to_block(target_block); } - self.builder.switch_to_block(block); + // We are definitely reachable again at this label + in_unreachable_code = false; } - // Sometimes the bytecode contains instructions after a return - // just ignore those until we are at the next label - if self.builder.is_filled() { + // If we're in unreachable code, skip this instruction unless the label re-entered above. + if in_unreachable_code { continue; } + // Actually compile this instruction: self.add_instruction(func_ref, bytecode, instruction, arg)?; + + // If that was a return instruction, mark future instructions unreachable + match instruction { + Instruction::ReturnValue | Instruction::ReturnConst { .. } => { + in_unreachable_code = true; + } + _ => {} + } } + // After processing, if the current block is unterminated, insert a trap or fallthrough + if let Some(cur) = self.builder.current_block() { + if self.builder.func.layout.last_inst(cur).is_none() { + self.builder.ins().trap(TrapCode::user(0).unwrap()); + } + } Ok(()) } @@ -214,10 +242,12 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { fn return_value(&mut self, val: JitValue) -> Result<(), JitCompileError> { if let Some(ref ty) = self.sig.ret { + // If the signature has a return type, enforce it if val.to_jit_type().as_ref() != Some(ty) { return Err(JitCompileError::NotSupported); } } else { + // First time we see a return, define it in the signature let ty = val.to_jit_type().ok_or(JitCompileError::NotSupported)?; self.sig.ret = Some(ty.clone()); self.builder @@ -226,7 +256,12 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { .returns .push(AbiParam::new(ty.to_cranelift())); } - self.builder.ins().return_(&[val.into_value().unwrap()]); + + // If this is e.g. an Int, Float, or Bool we have a Cranelift `Value`. + // If we have JitValue::None or .Tuple(...) but can't handle that, error out (or handle differently). + let cr_val = val.into_value().ok_or(JitCompileError::NotSupported)?; + + self.builder.ins().return_(&[cr_val]); Ok(()) } @@ -241,34 +276,34 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { Instruction::ExtendedArg => Ok(()), Instruction::JumpIfFalse { target } => { let cond = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; - let val = self.boolean_val(cond)?; let then_block = self.get_or_create_block(target.get(arg)); - self.builder.ins().brz(val, then_block, &[]); + let else_block = self.builder.create_block(); - let block = self.builder.create_block(); - self.builder.ins().jump(block, &[]); - self.builder.switch_to_block(block); + self.builder + .ins() + .brif(val, else_block, &[], then_block, &[]); + self.builder.switch_to_block(else_block); Ok(()) } Instruction::JumpIfTrue { target } => { let cond = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; - let val = self.boolean_val(cond)?; let then_block = self.get_or_create_block(target.get(arg)); - self.builder.ins().brnz(val, then_block, &[]); + let else_block = self.builder.create_block(); - let block = self.builder.create_block(); - self.builder.ins().jump(block, &[]); - self.builder.switch_to_block(block); + self.builder + .ins() + .brif(val, then_block, &[], else_block, &[]); + self.builder.switch_to_block(else_block); Ok(()) } + Instruction::Jump { target } => { let target_block = self.get_or_create_block(target.get(arg)); self.builder.ins().jump(target_block, &[]); - Ok(()) } Instruction::LoadFast(idx) => { @@ -354,9 +389,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { }; let val = self.builder.ins().icmp(cond, operand_one, operand_two); - // TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8 - self.stack - .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); + self.stack.push(JitValue::Bool(val)); Ok(()) } (JitValue::Float(a), JitValue::Float(b)) => { @@ -370,9 +403,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { }; let val = self.builder.ins().fcmp(cond, a, b); - // TODO: Remove this `bint` in cranelift 0.90 as fcmp now returns i8 - self.stack - .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); + self.stack.push(JitValue::Bool(val)); Ok(()) } _ => Err(JitCompileError::NotSupported), @@ -414,35 +445,34 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let val = match (op, a, b) { (BinaryOperator::Add, JitValue::Int(a), JitValue::Int(b)) => { - let (out, carry) = self.builder.ins().iadd_ifcout(a, b); - self.builder.ins().trapif( - IntCC::Overflow, - carry, - TrapCode::IntegerOverflow, - ); + let (out, carry) = self.builder.ins().sadd_overflow(a, b); + self.builder.ins().trapnz(carry, TrapCode::INTEGER_OVERFLOW); JitValue::Int(out) } (BinaryOperator::Subtract, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.compile_sub(a, b)) } - (BinaryOperator::Multiply, JitValue::Int(a), JitValue::Int(b)) => { - JitValue::Int(self.builder.ins().imul(a, b)) - } (BinaryOperator::FloorDivide, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.builder.ins().sdiv(a, b)) } (BinaryOperator::Divide, JitValue::Int(a), JitValue::Int(b)) => { - // Convert to float for regular division + // Check if b == 0, If so trap with a division by zero error + self.builder + .ins() + .trapz(b, TrapCode::INTEGER_DIVISION_BY_ZERO); + // Else convert to float and divide let a_float = self.builder.ins().fcvt_from_sint(types::F64, a); let b_float = self.builder.ins().fcvt_from_sint(types::F64, b); JitValue::Float(self.builder.ins().fdiv(a_float, b_float)) } + (BinaryOperator::Multiply, JitValue::Int(a), JitValue::Int(b)) => { + JitValue::Int(self.builder.ins().imul(a, b)) + } (BinaryOperator::Modulo, JitValue::Int(a), JitValue::Int(b)) => { JitValue::Int(self.builder.ins().srem(a, b)) } - // Todo: This should return int when possible (BinaryOperator::Power, JitValue::Int(a), JitValue::Int(b)) => { - JitValue::Float(self.compile_ipow(a, b)) + JitValue::Int(self.compile_ipow(a, b)) } ( BinaryOperator::Lshift | BinaryOperator::Rshift, @@ -454,7 +484,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let sign = self.builder.ins().ushr_imm(b, 63); self.builder.ins().trapnz( sign, - TrapCode::User(CustomTrapCode::NegativeShiftCount as u16), + TrapCode::user(CustomTrapCode::NegativeShiftCount as u8).unwrap(), ); let out = if op == BinaryOperator::Lshift { @@ -487,6 +517,9 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { (BinaryOperator::Divide, JitValue::Float(a), JitValue::Float(b)) => { JitValue::Float(self.builder.ins().fdiv(a, b)) } + (BinaryOperator::Power, JitValue::Float(a), JitValue::Float(b)) => { + JitValue::Float(self.compile_fpow(a, b)) + } // Floats and Integers (_, JitValue::Int(a), JitValue::Float(b)) @@ -514,6 +547,9 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { BinaryOperator::Divide => { JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two)) } + BinaryOperator::Power => { + JitValue::Float(self.compile_fpow(operand_one, operand_two)) + } _ => return Err(JitCompileError::NotSupported), } } @@ -523,7 +559,13 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { Ok(()) } - Instruction::SetupLoop { .. } | Instruction::PopBlock => { + Instruction::SetupLoop { .. } => { + let loop_head = self.builder.create_block(); + self.builder.ins().jump(loop_head, &[]); + self.builder.switch_to_block(loop_head); + Ok(()) + } + Instruction::PopBlock => { // TODO: block support Ok(()) } @@ -562,167 +604,660 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { } fn compile_sub(&mut self, a: Value, b: Value) -> Value { - // TODO: this should be fine, but cranelift doesn't special-case isub_ifbout - // let (out, carry) = self.builder.ins().isub_ifbout(a, b); - // self.builder - // .ins() - // .trapif(IntCC::Overflow, carry, TrapCode::IntegerOverflow); - // TODO: this shouldn't wrap - let neg_b = self.builder.ins().ineg(b); - let (out, carry) = self.builder.ins().iadd_ifcout(a, neg_b); - self.builder - .ins() - .trapif(IntCC::Overflow, carry, TrapCode::IntegerOverflow); + let (out, carry) = self.builder.ins().ssub_overflow(a, b); + self.builder.ins().trapnz(carry, TrapCode::INTEGER_OVERFLOW); out } - fn compile_ipow(&mut self, a: Value, b: Value) -> Value { - // Convert base to float since result might not always be a Int - let float_base = self.builder.ins().fcvt_from_sint(types::F64, a); - // Create code blocks - let check_block1 = self.builder.create_block(); - let check_block2 = self.builder.create_block(); - let check_block3 = self.builder.create_block(); - let handle_neg_exp = self.builder.create_block(); + /// Creates a double–double (DDValue) from a regular f64 constant. + /// The high part is set to x and the low part is set to 0.0. + fn dd_from_f64(&mut self, x: f64) -> DDValue { + DDValue { + hi: self.builder.ins().f64const(x), + lo: self.builder.ins().f64const(0.0), + } + } + + /// Creates a DDValue from a Value (assumed to represent an f64). + /// This function initializes the high part with x and the low part to 0.0. + fn dd_from_value(&mut self, x: Value) -> DDValue { + DDValue { + hi: x, + lo: self.builder.ins().f64const(0.0), + } + } + + /// Creates a DDValue from two f64 parts. + /// The 'hi' parameter sets the high part and 'lo' sets the low part. + fn dd_from_parts(&mut self, hi: f64, lo: f64) -> DDValue { + DDValue { + hi: self.builder.ins().f64const(hi), + lo: self.builder.ins().f64const(lo), + } + } + + /// Converts a DDValue back to a single f64 value by adding the high and low parts. + fn dd_to_f64(&mut self, dd: DDValue) -> Value { + self.builder.ins().fadd(dd.hi, dd.lo) + } + + /// Computes the negation of a DDValue. + /// It subtracts both the high and low parts from zero. + fn dd_neg(&mut self, dd: DDValue) -> DDValue { + let zero = self.builder.ins().f64const(0.0); + DDValue { + hi: self.builder.ins().fsub(zero, dd.hi), + lo: self.builder.ins().fsub(zero, dd.lo), + } + } + + /// Adds two DDValue numbers using error-free transformations to maintain extra precision. + /// It carefully adds the high parts, computes the rounding error, adds the low parts along with the error, + /// and then normalizes the result. + fn dd_add(&mut self, a: DDValue, b: DDValue) -> DDValue { + // Compute the sum of the high parts. + let s = self.builder.ins().fadd(a.hi, b.hi); + // Compute t = s - a.hi to capture part of the rounding error. + let t = self.builder.ins().fsub(s, a.hi); + // Compute the error e from the high part additions. + let s_minus_t = self.builder.ins().fsub(s, t); + let part1 = self.builder.ins().fsub(a.hi, s_minus_t); + let part2 = self.builder.ins().fsub(b.hi, t); + let e = self.builder.ins().fadd(part1, part2); + // Sum the low parts along with the error. + let lo = self.builder.ins().fadd(a.lo, b.lo); + let lo_sum = self.builder.ins().fadd(lo, e); + // Renormalize: add the low sum to s and compute a new low component. + let hi_new = self.builder.ins().fadd(s, lo_sum); + let hi_new_minus_s = self.builder.ins().fsub(hi_new, s); + let lo_new = self.builder.ins().fsub(lo_sum, hi_new_minus_s); + DDValue { + hi: hi_new, + lo: lo_new, + } + } + + /// Subtracts DDValue b from DDValue a by negating b and then using the addition function. + fn dd_sub(&mut self, a: DDValue, b: DDValue) -> DDValue { + let neg_b = self.dd_neg(b); + self.dd_add(a, neg_b) + } + + /// Multiplies two DDValue numbers using double–double arithmetic. + /// It calculates the high product, uses a fused multiply–add (FMA) to capture rounding error, + /// computes the cross products, and then normalizes the result. + fn dd_mul(&mut self, a: DDValue, b: DDValue) -> DDValue { + // p = a.hi * b.hi (primary product) + let p = self.builder.ins().fmul(a.hi, b.hi); + // err = fma(a.hi, b.hi, -p) recovers the rounding error. + let zero = self.builder.ins().f64const(0.0); + let neg_p = self.builder.ins().fsub(zero, p); + let err = self.builder.ins().fma(a.hi, b.hi, neg_p); + // Compute cross terms: a.hi*b.lo + a.lo*b.hi. + let a_hi_b_lo = self.builder.ins().fmul(a.hi, b.lo); + let a_lo_b_hi = self.builder.ins().fmul(a.lo, b.hi); + let cross = self.builder.ins().fadd(a_hi_b_lo, a_lo_b_hi); + // Sum p and the cross terms. + let s = self.builder.ins().fadd(p, cross); + // Isolate rounding error from the addition. + let t = self.builder.ins().fsub(s, p); + let s_minus_t = self.builder.ins().fsub(s, t); + let part1 = self.builder.ins().fsub(p, s_minus_t); + let part2 = self.builder.ins().fsub(cross, t); + let e = self.builder.ins().fadd(part1, part2); + // Include the error from the low parts multiplication. + let a_lo_b_lo = self.builder.ins().fmul(a.lo, b.lo); + let err_plus_e = self.builder.ins().fadd(err, e); + let lo_sum = self.builder.ins().fadd(err_plus_e, a_lo_b_lo); + // Renormalize the sum. + let hi_new = self.builder.ins().fadd(s, lo_sum); + let hi_new_minus_s = self.builder.ins().fsub(hi_new, s); + let lo_new = self.builder.ins().fsub(lo_sum, hi_new_minus_s); + DDValue { + hi: hi_new, + lo: lo_new, + } + } + + /// Multiplies a DDValue by a regular f64 (Value) using similar techniques as dd_mul. + /// It multiplies both the high and low parts by b, computes the rounding error, + /// and then renormalizes the result. + fn dd_mul_f64(&mut self, a: DDValue, b: Value) -> DDValue { + // p = a.hi * b (primary product) + let p = self.builder.ins().fmul(a.hi, b); + // Compute the rounding error using fma. + let zero = self.builder.ins().f64const(0.0); + let neg_p = self.builder.ins().fsub(zero, p); + let err = self.builder.ins().fma(a.hi, b, neg_p); + // Multiply the low part. + let cross = self.builder.ins().fmul(a.lo, b); + // Sum the primary product and the low multiplication. + let s = self.builder.ins().fadd(p, cross); + // Capture rounding error from addition. + let t = self.builder.ins().fsub(s, p); + let s_minus_t = self.builder.ins().fsub(s, t); + let part1 = self.builder.ins().fsub(p, s_minus_t); + let part2 = self.builder.ins().fsub(cross, t); + let e = self.builder.ins().fadd(part1, part2); + // Combine the error components. + let lo_sum = self.builder.ins().fadd(err, e); + // Renormalize to form the final double–double number. + let hi_new = self.builder.ins().fadd(s, lo_sum); + let hi_new_minus_s = self.builder.ins().fsub(hi_new, s); + let lo_new = self.builder.ins().fsub(lo_sum, hi_new_minus_s); + DDValue { + hi: hi_new, + lo: lo_new, + } + } + + /// Scales a DDValue by multiplying both its high and low parts by the given factor. + fn dd_scale(&mut self, dd: DDValue, factor: Value) -> DDValue { + DDValue { + hi: self.builder.ins().fmul(dd.hi, factor), + lo: self.builder.ins().fmul(dd.lo, factor), + } + } + + /// Approximates ln(1+f) using its Taylor series expansion in double–double arithmetic. + /// It computes the series ∑ (-1)^(i-1) * f^i / i from i = 1 to 1000 for high precision. + fn dd_ln_1p_series(&mut self, f: Value) -> DDValue { + // Convert f to a DDValue and initialize the sum and term. + let f_dd = self.dd_from_value(f); + let mut sum = f_dd.clone(); + let mut term = f_dd; + // Alternating sign starts at -1 for the second term. + let mut sign = -1.0_f64; + let range = 1000; + + // Loop over terms from i = 2 to 1000. + for i in 2..=range { + // Compute f^i by multiplying the previous term by f. + term = self.dd_mul_f64(term, f); + // Divide the term by i. + let inv_i = 1.0 / (i as f64); + let c_inv_i = self.builder.ins().f64const(inv_i); + let term_div = self.dd_mul_f64(term.clone(), c_inv_i); + // Multiply by the alternating sign. + let dd_sign = self.dd_from_f64(sign); + let to_add = self.dd_mul(dd_sign, term_div); + // Add the term to the cumulative sum. + sum = self.dd_add(sum, to_add); + // Flip the sign for the next term. + sign = -sign; + } + sum + } + + /// Computes the natural logarithm ln(x) in double–double arithmetic. + /// It first checks for domain errors (x ≤ 0 or NaN), then extracts the exponent + /// and mantissa from the bit-level representation of x. It computes ln(mantissa) using + /// the ln(1+f) series and adds k*ln2 to obtain ln(x). + fn dd_ln(&mut self, x: Value) -> DDValue { + // (A) Prepare a DDValue representing NaN. + let dd_nan = self.dd_from_f64(f64::NAN); + + // Build a zero constant for comparisons. + let zero_f64 = self.builder.ins().f64const(0.0); + + // Check if x is less than or equal to 0 or is NaN. + let cmp_le = self + .builder + .ins() + .fcmp(FloatCC::LessThanOrEqual, x, zero_f64); + let cmp_nan = self.builder.ins().fcmp(FloatCC::Unordered, x, x); + let need_nan = self.builder.ins().bor(cmp_le, cmp_nan); + + // (B) Reinterpret the bits of x as an integer. + let bits = self.builder.ins().bitcast(types::I64, MemFlags::new(), x); + + // (C) Extract the exponent (top 11 bits) from the bit representation. + let shift_52 = self.builder.ins().ushr_imm(bits, 52); + let exponent_mask = self.builder.ins().iconst(types::I64, 0x7FF); + let exponent = self.builder.ins().band(shift_52, exponent_mask); + + // k = exponent - 1023 (unbias the exponent). + let bias = self.builder.ins().iconst(types::I64, 1023); + let k_i64 = self.builder.ins().isub(exponent, bias); + + // (D) Extract the fraction (mantissa) from the lower 52 bits. + let fraction_mask = self.builder.ins().iconst(types::I64, 0x000F_FFFF_FFFF_FFFF); + let fraction_part = self.builder.ins().band(bits, fraction_mask); + + // (E) For normal numbers (exponent ≠ 0), add the implicit leading 1. + let implicit_one = self.builder.ins().iconst(types::I64, 1 << 52); + let zero_exp = self.builder.ins().icmp_imm(IntCC::Equal, exponent, 0); + let frac_one_bor = self.builder.ins().bor(fraction_part, implicit_one); + let fraction_with_leading_one = self.builder.ins().select( + zero_exp, + fraction_part, // For subnormals, do not add the implicit 1. + frac_one_bor, + ); + + // (F) Force the exponent bits to 1023, yielding a mantissa m in [1, 2). + let new_exp = self.builder.ins().iconst(types::I64, 0x3FF0_0000_0000_0000); + let fraction_bits = self.builder.ins().bor(fraction_with_leading_one, new_exp); + let m = self + .builder + .ins() + .bitcast(types::F64, MemFlags::new(), fraction_bits); + + // (G) Compute ln(m) using the series ln(1+f) with f = m - 1. + let one_f64 = self.builder.ins().f64const(1.0); + let f_val = self.builder.ins().fsub(m, one_f64); + let dd_ln_m = self.dd_ln_1p_series(f_val); + + // (H) Compute k*ln2 in double–double arithmetic. + let ln2_dd = self.dd_from_parts( + f64::from_bits(0x3fe62e42fefa39ef), + f64::from_bits(0x3c7abc9e3b39803f), + ); + let k_f64 = self.builder.ins().fcvt_from_sint(types::F64, k_i64); + let dd_ln2_k = self.dd_mul_f64(ln2_dd, k_f64); + + // Add ln(m) and k*ln2 to get the final ln(x). + let normal_result = self.dd_add(dd_ln_m, dd_ln2_k); + + // (I) If x was nonpositive or NaN, return NaN; otherwise, return the computed result. + let final_hi = self + .builder + .ins() + .select(need_nan, dd_nan.hi, normal_result.hi); + let final_lo = self + .builder + .ins() + .select(need_nan, dd_nan.lo, normal_result.lo); + + DDValue { + hi: final_hi, + lo: final_lo, + } + } + + /// Computes the exponential function exp(x) in double–double arithmetic. + /// It uses range reduction to write x = k*ln2 + r, computes exp(r) via a Taylor series, + /// scales the result by 2^k, and handles overflow by checking if k exceeds the maximum. + fn dd_exp(&mut self, dd: DDValue) -> DDValue { + // (A) Range reduction: Convert dd to a single f64 value. + let x = self.dd_to_f64(dd.clone()); + let ln2_f64 = self + .builder + .ins() + .f64const(f64::from_bits(0x3fe62e42fefa39ef)); + let div = self.builder.ins().fdiv(x, ln2_f64); + let half = self.builder.ins().f64const(0.5); + let div_plus_half = self.builder.ins().fadd(div, half); + // Rounding: floor(div + 0.5) gives the nearest integer k. + let k = self.builder.ins().fcvt_to_sint(types::I64, div_plus_half); + + // --- OVERFLOW CHECK --- + // Check if k is greater than the maximum exponent for finite doubles (1023). + let max_k = self.builder.ins().iconst(types::I64, 1023); + let is_overflow = self.builder.ins().icmp(IntCC::SignedGreaterThan, k, max_k); + + // Define infinity and zero for the overflow case. + let inf = self.builder.ins().f64const(f64::INFINITY); + let zero = self.builder.ins().f64const(0.0); + + // (B) Compute exp(x) normally when not overflowing. + // Compute k*ln2 in double–double arithmetic and subtract it from x. + let ln2_dd = self.dd_from_parts( + f64::from_bits(0x3fe62e42fefa39ef), + f64::from_bits(0x3c7abc9e3b39803f), + ); + let k_f64 = self.builder.ins().fcvt_from_sint(types::F64, k); + let k_ln2 = self.dd_mul_f64(ln2_dd, k_f64); + let r = self.dd_sub(dd, k_ln2); + + // Compute exp(r) using a Taylor series. + let mut sum = self.dd_from_f64(1.0); // Initialize sum to 1. + let mut term = self.dd_from_f64(1.0); // Initialize the first term to 1. + let n_terms = 1000; + for i in 1..=n_terms { + term = self.dd_mul(term, r.clone()); + let inv = 1.0 / (i as f64); + let inv_const = self.builder.ins().f64const(inv); + term = self.dd_mul_f64(term, inv_const); + sum = self.dd_add(sum, term.clone()); + } + + // Reconstruct the final result by scaling with 2^k. + let bias = self.builder.ins().iconst(types::I64, 1023); + let k_plus_bias = self.builder.ins().iadd(k, bias); + let shift_count = self.builder.ins().iconst(types::I64, 52); + let shifted = self.builder.ins().ishl(k_plus_bias, shift_count); + let two_to_k = self + .builder + .ins() + .bitcast(types::F64, MemFlags::new(), shifted); + let result = self.dd_scale(sum, two_to_k); + + // (C) If overflow was detected, return infinity; otherwise, return the computed value. + let final_hi = self.builder.ins().select(is_overflow, inf, result.hi); + let final_lo = self.builder.ins().select(is_overflow, zero, result.lo); + DDValue { + hi: final_hi, + lo: final_lo, + } + } + + /// Computes the power function a^b (f_pow) for f64 values using double–double arithmetic for high precision. + /// It handles different cases for the base 'a': + /// - For a > 0: Computes exp(b * ln(a)). + /// - For a == 0: Handles special cases for 0^b, including returning 0, 1, or a domain error. + /// - For a < 0: Allows only an integer exponent b and adjusts the sign if b is odd. + fn compile_fpow(&mut self, a: Value, b: Value) -> Value { + let f64_ty = types::F64; + let i64_ty = types::I64; + let zero_f = self.builder.ins().f64const(0.0); + let one_f = self.builder.ins().f64const(1.0); + let nan_f = self.builder.ins().f64const(f64::NAN); + let inf_f = self.builder.ins().f64const(f64::INFINITY); + let neg_inf_f = self.builder.ins().f64const(f64::NEG_INFINITY); + + // Merge block for final result. + let merge_block = self.builder.create_block(); + self.builder.append_block_param(merge_block, f64_ty); + + // --- Edge Case 1: b == 0.0 → return 1.0 + let cmp_b_zero = self.builder.ins().fcmp(FloatCC::Equal, b, zero_f); + let b_zero_block = self.builder.create_block(); + let continue_block = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_b_zero, b_zero_block, &[], continue_block, &[]); + self.builder.switch_to_block(b_zero_block); + self.builder.ins().jump(merge_block, &[one_f]); + self.builder.switch_to_block(continue_block); + + // --- Edge Case 2: b is NaN → return NaN + let cmp_b_nan = self.builder.ins().fcmp(FloatCC::Unordered, b, b); + let b_nan_block = self.builder.create_block(); + let continue_block2 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_b_nan, b_nan_block, &[], continue_block2, &[]); + self.builder.switch_to_block(b_nan_block); + self.builder.ins().jump(merge_block, &[nan_f]); + self.builder.switch_to_block(continue_block2); + + // --- Edge Case 3: a == 0.0 → return 0.0 + let cmp_a_zero = self.builder.ins().fcmp(FloatCC::Equal, a, zero_f); + let a_zero_block = self.builder.create_block(); + let continue_block3 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_a_zero, a_zero_block, &[], continue_block3, &[]); + self.builder.switch_to_block(a_zero_block); + self.builder.ins().jump(merge_block, &[zero_f]); + self.builder.switch_to_block(continue_block3); + + // --- Edge Case 4: a is NaN → return NaN + let cmp_a_nan = self.builder.ins().fcmp(FloatCC::Unordered, a, a); + let a_nan_block = self.builder.create_block(); + let continue_block4 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_a_nan, a_nan_block, &[], continue_block4, &[]); + self.builder.switch_to_block(a_nan_block); + self.builder.ins().jump(merge_block, &[nan_f]); + self.builder.switch_to_block(continue_block4); + + // --- Edge Case 5: b == +infinity → return +infinity + let cmp_b_inf = self.builder.ins().fcmp(FloatCC::Equal, b, inf_f); + let b_inf_block = self.builder.create_block(); + let continue_block5 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_b_inf, b_inf_block, &[], continue_block5, &[]); + self.builder.switch_to_block(b_inf_block); + self.builder.ins().jump(merge_block, &[inf_f]); + self.builder.switch_to_block(continue_block5); + + // --- Edge Case 6: b == -infinity → return 0.0 + let cmp_b_neg_inf = self.builder.ins().fcmp(FloatCC::Equal, b, neg_inf_f); + let b_neg_inf_block = self.builder.create_block(); + let continue_block6 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_b_neg_inf, b_neg_inf_block, &[], continue_block6, &[]); + self.builder.switch_to_block(b_neg_inf_block); + self.builder.ins().jump(merge_block, &[zero_f]); + self.builder.switch_to_block(continue_block6); + + // --- Edge Case 7: a == +infinity → return +infinity + let cmp_a_inf = self.builder.ins().fcmp(FloatCC::Equal, a, inf_f); + let a_inf_block = self.builder.create_block(); + let continue_block7 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_a_inf, a_inf_block, &[], continue_block7, &[]); + self.builder.switch_to_block(a_inf_block); + self.builder.ins().jump(merge_block, &[inf_f]); + self.builder.switch_to_block(continue_block7); + + // --- Edge Case 8: a == -infinity → check exponent parity + let cmp_a_neg_inf = self.builder.ins().fcmp(FloatCC::Equal, a, neg_inf_f); + let a_neg_inf_block = self.builder.create_block(); + let continue_block8 = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_a_neg_inf, a_neg_inf_block, &[], continue_block8, &[]); + + self.builder.switch_to_block(a_neg_inf_block); + // a is -infinity here. First, ensure that b is an integer. + let b_floor = self.builder.ins().floor(b); + let cmp_int = self.builder.ins().fcmp(FloatCC::Equal, b_floor, b); + let domain_error_blk = self.builder.create_block(); + let continue_neg_inf = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_int, continue_neg_inf, &[], domain_error_blk, &[]); + + self.builder.switch_to_block(domain_error_blk); + self.builder.ins().jump(merge_block, &[nan_f]); + + self.builder.switch_to_block(continue_neg_inf); + // b is an integer here; convert b_floor to an i64. + let b_i64 = self.builder.ins().fcvt_to_sint(i64_ty, b_floor); + let one_i = self.builder.ins().iconst(i64_ty, 1); + let remainder = self.builder.ins().band(b_i64, one_i); + let zero_i = self.builder.ins().iconst(i64_ty, 0); + let is_odd = self.builder.ins().icmp(IntCC::NotEqual, remainder, zero_i); + + // Create separate blocks for odd and even cases. + let odd_block = self.builder.create_block(); + let even_block = self.builder.create_block(); + self.builder.append_block_param(odd_block, f64_ty); + self.builder.append_block_param(even_block, f64_ty); + self.builder + .ins() + .brif(is_odd, odd_block, &[neg_inf_f], even_block, &[inf_f]); + + self.builder.switch_to_block(odd_block); + let phi_neg_inf = self.builder.block_params(odd_block)[0]; + self.builder.ins().jump(merge_block, &[phi_neg_inf]); + + self.builder.switch_to_block(even_block); + let phi_inf = self.builder.block_params(even_block)[0]; + self.builder.ins().jump(merge_block, &[phi_inf]); + + self.builder.switch_to_block(continue_block8); + + // --- Normal branch: neither a nor b hit the special cases. + // Here we branch based on the sign of a. + let cmp_lt = self.builder.ins().fcmp(FloatCC::LessThan, a, zero_f); + let a_neg_block = self.builder.create_block(); + let a_pos_block = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_lt, a_neg_block, &[], a_pos_block, &[]); + + // ----- Case: a > 0: Compute a^b = exp(b * ln(a)) using double–double arithmetic. + self.builder.switch_to_block(a_pos_block); + let ln_a_dd = self.dd_ln(a); + let b_dd = self.dd_from_value(b); + let product_dd = self.dd_mul(ln_a_dd, b_dd); + let exp_dd = self.dd_exp(product_dd); + let pos_res = self.dd_to_f64(exp_dd); + self.builder.ins().jump(merge_block, &[pos_res]); + + // ----- Case: a < 0: Only allow an integral exponent. + self.builder.switch_to_block(a_neg_block); + let b_floor = self.builder.ins().floor(b); + let cmp_int = self.builder.ins().fcmp(FloatCC::Equal, b_floor, b); + let neg_int_block = self.builder.create_block(); + let domain_error_blk = self.builder.create_block(); + self.builder + .ins() + .brif(cmp_int, neg_int_block, &[], domain_error_blk, &[]); + + // Domain error: non-integer exponent for negative base + self.builder.switch_to_block(domain_error_blk); + self.builder.ins().jump(merge_block, &[nan_f]); + + // For negative base with an integer exponent: + self.builder.switch_to_block(neg_int_block); + let abs_a = self.builder.ins().fabs(a); + let ln_abs_dd = self.dd_ln(abs_a); + let b_dd = self.dd_from_value(b); + let product_dd = self.dd_mul(ln_abs_dd, b_dd); + let exp_dd = self.dd_exp(product_dd); + let mag_val = self.dd_to_f64(exp_dd); + + let b_i64 = self.builder.ins().fcvt_to_sint(i64_ty, b_floor); + let one_i = self.builder.ins().iconst(i64_ty, 1); + let remainder = self.builder.ins().band(b_i64, one_i); + let zero_i = self.builder.ins().iconst(i64_ty, 0); + let is_odd = self.builder.ins().icmp(IntCC::NotEqual, remainder, zero_i); + + let odd_block = self.builder.create_block(); + let even_block = self.builder.create_block(); + // Append block parameters for both branches: + self.builder.append_block_param(odd_block, f64_ty); + self.builder.append_block_param(even_block, f64_ty); + // Pass mag_val to both branches: + self.builder + .ins() + .brif(is_odd, odd_block, &[mag_val], even_block, &[mag_val]); + + self.builder.switch_to_block(odd_block); + let phi_mag_val = self.builder.block_params(odd_block)[0]; + let neg_val = self.builder.ins().fneg(phi_mag_val); + self.builder.ins().jump(merge_block, &[neg_val]); + + self.builder.switch_to_block(even_block); + let phi_mag_val_even = self.builder.block_params(even_block)[0]; + self.builder.ins().jump(merge_block, &[phi_mag_val_even]); + + // ----- Merge: Return the final result. + self.builder.switch_to_block(merge_block); + let final_val = self.builder.block_params(merge_block)[0]; + final_val + } + + fn compile_ipow(&mut self, a: Value, b: Value) -> Value { + let zero = self.builder.ins().iconst(types::I64, 0); + let one_i64 = self.builder.ins().iconst(types::I64, 1); + + // Create required blocks + let check_negative = self.builder.create_block(); + let handle_negative = self.builder.create_block(); let loop_block = self.builder.create_block(); let continue_block = self.builder.create_block(); let exit_block = self.builder.create_block(); - // Set code block params - // Set code block params - self.builder.append_block_param(check_block1, types::F64); - self.builder.append_block_param(check_block1, types::I64); + // Set up block parameters + self.builder.append_block_param(check_negative, types::I64); // exponent + self.builder.append_block_param(check_negative, types::I64); // base - self.builder.append_block_param(check_block2, types::F64); - self.builder.append_block_param(check_block2, types::I64); + self.builder.append_block_param(handle_negative, types::I64); // abs(exponent) + self.builder.append_block_param(handle_negative, types::I64); // base - self.builder.append_block_param(check_block3, types::F64); - self.builder.append_block_param(check_block3, types::I64); + self.builder.append_block_param(loop_block, types::I64); // exponent + self.builder.append_block_param(loop_block, types::I64); // result + self.builder.append_block_param(loop_block, types::I64); // base - self.builder.append_block_param(handle_neg_exp, types::F64); - self.builder.append_block_param(handle_neg_exp, types::I64); + self.builder.append_block_param(exit_block, types::I64); // final result - self.builder.append_block_param(loop_block, types::F64); //base - self.builder.append_block_param(loop_block, types::F64); //result - self.builder.append_block_param(loop_block, types::I64); //exponent + // Set up parameters for continue_block + self.builder.append_block_param(continue_block, types::I64); // exponent + self.builder.append_block_param(continue_block, types::I64); // result + self.builder.append_block_param(continue_block, types::I64); // base - self.builder.append_block_param(continue_block, types::F64); //base - self.builder.append_block_param(continue_block, types::F64); //result - self.builder.append_block_param(continue_block, types::I64); //exponent + // Initial jump to check if exponent is negative + self.builder.ins().jump(check_negative, &[b, a]); - self.builder.append_block_param(exit_block, types::F64); + // Check if exponent is negative + self.builder.switch_to_block(check_negative); + let params = self.builder.block_params(check_negative); + let exp_check = params[0]; + let base_check = params[1]; - // Begin evaluating by jumping to first check block - self.builder.ins().jump(check_block1, &[float_base, b]); - - // Check block one: - // Checks if input is O ** n where n > 0 - // Jumps to exit_block as 0 if true - self.builder.switch_to_block(check_block1); - let paramsc1 = self.builder.block_params(check_block1); - let basec1 = paramsc1[0]; - let expc1 = paramsc1[1]; - let zero_f64 = self.builder.ins().f64const(0.0); - let zero_i64 = self.builder.ins().iconst(types::I64, 0); - let is_base_zero = self.builder.ins().fcmp(FloatCC::Equal, zero_f64, basec1); - let is_exp_positive = self + let is_negative = self .builder .ins() - .icmp(IntCC::SignedGreaterThan, expc1, zero_i64); - let is_zero_to_positive = self.builder.ins().band(is_base_zero, is_exp_positive); - self.builder - .ins() - .brnz(is_zero_to_positive, exit_block, &[zero_f64]); - self.builder.ins().jump(check_block2, &[basec1, expc1]); + .icmp(IntCC::SignedLessThan, exp_check, zero); + self.builder.ins().brif( + is_negative, + handle_negative, + &[exp_check, base_check], + loop_block, + &[exp_check, one_i64, base_check], + ); - // Check block two: - // Checks if exponent is negative - // Jumps to a special handle_neg_exponent block if true - self.builder.switch_to_block(check_block2); - let paramsc2 = self.builder.block_params(check_block2); - let basec2 = paramsc2[0]; - let expc2 = paramsc2[1]; - let zero_i64 = self.builder.ins().iconst(types::I64, 0); - let is_neg = self - .builder - .ins() - .icmp(IntCC::SignedLessThan, expc2, zero_i64); - self.builder - .ins() - .brnz(is_neg, handle_neg_exp, &[basec2, expc2]); - self.builder.ins().jump(check_block3, &[basec2, expc2]); + // Handle negative exponent (return 0 for integer exponentiation) + self.builder.switch_to_block(handle_negative); + self.builder.ins().jump(exit_block, &[zero]); // Return 0 for negative exponents - // Check block three: - // Checks if exponent is one - // jumps to exit block with the base of the exponents value - self.builder.switch_to_block(check_block3); - let paramsc3 = self.builder.block_params(check_block3); - let basec3 = paramsc3[0]; - let expc3 = paramsc3[1]; - let resc3 = self.builder.ins().f64const(1.0); - let one_i64 = self.builder.ins().iconst(types::I64, 1); - let is_one = self.builder.ins().icmp(IntCC::Equal, expc3, one_i64); - self.builder.ins().brnz(is_one, exit_block, &[basec3]); - self.builder.ins().jump(loop_block, &[basec3, resc3, expc3]); - - // Handles negative Exponents - // calculates x^(-n) = (1/x)^n - // then proceeds to the loop to evaluate - self.builder.switch_to_block(handle_neg_exp); - let paramshn = self.builder.block_params(handle_neg_exp); - let basehn = paramshn[0]; - let exphn = paramshn[1]; - let one_f64 = self.builder.ins().f64const(1.0); - let base_inverse = self.builder.ins().fdiv(one_f64, basehn); - let pos_exp = self.builder.ins().ineg(exphn); - self.builder - .ins() - .jump(loop_block, &[base_inverse, one_f64, pos_exp]); - - // Main loop block - // checks loop condition (exp > 0) - // Jumps to continue block if true, exit block if false + // Loop block logic (square-and-multiply algorithm) self.builder.switch_to_block(loop_block); - let paramslb = self.builder.block_params(loop_block); - let baselb = paramslb[0]; - let reslb = paramslb[1]; - let explb = paramslb[2]; - let zero = self.builder.ins().iconst(types::I64, 0); - let is_zero = self.builder.ins().icmp(IntCC::Equal, explb, zero); - self.builder.ins().brnz(is_zero, exit_block, &[reslb]); - self.builder - .ins() - .jump(continue_block, &[baselb, reslb, explb]); + let params = self.builder.block_params(loop_block); + let exp_phi = params[0]; + let result_phi = params[1]; + let base_phi = params[2]; - // Continue block - // Main math logic - // Always jumps back to loob_block + // Check if exponent is zero + let is_zero = self.builder.ins().icmp(IntCC::Equal, exp_phi, zero); + self.builder.ins().brif( + is_zero, + exit_block, + &[result_phi], + continue_block, + &[exp_phi, result_phi, base_phi], + ); + + // Continue block for non-zero case self.builder.switch_to_block(continue_block); - let paramscb = self.builder.block_params(continue_block); - let basecb = paramscb[0]; - let rescb = paramscb[1]; - let expcb = paramscb[2]; - let is_odd = self.builder.ins().band_imm(expcb, 1); + let params = self.builder.block_params(continue_block); + let exp_phi = params[0]; + let result_phi = params[1]; + let base_phi = params[2]; + + // If exponent is odd, multiply result by base + let is_odd = self.builder.ins().band_imm(exp_phi, 1); let is_odd = self.builder.ins().icmp_imm(IntCC::Equal, is_odd, 1); - let mul_result = self.builder.ins().fmul(rescb, basecb); - let new_result = self.builder.ins().select(is_odd, mul_result, rescb); - let squared_base = self.builder.ins().fmul(basecb, basecb); - let new_exp = self.builder.ins().sshr_imm(expcb, 1); + let mul_result = self.builder.ins().imul(result_phi, base_phi); + let new_result = self.builder.ins().select(is_odd, mul_result, result_phi); + + // Square the base and divide exponent by 2 + let squared_base = self.builder.ins().imul(base_phi, base_phi); + let new_exp = self.builder.ins().sshr_imm(exp_phi, 1); self.builder .ins() - .jump(loop_block, &[squared_base, new_result, new_exp]); + .jump(loop_block, &[new_exp, new_result, squared_base]); + // Exit block self.builder.switch_to_block(exit_block); - let result = self.builder.block_params(exit_block)[0]; + let res = self.builder.block_params(exit_block)[0]; - self.builder.seal_block(check_block1); - self.builder.seal_block(check_block2); - self.builder.seal_block(check_block3); - self.builder.seal_block(handle_neg_exp); + // Seal all blocks + self.builder.seal_block(check_negative); + self.builder.seal_block(handle_negative); self.builder.seal_block(loop_block); self.builder.seal_block(continue_block); self.builder.seal_block(exit_block); - result + res } } diff --git a/jit/src/lib.rs b/jit/src/lib.rs index 37f1f2a3d..33054b1c9 100644 --- a/jit/src/lib.rs +++ b/jit/src/lib.rs @@ -114,7 +114,7 @@ pub fn compile( let (id, sig) = jit.build_function(bytecode, args, ret)?; - jit.module.finalize_definitions(); + jit.module.finalize_definitions()?; let code = jit.module.get_finalized_function(id); Ok(CompiledCode { diff --git a/jit/tests/float_tests.rs b/jit/tests/float_tests.rs index 2ba7dec82..384d7b946 100644 --- a/jit/tests/float_tests.rs +++ b/jit/tests/float_tests.rs @@ -110,6 +110,117 @@ fn test_mul_with_integer() { assert_bits_eq!(mul(-0.0, -1), Ok(0.0f64)); } +#[test] +fn test_power() { + let pow = jit_function! { pow(a:f64, b:f64) -> f64 => r##" + def pow(a:float, b: float): + return a**b + "##}; + // Test base cases + assert_approx_eq!(pow(0.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(0.0, 1.0), Ok(0.0)); + assert_approx_eq!(pow(1.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(1.0, 1.0), Ok(1.0)); + assert_approx_eq!(pow(1.0, -1.0), Ok(1.0)); + assert_approx_eq!(pow(-1.0, 0.0), Ok(1.0)); + assert_approx_eq!(pow(-1.0, 1.0), Ok(-1.0)); + assert_approx_eq!(pow(-1.0, -1.0), Ok(-1.0)); + + // NaN and Infinity cases + assert_approx_eq!(pow(f64::NAN, 0.0), Ok(1.0)); + //assert_approx_eq!(pow(f64::NAN, 1.0), Ok(f64::NAN)); // Return the correct answer but fails compare + //assert_approx_eq!(pow(0.0, f64::NAN), Ok(f64::NAN)); // Return the correct answer but fails compare + assert_approx_eq!(pow(f64::INFINITY, 0.0), Ok(1.0)); + assert_approx_eq!(pow(f64::INFINITY, 1.0), Ok(f64::INFINITY)); + assert_approx_eq!(pow(f64::INFINITY, f64::INFINITY), Ok(f64::INFINITY)); + // Negative infinity cases: + // For any exponent of 0.0, the result is 1.0. + assert_approx_eq!(pow(f64::NEG_INFINITY, 0.0), Ok(1.0)); + // For negative infinity base, when b is an odd integer, result is -infinity; + // when b is even, result is +infinity. + assert_approx_eq!(pow(f64::NEG_INFINITY, 1.0), Ok(f64::NEG_INFINITY)); + assert_approx_eq!(pow(f64::NEG_INFINITY, 2.0), Ok(f64::INFINITY)); + assert_approx_eq!(pow(f64::NEG_INFINITY, 3.0), Ok(f64::NEG_INFINITY)); + // Exponent -infinity gives 0.0. + assert_approx_eq!(pow(f64::NEG_INFINITY, f64::NEG_INFINITY), Ok(0.0)); + + // Test positive float base, positive float exponent + assert_approx_eq!(pow(2.0, 2.0), Ok(4.0)); + assert_approx_eq!(pow(3.0, 3.0), Ok(27.0)); + assert_approx_eq!(pow(4.0, 4.0), Ok(256.0)); + assert_approx_eq!(pow(2.0, 3.0), Ok(8.0)); + assert_approx_eq!(pow(2.0, 4.0), Ok(16.0)); + // Test negative float base, positive float exponent (integral exponents only) + assert_approx_eq!(pow(-2.0, 2.0), Ok(4.0)); + assert_approx_eq!(pow(-3.0, 3.0), Ok(-27.0)); + assert_approx_eq!(pow(-4.0, 4.0), Ok(256.0)); + assert_approx_eq!(pow(-2.0, 3.0), Ok(-8.0)); + assert_approx_eq!(pow(-2.0, 4.0), Ok(16.0)); + // Test positive float base, positive float exponent + assert_approx_eq!(pow(2.5, 2.0), Ok(6.25)); + assert_approx_eq!(pow(3.5, 3.0), Ok(42.875)); + assert_approx_eq!(pow(4.5, 4.0), Ok(410.0625)); + assert_approx_eq!(pow(2.5, 3.0), Ok(15.625)); + assert_approx_eq!(pow(2.5, 4.0), Ok(39.0625)); + // Test negative float base, positive float exponent (integral exponents only) + assert_approx_eq!(pow(-2.5, 2.0), Ok(6.25)); + assert_approx_eq!(pow(-3.5, 3.0), Ok(-42.875)); + assert_approx_eq!(pow(-4.5, 4.0), Ok(410.0625)); + assert_approx_eq!(pow(-2.5, 3.0), Ok(-15.625)); + assert_approx_eq!(pow(-2.5, 4.0), Ok(39.0625)); + // Test positive float base, positive float exponent with nonintegral exponents + assert_approx_eq!(pow(2.0, 2.5), Ok(5.656854249492381)); + assert_approx_eq!(pow(3.0, 3.5), Ok(46.76537180435969)); + assert_approx_eq!(pow(4.0, 4.5), Ok(512.0)); + assert_approx_eq!(pow(2.0, 3.5), Ok(11.313708498984761)); + assert_approx_eq!(pow(2.0, 4.5), Ok(22.627416997969522)); + // Test positive float base, negative float exponent + assert_approx_eq!(pow(2.0, -2.5), Ok(0.1767766952966369)); + assert_approx_eq!(pow(3.0, -3.5), Ok(0.021383343303319473)); + assert_approx_eq!(pow(4.0, -4.5), Ok(0.001953125)); + assert_approx_eq!(pow(2.0, -3.5), Ok(0.08838834764831845)); + assert_approx_eq!(pow(2.0, -4.5), Ok(0.04419417382415922)); + // Test negative float base, negative float exponent (integral exponents only) + assert_approx_eq!(pow(-2.0, -2.0), Ok(0.25)); + assert_approx_eq!(pow(-3.0, -3.0), Ok(-0.037037037037037035)); + assert_approx_eq!(pow(-4.0, -4.0), Ok(0.00390625)); + assert_approx_eq!(pow(-2.0, -3.0), Ok(-0.125)); + assert_approx_eq!(pow(-2.0, -4.0), Ok(0.0625)); + + // Currently negative float base with nonintegral exponent is not supported: + // assert_approx_eq!(pow(-2.0, 2.5), Ok(5.656854249492381)); + // assert_approx_eq!(pow(-3.0, 3.5), Ok(-46.76537180435969)); + // assert_approx_eq!(pow(-4.0, 4.5), Ok(512.0)); + // assert_approx_eq!(pow(-2.0, -2.5), Ok(0.1767766952966369)); + // assert_approx_eq!(pow(-3.0, -3.5), Ok(0.021383343303319473)); + // assert_approx_eq!(pow(-4.0, -4.5), Ok(0.001953125)); + + // Extra cases **NOTE** these are not all working: + // * If they are commented in then they work + // * If they are commented out with a number that is the current return value it throws vs the expected value + // * If they are commented out with a "fail to run" that means I couldn't get them to work, could add a case for really big or small values + // 1e308^2.0 + assert_approx_eq!(pow(1e308, 2.0), Ok(f64::INFINITY)); + // 1e308^(1e-2) + assert_approx_eq!(pow(1e308, 1e-2), Ok(1202.2644346174131)); + // 1e-308^2.0 + //assert_approx_eq!(pow(1e-308, 2.0), Ok(0.0)); // --8.403311421507407 + // 1e-308^-2.0 + assert_approx_eq!(pow(1e-308, -2.0), Ok(f64::INFINITY)); + // 1e100^(1e50) + //assert_approx_eq!(pow(1e100, 1e50), Ok(1.0000000000000002e+150)); // fail to run (Crashes as "illegal hardware instruction") + // 1e50^(1e-100) + assert_approx_eq!(pow(1e50, 1e-100), Ok(1.0)); + // 1e308^(-1e2) + //assert_approx_eq!(pow(1e308, -1e2), Ok(0.0)); // 2.961801792837933e25 + // 1e-308^(1e2) + //assert_approx_eq!(pow(1e-308, 1e2), Ok(f64::INFINITY)); // 1.6692559244043896e46 + // 1e308^(-1e308) + // assert_approx_eq!(pow(1e308, -1e308), Ok(0.0)); // fail to run (Crashes as "illegal hardware instruction") + // 1e-308^(1e308) + // assert_approx_eq!(pow(1e-308, 1e308), Ok(0.0)); // fail to run (Crashes as "illegal hardware instruction") +} + #[test] fn test_div() { let div = jit_function! { div(a:f64, b:f64) -> f64 => r##" diff --git a/jit/tests/int_tests.rs b/jit/tests/int_tests.rs index 353052df0..5ab2697e0 100644 --- a/jit/tests/int_tests.rs +++ b/jit/tests/int_tests.rs @@ -45,7 +45,6 @@ fn test_mul() { } #[test] - fn test_div() { let div = jit_function! { div(a:i64, b:i64) -> f64 => r##" def div(a: int, b: int): @@ -87,23 +86,23 @@ fn test_floor_div() { #[test] fn test_exp() { - let exp = jit_function! { exp(a: i64, b: i64) -> f64 => r##" + let exp = jit_function! { exp(a: i64, b: i64) -> i64 => r##" def exp(a: int, b: int): return a ** b "## }; - assert_eq!(exp(2, 3), Ok(8.0)); - assert_eq!(exp(3, 2), Ok(9.0)); - assert_eq!(exp(5, 0), Ok(1.0)); - assert_eq!(exp(0, 0), Ok(1.0)); - assert_eq!(exp(-5, 0), Ok(1.0)); - assert_eq!(exp(0, 1), Ok(0.0)); - assert_eq!(exp(0, 5), Ok(0.0)); - assert_eq!(exp(-2, 2), Ok(4.0)); - assert_eq!(exp(-3, 4), Ok(81.0)); - assert_eq!(exp(-2, 3), Ok(-8.0)); - assert_eq!(exp(-3, 3), Ok(-27.0)); - assert_eq!(exp(1000, 2), Ok(1000000.0)); + assert_eq!(exp(2, 3), Ok(8)); + assert_eq!(exp(3, 2), Ok(9)); + assert_eq!(exp(5, 0), Ok(1)); + assert_eq!(exp(0, 0), Ok(1)); + assert_eq!(exp(-5, 0), Ok(1)); + assert_eq!(exp(0, 1), Ok(0)); + assert_eq!(exp(0, 5), Ok(0)); + assert_eq!(exp(-2, 2), Ok(4)); + assert_eq!(exp(-3, 4), Ok(81)); + assert_eq!(exp(-2, 3), Ok(-8)); + assert_eq!(exp(-3, 3), Ok(-27)); + assert_eq!(exp(1000, 2), Ok(1000000)); } #[test] @@ -121,6 +120,18 @@ fn test_mod() { assert_eq!(modulo(-5, 10), Ok(-5)); } +#[test] +fn test_power() { + let power = jit_function! { power(a:i64, b:i64) -> i64 => r##" + def power(a: int, b: int): + return a ** b + "## + }; + assert_eq!(power(10, 2), Ok(100)); + assert_eq!(power(5, 1), Ok(5)); + assert_eq!(power(1, 0), Ok(1)); +} + #[test] fn test_lshift() { let lshift = jit_function! { lshift(a:i64, b:i64) -> i64 => r##" diff --git a/jit/tests/misc_tests.rs b/jit/tests/misc_tests.rs index 7e1174da4..25d66c46c 100644 --- a/jit/tests/misc_tests.rs +++ b/jit/tests/misc_tests.rs @@ -95,7 +95,6 @@ fn test_while_loop() { a -= 1 return b "## }; - assert_eq!(while_loop(0), Ok(0)); assert_eq!(while_loop(-1), Ok(0)); assert_eq!(while_loop(1), Ok(1));