diff --git a/tests/snippets/dict.py b/tests/snippets/dict.py index 270d2c07e..f7506a205 100644 --- a/tests/snippets/dict.py +++ b/tests/snippets/dict.py @@ -198,3 +198,5 @@ z = {'c': 3, 'd': 3, 'e': 3} w = {1: 1, **x, 2: 2, **y, 3: 3, **z, 4: 4} assert w == {1: 1, 'a': 1, 'b': 2, 'c': 3, 2: 2, 'd': 3, 3: 3, 'e': 3, 4: 4} + +assert str({True: True, 1.0: 1.0}) == str({True: 1.0}) diff --git a/tests/snippets/floats.py b/tests/snippets/floats.py index 82f2b86e3..c116b8010 100644 --- a/tests/snippets/floats.py +++ b/tests/snippets/floats.py @@ -98,6 +98,21 @@ assert float(b'2.99e-23') == 2.99e-23 assert_raises(ValueError, lambda: float('foo')) assert_raises(OverflowError, lambda: float(2**10000)) +# check eq and hash for small numbers + +assert 1.0 == 1 +assert 1.0 == True +assert 0.0 == 0 +assert 0.0 == False +assert hash(1.0) == hash(1) +assert hash(1.0) == hash(True) +assert hash(0.0) == hash(0) +assert hash(0.0) == hash(False) +assert hash(1.0) != hash(1.0000000001) + +assert 5.0 in {3, 4, 5} +assert {-1: 2}[-1.0] == 2 + # check that magic methods are implemented for ints and floats assert 1.0.__add__(1.0) == 2.0 diff --git a/vm/src/lib.rs b/vm/src/lib.rs index 8b14454fe..2e5fcb3f4 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -55,6 +55,7 @@ pub mod frame; pub mod function; pub mod import; pub mod obj; +mod pyhash; pub mod pyobject; pub mod stdlib; mod symboltable; diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index add61a684..8d983c3ae 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -4,6 +4,7 @@ use super::objstr; use super::objtype; use crate::function::OptionalArg; use crate::obj::objtype::PyClassRef; +use crate::pyhash; use crate::pyobject::{ IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, @@ -429,6 +430,11 @@ impl PyFloat { zelf } + #[pymethod(name = "__hash__")] + fn hash(&self, _vm: &VirtualMachine) -> pyhash::PyHash { + pyhash::hash_float(self.value) + } + #[pyproperty(name = "real")] fn real(zelf: PyRef, _vm: &VirtualMachine) -> PyFloatRef { zelf diff --git a/vm/src/pyhash.rs b/vm/src/pyhash.rs new file mode 100644 index 000000000..81c605584 --- /dev/null +++ b/vm/src/pyhash.rs @@ -0,0 +1,67 @@ +use std::hash::{Hash, Hasher}; + +use crate::pyobject::PyObjectRef; +use crate::pyobject::PyResult; +use crate::vm::VirtualMachine; + +pub type PyHash = i64; +pub type PyUHash = u64; + +pub const BITS: usize = 61; +pub const MODULUS: PyUHash = (1 << BITS) - 1; +// pub const CUTOFF: usize = 7; + +pub const INF: PyHash = 314159; +pub const NAN: PyHash = 0; + +pub fn hash_float(value: f64) -> PyHash { + // cpython _Py_HashDouble + if !value.is_finite() { + return if value.is_infinite() { + if value > 0.0 { + INF + } else { + -INF + } + } else { + NAN + }; + } + + let frexp = if 0.0 == value { + (value, 0i32) + } else { + let bits = value.to_bits(); + let exponent: i32 = ((bits >> 52) & 0x7ff) as i32 - 1022; + let mantissa_bits = bits & (0x000fffffffffffff) | (1022 << 52); + (f64::from_bits(mantissa_bits), exponent) + }; + + // process 28 bits at a time; this should work well both for binary + // and hexadecimal floating point. + let mut m = frexp.0; + let mut e = frexp.1; + let mut x: PyUHash = 0; + while m != 0.0 { + x = ((x << 28) & MODULUS) | x >> (BITS - 28); + m *= 268435456.0; // 2**28 + e -= 28; + let y = m as PyUHash; // pull out integer part + m -= y as f64; + x += y; + if x >= MODULUS { + x -= MODULUS; + } + } + + // adjust for the exponent; first reduce it modulo BITS + const BITS32: i32 = BITS as i32; + e = if e >= 0 { + e % BITS32 + } else { + BITS32 - 1 - ((-1 - e) % BITS32) + }; + x = ((x << e) & MODULUS) | x >> (BITS32 - e); + + x as PyHash * value.signum() as PyHash +}