diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 52f24eb86..0f3521080 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -2033,8 +2033,6 @@ ELSE self.assertIn('ASCII', str(re.A)) self.assertIn('DOTALL', str(re.S)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pattern_compare(self): pattern1 = re.compile('abc', re.IGNORECASE) @@ -2064,8 +2062,6 @@ ELSE with self.assertRaises(TypeError): pattern1 < pattern2 - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pattern_compare_bytes(self): pattern1 = re.compile(b'abc') diff --git a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 94e2d8b44..cbdbb9e19 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -6,6 +6,7 @@ mod _sre { use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::borrow::BorrowValue; + use rustpython_common::hash::PyHash; use crate::builtins::list::PyListRef; use crate::builtins::memory::try_buffer_from_object; @@ -15,9 +16,10 @@ mod _sre { }; use crate::function::{Args, OptionalArg}; use crate::pyobject::{ - IntoPyObject, ItemProtocol, PyCallable, PyObjectRef, PyRef, PyResult, PyValue, StaticType, - TryFromObject, + IntoPyObject, ItemProtocol, PyCallable, PyComparisonValue, PyObjectRef, PyRef, PyResult, + PyValue, StaticType, TryFromObject, }; + use crate::slots::{Comparable, Hashable}; use crate::VirtualMachine; use core::str; use sre_engine::constants::SreFlag; @@ -139,7 +141,7 @@ mod _sre { } } - #[pyimpl] + #[pyimpl(with(Hashable, Comparable))] impl Pattern { fn with_str_drive PyResult>( &self, @@ -529,6 +531,42 @@ mod _sre { } } + impl Hashable for Pattern { + fn hash(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let hash = vm._hash(&zelf.pattern)?; + let (_, code, _) = unsafe { zelf.code.align_to::() }; + let hash = hash ^ vm.state.hash_secret.hash_bytes(code); + let hash = hash ^ (zelf.flags.bits() as PyHash); + let hash = hash ^ (zelf.isbytes as i64); + Ok(hash) + } + } + + impl Comparable for Pattern { + fn cmp( + zelf: &PyRef, + other: &PyObjectRef, + op: crate::slots::PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult { + if let Some(res) = op.identical_optimization(zelf, other) { + return Ok(res.into()); + } + op.eq_only(|| { + if let Some(other) = other.downcast_ref::() { + Ok(PyComparisonValue::Implemented( + zelf.flags == other.flags + && zelf.isbytes == other.isbytes + && zelf.code == other.code + && vm.bool_eq(&zelf.pattern, &other.pattern)?, + )) + } else { + Ok(PyComparisonValue::NotImplemented) + } + }) + } + } + #[pyattr] #[pyclass(name = "Match")] #[derive(Debug)]