Compare commits

...

13 Commits

Author SHA1 Message Date
Shahar Naveh
8e35cc9e72 Update asyncio to 3.14.5 (#8012) 2026-06-02 09:48:45 +00:00
Bas Schoenmaeckers
5b2f6bc270 Add sequence protocol to the c-api (#8016) 2026-06-02 09:33:26 +00:00
Shahar Naveh
b5b82587dd Update test_super.py to 3.14.5 (#8015) 2026-06-02 09:25:40 +00:00
dependabot[bot]
ef74577f4b Bump codecov/codecov-action from 6.0.0 to 6.0.1 (#7979)
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 6.0.0 to 6.0.1.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](57e3a136b7...e79a6962e0)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: 6.0.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-02 18:12:00 +09:00
Shahar Naveh
14cb5dd874 Remove node24 installation. It's the default now (#8014)
* Remove node24 installation. It's the default now

* Fix insta tests for llvm-cov
2026-06-02 18:07:32 +09:00
Shahar Naveh
a781980d9b Update importlib to 3.14.5 (#8013) 2026-06-02 18:06:36 +09:00
Shahar Naveh
6a860b635a Update some tests to 3.14.5 (#7991) 2026-06-02 18:00:54 +09:00
Shahar Naveh
0c9ed36b79 CI: Don't skip windows tests (#8004) 2026-06-02 15:42:33 +09:00
Bas Schoenmaeckers
2c46b69e4b Add weakref support to c-api (#8006) 2026-06-02 15:31:21 +09:00
Bas Schoenmaeckers
74b4707b65 Add mapping c-api functions (#8003) 2026-06-02 15:20:45 +09:00
Bas Schoenmaeckers
885cf5c29c Add set object functions to c-api (#8002) 2026-06-02 00:35:43 +09:00
Shahar Naveh
3e1c3bc86d impl _queue module + Update multiprocessing to 3.14.5 (#7994) 2026-06-02 00:35:17 +09:00
Bas Schoenmaeckers
36a1722d1b Add agent skill for adding c-api functions (#8001)
* Add agent skill

* abi3/abi3t only
2026-06-02 00:34:06 +09:00
62 changed files with 3194 additions and 401 deletions

View File

@@ -0,0 +1,60 @@
---
name: rustpython-capi-expansion
description: Implement missing RustPython C-API functions in crates/capi using the pyo3-ffi header split mapping (`pyo3-ffi/src/*.rs`, mirroring CPython C API headers). Use this whenever the user asks to add or port C-API functions (for example from setobject.h, dictobject.h, unicodeobject.h) or add capi tests.
---
# RustPython C-API Expansion
Use this workflow for adding missing C-API functions to RustPython.
## Source of truth for target files
- Use this mapping source: `pyo3-ffi/src/*.rs`, which mirrors the CPython header split used by the C API.
- Map requested header APIs to `crates/capi/src/<header_basename>.rs` using that split. Examples:
- `setobject.h` -> `crates/capi/src/setobject.rs`
- `dictobject.h` -> `crates/capi/src/dictobject.rs`
- `unicodeobject.h` -> `crates/capi/src/unicodeobject.rs`
- Do not invent alternate target modules when the header split implies a direct target.
- If the target file is not present yet, create it and wire it in `crates/capi/src/lib.rs`.
## Implementation workflow
1. Identify requested missing APIs from the user request and their originating C API header.
2. Open nearby capi modules in `crates/capi/src/` and follow existing style and patterns.
3. Implement only the requested functions in the mapped target file.
4. Keep behavior aligned with CPython C-API contracts.
5. Prefer using existing `rustpython-vm` functionality as much as possible instead of re-implementing behavior in capi.
6. If a needed `rustpython-vm` helper exists but is private, make it public with a minimal, focused visibility change.
7. Prefer direct contract assumptions over defensive null checks unless required by the established local style.
8. Add basic tests only; do not overfit with very specific edge-case clutter.
9. In tests, use `pyo3` only as a safe wrapper over the API. Avoid raw pointer-heavy direct FFI-style tests.
10. Run tests from `crates/capi`.
## Testing rules
- Run test commands with working directory set to `crates/capi`.
- Prefer targeted tests first (module/function filter), then broader capi tests if needed.
- Keep test names concise (no required `test_` prefix).
## Style rules
- Follow existing RustPython capi coding style in neighboring files.
- Reuse `rustpython-vm` methods and types first; avoid duplicating VM logic in capi wrappers.
- When exposing previously private VM helpers, keep the API surface minimal and avoid unrelated refactors.
- Only expose and implement ABI-stable C-API surface needed for `abi3` / `abi3t`.
- Add comments only when they explain non-obvious behavior.
- Keep edits minimal and focused on requested API expansion.
## Completion checklist
- [ ] All requested functions implemented in mapped target file.
- [ ] New module exported in `crates/capi/src/lib.rs` when applicable.
- [ ] Basic safe-wrapper `pyo3` tests added/updated.
- [ ] Tests executed from `crates/capi` and passing for changed area.
- [ ] Final response includes changed file paths and test command summary.
## Example prompts this skill should handle
- "Implement these missing functions from `dictobject.h`."
- "Add `setobject.h` C-API functions in RustPython and include basic tests."
- "Port the listed `unicodeobject.h` APIs in capi, follow existing style, and run tests from `crates/capi`."

View File

@@ -33,7 +33,6 @@ env:
CARGO_PROFILE_DEV_DEBUG: 0 CARGO_PROFILE_DEV_DEBUG: 0
CARGO_PROFILE_RELEASE_DEBUG: 0 CARGO_PROFILE_RELEASE_DEBUG: 0
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true # TODO: Remove on 2026/06/02
CI: true CI: true
jobs: jobs:
@@ -310,11 +309,7 @@ jobs:
extra_test_args: [] # TODO: Enable '-u all' extra_test_args: [] # TODO: Enable '-u all'
env_polluting_tests: env_polluting_tests:
- test_set - test_set
skips: skips: []
- test_rlcompleter
- test_pathlib # panic by surrogate chars
- test_posixpath # OSError: (22, 'The filename, directory name, or volume label syntax is incorrect. (os error 123)')
- test_venv # couple of failing tests
timeout: 50 timeout: 50
fail-fast: false fail-fast: false
steps: steps:
@@ -545,12 +540,6 @@ jobs:
key: prek-${{ hashFiles('.pre-commit-config.yaml') }} key: prek-${{ hashFiles('.pre-commit-config.yaml') }}
path: ~/.cache/prek path: ~/.cache/prek
# TODO: Remove on 2026/06/02 when node24 is the default
- uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
package-manager-cache: false
node-version: "24"
- name: install prek - name: install prek
id: prek id: prek
uses: j178/prek-action@bdca6f102f98e2b4c7029491a53dfd366469e33d # v2.0.4 uses: j178/prek-action@bdca6f102f98e2b4c7029491a53dfd366469e33d # v2.0.4

View File

@@ -15,7 +15,6 @@ on:
env: env:
CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls-aws-lc,jit,host_env CARGO_ARGS: --no-default-features --features stdlib,importlib,stdio,encodings,ssl-rustls-aws-lc,jit,host_env
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true' # TODO: Remove on 2026/06/02
jobs: jobs:
# codecov collects code coverage data from the rust tests, python snippets and python test suite. # codecov collects code coverage data from the rust tests, python snippets and python test suite.
@@ -25,6 +24,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
# Disable this scheduled job when running on a fork. # Disable this scheduled job when running on a fork.
if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }}
env:
INSTA_WORKSPACE_ROOT: ${{ github.workspace }}
steps: steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with: with:
@@ -56,7 +57,7 @@ jobs:
- name: Upload to Codecov - name: Upload to Codecov
if: ${{ github.event_name != 'pull_request' }} if: ${{ github.event_name != 'pull_request' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
with: with:
files: ./codecov.lcov files: ./codecov.lcov

1
Cargo.lock generated
View File

@@ -3479,6 +3479,7 @@ name = "rustpython-capi"
version = "0.5.0" version = "0.5.0"
dependencies = [ dependencies = [
"bitflags 2.11.1", "bitflags 2.11.1",
"itertools 0.14.0",
"num-complex", "num-complex",
"pyo3", "pyo3",
"rustpython-stdlib", "rustpython-stdlib",

View File

@@ -86,22 +86,27 @@ class REPLThread(threading.Thread):
global return_code global return_code
try: try:
banner = ( if not sys.flags.quiet:
f'asyncio REPL {sys.version} on {sys.platform}\n' banner = (
f'Use "await" directly instead of "asyncio.run()".\n' f'asyncio REPL {sys.version} on {sys.platform}\n'
f'Type "help", "copyright", "credits" or "license" ' f'Use "await" directly instead of "asyncio.run()".\n'
f'for more information.\n' f'Type "help", "copyright", "credits" or "license" '
) f'for more information.\n'
)
console.write(banner) console.write(banner)
if startup_path := os.getenv("PYTHONSTARTUP"): if not sys.flags.isolated and (startup_path := os.getenv("PYTHONSTARTUP")):
sys.audit("cpython.run_startup", startup_path) sys.audit("cpython.run_startup", startup_path)
try:
import tokenize import tokenize
with tokenize.open(startup_path) as f: with tokenize.open(startup_path) as f:
startup_code = compile(f.read(), startup_path, "exec") startup_code = compile(f.read(), startup_path, "exec")
exec(startup_code, console.locals) exec(startup_code, console.locals)
except SystemExit:
raise
except BaseException:
console.showtraceback()
ps1 = getattr(sys, "ps1", ">>> ") ps1 = getattr(sys, "ps1", ">>> ")
if CAN_USE_PYREPL: if CAN_USE_PYREPL:
@@ -236,4 +241,5 @@ if __name__ == '__main__':
break break
console.write('exiting asyncio REPL...\n') console.write('exiting asyncio REPL...\n')
loop.close()
sys.exit(return_code) sys.exit(return_code)

View File

@@ -1345,6 +1345,17 @@ class BaseEventLoop(events.AbstractEventLoop):
# have a chance to get called before "ssl_protocol.connection_made()". # have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading() transport.pause_reading()
# gh-142352: move buffered StreamReader data to SSLProtocol
if server_side:
from .streams import StreamReaderProtocol
if isinstance(protocol, StreamReaderProtocol):
stream_reader = getattr(protocol, '_stream_reader', None)
if stream_reader is not None:
buffer = stream_reader._buffer
if buffer:
ssl_protocol._incoming.write(buffer)
buffer.clear()
transport.set_protocol(ssl_protocol) transport.set_protocol(ssl_protocol)
conmade_cb = self.call_soon(ssl_protocol.connection_made, transport) conmade_cb = self.call_soon(ssl_protocol.connection_made, transport)
resume_cb = self.call_soon(transport.resume_reading) resume_cb = self.call_soon(transport.resume_reading)

View File

@@ -265,7 +265,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
# to avoid hanging forever in self._wait as otherwise _exit_waiters # to avoid hanging forever in self._wait as otherwise _exit_waiters
# would never be woken up, we wake them up here. # would never be woken up, we wake them up here.
for waiter in self._exit_waiters: for waiter in self._exit_waiters:
if not waiter.cancelled(): if not waiter.done():
waiter.set_result(self._returncode) waiter.set_result(self._returncode)
if all(p is not None and p.disconnected if all(p is not None and p.disconnected
for p in self._pipes.values()): for p in self._pipes.values()):
@@ -278,7 +278,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
finally: finally:
# wake up futures waiting for wait() # wake up futures waiting for wait()
for waiter in self._exit_waiters: for waiter in self._exit_waiters:
if not waiter.cancelled(): if not waiter.done():
waiter.set_result(self._returncode) waiter.set_result(self._returncode)
self._exit_waiters = None self._exit_waiters = None
self._loop = None self._loop = None

View File

@@ -392,7 +392,7 @@ def _chain_future(source, destination):
def _call_check_cancel(destination): def _call_check_cancel(destination):
if destination.cancelled(): if destination.cancelled():
if source_loop is None or source_loop is dest_loop: if source_loop is None or source_loop is events._get_running_loop():
source.cancel() source.cancel()
else: else:
source_loop.call_soon_threadsafe(source.cancel) source_loop.call_soon_threadsafe(source.cancel)
@@ -401,7 +401,7 @@ def _chain_future(source, destination):
if (destination.cancelled() and if (destination.cancelled() and
dest_loop is not None and dest_loop.is_closed()): dest_loop is not None and dest_loop.is_closed()):
return return
if dest_loop is None or dest_loop is source_loop: if dest_loop is None or dest_loop is events._get_running_loop():
_set_state(destination, source) _set_state(destination, source)
else: else:
if dest_loop.is_closed(): if dest_loop.is_closed():

View File

@@ -37,7 +37,7 @@ class Queue(mixins._LoopBoundMixin):
is an integer greater than 0, then "await put()" will block when the is an integer greater than 0, then "await put()" will block when the
queue reaches maxsize, until an item is removed by get(). queue reaches maxsize, until an item is removed by get().
Unlike the standard library Queue, you can reliably know this Queue's size Unlike queue.Queue, you can reliably know this Queue's size
with qsize(), since your single-threaded asyncio application won't be with qsize(), since your single-threaded asyncio application won't be
interrupted between calling qsize() and doing an operation on the Queue. interrupted between calling qsize() and doing an operation on the Queue.
""" """

View File

@@ -10,7 +10,6 @@ import itertools
import msvcrt import msvcrt
import os import os
import subprocess import subprocess
import tempfile
import warnings import warnings
@@ -24,6 +23,7 @@ BUFSIZE = 8192
PIPE = subprocess.PIPE PIPE = subprocess.PIPE
STDOUT = subprocess.STDOUT STDOUT = subprocess.STDOUT
_mmap_counter = itertools.count() _mmap_counter = itertools.count()
_MAX_PIPE_ATTEMPTS = 20
# Replacement for os.pipe() using handles instead of fds # Replacement for os.pipe() using handles instead of fds
@@ -31,10 +31,6 @@ _mmap_counter = itertools.count()
def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
"""Like os.pipe() but with overlapped support and using handles not fds.""" """Like os.pipe() but with overlapped support and using handles not fds."""
address = tempfile.mktemp(
prefix=r'\\.\pipe\python-pipe-{:d}-{:d}-'.format(
os.getpid(), next(_mmap_counter)))
if duplex: if duplex:
openmode = _winapi.PIPE_ACCESS_DUPLEX openmode = _winapi.PIPE_ACCESS_DUPLEX
access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
@@ -56,9 +52,20 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
h1 = h2 = None h1 = h2 = None
try: try:
h1 = _winapi.CreateNamedPipe( for attempts in itertools.count():
address, openmode, _winapi.PIPE_WAIT, address = r'\\.\pipe\python-pipe-{:d}-{:d}-{}'.format(
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) os.getpid(), next(_mmap_counter), os.urandom(8).hex())
try:
h1 = _winapi.CreateNamedPipe(
address, openmode, _winapi.PIPE_WAIT,
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
break
except OSError as e:
if attempts >= _MAX_PIPE_ATTEMPTS:
raise
if e.winerror not in (_winapi.ERROR_PIPE_BUSY,
_winapi.ERROR_ACCESS_DENIED):
raise
h2 = _winapi.CreateFile( h2 = _winapi.CreateFile(
address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
@@ -104,8 +111,9 @@ class PipeHandle:
def close(self, *, CloseHandle=_winapi.CloseHandle): def close(self, *, CloseHandle=_winapi.CloseHandle):
if self._handle is not None: if self._handle is not None:
CloseHandle(self._handle) handle = self._handle
self._handle = None self._handle = None
CloseHandle(handle)
def __del__(self, _warn=warnings.warn): def __del__(self, _warn=warnings.warn):
if self._handle is not None: if self._handle is not None:

View File

@@ -1375,6 +1375,14 @@ def _find_and_load(name, import_):
# NOTE: because of this, initializing must be set *before* # NOTE: because of this, initializing must be set *before*
# putting the new module in sys.modules. # putting the new module in sys.modules.
_lock_unlock_module(name) _lock_unlock_module(name)
else:
# Verify the module is still in sys.modules. Another thread may have
# removed it (due to import failure) between our sys.modules.get()
# above and the _initializing check. If removed, we retry the import
# to preserve normal semantics: the caller gets the exception from
# the actual import failure rather than a synthetic error.
if sys.modules.get(name) is not module:
return _find_and_load(name, import_)
if module is None: if module is None:
message = f'import of {name} halted; None in sys.modules' message = f'import of {name} halted; None in sys.modules'

View File

@@ -946,7 +946,7 @@ class FileLoader:
def get_data(self, path): def get_data(self, path):
"""Return the data from path as raw bytes.""" """Return the data from path as raw bytes."""
if isinstance(self, (SourceLoader, ExtensionFileLoader)): if isinstance(self, (SourceLoader, SourcelessFileLoader, ExtensionFileLoader)):
with _io.open_code(str(path)) as file: with _io.open_code(str(path)) as file:
return file.read() return file.read()
else: else:

View File

@@ -11,13 +11,12 @@ __all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ]
import errno import errno
import io import io
import itertools
import os import os
import sys import sys
import socket import socket
import struct import struct
import time import time
import tempfile
import itertools
from . import util from . import util
@@ -39,11 +38,14 @@ except ImportError:
# #
# #
BUFSIZE = 8192 # 64 KiB is the default PIPE buffer size of most POSIX platforms.
BUFSIZE = 64 * 1024
# A very generous timeout when it comes to local connections... # A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20. CONNECTION_TIMEOUT = 20.
_mmap_counter = itertools.count() _mmap_counter = itertools.count()
_MAX_PIPE_ATTEMPTS = 100
default_family = 'AF_INET' default_family = 'AF_INET'
families = ['AF_INET'] families = ['AF_INET']
@@ -74,10 +76,14 @@ def arbitrary_address(family):
if family == 'AF_INET': if family == 'AF_INET':
return ('localhost', 0) return ('localhost', 0)
elif family == 'AF_UNIX': elif family == 'AF_UNIX':
return tempfile.mktemp(prefix='sock-', dir=util.get_temp_dir()) # NOTE: util.get_temp_dir() is a 0o700 per-process directory. A
# mktemp-style ToC vs ToU concern is not important; bind() surfaces
# the extremely unlikely collision as EADDRINUSE.
return os.path.join(util.get_temp_dir(),
f'sock-{os.urandom(6).hex()}')
elif family == 'AF_PIPE': elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % return (r'\\.\pipe\pyc-%d-%d-%s' %
(os.getpid(), next(_mmap_counter)), dir="") (os.getpid(), next(_mmap_counter), os.urandom(8).hex()))
else: else:
raise ValueError('unrecognized family') raise ValueError('unrecognized family')
@@ -179,6 +185,10 @@ class _ConnectionBase:
finally: finally:
self._handle = None self._handle = None
def _detach(self):
"""Stop managing the underlying file descriptor or handle."""
self._handle = None
def send_bytes(self, buf, offset=0, size=None): def send_bytes(self, buf, offset=0, size=None):
"""Send the bytes data from a bytes-like object""" """Send the bytes data from a bytes-like object"""
self._check_closed() self._check_closed()
@@ -316,22 +326,32 @@ if _winapi:
try: try:
ov, err = _winapi.ReadFile(self._handle, bsize, ov, err = _winapi.ReadFile(self._handle, bsize,
overlapped=True) overlapped=True)
sentinel = object()
return_value = sentinel
try: try:
if err == _winapi.ERROR_IO_PENDING: try:
waitres = _winapi.WaitForMultipleObjects( if err == _winapi.ERROR_IO_PENDING:
[ov.event], False, INFINITE) waitres = _winapi.WaitForMultipleObjects(
assert waitres == WAIT_OBJECT_0 [ov.event], False, INFINITE)
assert waitres == WAIT_OBJECT_0
except:
ov.cancel()
raise
finally:
nread, err = ov.GetOverlappedResult(True)
if err == 0:
f = io.BytesIO()
f.write(ov.getbuffer())
return_value = f
elif err == _winapi.ERROR_MORE_DATA:
return_value = self._get_more_data(ov, maxsize)
except: except:
ov.cancel() if return_value is sentinel:
raise raise
finally:
nread, err = ov.GetOverlappedResult(True) if return_value is not sentinel:
if err == 0: return return_value
f = io.BytesIO()
f.write(ov.getbuffer())
return f
elif err == _winapi.ERROR_MORE_DATA:
return self._get_more_data(ov, maxsize)
except OSError as e: except OSError as e:
if e.winerror == _winapi.ERROR_BROKEN_PIPE: if e.winerror == _winapi.ERROR_BROKEN_PIPE:
raise EOFError raise EOFError
@@ -392,7 +412,8 @@ class Connection(_ConnectionBase):
handle = self._handle handle = self._handle
remaining = size remaining = size
while remaining > 0: while remaining > 0:
chunk = read(handle, remaining) to_read = min(BUFSIZE, remaining)
chunk = read(handle, to_read)
n = len(chunk) n = len(chunk)
if n == 0: if n == 0:
if remaining == size: if remaining == size:
@@ -455,17 +476,29 @@ class Listener(object):
def __init__(self, address=None, family=None, backlog=1, authkey=None): def __init__(self, address=None, family=None, backlog=1, authkey=None):
family = family or (address and address_type(address)) \ family = family or (address and address_type(address)) \
or default_family or default_family
address = address or arbitrary_address(family)
_validate_family(family) _validate_family(family)
if family == 'AF_PIPE':
self._listener = PipeListener(address, backlog)
else:
self._listener = SocketListener(address, family, backlog)
if authkey is not None and not isinstance(authkey, bytes): if authkey is not None and not isinstance(authkey, bytes):
raise TypeError('authkey should be a byte string') raise TypeError('authkey should be a byte string')
if family == 'AF_PIPE':
if address:
self._listener = PipeListener(address, backlog)
else:
for attempts in itertools.count():
address = arbitrary_address(family)
try:
self._listener = PipeListener(address, backlog)
break
except OSError as e:
if attempts >= _MAX_PIPE_ATTEMPTS:
raise
if e.winerror not in (_winapi.ERROR_PIPE_BUSY,
_winapi.ERROR_ACCESS_DENIED):
raise
else:
address = address or arbitrary_address(family)
self._listener = SocketListener(address, family, backlog)
self._authkey = authkey self._authkey = authkey
def accept(self): def accept(self):
@@ -553,7 +586,6 @@ else:
''' '''
Returns pair of connection objects at either end of a pipe Returns pair of connection objects at either end of a pipe
''' '''
address = arbitrary_address('AF_PIPE')
if duplex: if duplex:
openmode = _winapi.PIPE_ACCESS_DUPLEX openmode = _winapi.PIPE_ACCESS_DUPLEX
access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
@@ -563,15 +595,25 @@ else:
access = _winapi.GENERIC_WRITE access = _winapi.GENERIC_WRITE
obsize, ibsize = 0, BUFSIZE obsize, ibsize = 0, BUFSIZE
h1 = _winapi.CreateNamedPipe( for attempts in itertools.count():
address, openmode | _winapi.FILE_FLAG_OVERLAPPED | address = arbitrary_address('AF_PIPE')
_winapi.FILE_FLAG_FIRST_PIPE_INSTANCE, try:
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | h1 = _winapi.CreateNamedPipe(
_winapi.PIPE_WAIT, address, openmode | _winapi.FILE_FLAG_OVERLAPPED |
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE,
# default security descriptor: the handle cannot be inherited _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
_winapi.NULL _winapi.PIPE_WAIT,
) 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER,
# default security descriptor: the handle cannot be inherited
_winapi.NULL
)
break
except OSError as e:
if attempts >= _MAX_PIPE_ATTEMPTS:
raise
if e.winerror not in (_winapi.ERROR_PIPE_BUSY,
_winapi.ERROR_ACCESS_DENIED):
raise
h2 = _winapi.CreateFile( h2 = _winapi.CreateFile(
address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
_winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL

View File

@@ -145,7 +145,13 @@ class BaseContext(object):
'''Check whether this is a fake forked process in a frozen executable. '''Check whether this is a fake forked process in a frozen executable.
If so then run code specified by commandline and exit. If so then run code specified by commandline and exit.
''' '''
if self.get_start_method() == 'spawn' and getattr(sys, 'frozen', False): # gh-140814: allow_none=True avoids locking in the default start
# method, which would cause a later set_start_method() to fail.
# None is safe to pass through: spawn.freeze_support()
# independently detects whether this process is a spawned
# child, so the start method check here is only an optimization.
if (getattr(sys, 'frozen', False)
and self.get_start_method(allow_none=True) in ('spawn', None)):
from .spawn import freeze_support from .spawn import freeze_support
freeze_support() freeze_support()
@@ -167,7 +173,7 @@ class BaseContext(object):
''' '''
# This is undocumented. In previous versions of multiprocessing # This is undocumented. In previous versions of multiprocessing
# its only effect was to make socket objects inheritable on Windows. # its only effect was to make socket objects inheritable on Windows.
from . import connection from . import connection # noqa: F401
def set_executable(self, executable): def set_executable(self, executable):
'''Sets the path to a python.exe or pythonw.exe binary used to run '''Sets the path to a python.exe or pythonw.exe binary used to run
@@ -259,13 +265,12 @@ class DefaultContext(BaseContext):
def get_all_start_methods(self): def get_all_start_methods(self):
"""Returns a list of the supported start methods, default first.""" """Returns a list of the supported start methods, default first."""
if sys.platform == 'win32': default = self._default_context.get_start_method()
return ['spawn'] start_method_names = [default]
else: start_method_names.extend(
methods = ['spawn', 'fork'] if sys.platform == 'darwin' else ['fork', 'spawn'] name for name in _concrete_contexts if name != default
if reduction.HAVE_SEND_HANDLE: )
methods.append('forkserver') return start_method_names
return methods
# #
@@ -320,14 +325,15 @@ if sys.platform != 'win32':
'spawn': SpawnContext(), 'spawn': SpawnContext(),
'forkserver': ForkServerContext(), 'forkserver': ForkServerContext(),
} }
if sys.platform == 'darwin': # bpo-33725: running arbitrary code after fork() is no longer reliable
# bpo-33725: running arbitrary code after fork() is no longer reliable # on macOS since macOS 10.14 (Mojave). Use spawn by default instead.
# on macOS since macOS 10.14 (Mojave). Use spawn by default instead. # gh-84559: We changed everyones default to a thread safeish one in 3.14.
_default_context = DefaultContext(_concrete_contexts['spawn']) if reduction.HAVE_SEND_HANDLE and sys.platform != 'darwin':
_default_context = DefaultContext(_concrete_contexts['forkserver'])
else: else:
_default_context = DefaultContext(_concrete_contexts['fork']) _default_context = DefaultContext(_concrete_contexts['spawn'])
else: else: # Windows
class SpawnProcess(process.BaseProcess): class SpawnProcess(process.BaseProcess):
_start_method = 'spawn' _start_method = 'spawn'

View File

@@ -33,7 +33,7 @@ from queue import Queue
class DummyProcess(threading.Thread): class DummyProcess(threading.Thread):
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}): def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
threading.Thread.__init__(self, group, target, name, args, kwargs) threading.Thread.__init__(self, group, target, name, args, kwargs)
self._pid = None self._pid = None
self._children = weakref.WeakKeyDictionary() self._children = weakref.WeakKeyDictionary()

View File

@@ -9,6 +9,7 @@ import sys
import threading import threading
import warnings import warnings
from . import AuthenticationError
from . import connection from . import connection
from . import process from . import process
from .context import reduction from .context import reduction
@@ -25,6 +26,7 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
MAXFDS_TO_SEND = 256 MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
_AUTHKEY_LEN = 32 # <= PIPEBUF so it fits a single write to an empty pipe.
# #
# Forkserver class # Forkserver class
@@ -33,6 +35,7 @@ SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
class ForkServer(object): class ForkServer(object):
def __init__(self): def __init__(self):
self._forkserver_authkey = None
self._forkserver_address = None self._forkserver_address = None
self._forkserver_alive_fd = None self._forkserver_alive_fd = None
self._forkserver_pid = None self._forkserver_pid = None
@@ -59,6 +62,7 @@ class ForkServer(object):
if not util.is_abstract_socket_namespace(self._forkserver_address): if not util.is_abstract_socket_namespace(self._forkserver_address):
os.unlink(self._forkserver_address) os.unlink(self._forkserver_address)
self._forkserver_address = None self._forkserver_address = None
self._forkserver_authkey = None
def set_forkserver_preload(self, modules_names): def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.''' '''Set list of module names to try to load in forkserver process.'''
@@ -83,6 +87,7 @@ class ForkServer(object):
process data. process data.
''' '''
self.ensure_running() self.ensure_running()
assert self._forkserver_authkey
if len(fds) + 4 >= MAXFDS_TO_SEND: if len(fds) + 4 >= MAXFDS_TO_SEND:
raise ValueError('too many fds') raise ValueError('too many fds')
with socket.socket(socket.AF_UNIX) as client: with socket.socket(socket.AF_UNIX) as client:
@@ -93,6 +98,18 @@ class ForkServer(object):
resource_tracker.getfd()] resource_tracker.getfd()]
allfds += fds allfds += fds
try: try:
client.setblocking(True)
wrapped_client = connection.Connection(client.fileno())
# The other side of this exchange happens in the child as
# implemented in main().
try:
connection.answer_challenge(
wrapped_client, self._forkserver_authkey)
connection.deliver_challenge(
wrapped_client, self._forkserver_authkey)
finally:
wrapped_client._detach()
del wrapped_client
reduction.sendfds(client, allfds) reduction.sendfds(client, allfds)
return parent_r, parent_w return parent_r, parent_w
except: except:
@@ -120,20 +137,30 @@ class ForkServer(object):
return return
# dead, launch it again # dead, launch it again
os.close(self._forkserver_alive_fd) os.close(self._forkserver_alive_fd)
self._forkserver_authkey = None
self._forkserver_address = None self._forkserver_address = None
self._forkserver_alive_fd = None self._forkserver_alive_fd = None
self._forkserver_pid = None self._forkserver_pid = None
cmd = ('from multiprocessing.forkserver import main; ' + # gh-144503: sys_argv is passed as real argv elements after the
'main(%d, %d, %r, **%r)') # ``-c cmd`` rather than repr'd into main_kws so that a large
# parent sys.argv cannot push the single ``-c`` command string
# over the OS per-argument length limit (MAX_ARG_STRLEN on Linux).
# The child sees them as sys.argv[1:].
cmd = ('import sys; '
'from multiprocessing.forkserver import main; '
'main(%d, %d, %r, sys_argv=sys.argv[1:], **%r)')
main_kws = {} main_kws = {}
sys_argv = None
if self._preload_modules: if self._preload_modules:
data = spawn.get_preparation_data('ignore') data = spawn.get_preparation_data('ignore')
if 'sys_path' in data: if 'sys_path' in data:
main_kws['sys_path'] = data['sys_path'] main_kws['sys_path'] = data['sys_path']
if 'init_main_from_path' in data: if 'init_main_from_path' in data:
main_kws['main_path'] = data['init_main_from_path'] main_kws['main_path'] = data['init_main_from_path']
if 'sys_argv' in data:
sys_argv = data['sys_argv']
with socket.socket(socket.AF_UNIX) as listener: with socket.socket(socket.AF_UNIX) as listener:
address = connection.arbitrary_address('AF_UNIX') address = connection.arbitrary_address('AF_UNIX')
@@ -145,19 +172,33 @@ class ForkServer(object):
# all client processes own the write end of the "alive" pipe; # all client processes own the write end of the "alive" pipe;
# when they all terminate the read end becomes ready. # when they all terminate the read end becomes ready.
alive_r, alive_w = os.pipe() alive_r, alive_w = os.pipe()
# A short lived pipe to initialize the forkserver authkey.
authkey_r, authkey_w = os.pipe()
try: try:
fds_to_pass = [listener.fileno(), alive_r] fds_to_pass = [listener.fileno(), alive_r, authkey_r]
main_kws['authkey_r'] = authkey_r
cmd %= (listener.fileno(), alive_r, self._preload_modules, cmd %= (listener.fileno(), alive_r, self._preload_modules,
main_kws) main_kws)
exe = spawn.get_executable() exe = spawn.get_executable()
args = [exe] + util._args_from_interpreter_flags() args = [exe] + util._args_from_interpreter_flags()
args += ['-c', cmd] args += ['-c', cmd]
if sys_argv is not None:
args += sys_argv
pid = util.spawnv_passfds(exe, args, fds_to_pass) pid = util.spawnv_passfds(exe, args, fds_to_pass)
except: except:
os.close(alive_w) os.close(alive_w)
os.close(authkey_w)
raise raise
finally: finally:
os.close(alive_r) os.close(alive_r)
os.close(authkey_r)
# Authenticate our control socket to prevent access from
# processes we have not shared this key with.
try:
self._forkserver_authkey = os.urandom(_AUTHKEY_LEN)
os.write(authkey_w, self._forkserver_authkey)
finally:
os.close(authkey_w)
self._forkserver_address = address self._forkserver_address = address
self._forkserver_alive_fd = alive_w self._forkserver_alive_fd = alive_w
self._forkserver_pid = pid self._forkserver_pid = pid
@@ -166,9 +207,21 @@ class ForkServer(object):
# #
# #
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
'''Run forkserver.''' *, sys_argv=None, authkey_r=None):
"""Run forkserver."""
if authkey_r is not None:
try:
authkey = os.read(authkey_r, _AUTHKEY_LEN)
assert len(authkey) == _AUTHKEY_LEN, f'{len(authkey)} < {_AUTHKEY_LEN}'
finally:
os.close(authkey_r)
else:
authkey = b''
if preload: if preload:
if sys_argv is not None:
sys.argv[:] = sys_argv
if sys_path is not None: if sys_path is not None:
sys.path[:] = sys_path sys.path[:] = sys_path
if '__main__' in preload and main_path is not None: if '__main__' in preload and main_path is not None:
@@ -262,8 +315,24 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
if listener in rfds: if listener in rfds:
# Incoming fork request # Incoming fork request
with listener.accept()[0] as s: with listener.accept()[0] as s:
# Receive fds from client try:
fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) if authkey:
wrapped_s = connection.Connection(s.fileno())
# The other side of this exchange happens in
# in connect_to_new_process().
try:
connection.deliver_challenge(
wrapped_s, authkey)
connection.answer_challenge(
wrapped_s, authkey)
finally:
wrapped_s._detach()
del wrapped_s
# Receive fds from client
fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
except (EOFError, BrokenPipeError, AuthenticationError):
s.close()
continue
if len(fds) > MAXFDS_TO_SEND: if len(fds) > MAXFDS_TO_SEND:
raise RuntimeError( raise RuntimeError(
"Too many ({0:n}) fds to send".format( "Too many ({0:n}) fds to send".format(
@@ -331,13 +400,14 @@ def _serve_one(child_r, fds, unused_fds, handlers):
# #
def read_signed(fd): def read_signed(fd):
data = b'' data = bytearray(SIGNED_STRUCT.size)
length = SIGNED_STRUCT.size unread = memoryview(data)
while len(data) < length: while unread:
s = os.read(fd, length - len(data)) count = os.readinto(fd, unread)
if not s: if count == 0:
raise EOFError('unexpected EOF') raise EOFError('unexpected EOF')
data += s unread = unread[count:]
return SIGNED_STRUCT.unpack(data)[0] return SIGNED_STRUCT.unpack(data)[0]
def write_signed(fd, n): def write_signed(fd, n):

View File

@@ -18,6 +18,7 @@ import sys
import threading import threading
import signal import signal
import array import array
import collections.abc
import queue import queue
import time import time
import types import types
@@ -1058,12 +1059,14 @@ class IteratorProxy(BaseProxy):
class AcquirerProxy(BaseProxy): class AcquirerProxy(BaseProxy):
_exposed_ = ('acquire', 'release') _exposed_ = ('acquire', 'release', 'locked')
def acquire(self, blocking=True, timeout=None): def acquire(self, blocking=True, timeout=None):
args = (blocking,) if timeout is None else (blocking, timeout) args = (blocking,) if timeout is None else (blocking, timeout)
return self._callmethod('acquire', args) return self._callmethod('acquire', args)
def release(self): def release(self):
return self._callmethod('release') return self._callmethod('release')
def locked(self):
return self._callmethod('locked')
def __enter__(self): def __enter__(self):
return self._callmethod('acquire') return self._callmethod('acquire')
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
@@ -1071,7 +1074,7 @@ class AcquirerProxy(BaseProxy):
class ConditionProxy(AcquirerProxy): class ConditionProxy(AcquirerProxy):
_exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all') _exposed_ = ('acquire', 'release', 'locked', 'wait', 'notify', 'notify_all')
def wait(self, timeout=None): def wait(self, timeout=None):
return self._callmethod('wait', (timeout,)) return self._callmethod('wait', (timeout,))
def notify(self, n=1): def notify(self, n=1):
@@ -1159,10 +1162,10 @@ class ValueProxy(BaseProxy):
BaseListProxy = MakeProxyType('BaseListProxy', ( BaseListProxy = MakeProxyType('BaseListProxy', (
'__add__', '__contains__', '__delitem__', '__getitem__', '__len__', '__add__', '__contains__', '__delitem__', '__getitem__', '__imul__',
'__mul__', '__reversed__', '__rmul__', '__setitem__', '__len__', '__mul__', '__reversed__', '__rmul__', '__setitem__',
'append', 'count', 'extend', 'index', 'insert', 'pop', 'remove', 'append', 'clear', 'copy', 'count', 'extend', 'index', 'insert', 'pop',
'reverse', 'sort', '__imul__' 'remove', 'reverse', 'sort',
)) ))
class ListProxy(BaseListProxy): class ListProxy(BaseListProxy):
def __iadd__(self, value): def __iadd__(self, value):
@@ -1174,18 +1177,55 @@ class ListProxy(BaseListProxy):
__class_getitem__ = classmethod(types.GenericAlias) __class_getitem__ = classmethod(types.GenericAlias)
collections.abc.MutableSequence.register(BaseListProxy)
_BaseDictProxy = MakeProxyType('DictProxy', ( _BaseDictProxy = MakeProxyType('_BaseDictProxy', (
'__contains__', '__delitem__', '__getitem__', '__iter__', '__len__', '__contains__', '__delitem__', '__getitem__', '__ior__', '__iter__',
'__setitem__', 'clear', 'copy', 'get', 'items', '__len__', '__or__', '__reversed__', '__ror__',
'__setitem__', 'clear', 'copy', 'fromkeys', 'get', 'items',
'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values'
)) ))
_BaseDictProxy._method_to_typeid_ = { _BaseDictProxy._method_to_typeid_ = {
'__iter__': 'Iterator', '__iter__': 'Iterator',
} }
class DictProxy(_BaseDictProxy): class DictProxy(_BaseDictProxy):
def __ior__(self, value):
self._callmethod('__ior__', (value,))
return self
__class_getitem__ = classmethod(types.GenericAlias) __class_getitem__ = classmethod(types.GenericAlias)
collections.abc.MutableMapping.register(_BaseDictProxy)
_BaseSetProxy = MakeProxyType("_BaseSetProxy", (
'__and__', '__class_getitem__', '__contains__', '__iand__', '__ior__',
'__isub__', '__iter__', '__ixor__', '__len__', '__or__', '__rand__',
'__ror__', '__rsub__', '__rxor__', '__sub__', '__xor__',
'__ge__', '__gt__', '__le__', '__lt__',
'add', 'clear', 'copy', 'difference', 'difference_update', 'discard',
'intersection', 'intersection_update', 'isdisjoint', 'issubset',
'issuperset', 'pop', 'remove', 'symmetric_difference',
'symmetric_difference_update', 'union', 'update',
))
class SetProxy(_BaseSetProxy):
def __ior__(self, value):
self._callmethod('__ior__', (value,))
return self
def __iand__(self, value):
self._callmethod('__iand__', (value,))
return self
def __ixor__(self, value):
self._callmethod('__ixor__', (value,))
return self
def __isub__(self, value):
self._callmethod('__isub__', (value,))
return self
__class_getitem__ = classmethod(types.GenericAlias)
collections.abc.MutableMapping.register(_BaseSetProxy)
ArrayProxy = MakeProxyType('ArrayProxy', ( ArrayProxy = MakeProxyType('ArrayProxy', (
'__len__', '__getitem__', '__setitem__' '__len__', '__getitem__', '__setitem__'
@@ -1237,6 +1277,7 @@ SyncManager.register('Barrier', threading.Barrier, BarrierProxy)
SyncManager.register('Pool', pool.Pool, PoolProxy) SyncManager.register('Pool', pool.Pool, PoolProxy)
SyncManager.register('list', list, ListProxy) SyncManager.register('list', list, ListProxy)
SyncManager.register('dict', dict, DictProxy) SyncManager.register('dict', dict, DictProxy)
SyncManager.register('set', set, SetProxy)
SyncManager.register('Value', Value, ValueProxy) SyncManager.register('Value', Value, ValueProxy)
SyncManager.register('Array', Array, ArrayProxy) SyncManager.register('Array', Array, ArrayProxy)
SyncManager.register('Namespace', Namespace, NamespaceProxy) SyncManager.register('Namespace', Namespace, NamespaceProxy)

View File

@@ -54,6 +54,9 @@ class Popen(object):
if self.wait(timeout=0.1) is None: if self.wait(timeout=0.1) is None:
raise raise
def interrupt(self):
self._send_signal(signal.SIGINT)
def terminate(self): def terminate(self):
self._send_signal(signal.SIGTERM) self._send_signal(signal.SIGTERM)
@@ -64,7 +67,17 @@ class Popen(object):
code = 1 code = 1
parent_r, child_w = os.pipe() parent_r, child_w = os.pipe()
child_r, parent_w = os.pipe() child_r, parent_w = os.pipe()
self.pid = os.fork() # gh-146313: Tell the resource tracker's at-fork handler to keep
# the inherited pipe fd so this child reuses the parent's tracker
# (gh-80849) rather than closing it and launching its own.
from .resource_tracker import _fork_intent
_fork_intent.preserve_fd = True
try:
self.pid = os.fork()
finally:
# Reset in both parent and child so the flag does not leak
# into a subsequent raw os.fork() or nested Process launch.
_fork_intent.preserve_fd = False
if self.pid == 0: if self.pid == 0:
try: try:
atexit._clear() atexit._clear()

View File

@@ -77,7 +77,7 @@ class BaseProcess(object):
def _Popen(self): def _Popen(self):
raise NotImplementedError raise NotImplementedError
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, def __init__(self, group=None, target=None, name=None, args=(), kwargs=None,
*, daemon=None): *, daemon=None):
assert group is None, 'group argument must be None for now' assert group is None, 'group argument must be None for now'
count = next(_process_counter) count = next(_process_counter)
@@ -89,7 +89,7 @@ class BaseProcess(object):
self._closed = False self._closed = False
self._target = target self._target = target
self._args = tuple(args) self._args = tuple(args)
self._kwargs = dict(kwargs) self._kwargs = dict(kwargs) if kwargs else {}
self._name = name or type(self).__name__ + '-' + \ self._name = name or type(self).__name__ + '-' + \
':'.join(str(i) for i in self._identity) ':'.join(str(i) for i in self._identity)
if daemon is not None: if daemon is not None:
@@ -125,6 +125,13 @@ class BaseProcess(object):
del self._target, self._args, self._kwargs del self._target, self._args, self._kwargs
_children.add(self) _children.add(self)
def interrupt(self):
'''
Terminate process; sends SIGINT signal
'''
self._check_closed()
self._popen.interrupt()
def terminate(self): def terminate(self):
''' '''
Terminate process; sends SIGTERM signal or uses TerminateProcess() Terminate process; sends SIGTERM signal or uses TerminateProcess()

View File

@@ -121,7 +121,7 @@ class Queue(object):
def qsize(self): def qsize(self):
# Raises NotImplementedError on Mac OSX because of broken sem_getvalue() # Raises NotImplementedError on Mac OSX because of broken sem_getvalue()
return self._maxsize - self._sem._semlock._get_value() return self._maxsize - self._sem.get_value()
def empty(self): def empty(self):
return not self._poll() return not self._poll()

View File

@@ -139,15 +139,12 @@ else:
__all__ += ['DupFd', 'sendfds', 'recvfds'] __all__ += ['DupFd', 'sendfds', 'recvfds']
import array import array
# On MacOSX we should acknowledge receipt of fds -- see Issue14669
ACKNOWLEDGE = sys.platform == 'darwin'
def sendfds(sock, fds): def sendfds(sock, fds):
'''Send an array of fds over an AF_UNIX socket.''' '''Send an array of fds over an AF_UNIX socket.'''
fds = array.array('i', fds) fds = array.array('i', fds)
msg = bytes([len(fds) % 256]) msg = bytes([len(fds) % 256])
sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]) sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
if ACKNOWLEDGE and sock.recv(1) != b'A': if sock.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd') raise RuntimeError('did not receive acknowledgement of fd')
def recvfds(sock, size): def recvfds(sock, size):
@@ -158,8 +155,11 @@ else:
if not msg and not ancdata: if not msg and not ancdata:
raise EOFError raise EOFError
try: try:
if ACKNOWLEDGE: # We send/recv an Ack byte after the fds to work around an old
sock.send(b'A') # macOS bug; it isn't clear if this is still required but it
# makes unit testing fd sending easier.
# See: https://github.com/python/cpython/issues/58874
sock.send(b'A') # Acknowledge
if len(ancdata) != 1: if len(ancdata) != 1:
raise RuntimeError('received %d items of ancdata' % raise RuntimeError('received %d items of ancdata' %
len(ancdata)) len(ancdata))

View File

@@ -20,6 +20,7 @@ import os
import signal import signal
import sys import sys
import threading import threading
import time
import warnings import warnings
from collections import deque from collections import deque
@@ -51,12 +52,8 @@ if os.name == 'posix':
# absence of POSIX named semaphores. In that case, no named semaphores were # absence of POSIX named semaphores. In that case, no named semaphores were
# ever opened, so no cleanup would be necessary. # ever opened, so no cleanup would be necessary.
if hasattr(_multiprocessing, 'sem_unlink'): if hasattr(_multiprocessing, 'sem_unlink'):
_CLEANUP_FUNCS.update({ _CLEANUP_FUNCS['semaphore'] = _multiprocessing.sem_unlink
'semaphore': _multiprocessing.sem_unlink, _CLEANUP_FUNCS['shared_memory'] = _posixshmem.shm_unlink
})
_CLEANUP_FUNCS.update({
'shared_memory': _posixshmem.shm_unlink,
})
class ReentrantCallError(RuntimeError): class ReentrantCallError(RuntimeError):
@@ -79,6 +76,10 @@ class ResourceTracker(object):
# The reader should understand all formats. # The reader should understand all formats.
self._use_simple_format = True self._use_simple_format = True
# Set to True by _stop_locked() if the waitpid polling loop ran to
# its timeout without reaping the tracker. Exposed for tests.
self._waitpid_timed_out = False
def _reentrant_call_error(self): def _reentrant_call_error(self):
# gh-109629: this happens if an explicit call to the ResourceTracker # gh-109629: this happens if an explicit call to the ResourceTracker
# gets interrupted by a garbage collection, invoking a finalizer (*) # gets interrupted by a garbage collection, invoking a finalizer (*)
@@ -91,16 +92,51 @@ class ResourceTracker(object):
# making sure child processess are cleaned before ResourceTracker # making sure child processess are cleaned before ResourceTracker
# gets destructed. # gets destructed.
# see https://github.com/python/cpython/issues/88887 # see https://github.com/python/cpython/issues/88887
self._stop(use_blocking_lock=False) # gh-146313: use a timeout to avoid deadlocking if a forked child
# still holds the pipe's write end open.
self._stop(use_blocking_lock=False, wait_timeout=1.0)
def _stop(self, use_blocking_lock=True): def _after_fork_in_child(self):
# gh-146313: Called in the child right after os.fork().
#
# The tracker process is a child of the *parent*, not of us, so we
# could never waitpid() it anyway. Clearing _pid means our __del__
# becomes a no-op (the early return for _pid is None).
#
# Whether we keep the inherited _fd depends on who forked us:
#
# - multiprocessing.Process with the 'fork' start method sets
# _fork_intent.preserve_fd before forking. The child keeps the
# fd and reuses the parent's tracker (gh-80849). This is safe
# because multiprocessing's atexit handler joins all children
# before the parent's __del__ runs, so by then the fd copies
# are gone and the parent can reap the tracker promptly.
#
# - A raw os.fork() leaves the flag unset. We close the fd in the child after forking so
# the parent's __del__ can reap the tracker without waiting
# for the child to exit. If we later need a tracker, ensure_running()
# will launch a fresh one.
self._lock._at_fork_reinit()
self._reentrant_messages.clear()
self._pid = None
self._exitcode = None
if (self._fd is not None and
not getattr(_fork_intent, 'preserve_fd', False)):
fd = self._fd
self._fd = None
try:
os.close(fd)
except OSError:
pass
def _stop(self, use_blocking_lock=True, wait_timeout=None):
if use_blocking_lock: if use_blocking_lock:
with self._lock: with self._lock:
self._stop_locked() self._stop_locked(wait_timeout=wait_timeout)
else: else:
acquired = self._lock.acquire(blocking=False) acquired = self._lock.acquire(blocking=False)
try: try:
self._stop_locked() self._stop_locked(wait_timeout=wait_timeout)
finally: finally:
if acquired: if acquired:
self._lock.release() self._lock.release()
@@ -110,6 +146,10 @@ class ResourceTracker(object):
close=os.close, close=os.close,
waitpid=os.waitpid, waitpid=os.waitpid,
waitstatus_to_exitcode=os.waitstatus_to_exitcode, waitstatus_to_exitcode=os.waitstatus_to_exitcode,
monotonic=time.monotonic,
sleep=time.sleep,
WNOHANG=getattr(os, 'WNOHANG', None),
wait_timeout=None,
): ):
# This shouldn't happen (it might when called by a finalizer) # This shouldn't happen (it might when called by a finalizer)
# so we check for it anyway. # so we check for it anyway.
@@ -126,7 +166,30 @@ class ResourceTracker(object):
self._fd = None self._fd = None
try: try:
_, status = waitpid(self._pid, 0) if wait_timeout is None:
_, status = waitpid(self._pid, 0)
else:
# gh-146313: A forked child may still hold the pipe's write
# end open, preventing the tracker from seeing EOF and
# exiting. Poll with WNOHANG to avoid blocking forever.
deadline = monotonic() + wait_timeout
delay = 0.001
while True:
result_pid, status = waitpid(self._pid, WNOHANG)
if result_pid != 0:
break
remaining = deadline - monotonic()
if remaining <= 0:
# The tracker is still running; it will be
# reparented to PID 1 (or the nearest subreaper)
# when we exit, and reaped there once all pipe
# holders release their fd.
self._pid = None
self._exitcode = None
self._waitpid_timed_out = True
return
delay = min(delay * 2, remaining, 0.1)
sleep(delay)
except ChildProcessError: except ChildProcessError:
self._pid = None self._pid = None
self._exitcode = None self._exitcode = None
@@ -312,12 +375,24 @@ class ResourceTracker(object):
self._ensure_running_and_write(msg) self._ensure_running_and_write(msg)
# gh-146313: Per-thread flag set by .popen_fork.Popen._launch() just before
# os.fork(), telling _after_fork_in_child() to keep the inherited pipe fd so
# the child can reuse this tracker (gh-80849). Unset for raw os.fork() calls,
# where the child instead closes the fd so the parent's __del__ can reap the
# tracker. Using threading.local() keeps multiple threads calling
# popen_fork.Popen._launch() at once from clobbering eachothers intent.
_fork_intent = threading.local()
_resource_tracker = ResourceTracker() _resource_tracker = ResourceTracker()
ensure_running = _resource_tracker.ensure_running ensure_running = _resource_tracker.ensure_running
register = _resource_tracker.register register = _resource_tracker.register
unregister = _resource_tracker.unregister unregister = _resource_tracker.unregister
getfd = _resource_tracker.getfd getfd = _resource_tracker.getfd
# gh-146313: See _after_fork_in_child docstring.
if hasattr(os, 'register_at_fork'):
os.register_at_fork(after_in_child=_resource_tracker._after_fork_in_child)
def _decode_message(line): def _decode_message(line):
if line.startswith(b'{'): if line.startswith(b'{'):

View File

@@ -539,6 +539,6 @@ class ShareableList:
if value == entry: if value == entry:
return position return position
else: else:
raise ValueError(f"{value!r} not in this container") raise ValueError("ShareableList.index(x): x not in list")
__class_getitem__ = classmethod(types.GenericAlias) __class_getitem__ = classmethod(types.GenericAlias)

View File

@@ -184,7 +184,7 @@ def get_preparation_data(name):
sys_argv=sys.argv, sys_argv=sys.argv,
orig_dir=process.ORIGINAL_DIR, orig_dir=process.ORIGINAL_DIR,
dir=os.getcwd(), dir=os.getcwd(),
start_method=get_start_method(), start_method=get_start_method(allow_none=True),
) )
# Figure out whether to initialise main in the subprocess as a module # Figure out whether to initialise main in the subprocess as a module

View File

@@ -21,22 +21,21 @@ from . import context
from . import process from . import process
from . import util from . import util
# Try to import the mp.synchronize module cleanly, if it fails # TODO: Do any platforms still lack a functioning sem_open?
# raise ImportError for platforms lacking a working sem_open implementation.
# See issue 3770
try: try:
from _multiprocessing import SemLock, sem_unlink from _multiprocessing import SemLock, sem_unlink
except (ImportError): except ImportError:
raise ImportError("This platform lacks a functioning sem_open" + raise ImportError("This platform lacks a functioning sem_open" +
" implementation, therefore, the required" + " implementation. https://github.com/python/cpython/issues/48020.")
" synchronization primitives needed will not" +
" function, see issue 3770.")
# #
# Constants # Constants
# #
RECURSIVE_MUTEX, SEMAPHORE = list(range(2)) # These match the enum in Modules/_multiprocessing/semaphore.c
RECURSIVE_MUTEX = 0
SEMAPHORE = 1
SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX SEM_VALUE_MAX = _multiprocessing.SemLock.SEM_VALUE_MAX
# #
@@ -91,6 +90,9 @@ class SemLock(object):
self.acquire = self._semlock.acquire self.acquire = self._semlock.acquire
self.release = self._semlock.release self.release = self._semlock.release
def locked(self):
return self._semlock._is_zero()
def __enter__(self): def __enter__(self):
return self._semlock.__enter__() return self._semlock.__enter__()
@@ -133,11 +135,16 @@ class Semaphore(SemLock):
SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx) SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx)
def get_value(self): def get_value(self):
'''Returns current value of Semaphore.
Raises NotImplementedError on Mac OSX
because of broken sem_getvalue().
'''
return self._semlock._get_value() return self._semlock._get_value()
def __repr__(self): def __repr__(self):
try: try:
value = self._semlock._get_value() value = self.get_value()
except Exception: except Exception:
value = 'unknown' value = 'unknown'
return '<%s(value=%s)>' % (self.__class__.__name__, value) return '<%s(value=%s)>' % (self.__class__.__name__, value)
@@ -153,7 +160,7 @@ class BoundedSemaphore(Semaphore):
def __repr__(self): def __repr__(self):
try: try:
value = self._semlock._get_value() value = self.get_value()
except Exception: except Exception:
value = 'unknown' value = 'unknown'
return '<%s(value=%s, maxvalue=%s)>' % \ return '<%s(value=%s, maxvalue=%s)>' % \
@@ -245,8 +252,8 @@ class Condition(object):
def __repr__(self): def __repr__(self):
try: try:
num_waiters = (self._sleeping_count._semlock._get_value() - num_waiters = (self._sleeping_count.get_value() -
self._woken_count._semlock._get_value()) self._woken_count.get_value())
except Exception: except Exception:
num_waiters = 'unknown' num_waiters = 'unknown'
return '<%s(%s, %s)>' % (self.__class__.__name__, self._lock, num_waiters) return '<%s(%s, %s)>' % (self.__class__.__name__, self._lock, num_waiters)

View File

@@ -14,12 +14,12 @@ import weakref
import atexit import atexit
import threading # we want threading to install it's import threading # we want threading to install it's
# cleanup function before multiprocessing does # cleanup function before multiprocessing does
from subprocess import _args_from_interpreter_flags from subprocess import _args_from_interpreter_flags # noqa: F401
from . import process from . import process
__all__ = [ __all__ = [
'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger', 'sub_debug', 'debug', 'info', 'sub_warning', 'warn', 'get_logger',
'log_to_stderr', 'get_temp_dir', 'register_after_fork', 'log_to_stderr', 'get_temp_dir', 'register_after_fork',
'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal', 'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal',
'close_all_fds_except', 'SUBDEBUG', 'SUBWARNING', 'close_all_fds_except', 'SUBDEBUG', 'SUBWARNING',
@@ -54,7 +54,7 @@ def info(msg, *args):
if _logger: if _logger:
_logger.log(INFO, msg, *args, stacklevel=2) _logger.log(INFO, msg, *args, stacklevel=2)
def _warn(msg, *args): def warn(msg, *args):
if _logger: if _logger:
_logger.log(WARNING, msg, *args, stacklevel=2) _logger.log(WARNING, msg, *args, stacklevel=2)
@@ -196,14 +196,14 @@ def _get_base_temp_dir(tempfile):
try: try:
base_system_tempdir = tempfile._get_default_tempdir(dirlist) base_system_tempdir = tempfile._get_default_tempdir(dirlist)
except FileNotFoundError: except FileNotFoundError:
_warn("Process-wide temporary directory %s will not be usable for " warn("Process-wide temporary directory %s will not be usable for "
"creating socket files and no usable system-wide temporary " "creating socket files and no usable system-wide temporary "
"directory was found in %s", base_tempdir, dirlist) "directory was found in %s", base_tempdir, dirlist)
# At this point, the system-wide temporary directory is not usable # At this point, the system-wide temporary directory is not usable
# but we may assume that the user-defined one is, even if we will # but we may assume that the user-defined one is, even if we will
# not be able to write socket files out there. # not be able to write socket files out there.
return base_tempdir return base_tempdir
_warn("Ignoring user-defined temporary directory: %s", base_tempdir) warn("Ignoring user-defined temporary directory: %s", base_tempdir)
# at most max(map(len, dirlist)) + 14 + 14 = 36 characters # at most max(map(len, dirlist)) + 14 + 14 = 36 characters
assert len(base_system_tempdir) + 14 + 14 < _SUN_PATH_MAX assert len(base_system_tempdir) + 14 + 14 < _SUN_PATH_MAX
return base_system_tempdir return base_system_tempdir

File diff suppressed because it is too large Load Diff

30
Lib/test/mp_preload_large_sysargv.py vendored Normal file
View File

@@ -0,0 +1,30 @@
# gh-144503: Test that the forkserver can start when the parent process has
# a very large sys.argv. Prior to the fix, sys.argv was repr'd into the
# forkserver ``-c`` command string which could exceed the OS limit on the
# length of a single argv element (MAX_ARG_STRLEN on Linux, ~128 KiB),
# causing posix_spawn to fail and the parent to see a BrokenPipeError.
import multiprocessing
import sys
EXPECTED_LEN = 5002 # argv[0] + 5000 padding entries + sentinel
def fun():
print(f"worker:{len(sys.argv)}:{sys.argv[-1]}")
if __name__ == "__main__":
# Inflate sys.argv well past 128 KiB before the forkserver is started.
sys.argv[1:] = ["x" * 50] * 5000 + ["sentinel"]
assert len(sys.argv) == EXPECTED_LEN
ctx = multiprocessing.get_context("forkserver")
p = ctx.Process(target=fun)
p.start()
p.join()
sys.exit(p.exitcode)
else:
# This branch runs when the forkserver preloads this module as
# __mp_main__; confirm the large argv was propagated intact.
print(f"preload:{len(sys.argv)}:{sys.argv[-1]}")

22
Lib/test/mp_preload_sysargv.py vendored Normal file
View File

@@ -0,0 +1,22 @@
# gh-143706: Test that sys.argv is correctly set during main module import
# when using forkserver with __main__ preloading.
import multiprocessing
import sys
# This will be printed during module import - sys.argv should be correct here
print(f"module:{sys.argv[1:]}")
def fun():
# This will be printed when the function is called
print(f"fun:{sys.argv[1:]}")
if __name__ == "__main__":
ctx = multiprocessing.get_context("forkserver")
ctx.set_forkserver_preload(['__main__'])
fun()
p = ctx.Process(target=fun)
p.start()
p.join()

View File

@@ -1,5 +1,4 @@
import io import io
import platform
import queue import queue
import re import re
import subprocess import subprocess
@@ -17,17 +16,11 @@ from unittest.mock import patch
if sys.platform != "android": if sys.platform != "android":
raise unittest.SkipTest("Android-specific") raise unittest.SkipTest("Android-specific")
api_level = platform.android_ver().api_level
# (name, level, fileno) # (name, level, fileno)
STREAM_INFO = [("stdout", "I", 1), ("stderr", "W", 2)] STREAM_INFO = [("stdout", "I", 1), ("stderr", "W", 2)]
# Test redirection of stdout and stderr to the Android log. # Test redirection of stdout and stderr to the Android log.
@unittest.skipIf(
api_level < 23 and platform.machine() == "aarch64",
"SELinux blocks reading logs on older ARM64 emulators"
)
class TestAndroidOutput(unittest.TestCase): class TestAndroidOutput(unittest.TestCase):
maxDiff = None maxDiff = None
@@ -42,31 +35,41 @@ class TestAndroidOutput(unittest.TestCase):
for line in self.logcat_process.stdout: for line in self.logcat_process.stdout:
self.logcat_queue.put(line.rstrip("\n")) self.logcat_queue.put(line.rstrip("\n"))
self.logcat_process.stdout.close() self.logcat_process.stdout.close()
self.logcat_thread = Thread(target=logcat_thread) self.logcat_thread = Thread(target=logcat_thread)
self.logcat_thread.start() self.logcat_thread.start()
from ctypes import CDLL, c_char_p, c_int try:
android_log_write = getattr(CDLL("liblog.so"), "__android_log_write") from ctypes import CDLL, c_char_p, c_int
android_log_write.argtypes = (c_int, c_char_p, c_char_p) android_log_write = getattr(CDLL("liblog.so"), "__android_log_write")
ANDROID_LOG_INFO = 4 android_log_write.argtypes = (c_int, c_char_p, c_char_p)
ANDROID_LOG_INFO = 4
# Separate tests using a marker line with a different tag. # Separate tests using a marker line with a different tag.
tag, message = "python.test", f"{self.id()} {time()}" tag, message = "python.test", f"{self.id()} {time()}"
android_log_write( android_log_write(
ANDROID_LOG_INFO, tag.encode("UTF-8"), message.encode("UTF-8")) ANDROID_LOG_INFO, tag.encode("UTF-8"), message.encode("UTF-8"))
self.assert_log("I", tag, message, skip=True, timeout=5) self.assert_log("I", tag, message, skip=True)
except:
# If setUp throws an exception, tearDown is not automatically
# called. Avoid leaving a dangling thread which would keep the
# Python process alive indefinitely.
self.tearDown()
raise
def assert_logs(self, level, tag, expected, **kwargs): def assert_logs(self, level, tag, expected, **kwargs):
for line in expected: for line in expected:
self.assert_log(level, tag, line, **kwargs) self.assert_log(level, tag, line, **kwargs)
def assert_log(self, level, tag, expected, *, skip=False, timeout=0.5): def assert_log(self, level, tag, expected, *, skip=False):
deadline = time() + timeout deadline = time() + LOOPBACK_TIMEOUT
while True: while True:
try: try:
line = self.logcat_queue.get(timeout=(deadline - time())) line = self.logcat_queue.get(timeout=(deadline - time()))
except queue.Empty: except queue.Empty:
self.fail(f"line not found: {expected!r}") raise self.failureException(
f"line not found: {expected!r}"
) from None
if match := re.fullmatch(fr"(.)/{tag}: (.*)", line): if match := re.fullmatch(fr"(.)/{tag}: (.*)", line):
try: try:
self.assertEqual(level, match[1]) self.assertEqual(level, match[1])
@@ -81,35 +84,42 @@ class TestAndroidOutput(unittest.TestCase):
self.logcat_process.wait(LOOPBACK_TIMEOUT) self.logcat_process.wait(LOOPBACK_TIMEOUT)
self.logcat_thread.join(LOOPBACK_TIMEOUT) self.logcat_thread.join(LOOPBACK_TIMEOUT)
# Avoid an irrelevant warning about threading._dangling.
self.logcat_thread = None
@contextmanager @contextmanager
def unbuffered(self, stream): def reconfigure(self, stream, **settings):
stream.reconfigure(write_through=True) original_settings = {key: getattr(stream, key, None) for key in settings.keys()}
stream.reconfigure(**settings)
try: try:
yield yield
finally: finally:
stream.reconfigure(write_through=False) stream.reconfigure(**original_settings)
# In --verbose3 mode, sys.stdout and sys.stderr are captured, so we can't
# test them directly. Detect this mode and use some temporary streams with
# the same properties.
def stream_context(self, stream_name, level): def stream_context(self, stream_name, level):
# https://developer.android.com/ndk/reference/group/logging
prio = {"I": 4, "W": 5}[level]
stack = ExitStack() stack = ExitStack()
stack.enter_context(self.subTest(stream_name)) stack.enter_context(self.subTest(stream_name))
# In --verbose3 mode, sys.stdout and sys.stderr are captured, so we can't
# test them directly. Detect this mode and use some temporary streams with
# the same properties.
stream = getattr(sys, stream_name) stream = getattr(sys, stream_name)
native_stream = getattr(sys, f"__{stream_name}__") native_stream = getattr(sys, f"__{stream_name}__")
if isinstance(stream, io.StringIO): if isinstance(stream, io.StringIO):
# https://developer.android.com/ndk/reference/group/logging
prio = {"I": 4, "W": 5}[level]
stack.enter_context( stack.enter_context(
patch( patch(
f"sys.{stream_name}", f"sys.{stream_name}",
TextLogStream( stream := TextLogStream(
prio, f"python.{stream_name}", native_stream.fileno(), prio, f"python.{stream_name}", native_stream,
errors="backslashreplace"
), ),
) )
) )
# The tests assume the stream is initially buffered.
stack.enter_context(self.reconfigure(stream, write_through=False))
return stack return stack
def test_str(self): def test_str(self):
@@ -136,7 +146,7 @@ class TestAndroidOutput(unittest.TestCase):
self.assert_logs(level, tag, lines) self.assert_logs(level, tag, lines)
# Single-line messages, # Single-line messages,
with self.unbuffered(stream): with self.reconfigure(stream, write_through=True):
write("", []) write("", [])
write("a") write("a")
@@ -166,14 +176,18 @@ class TestAndroidOutput(unittest.TestCase):
# Multi-line messages. Avoid identical consecutive lines, as # Multi-line messages. Avoid identical consecutive lines, as
# they may activate "chatty" filtering and break the tests. # they may activate "chatty" filtering and break the tests.
write("\nx", [""]) #
# Additional spaces will appear in the output where necessary to
# protect leading newlines.
write("\nx", [" "])
write("\na\n", ["x", "a"]) write("\na\n", ["x", "a"])
write("\n", [""]) write("\n", [" "])
write("\n\n", [" ", " "])
write("b\n", ["b"]) write("b\n", ["b"])
write("c\n\n", ["c", ""]) write("c\n\n", ["c", " "])
write("d\ne", ["d"]) write("d\ne", ["d"])
write("xx", []) write("xx", [])
write("f\n\ng", ["exxf", ""]) write("f\n\ng", ["exxf", " "])
write("\n", ["g"]) write("\n", ["g"])
# Since this is a line-based logging system, line buffering # Since this is a line-based logging system, line buffering
@@ -183,16 +197,17 @@ class TestAndroidOutput(unittest.TestCase):
# However, buffering can be turned off completely if you want a # However, buffering can be turned off completely if you want a
# flush after every write. # flush after every write.
with self.unbuffered(stream): with self.reconfigure(stream, write_through=True):
write("\nx", ["", "x"]) write("\nx", [" ", "x"])
write("\na\n", ["", "a"]) write("\na\n", [" ", "a"])
write("\n", [""]) write("\n", [" "])
write("\n\n", [" ", " "])
write("b\n", ["b"]) write("b\n", ["b"])
write("c\n\n", ["c", ""]) write("c\n\n", ["c", " "])
write("d\ne", ["d", "e"]) write("d\ne", ["d", "e"])
write("xx", ["xx"]) write("xx", ["xx"])
write("f\n\ng", ["f", "", "g"]) write("f\n\ng", ["f", " ", "g"])
write("\n", [""]) write("\n", [" "])
# "\r\n" should be translated into "\n". # "\r\n" should be translated into "\n".
write("hello\r\n", ["hello"]) write("hello\r\n", ["hello"])
@@ -312,19 +327,16 @@ class TestAndroidOutput(unittest.TestCase):
# currently use `logcat -v tag`, which shows each line as if it # currently use `logcat -v tag`, which shows each line as if it
# was a separate log entry, but strips a single trailing # was a separate log entry, but strips a single trailing
# newline. # newline.
# write(b"\nx", [" ", "x"])
# On newer versions of Android, all three of the above tools (or write(b"\na\n", [" ", "a"])
# maybe Logcat itself) will also strip any number of leading write(b"\n", [" "])
# newlines. write(b"\n\n", [" ", ""])
write(b"\nx", ["", "x"] if api_level < 30 else ["x"])
write(b"\na\n", ["", "a"] if api_level < 30 else ["a"])
write(b"\n", [""])
write(b"b\n", ["b"]) write(b"b\n", ["b"])
write(b"c\n\n", ["c", ""]) write(b"c\n\n", ["c", ""])
write(b"d\ne", ["d", "e"]) write(b"d\ne", ["d", "e"])
write(b"xx", ["xx"]) write(b"xx", ["xx"])
write(b"f\n\ng", ["f", "", "g"]) write(b"f\n\ng", ["f", "", "g"])
write(b"\n", [""]) write(b"\n", [" "])
# "\r\n" should be translated into "\n". # "\r\n" should be translated into "\n".
write(b"hello\r\n", ["hello"]) write(b"hello\r\n", ["hello"])

View File

@@ -427,6 +427,27 @@ class BaseSockTestsMixin:
self.loop.run_until_complete( self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address)) self._basetest_datagram_recvfrom_into(server_address))
async def _basetest_datagram_recvfrom_into_wrong_size(self, server_address):
# Call sock_sendto() with a size larger than the buffer
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
buf = bytearray(5000)
data = b'\x01' * 4096
wrong_size = len(buf) + 1
await self.loop.sock_sendto(sock, data, server_address)
with self.assertRaises(ValueError):
await self.loop.sock_recvfrom_into(
sock, buf, wrong_size)
size, addr = await self.loop.sock_recvfrom_into(sock, buf)
self.assertEqual(buf[:size], data)
def test_recvfrom_into_wrong_size(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into_wrong_size(server_address))
async def _basetest_datagram_sendto_blocking(self, server_address): async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError # Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but # This involves patching sock.sendto() to raise BlockingIOError but
@@ -642,6 +663,10 @@ if sys.platform == 'win32':
self._basetest_datagram_send_to_non_listening_address( self._basetest_datagram_send_to_non_listening_address(
recvfrom_into)) recvfrom_into))
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AssertionError: ValueError not raised")
def test_recvfrom_into_wrong_size(self):
return super().test_recvfrom_into_wrong_size()
else: else:
import selectors import selectors

View File

@@ -27,6 +27,7 @@ from test.test_asyncio import utils as test_utils
MACOS = (sys.platform == 'darwin') MACOS = (sys.platform == 'darwin')
BUF_MULTIPLIER = 1024 if not MACOS else 64 BUF_MULTIPLIER = 1024 if not MACOS else 64
HANDSHAKE_TIMEOUT = support.LONG_TIMEOUT
def tearDownModule(): def tearDownModule():
@@ -257,15 +258,12 @@ class TestSSL(test_utils.TestCase):
await fut await fut
async def start_server(): async def start_server():
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
srv = await asyncio.start_server( srv = await asyncio.start_server(
handle_client, handle_client,
'127.0.0.1', 0, '127.0.0.1', 0,
family=socket.AF_INET, family=socket.AF_INET,
ssl=sslctx, ssl=sslctx,
**extras) ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
try: try:
srv_socks = srv.sockets srv_socks = srv.sockets
@@ -322,14 +320,11 @@ class TestSSL(test_utils.TestCase):
sock.close() sock.close()
async def client(addr): async def client(addr):
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='', server_hostname='',
**extras) ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(A_DATA) writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK') self.assertEqual(await reader.readexactly(2), b'OK')
@@ -349,7 +344,8 @@ class TestSSL(test_utils.TestCase):
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
sock=sock, sock=sock,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='') server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(A_DATA) writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK') self.assertEqual(await reader.readexactly(2), b'OK')
@@ -448,7 +444,7 @@ class TestSSL(test_utils.TestCase):
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='', server_hostname='',
ssl_handshake_timeout=support.SHORT_TIMEOUT) ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.close() writer.close()
await self.wait_closed(writer) await self.wait_closed(writer)
@@ -610,7 +606,7 @@ class TestSSL(test_utils.TestCase):
extras = {} extras = {}
if server_ssl: if server_ssl:
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) extras = dict(ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
f = loop.create_task( f = loop.create_task(
loop.connect_accepted_socket( loop.connect_accepted_socket(
@@ -659,7 +655,8 @@ class TestSSL(test_utils.TestCase):
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='') server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
self.assertEqual(await reader.readline(), b'A\n') self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B') writer.write(b'B')
@@ -1153,14 +1150,11 @@ class TestSSL(test_utils.TestCase):
await fut await fut
async def start_server(): async def start_server():
extras = {}
srv = await self.loop.create_server( srv = await self.loop.create_server(
server_protocol_factory, server_protocol_factory,
'127.0.0.1', 0, '127.0.0.1', 0,
family=socket.AF_INET, family=socket.AF_INET,
ssl=sslctx_1, ssl=sslctx_1)
**extras)
try: try:
srv_socks = srv.sockets srv_socks = srv.sockets
@@ -1210,14 +1204,11 @@ class TestSSL(test_utils.TestCase):
sock.close() sock.close()
async def client(addr): async def client(addr):
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='', server_hostname='',
**extras) ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(A_DATA) writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK') self.assertEqual(await reader.readexactly(2), b'OK')
@@ -1287,7 +1278,8 @@ class TestSSL(test_utils.TestCase):
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='') server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
sslprotocol = writer.transport._ssl_protocol sslprotocol = writer.transport._ssl_protocol
writer.write(b'ping') writer.write(b'ping')
data = await reader.readexactly(4) data = await reader.readexactly(4)
@@ -1399,7 +1391,8 @@ class TestSSL(test_utils.TestCase):
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='') server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(b'ping') writer.write(b'ping')
data = await reader.readexactly(4) data = await reader.readexactly(4)
self.assertEqual(data, b'pong') self.assertEqual(data, b'pong')
@@ -1530,7 +1523,8 @@ class TestSSL(test_utils.TestCase):
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
*addr, *addr,
ssl=client_sslctx, ssl=client_sslctx,
server_hostname='') server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(b'ping') writer.write(b'ping')
data = await reader.readexactly(4) data = await reader.readexactly(4)
self.assertEqual(data, b'pong') self.assertEqual(data, b'pong')

View File

@@ -819,6 +819,48 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(msg1, b"hello world 1!\n") self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n") self.assertEqual(msg2, b"hello world 2!\n")
@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls_buffered_data(self):
# gh-142352: test start_tls() with buffered data
async def server_handler(client_reader, client_writer):
# Wait for TLS ClientHello to be buffered before start_tls().
await client_reader._wait_for_data('test_start_tls_buffered_data'),
self.assertTrue(client_reader._buffer)
await client_writer.start_tls(test_utils.simple_server_sslcontext())
line = await client_reader.readline()
self.assertEqual(line, b"ping\n")
client_writer.write(b"pong\n")
await client_writer.drain()
client_writer.close()
await client_writer.wait_closed()
async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
await writer.start_tls(test_utils.simple_client_sslcontext())
writer.write(b"ping\n")
await writer.drain()
line = await reader.readline()
self.assertEqual(line, b"pong\n")
writer.close()
await writer.wait_closed()
async def run_test():
server = await asyncio.start_server(
server_handler, socket_helper.HOSTv4, 0)
server_addr = server.sockets[0].getsockname()
await client(server_addr)
server.close()
await server.wait_closed()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
self.loop.run_until_complete(run_test())
self.assertEqual(messages, [])
def test_streamreader_constructor_without_loop(self): def test_streamreader_constructor_without_loop(self):
with self.assertRaisesRegex(RuntimeError, 'no current event loop'): with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
asyncio.StreamReader() asyncio.StreamReader()

View File

@@ -111,6 +111,37 @@ class SubprocessTransportTests(test_utils.TestCase):
) )
transport.close() transport.close()
def test_proc_exited_no_invalid_state_error_on_exit_waiters(self):
# gh-145541: when _connect_pipes hasn't completed (so
# _pipes_connected is False) and the process exits, _try_finish()
# sets the result on exit waiters. Then _call_connection_lost() must
# not call set_result() again on the same waiters.
self.loop.set_exception_handler(
lambda loop, context: self.fail(
f"unexpected exception: {context}")
)
waiter = self.loop.create_future()
transport, protocol = self.create_transport(waiter)
# Simulate a waiter registered via _wait() before the process exits.
exit_waiter = self.loop.create_future()
transport._exit_waiters.append(exit_waiter)
# _connect_pipes hasn't completed, so _pipes_connected is False.
self.assertFalse(transport._pipes_connected)
# Simulate process exit. _try_finish() will set the result on
# exit_waiter because _pipes_connected is False, and then schedule
# _call_connection_lost() because _pipes is empty (vacuously all
# disconnected). _call_connection_lost() must skip exit_waiter
# because it's already done.
transport._process_exited(6)
self.loop.run_until_complete(waiter)
self.assertEqual(exit_waiter.result(), 6)
transport.close()
class SubprocessMixin: class SubprocessMixin:

View File

@@ -2947,7 +2947,6 @@ class CTask_CFuture_Tests(BaseTaskTests, SetMethodsTest,
return super().test_log_destroyed_pending_task() return super().test_log_destroyed_pending_task()
@unittest.skipUnless(hasattr(futures, '_CFuture') and @unittest.skipUnless(hasattr(futures, '_CFuture') and
hasattr(tasks, '_CTask'), hasattr(tasks, '_CTask'),
'requires the C _asyncio module') 'requires the C _asyncio module')
@@ -3030,7 +3029,6 @@ class PyTask_PyFuture_SubclassTests(BaseTaskTests, test_utils.TestCase):
all_tasks = staticmethod(tasks._py_all_tasks) all_tasks = staticmethod(tasks._py_all_tasks)
current_task = staticmethod(tasks._py_current_task) current_task = staticmethod(tasks._py_current_task)
@unittest.skipUnless(hasattr(tasks, '_CTask'), @unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module') 'requires the C _asyncio module')
class CTask_Future_Tests(test_utils.TestCase): class CTask_Future_Tests(test_utils.TestCase):
@@ -3063,6 +3061,26 @@ class BaseTaskIntrospectionTests:
_enter_task = None _enter_task = None
_leave_task = None _leave_task = None
all_tasks = None all_tasks = None
Task = None
def test_register_task_resurrection(self):
register_task = self._register_task
class EvilLoop:
def get_debug(self):
return False
def call_exception_handler(self, context):
register_task(context["task"])
async def coro_fn ():
pass
coro = coro_fn()
self.addCleanup(coro.close)
loop = EvilLoop()
with self.assertRaises(AttributeError):
self.Task(coro, loop=loop)
def test__register_task_1(self): def test__register_task_1(self):
class TaskLike: class TaskLike:
@@ -3193,6 +3211,7 @@ class PyIntrospectionTests(test_utils.TestCase, BaseTaskIntrospectionTests):
_leave_task = staticmethod(tasks._py_leave_task) _leave_task = staticmethod(tasks._py_leave_task)
all_tasks = staticmethod(tasks._py_all_tasks) all_tasks = staticmethod(tasks._py_all_tasks)
current_task = staticmethod(tasks._py_current_task) current_task = staticmethod(tasks._py_current_task)
Task = tasks._PyTask
@unittest.skipUnless(hasattr(tasks, '_c_register_task'), @unittest.skipUnless(hasattr(tasks, '_c_register_task'),
@@ -3205,10 +3224,12 @@ class CIntrospectionTests(test_utils.TestCase, BaseTaskIntrospectionTests):
_leave_task = staticmethod(tasks._c_leave_task) _leave_task = staticmethod(tasks._c_leave_task)
all_tasks = staticmethod(tasks._c_all_tasks) all_tasks = staticmethod(tasks._c_all_tasks)
current_task = staticmethod(tasks._c_current_task) current_task = staticmethod(tasks._c_current_task)
Task = tasks._CTask
else: else:
_register_task = _unregister_task = _enter_task = _leave_task = None _register_task = _unregister_task = _enter_task = _leave_task = None
class BaseCurrentLoopTests: class BaseCurrentLoopTests:
current_task = None current_task = None
@@ -3698,6 +3719,30 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
(loop, context), kwargs = callback.call_args (loop, context), kwargs = callback.call_args
self.assertEqual(context['exception'], exc_context.exception) self.assertEqual(context['exception'], exc_context.exception)
def test_run_coroutine_threadsafe_and_cancel(self):
task = None
thread_future = None
# Use a custom task factory to capture the created Task
def task_factory(loop, coro):
nonlocal task
task = asyncio.Task(coro, loop=loop)
return task
self.addCleanup(self.loop.set_task_factory,
self.loop.get_task_factory())
async def target():
nonlocal thread_future
self.loop.set_task_factory(task_factory)
thread_future = asyncio.run_coroutine_threadsafe(asyncio.sleep(10), self.loop)
await asyncio.sleep(0)
thread_future.cancel()
self.loop.run_until_complete(target())
self.assertTrue(task.cancelled())
self.assertTrue(thread_future.cancelled())
class SleepTests(test_utils.TestCase): class SleepTests(test_utils.TestCase):
def setUp(self): def setUp(self):

View File

@@ -77,6 +77,30 @@ class PipeTests(unittest.TestCase):
else: else:
raise RuntimeError('expected ERROR_INVALID_HANDLE') raise RuntimeError('expected ERROR_INVALID_HANDLE')
def test_pipe_handle_close_after_external_close(self):
# gh-149388: PipeHandle.close() must clear ``_handle`` before calling
# CloseHandle so that if CloseHandle raises on a stale handle the
# PipeHandle is still marked closed and __del__ / subsequent close()
# calls are silent no-ops.
h1, h2 = windows_utils.pipe(overlapped=(False, False))
try:
p = windows_utils.PipeHandle(h1)
# Simulate an external close of the underlying handle (e.g.
# a finalizer race or a concurrent close on the same object).
_winapi.CloseHandle(p.handle)
# First close() still propagates the OSError from CloseHandle,
# but must clear ``_handle`` first.
with self.assertRaises(OSError):
p.close()
self.assertIsNone(p.handle)
# Second close() is a no-op.
p.close()
# __del__ through GC is also a silent no-op — no unraisable.
del p
support.gc_collect()
finally:
_winapi.CloseHandle(h2)
class PopenTests(unittest.TestCase): class PopenTests(unittest.TestCase):
@@ -129,5 +153,25 @@ class PopenTests(unittest.TestCase):
pass pass
class OverlappedRefleakTests(unittest.TestCase):
def test_wsasendto_failure(self):
ov = _overlapped.Overlapped()
buf = bytearray(4096)
with self.assertRaises(OSError):
ov.WSASendTo(0x1234, buf, 0, ("127.0.0.1", 1))
def test_wsarecvfrom_failure(self):
ov = _overlapped.Overlapped()
with self.assertRaises(OSError):
ov.WSARecvFrom(0x1234, 1024, 0)
def test_wsarecvfrominto_failure(self):
ov = _overlapped.Overlapped()
buf = bytearray(4096)
with self.assertRaises(OSError):
ov.WSARecvFromInto(0x1234, buf, len(buf), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -222,6 +222,7 @@ class BinASCIITest(unittest.TestCase):
assertInvalidLength(b'a' * (4 * 87 + 1)) assertInvalidLength(b'a' * (4 * 87 + 1))
assertInvalidLength(b'A\tB\nC ??DE') # only 5 valid characters assertInvalidLength(b'A\tB\nC ??DE') # only 5 valid characters
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: Error not raised by a2b_uu
def test_uu(self): def test_uu(self):
MAX_UU = 45 MAX_UU = 45
for backtick in (True, False): for backtick in (True, False):
@@ -242,6 +243,10 @@ class BinASCIITest(unittest.TestCase):
self.assertEqual(binascii.a2b_uu(b"\xff"), b"\x00"*31) self.assertEqual(binascii.a2b_uu(b"\xff"), b"\x00"*31)
self.assertRaises(binascii.Error, binascii.a2b_uu, b"\xff\x00") self.assertRaises(binascii.Error, binascii.a2b_uu, b"\xff\x00")
self.assertRaises(binascii.Error, binascii.a2b_uu, b"!!!!") self.assertRaises(binascii.Error, binascii.a2b_uu, b"!!!!")
self.assertRaises(binascii.Error, binascii.a2b_uu,
self.type2test(b""))
self.assertRaises(binascii.Error, binascii.a2b_uu,
self.type2test(b"#86)C")[:0])
self.assertRaises(binascii.Error, binascii.b2a_uu, 46*b"!") self.assertRaises(binascii.Error, binascii.b2a_uu, 46*b"!")
# Issue #7701 (crash on a pydebug build) # Issue #7701 (crash on a pydebug build)
@@ -440,6 +445,7 @@ class BinASCIITest(unittest.TestCase):
self.assertConversion(binary, converted, restored, self.assertConversion(binary, converted, restored,
quotetabs=quotetabs, istext=istext, header=header) quotetabs=quotetabs, istext=istext, header=header)
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: Error not raised by a2b_uu
def test_empty_string(self): def test_empty_string(self):
# A test for SF bug #1022953. Make sure SystemError is not raised. # A test for SF bug #1022953. Make sure SystemError is not raised.
empty = self.type2test(b'') empty = self.type2test(b'')
@@ -449,6 +455,9 @@ class BinASCIITest(unittest.TestCase):
binascii.crc_hqx(empty, 0) binascii.crc_hqx(empty, 0)
continue continue
f = getattr(binascii, func) f = getattr(binascii, func)
if func == 'a2b_uu':
self.assertRaises(binascii.Error, f, empty)
continue
try: try:
f(empty) f(empty)
except Exception as err: except Exception as err:

View File

@@ -293,6 +293,27 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertEqual(overridden_outputs, ['all', 'any', 'tuple']) self.assertEqual(overridden_outputs, ['all', 'any', 'tuple'])
def test_builtin_call_async_genexpr_no_crash(self):
async def f_all():
return all(await 2 for _ in [])
async def f_any():
return any(await 2 for _ in [])
async def f_tuple():
return tuple(await 2 for _ in [])
async def f_list():
return list(await 2 for _ in [])
async def f_set():
return set(await 2 for _ in [])
for f in (f_all, f_any, f_tuple, f_list, f_set):
with self.subTest(func=f.__name__):
with self.assertRaises(TypeError):
run_yielding_async_fn(f)
def test_ascii(self): def test_ascii(self):
self.assertEqual(ascii(''), '\'\'') self.assertEqual(ascii(''), '\'\'')
self.assertEqual(ascii(0), '0') self.assertEqual(ascii(0), '0')
@@ -606,7 +627,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
exec(co, glob) exec(co, glob)
self.assertEqual(type(glob['ticker']()), AsyncGeneratorType) self.assertEqual(type(glob['ticker']()), AsyncGeneratorType)
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: <_ast.Name object at 0xb40000731e3d1360> is not an instance of <class '_ast.Constant'> @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: <_ast.Name object at 0xb40000731e3d1360> is not an instance of <class '_ast.Constant'>
def test_compile_ast(self): def test_compile_ast(self):
args = ("a*__debug__", "f.py", "exec") args = ("a*__debug__", "f.py", "exec")
raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0] raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0]
@@ -970,7 +991,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaisesRegex(NameError, "name 'superglobal' is not defined", self.assertRaisesRegex(NameError, "name 'superglobal' is not defined",
eval, code, ns) eval, code, ns)
@unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message @unittest.expectedFailure # TODO: RUSTPYTHON; wrong error message
def test_exec_builtins_mapping_import(self): def test_exec_builtins_mapping_import(self):
code = compile("import foo.bar", "test", "exec") code = compile("import foo.bar", "test", "exec")
ns = {'__builtins__': types.MappingProxyType({})} ns = {'__builtins__': types.MappingProxyType({})}
@@ -979,7 +1000,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
exec(code, ns) exec(code, ns)
self.assertEqual(ns['foo'], ('foo.bar', ns, ns, None, 0)) self.assertEqual(ns['foo'], ('foo.bar', ns, ns, None, 0))
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: AttributeError not raised by eval @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: AttributeError not raised by eval
def test_eval_builtins_mapping_reduce(self): def test_eval_builtins_mapping_reduce(self):
# list_iterator.__reduce__() calls _PyEval_GetBuiltin("iter") # list_iterator.__reduce__() calls _PyEval_GetBuiltin("iter")
code = compile("x.__reduce__()", "test", "eval") code = compile("x.__reduce__()", "test", "eval")

View File

@@ -204,5 +204,20 @@ class TestDefaultDict(unittest.TestCase):
self.assertEqual(test_dict[key], 2) self.assertEqual(test_dict[key], 2)
self.assertEqual(count, 2) self.assertEqual(count, 2)
def test_repr_recursive_factory(self):
# gh-145492: defaultdict.__repr__ should not cause infinite recursion
# when the factory's __repr__ calls repr() on the defaultdict.
class ProblematicFactory:
def __call__(self):
return {}
def __repr__(self):
repr(dd)
return f"ProblematicFactory for {dd}"
dd = defaultdict(ProblematicFactory())
# Should not raise RecursionError
r = repr(dd)
self.assertIn("ProblematicFactory for", r)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -5,7 +5,6 @@ machinery = util.import_importlib('importlib.machinery')
from test.support import captured_stdout, import_helper, STDLIB_DIR from test.support import captured_stdout, import_helper, STDLIB_DIR
import contextlib import contextlib
import os.path import os.path
import sys
import types import types
import unittest import unittest
import warnings import warnings
@@ -76,7 +75,7 @@ class ExecModuleTests(abc.LoaderTests):
self.assertHasAttr(module, '__spec__') self.assertHasAttr(module, '__spec__')
self.assertEqual(module.__spec__.loader_state.origname, name) self.assertEqual(module.__spec__.loader_state.origname, name)
@unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON") @unittest.skipIf(__import__("sys").platform == "win32", "TODO: RUSTPYTHON")
def test_package(self): def test_package(self):
name = '__phello__' name = '__phello__'
module, output = self.exec_module(name) module, output = self.exec_module(name)
@@ -147,7 +146,7 @@ class InspectLoaderTests:
result = self.machinery.FrozenImporter.get_source('__hello__') result = self.machinery.FrozenImporter.get_source('__hello__')
self.assertIsNone(result) self.assertIsNone(result)
@unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON") @unittest.skipIf(__import__("sys").platform == "win32", "TODO: RUSTPYTHON")
def test_is_package(self): def test_is_package(self):
# Should be able to tell what is a package. # Should be able to tell what is a package.
test_for = (('__hello__', False), ('__phello__', True), test_for = (('__hello__', False), ('__phello__', True),

View File

@@ -1,4 +1,3 @@
import os
import unittest import unittest
from . import util from . import util

View File

@@ -155,8 +155,8 @@ class LifetimeTests:
# TODO: RUSTPYTHON; dead weakref module locks not cleaned up in frozen bootstrap # TODO: RUSTPYTHON; dead weakref module locks not cleaned up in frozen bootstrap
Frozen_LifetimeTests.test_all_locks = unittest.skip("TODO: RUSTPYTHON")( Frozen_LifetimeTests.test_all_locks = unittest.skip("TODO: RUSTPYTHON")(
Frozen_LifetimeTests.test_all_locks) Frozen_LifetimeTests.test_all_locks
)
def setUpModule(): def setUpModule():
thread_info = threading_helper.threading_setup() thread_info = threading_helper.threading_setup()

View File

@@ -263,6 +263,72 @@ class ThreadedImportTests(unittest.TestCase):
'partial', 'pool_in_threads.py') 'partial', 'pool_in_threads.py')
script_helper.assert_python_ok(fn) script_helper.assert_python_ok(fn)
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_import_failure_race_condition(self):
# Regression test for race condition where a thread could receive
# a partially-initialized module when another thread's import fails.
# The race occurs when:
# 1. Thread 1 starts importing, adds module to sys.modules
# 2. Thread 2 sees the module in sys.modules
# 3. Thread 1's import fails, removes module from sys.modules
# 4. Thread 2 should NOT return the stale module reference
os.mkdir(TESTFN)
self.addCleanup(shutil.rmtree, TESTFN)
sys.path.insert(0, TESTFN)
self.addCleanup(sys.path.remove, TESTFN)
# Create a module that partially initializes then fails
modname = 'failing_import_module'
with open(os.path.join(TESTFN, modname + '.py'), 'w') as f:
f.write('''
import time
PARTIAL_ATTR = 'initialized'
time.sleep(0.05) # Widen race window
raise RuntimeError("Intentional import failure")
''')
self.addCleanup(forget, modname)
importlib.invalidate_caches()
errors = []
results = []
def do_import(delay=0):
time.sleep(delay)
try:
mod = __import__(modname)
# If we got a module, verify it's in sys.modules
if modname not in sys.modules:
errors.append(
f"Got module {mod!r} but {modname!r} not in sys.modules"
)
elif sys.modules[modname] is not mod:
errors.append(
f"Got different module than sys.modules[{modname!r}]"
)
else:
results.append(('success', mod))
except RuntimeError:
results.append(('RuntimeError',))
except Exception as e:
errors.append(f"Unexpected exception: {e}")
# Run multiple iterations to increase chance of hitting the race
for _ in range(10):
errors.clear()
results.clear()
if modname in sys.modules:
del sys.modules[modname]
t1 = threading.Thread(target=do_import, args=(0,))
t2 = threading.Thread(target=do_import, args=(0.01,))
t1.start()
t2.start()
t1.join()
t2.join()
# Neither thread should have errors about stale modules
self.assertEqual(errors, [], f"Race condition detected: {errors}")
def setUpModule(): def setUpModule():
thread_info = threading_helper.threading_setup() thread_info = threading_helper.threading_setup()

View File

@@ -15,10 +15,10 @@ import sys
import tempfile import tempfile
import types import types
try: try: # TODO: RUSTPYTHON
_testsinglephase = import_helper.import_module("_testsinglephase") import_helper.import_module("_testmultiphase")
except unittest.SkipTest: except unittest.SkipTest:
_testsinglephase = None # TODO: RUSTPYTHON _testmultiphase = None
BUILTINS = types.SimpleNamespace() BUILTINS = types.SimpleNamespace()

385
Lib/test/test_mmap.py vendored
View File

@@ -1,9 +1,12 @@
from test.support import ( from test.support import (
requires, _2G, _4G, gc_collect, cpython_only, is_emscripten requires, _2G, _4G, gc_collect, cpython_only, is_emscripten, is_apple,
in_systemd_nspawn_sync_suppressed,
) )
from test.support.import_helper import import_module from test.support.import_helper import import_module
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from test.support.script_helper import assert_python_ok
import unittest import unittest
import errno
import os import os
import re import re
import itertools import itertools
@@ -11,6 +14,7 @@ import random
import socket import socket
import string import string
import sys import sys
import textwrap
import weakref import weakref
# Skip test if we can't import mmap. # Skip test if we can't import mmap.
@@ -32,6 +36,10 @@ if is_emscripten:
class MmapTests(unittest.TestCase): class MmapTests(unittest.TestCase):
def setUp(self): def setUp(self):
# TODO: RUSTPYTHON; Remove this once windows doesn't get errored on setup:/
if os.name == "nt":
raise unittest.SkipTest("TODO: RUSTPYTHON; Error during class setUp")
if os.path.exists(TESTFN): if os.path.exists(TESTFN):
os.unlink(TESTFN) os.unlink(TESTFN)
@@ -41,6 +49,7 @@ class MmapTests(unittest.TestCase):
except OSError: except OSError:
pass pass
@unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'mmap' object has no attribute 'seekable'
def test_basic(self): def test_basic(self):
# Test mmap module on Unix systems and Windows # Test mmap module on Unix systems and Windows
@@ -93,11 +102,12 @@ class MmapTests(unittest.TestCase):
self.assertEqual(end, PAGESIZE + 6) self.assertEqual(end, PAGESIZE + 6)
# test seeking around (try to overflow the seek implementation) # test seeking around (try to overflow the seek implementation)
m.seek(0,0) self.assertTrue(m.seekable())
self.assertEqual(m.seek(0, 0), 0)
self.assertEqual(m.tell(), 0) self.assertEqual(m.tell(), 0)
m.seek(42,1) self.assertEqual(m.seek(42, 1), 42)
self.assertEqual(m.tell(), 42) self.assertEqual(m.tell(), 42)
m.seek(0,2) self.assertEqual(m.seek(0, 2), len(m))
self.assertEqual(m.tell(), len(m)) self.assertEqual(m.tell(), len(m))
# Try to seek to negative position... # Try to seek to negative position...
@@ -162,7 +172,7 @@ class MmapTests(unittest.TestCase):
# Ensuring that readonly mmap can't be write() to # Ensuring that readonly mmap can't be write() to
try: try:
m.seek(0,0) m.seek(0, 0)
m.write(b'abc') m.write(b'abc')
except TypeError: except TypeError:
pass pass
@@ -171,7 +181,7 @@ class MmapTests(unittest.TestCase):
# Ensuring that readonly mmap can't be write_byte() to # Ensuring that readonly mmap can't be write_byte() to
try: try:
m.seek(0,0) m.seek(0, 0)
m.write_byte(b'd') m.write_byte(b'd')
except TypeError: except TypeError:
pass pass
@@ -255,15 +265,79 @@ class MmapTests(unittest.TestCase):
# Try writing with PROT_EXEC and without PROT_WRITE # Try writing with PROT_EXEC and without PROT_WRITE
prot = mmap.PROT_READ | getattr(mmap, 'PROT_EXEC', 0) prot = mmap.PROT_READ | getattr(mmap, 'PROT_EXEC', 0)
with open(TESTFN, "r+b") as f: with open(TESTFN, "r+b") as f:
m = mmap.mmap(f.fileno(), mapsize, prot=prot) try:
self.assertRaises(TypeError, m.write, b"abcdef") m = mmap.mmap(f.fileno(), mapsize, prot=prot)
self.assertRaises(TypeError, m.write_byte, 0) except PermissionError:
m.close() # on macOS 14, PROT_READ | PROT_EXEC is not allowed
pass
else:
self.assertRaises(TypeError, m.write, b"abcdef")
self.assertRaises(TypeError, m.write_byte, 0)
m.close()
@unittest.expectedFailure # TODO: RUSTPYTHON
@unittest.skipIf(os.name == 'nt', 'trackfd not present on Windows')
def test_trackfd_parameter(self):
size = 64
with open(TESTFN, "wb") as f:
f.write(b"a"*size)
for close_original_fd in True, False:
with self.subTest(close_original_fd=close_original_fd):
with open(TESTFN, "r+b") as f:
with mmap.mmap(f.fileno(), size, trackfd=False) as m:
if close_original_fd:
f.close()
self.assertEqual(len(m), size)
with self.assertRaises(OSError) as err_cm:
m.size()
self.assertEqual(err_cm.exception.errno, errno.EBADF)
with self.assertRaises(ValueError):
m.resize(size * 2)
with self.assertRaises(ValueError):
m.resize(size // 2)
self.assertEqual(m.closed, False)
# Smoke-test other API
m.write_byte(ord('X'))
m[2] = ord('Y')
m.flush()
with open(TESTFN, "rb") as f:
self.assertEqual(f.read(4), b'XaYa')
self.assertEqual(m.tell(), 1)
m.seek(0)
self.assertEqual(m.tell(), 0)
self.assertEqual(m.read_byte(), ord('X'))
self.assertEqual(m.closed, True)
self.assertEqual(os.stat(TESTFN).st_size, size)
@unittest.expectedFailure # TODO: RUSTPYTHON
@unittest.skipIf(os.name == 'nt', 'trackfd not present on Windows')
def test_trackfd_neg1(self):
size = 64
with mmap.mmap(-1, size, trackfd=False) as m:
with self.assertRaises(OSError):
m.size()
with self.assertRaises(ValueError):
m.resize(size // 2)
self.assertEqual(len(m), size)
m[0] = ord('a')
assert m[0] == ord('a')
@unittest.skipIf(os.name != 'nt', 'trackfd only fails on Windows')
def test_no_trackfd_parameter_on_windows(self):
# 'trackffd' is an invalid keyword argument for this function
size = 64
with self.assertRaises(TypeError):
mmap.mmap(-1, size, trackfd=True)
with self.assertRaises(TypeError):
mmap.mmap(-1, size, trackfd=False)
def test_bad_file_desc(self): def test_bad_file_desc(self):
# Try opening a bad file descriptor... # Try opening a bad file descriptor...
self.assertRaises(OSError, mmap.mmap, -2, 4096) self.assertRaises(OSError, mmap.mmap, -2, 4096)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_tougher_find(self): def test_tougher_find(self):
# Do a tougher .find() test. SF bug 515943 pointed out that, in 2.2, # Do a tougher .find() test. SF bug 515943 pointed out that, in 2.2,
# searching for data with embedded \0 bytes didn't work. # searching for data with embedded \0 bytes didn't work.
@@ -282,6 +356,7 @@ class MmapTests(unittest.TestCase):
self.assertEqual(m.find(slice + b'x'), -1) self.assertEqual(m.find(slice + b'x'), -1)
m.close() m.close()
@unittest.skip("TODO: RUSTPYTHON; panic")
def test_find_end(self): def test_find_end(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
with open(TESTFN, 'wb+') as f: with open(TESTFN, 'wb+') as f:
@@ -299,7 +374,29 @@ class MmapTests(unittest.TestCase):
self.assertEqual(m.find(b'one', 1, -2), -1) self.assertEqual(m.find(b'one', 1, -2), -1)
self.assertEqual(m.find(bytearray(b'one')), 0) self.assertEqual(m.find(bytearray(b'one')), 0)
for i in range(-n-1, n+1):
for j in range(-n-1, n+1):
for p in [b"o", b"on", b"two", b"ones", b"s"]:
expected = data.find(p, i, j)
self.assertEqual(m.find(p, i, j), expected, (p, i, j))
def test_find_does_not_access_beyond_buffer(self):
try:
flags = mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS
PAGESIZE = mmap.PAGESIZE
PROT_NONE = 0
PROT_READ = mmap.PROT_READ
except AttributeError as e:
raise unittest.SkipTest("mmap flags unavailable") from e
for i in range(0, 2049):
with mmap.mmap(-1, PAGESIZE * (i + 1),
flags=flags, prot=PROT_NONE) as guard:
with mmap.mmap(-1, PAGESIZE * (i + 2048),
flags=flags, prot=PROT_READ) as fm:
fm.find(b"fo", -2)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_rfind(self): def test_rfind(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
with open(TESTFN, 'wb+') as f: with open(TESTFN, 'wb+') as f:
@@ -360,6 +457,7 @@ class MmapTests(unittest.TestCase):
self.assertRaises(ValueError, mmap.mmap, f.fileno(), 0, self.assertRaises(ValueError, mmap.mmap, f.fileno(), 0,
offset=2147418112) offset=2147418112)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_move(self): def test_move(self):
# make move works everywhere (64-bit format problem earlier) # make move works everywhere (64-bit format problem earlier)
with open(TESTFN, 'wb+') as f: with open(TESTFN, 'wb+') as f:
@@ -407,7 +505,6 @@ class MmapTests(unittest.TestCase):
m.move(0, 0, 1) m.move(0, 0, 1)
m.move(0, 0, 0) m.move(0, 0, 0)
def test_anonymous(self): def test_anonymous(self):
# anonymous mmap.mmap(-1, PAGE) # anonymous mmap.mmap(-1, PAGE)
m = mmap.mmap(-1, PAGESIZE) m = mmap.mmap(-1, PAGESIZE)
@@ -420,6 +517,7 @@ class MmapTests(unittest.TestCase):
m[x] = b m[x] = b
self.assertEqual(m[x], b) self.assertEqual(m[x], b)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_read_all(self): def test_read_all(self):
m = mmap.mmap(-1, 16) m = mmap.mmap(-1, 16)
self.addCleanup(m.close) self.addCleanup(m.close)
@@ -441,6 +539,7 @@ class MmapTests(unittest.TestCase):
m.seek(9) m.seek(9)
self.assertEqual(m.read(-42), bytes(range(9, 16))) self.assertEqual(m.read(-42), bytes(range(9, 16)))
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_read_invalid_arg(self): def test_read_invalid_arg(self):
m = mmap.mmap(-1, 16) m = mmap.mmap(-1, 16)
self.addCleanup(m.close) self.addCleanup(m.close)
@@ -557,6 +656,7 @@ class MmapTests(unittest.TestCase):
except OSError: except OSError:
pass pass
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_subclass(self): def test_subclass(self):
class anon_mmap(mmap.mmap): class anon_mmap(mmap.mmap):
def __new__(klass, *args, **kwargs): def __new__(klass, *args, **kwargs):
@@ -609,6 +709,7 @@ class MmapTests(unittest.TestCase):
self.assertEqual(m[:], b"012barbaz9") self.assertEqual(m[:], b"012barbaz9")
self.assertRaises(ValueError, m.write, b"ba") self.assertRaises(ValueError, m.write, b"ba")
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_non_ascii_byte(self): def test_non_ascii_byte(self):
for b in (129, 200, 255): # > 128 for b in (129, 200, 255): # > 128
m = mmap.mmap(-1, 1) m = mmap.mmap(-1, 1)
@@ -618,6 +719,7 @@ class MmapTests(unittest.TestCase):
self.assertEqual(m.read_byte(), b) self.assertEqual(m.read_byte(), b)
m.close() m.close()
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.skipUnless(os.name == 'nt', 'requires Windows')
def test_tagname(self): def test_tagname(self):
data1 = b"0123456789" data1 = b"0123456789"
@@ -646,14 +748,16 @@ class MmapTests(unittest.TestCase):
m2.close() m2.close()
m1.close() m1.close()
with self.assertRaisesRegex(TypeError, 'tagname'):
mmap.mmap(-1, 8, tagname=1)
@cpython_only @cpython_only
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.skipUnless(os.name == 'nt', 'requires Windows')
def test_sizeof(self): def test_sizeof(self):
m1 = mmap.mmap(-1, 100) m1 = mmap.mmap(-1, 100)
tagname = random_tagname() tagname = random_tagname()
m2 = mmap.mmap(-1, 100, tagname=tagname) m2 = mmap.mmap(-1, 100, tagname=tagname)
self.assertEqual(sys.getsizeof(m2), self.assertGreater(sys.getsizeof(m2), sys.getsizeof(m1))
sys.getsizeof(m1) + len(tagname) + 1)
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.skipUnless(os.name == 'nt', 'requires Windows')
def test_crasher_on_windows(self): def test_crasher_on_windows(self):
@@ -708,6 +812,7 @@ class MmapTests(unittest.TestCase):
"wrong exception raised in context manager") "wrong exception raised in context manager")
self.assertTrue(m.closed, "context manager failed") self.assertTrue(m.closed, "context manager failed")
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_weakref(self): def test_weakref(self):
# Check mmap objects are weakrefable # Check mmap objects are weakrefable
mm = mmap.mmap(-1, 16) mm = mmap.mmap(-1, 16)
@@ -717,6 +822,7 @@ class MmapTests(unittest.TestCase):
gc_collect() gc_collect()
self.assertIs(wr(), None) self.assertIs(wr(), None)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_write_returning_the_number_of_bytes_written(self): def test_write_returning_the_number_of_bytes_written(self):
mm = mmap.mmap(-1, 16) mm = mmap.mmap(-1, 16)
self.assertEqual(mm.write(b""), 0) self.assertEqual(mm.write(b""), 0)
@@ -724,6 +830,7 @@ class MmapTests(unittest.TestCase):
self.assertEqual(mm.write(b"yz"), 2) self.assertEqual(mm.write(b"yz"), 2)
self.assertEqual(mm.write(b"python"), 6) self.assertEqual(mm.write(b"python"), 6)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_resize_past_pos(self): def test_resize_past_pos(self):
m = mmap.mmap(-1, 8192) m = mmap.mmap(-1, 8192)
self.addCleanup(m.close) self.addCleanup(m.close)
@@ -744,7 +851,7 @@ class MmapTests(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
m * 2 m * 2
@unittest.skipIf(sys.platform.startswith("linux"), "TODO: RUSTPYTHON, memmap2 doesn't throw OSError when offset is not a multiple of mmap.PAGESIZE on Linux") @unittest.skipIf(sys.platform.startswith("linux"), "TODO: RUSTPYTHON; memmap2 doesn't throw OSError when offset is not a multiple of mmap.PAGESIZE on Linux")
def test_flush_return_value(self): def test_flush_return_value(self):
# mm.flush() should return None on success, raise an # mm.flush() should return None on success, raise an
# exception on error under all platforms. # exception on error under all platforms.
@@ -753,11 +860,13 @@ class MmapTests(unittest.TestCase):
mm.write(b'python') mm.write(b'python')
result = mm.flush() result = mm.flush()
self.assertIsNone(result) self.assertIsNone(result)
if sys.platform.startswith('linux'): if (sys.platform.startswith(('linux', 'android'))
and not in_systemd_nspawn_sync_suppressed()):
# 'offset' must be a multiple of mmap.PAGESIZE on Linux. # 'offset' must be a multiple of mmap.PAGESIZE on Linux.
# See bpo-34754 for details. # See bpo-34754 for details.
self.assertRaises(OSError, mm.flush, 1, len(b'python')) self.assertRaises(OSError, mm.flush, 1, len(b'python'))
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_repr(self): def test_repr(self):
open_mmap_repr_pat = re.compile( open_mmap_repr_pat = re.compile(
r"<mmap.mmap closed=False, " r"<mmap.mmap closed=False, "
@@ -794,11 +903,16 @@ class MmapTests(unittest.TestCase):
match = closed_mmap_repr_pat.match(repr(mm)) match = closed_mmap_repr_pat.match(repr(mm))
self.assertIsNotNone(match) self.assertIsNotNone(match)
@unittest.expectedFailure # TODO: RUSTPYTHON
@unittest.skipUnless(hasattr(mmap.mmap, 'madvise'), 'needs madvise') @unittest.skipUnless(hasattr(mmap.mmap, 'madvise'), 'needs madvise')
def test_madvise(self): def test_madvise(self):
size = 2 * PAGESIZE size = 2 * PAGESIZE
m = mmap.mmap(-1, size) m = mmap.mmap(-1, size)
class Number:
def __index__(self):
return 2
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"): with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
m.madvise(mmap.MADV_NORMAL, size) m.madvise(mmap.MADV_NORMAL, size)
with self.assertRaisesRegex(ValueError, "madvise start out of bounds"): with self.assertRaisesRegex(ValueError, "madvise start out of bounds"):
@@ -807,42 +921,84 @@ class MmapTests(unittest.TestCase):
m.madvise(mmap.MADV_NORMAL, 0, -1) m.madvise(mmap.MADV_NORMAL, 0, -1)
with self.assertRaisesRegex(OverflowError, "madvise length too large"): with self.assertRaisesRegex(OverflowError, "madvise length too large"):
m.madvise(mmap.MADV_NORMAL, PAGESIZE, sys.maxsize) m.madvise(mmap.MADV_NORMAL, PAGESIZE, sys.maxsize)
with self.assertRaisesRegex(
TypeError, "'str' object cannot be interpreted as an integer"):
m.madvise(mmap.MADV_NORMAL, PAGESIZE, "Not a Number")
self.assertEqual(m.madvise(mmap.MADV_NORMAL), None) self.assertEqual(m.madvise(mmap.MADV_NORMAL), None)
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE), None) self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE), None)
self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE, size), None) self.assertEqual(m.madvise(mmap.MADV_NORMAL, PAGESIZE, size), None)
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, 2), None) self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, 2), None)
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, Number()), None)
self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, size), None) self.assertEqual(m.madvise(mmap.MADV_NORMAL, 0, size), None)
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.expectedFailureIf(sys.platform in ("linux", "win32"), "TODO: RUSTPYTHON")
def test_resize_up_when_mapped_to_pagefile(self): def test_resize_up_anonymous_mapping(self):
"""If the mmap is backed by the pagefile ensure a resize up can happen """If the mmap is backed by the pagefile ensure a resize up can happen
and that the original data is still in place and that the original data is still in place
""" """
start_size = PAGESIZE start_size = PAGESIZE
new_size = 2 * start_size new_size = 2 * start_size
data = bytes(random.getrandbits(8) for _ in range(start_size)) data = random.randbytes(start_size)
m = mmap.mmap(-1, start_size) with mmap.mmap(-1, start_size) as m:
m[:] = data m[:] = data
m.resize(new_size) if sys.platform.startswith(('linux', 'android')):
self.assertEqual(len(m), new_size) # Can't expand a shared anonymous mapping on Linux.
self.assertEqual(m[:start_size], data[:start_size]) # See https://bugzilla.kernel.org/show_bug.cgi?id=8691
with self.assertRaises(ValueError):
m.resize(new_size)
else:
try:
m.resize(new_size)
except SystemError:
pass
else:
self.assertEqual(len(m), new_size)
self.assertEqual(m[:start_size], data)
self.assertEqual(m[start_size:], b'\0' * (new_size - start_size))
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_resize_down_when_mapped_to_pagefile(self): @unittest.skipUnless(os.name == 'posix', 'requires Posix')
def test_resize_up_private_anonymous_mapping(self):
start_size = PAGESIZE
new_size = 2 * start_size
data = random.randbytes(start_size)
with mmap.mmap(-1, start_size, flags=mmap.MAP_PRIVATE) as m:
m[:] = data
try:
m.resize(new_size)
except SystemError:
pass
else:
self.assertEqual(len(m), new_size)
self.assertEqual(m[:start_size], data)
self.assertEqual(m[start_size:], b'\0' * (new_size - start_size))
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_resize_down_anonymous_mapping(self):
"""If the mmap is backed by the pagefile ensure a resize down up can happen """If the mmap is backed by the pagefile ensure a resize down up can happen
and that a truncated form of the original data is still in place and that a truncated form of the original data is still in place
""" """
start_size = PAGESIZE start_size = 2 * PAGESIZE
new_size = start_size // 2 new_size = start_size // 2
data = bytes(random.getrandbits(8) for _ in range(start_size)) data = random.randbytes(start_size)
m = mmap.mmap(-1, start_size) with mmap.mmap(-1, start_size) as m:
m[:] = data m[:] = data
m.resize(new_size) try:
self.assertEqual(len(m), new_size) m.resize(new_size)
self.assertEqual(m[:new_size], data[:new_size]) except SystemError:
pass
else:
self.assertEqual(len(m), new_size)
self.assertEqual(m[:], data[:new_size])
if sys.platform.startswith(('linux', 'android')):
# Can't expand to its original size.
with self.assertRaises(ValueError):
m.resize(start_size)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.skipUnless(os.name == 'nt', 'requires Windows')
def test_resize_fails_if_mapping_held_elsewhere(self): def test_resize_fails_if_mapping_held_elsewhere(self):
"""If more than one mapping is held against a named file on Windows, neither """If more than one mapping is held against a named file on Windows, neither
@@ -867,6 +1023,7 @@ class MmapTests(unittest.TestCase):
finally: finally:
f.close() f.close()
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(os.name == 'nt', 'requires Windows') @unittest.skipUnless(os.name == 'nt', 'requires Windows')
def test_resize_succeeds_with_error_for_second_named_mapping(self): def test_resize_succeeds_with_error_for_second_named_mapping(self):
"""If a more than one mapping exists of the same name, none of them can """If a more than one mapping exists of the same name, none of them can
@@ -888,6 +1045,168 @@ class MmapTests(unittest.TestCase):
self.assertEqual(m1[:data_length], data) self.assertEqual(m1[:data_length], data)
self.assertEqual(m2[:data_length], data) self.assertEqual(m2[:data_length], data)
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_mmap_closed_by_int_scenarios(self):
"""
gh-103987: Test that mmap objects raise ValueError
for closed mmap files
"""
class MmapClosedByIntContext:
def __init__(self, access) -> None:
self.access = access
def __enter__(self):
self.f = open(TESTFN, "w+b")
self.f.write(random.randbytes(100))
self.f.flush()
m = mmap.mmap(self.f.fileno(), 100, access=self.access)
class X:
def __index__(self):
m.close()
return 10
return (m, X)
def __exit__(self, exc_type, exc_value, traceback):
self.f.close()
read_access_modes = [
mmap.ACCESS_READ,
mmap.ACCESS_WRITE,
mmap.ACCESS_COPY,
mmap.ACCESS_DEFAULT,
]
write_access_modes = [
mmap.ACCESS_WRITE,
mmap.ACCESS_COPY,
mmap.ACCESS_DEFAULT,
]
for access in read_access_modes:
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[X()]
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[X() : 20]
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[X() : 20 : 2]
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[20 : X() : -2]
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m.read(X())
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m.find(b"1", 1, X())
for access in write_access_modes:
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[X() : 20] = b"1" * 10
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[X() : 20 : 2] = b"1" * 5
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m[20 : X() : -2] = b"1" * 5
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m.move(1, 2, X())
with MmapClosedByIntContext(access) as (m, X):
with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
m.write_byte(X())
@unittest.skipUnless(os.name == 'nt', 'requires Windows')
@unittest.skipUnless(hasattr(mmap.mmap, '_protect'), 'test needs debug build')
def test_access_violations(self):
from test.support.os_helper import TESTFN
code = textwrap.dedent("""
import faulthandler
import mmap
import os
import sys
from contextlib import suppress
# Prevent logging access violations to stderr.
faulthandler.disable()
PAGESIZE = mmap.PAGESIZE
PAGE_NOACCESS = 0x01
with open(sys.argv[1], 'bw+') as f:
f.write(b'A'* PAGESIZE)
f.flush()
m = mmap.mmap(f.fileno(), PAGESIZE)
m._protect(PAGE_NOACCESS, 0, PAGESIZE)
with suppress(OSError):
m.read(PAGESIZE)
assert False, 'mmap.read() did not raise'
with suppress(OSError):
m.read_byte()
assert False, 'mmap.read_byte() did not raise'
with suppress(OSError):
m.readline()
assert False, 'mmap.readline() did not raise'
with suppress(OSError):
m.write(b'A'* PAGESIZE)
assert False, 'mmap.write() did not raise'
with suppress(OSError):
m.write_byte(0)
assert False, 'mmap.write_byte() did not raise'
with suppress(OSError):
m[0] # test mmap_subscript
assert False, 'mmap.__getitem__() did not raise'
with suppress(OSError):
m[0:10] # test mmap_subscript
assert False, 'mmap.__getitem__() did not raise'
with suppress(OSError):
m[0:10:2] # test mmap_subscript
assert False, 'mmap.__getitem__() did not raise'
with suppress(OSError):
m[0] = 1
assert False, 'mmap.__setitem__() did not raise'
with suppress(OSError):
m[0:10] = b'A'* 10
assert False, 'mmap.__setitem__() did not raise'
with suppress(OSError):
m[0:10:2] = b'A'* 5
assert False, 'mmap.__setitem__() did not raise'
with suppress(OSError):
m.move(0, 10, 1)
assert False, 'mmap.move() did not raise'
with suppress(OSError):
list(m) # test mmap_item
assert False, 'mmap.__getitem__() did not raise'
with suppress(OSError):
m.find(b'A')
assert False, 'mmap.find() did not raise'
with suppress(OSError):
m.rfind(b'A')
assert False, 'mmap.rfind() did not raise'
""")
rt, stdout, stderr = assert_python_ok("-c", code, TESTFN)
self.assertEqual(stdout.strip(), b'')
self.assertEqual(stderr.strip(), b'')
class LargeMmapTests(unittest.TestCase): class LargeMmapTests(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -897,7 +1216,7 @@ class LargeMmapTests(unittest.TestCase):
unlink(TESTFN) unlink(TESTFN)
def _make_test_file(self, num_zeroes, tail): def _make_test_file(self, num_zeroes, tail):
if sys.platform[:3] == 'win' or sys.platform == 'darwin': if sys.platform[:3] == 'win' or is_apple:
requires('largefile', requires('largefile',
'test requires %s bytes and a long time to run' % str(0x180000000)) 'test requires %s bytes and a long time to run' % str(0x180000000))
f = open(TESTFN, 'w+b') f = open(TESTFN, 'w+b')

View File

@@ -3,6 +3,5 @@ from test._test_multiprocessing import install_tests_in_module_dict
install_tests_in_module_dict(globals(), 'fork', only_type="processes") install_tests_in_module_dict(globals(), 'fork', only_type="processes")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@@ -1934,6 +1934,11 @@ class _PosixSpawnMixin:
) )
support.wait_process(pid, exitcode=0) support.wait_process(pid, exitcode=0)
def test_setpgroup_allow_none(self):
path, args = self.NOOP_PROGRAM[0], self.NOOP_PROGRAM
pid = self.spawn_func(path, args, os.environ, setpgroup=None)
support.wait_process(pid, exitcode=0)
def test_setpgroup_wrong_type(self): def test_setpgroup_wrong_type(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.spawn_func(sys.executable, self.spawn_func(sys.executable,
@@ -2034,6 +2039,21 @@ class _PosixSpawnMixin:
[sys.executable, "-c", "pass"], [sys.executable, "-c", "pass"],
os.environ, setsigdef=[signal.NSIG, signal.NSIG+1]) os.environ, setsigdef=[signal.NSIG, signal.NSIG+1])
def test_scheduler_allow_none(self):
path, args = self.NOOP_PROGRAM[0], self.NOOP_PROGRAM
pid = self.spawn_func(path, args, os.environ, scheduler=None)
support.wait_process(pid, exitcode=0)
@unittest.expectedFailure # TODO: RUSTPYTHON; Wrong error message
@support.subTests("scheduler", [object(), 1, [1, 2]])
def test_scheduler_wrong_type(self, scheduler):
path, args = self.NOOP_PROGRAM[0], self.NOOP_PROGRAM
with self.assertRaisesRegex(
TypeError,
"scheduler must be a tuple or None",
):
self.spawn_func(path, args, os.environ, scheduler=scheduler)
@unittest.expectedFailureIf(sys.platform in ("darwin", "linux"), "TODO: RUSTPYTHON; NotImplementedError: scheduler parameter is not yet implemented") @unittest.expectedFailureIf(sys.platform in ("darwin", "linux"), "TODO: RUSTPYTHON; NotImplementedError: scheduler parameter is not yet implemented")
@requires_sched @requires_sched
@unittest.skipIf(sys.platform.startswith(('freebsd', 'netbsd')), @unittest.skipIf(sys.platform.startswith(('freebsd', 'netbsd')),

View File

@@ -1,5 +1,7 @@
"""Unit tests for zero-argument super() & related machinery.""" """Unit tests for zero-argument super() & related machinery."""
import copy
import pickle
import textwrap import textwrap
import threading import threading
import unittest import unittest
@@ -7,9 +9,6 @@ from unittest.mock import patch
from test.support import import_helper, threading_helper from test.support import import_helper, threading_helper
ADAPTIVE_WARMUP_DELAY = 2
class A: class A:
def f(self): def f(self):
return 'A' return 'A'
@@ -91,8 +90,7 @@ class TestSuper(unittest.TestCase):
self.assertEqual(E().f(), 'AE') self.assertEqual(E().f(), 'AE')
# TODO: RUSTPYTHON; SyntaxError: name '__class__' is assigned to before global declaration @unittest.expectedFailure # TODO: RUSTPYTHON
'''
def test_various___class___pathologies(self): def test_various___class___pathologies(self):
# See issue #12370 # See issue #12370
class X(A): class X(A):
@@ -114,7 +112,7 @@ class TestSuper(unittest.TestCase):
__class__""", globals(), {}) __class__""", globals(), {})
self.assertIs(type(e.exception), NameError) # Not UnboundLocalError self.assertIs(type(e.exception), NameError) # Not UnboundLocalError
class X: class X:
global __class__ # global __class__ # TODO: RUSTPYTHON; SyntaxError: name '__class__' is assigned to before global declaration
__class__ = 42 __class__ = 42
def f(): def f():
__class__ __class__
@@ -122,12 +120,11 @@ class TestSuper(unittest.TestCase):
del globals()["__class__"] del globals()["__class__"]
self.assertNotIn("__class__", X.__dict__) self.assertNotIn("__class__", X.__dict__)
class X: class X:
nonlocal __class__ # nonlocal __class__ # TODO: RUSTPYTHON; SyntaxError: name '__class__' is assigned to before nonlocal declaration
__class__ = 42 __class__ = 42
def f(): def f():
__class__ __class__
self.assertEqual(__class__, 42) self.assertEqual(__class__, 42)
'''
def test___class___instancemethod(self): def test___class___instancemethod(self):
# See issue #14857 # See issue #14857
@@ -191,7 +188,7 @@ class TestSuper(unittest.TestCase):
B = type("B", (), test_namespace) B = type("B", (), test_namespace)
self.assertIs(B.f(), B) self.assertIs(B.f(), B)
@unittest.expectedFailure # TODO: RUSTPYTHON @unittest.expectedFailure # TODO: RUSTPYTHON
def test___class___mro(self): def test___class___mro(self):
# See issue #23722 # See issue #23722
test_class = None test_class = None
@@ -449,7 +446,7 @@ class TestSuper(unittest.TestCase):
self.assertEqual(C().method(), super) self.assertEqual(C().method(), super)
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: type 'super' is not an acceptable base type @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: type 'super' is not an acceptable base type
def test_super_subclass___class__(self): def test_super_subclass___class__(self):
class mysuper(super): class mysuper(super):
pass pass
@@ -469,7 +466,8 @@ class TestSuper(unittest.TestCase):
super(MyType, type(mytype)).__setattr__(mytype, "bar", 1) super(MyType, type(mytype)).__setattr__(mytype, "bar", 1)
self.assertEqual(mytype.bar, 1) self.assertEqual(mytype.bar, 1)
for _ in range(ADAPTIVE_WARMUP_DELAY): _testinternalcapi = import_helper.import_module("_testinternalcapi")
for _ in range(_testinternalcapi.SPECIALIZATION_THRESHOLD):
test("foo1") test("foo1")
def test_reassigned_new(self): def test_reassigned_new(self):
@@ -488,7 +486,8 @@ class TestSuper(unittest.TestCase):
def __new__(cls): def __new__(cls):
return super().__new__(cls) return super().__new__(cls)
for _ in range(ADAPTIVE_WARMUP_DELAY): _testinternalcapi = import_helper.import_module("_testinternalcapi")
for _ in range(_testinternalcapi.SPECIALIZATION_THRESHOLD):
C() C()
def test_mixed_staticmethod_hierarchy(self): def test_mixed_staticmethod_hierarchy(self):
@@ -508,7 +507,8 @@ class TestSuper(unittest.TestCase):
def some(cls): def some(cls):
return super().some(cls) return super().some(cls)
for _ in range(ADAPTIVE_WARMUP_DELAY): _testinternalcapi = import_helper.import_module("_testinternalcapi")
for _ in range(_testinternalcapi.SPECIALIZATION_THRESHOLD):
C.some(C) C.some(C)
@threading_helper.requires_working_threading() @threading_helper.requires_working_threading()
@@ -544,6 +544,74 @@ class TestSuper(unittest.TestCase):
for thread in threads: for thread in threads:
thread.join() thread.join()
def test_special_methods(self):
for e in E(), E:
s = super(C, e)
self.assertEqual(s.__reduce__, e.__reduce__)
self.assertEqual(s.__reduce_ex__, e.__reduce_ex__)
self.assertEqual(s.__getstate__, e.__getstate__)
self.assertNotHasAttr(s, '__getnewargs__')
self.assertNotHasAttr(s, '__getnewargs_ex__')
self.assertNotHasAttr(s, '__setstate__')
self.assertNotHasAttr(s, '__copy__')
self.assertNotHasAttr(s, '__deepcopy__')
def test_pickling(self):
e = E()
e.x = 1
s = super(C, e)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
u = pickle.loads(pickle.dumps(s, proto))
self.assertEqual(u.f(), s.f())
self.assertIs(type(u), type(s))
self.assertIs(type(u.__self__), E)
self.assertEqual(u.__self__.x, 1)
self.assertIs(u.__thisclass__, C)
self.assertIs(u.__self_class__, E)
s = super(C, E)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
u = pickle.loads(pickle.dumps(s, proto))
self.assertEqual(u.cm(), s.cm())
self.assertEqual(u.f, s.f)
self.assertIs(type(u), type(s))
self.assertIs(u.__self__, E)
self.assertIs(u.__thisclass__, C)
self.assertIs(u.__self_class__, E)
def test_shallow_copying(self):
s = super(C, E())
self.assertIs(copy.copy(s), s)
s = super(C, E)
self.assertIs(copy.copy(s), s)
def test_deep_copying(self):
e = E()
e.x = [1]
s = super(C, e)
u = copy.deepcopy(s)
self.assertEqual(u.f(), s.f())
self.assertIs(type(u), type(s))
self.assertIsNot(u, s)
self.assertIs(type(u.__self__), E)
self.assertIsNot(u.__self__, e)
self.assertIsNot(u.__self__.x, e.x)
self.assertEqual(u.__self__.x, [1])
self.assertIs(u.__thisclass__, C)
self.assertIs(u.__self_class__, E)
s = super(C, E)
u = copy.deepcopy(s)
self.assertEqual(u.cm(), s.cm())
self.assertEqual(u.f, s.f)
self.assertIsNot(u, s)
self.assertIs(type(u), type(s))
self.assertIs(u.__self__, E)
self.assertIs(u.__thisclass__, C)
self.assertIs(u.__self_class__, E)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

84
Lib/test/test_uuid.py vendored
View File

@@ -1176,6 +1176,47 @@ class CommandLineTestCases:
self.assertEqual(cm.exception.code, 2) self.assertEqual(cm.exception.code, 2)
self.assertIn("error: Incorrect number of arguments", mock_err.getvalue()) self.assertIn("error: Incorrect number of arguments", mock_err.getvalue())
@mock.patch.object(sys, "argv",
["", "-u", "uuid3", "-n", "@dns", "-N", "python.org"])
def test_cli_uuid3_outputted_with_valid_namespace_and_name(self):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.uuid.main()
output = stdout.getvalue().strip()
uuid_output = self.uuid.UUID(output)
# Output should be in the form of uuid3
self.assertEqual(output, str(uuid_output))
self.assertEqual(uuid_output.version, 3)
@mock.patch.object(sys, "argv",
["", "-u", "uuid3", "-n",
"0d6a16cc-34a7-47d8-b660-214d0ae184d2",
"-N", "some.user"])
def test_cli_uuid3_outputted_with_custom_namespace_and_name(self):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.uuid.main()
output = stdout.getvalue().strip()
uuid_output = self.uuid.UUID(output)
# Output should be in the form of uuid3
self.assertEqual(output, str(uuid_output))
self.assertEqual(uuid_output.version, 3)
@mock.patch.object(sys, "argv",
["", "-u", "uuid3", "-n", "any UUID", "-N", "python.org"])
@mock.patch('sys.stderr', new_callable=io.StringIO)
def test_cli_uuid3_with_invalid_namespace(self, mock_err):
with self.assertRaises(SystemExit) as cm:
self.uuid.main()
# Check that exception code is the same as argparse.ArgumentParser.error
self.assertEqual(cm.exception.code, 2)
self.assertIn("error: badly formed hexadecimal UUID string",
mock_err.getvalue())
@mock.patch.object(sys, "argv", [""]) @mock.patch.object(sys, "argv", [""])
def test_cli_uuid4_outputted_with_no_args(self): def test_cli_uuid4_outputted_with_no_args(self):
stdout = io.StringIO() stdout = io.StringIO()
@@ -1203,23 +1244,9 @@ class CommandLineTestCases:
uuid_output = self.uuid.UUID(o) uuid_output = self.uuid.UUID(o)
self.assertEqual(uuid_output.version, 4) self.assertEqual(uuid_output.version, 4)
@mock.patch.object(sys, "argv",
["", "-u", "uuid3", "-n", "@dns", "-N", "python.org"])
def test_cli_uuid3_ouputted_with_valid_namespace_and_name(self):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.uuid.main()
output = stdout.getvalue().strip()
uuid_output = self.uuid.UUID(output)
# Output should be in the form of uuid5
self.assertEqual(output, str(uuid_output))
self.assertEqual(uuid_output.version, 3)
@mock.patch.object(sys, "argv", @mock.patch.object(sys, "argv",
["", "-u", "uuid5", "-n", "@dns", "-N", "python.org"]) ["", "-u", "uuid5", "-n", "@dns", "-N", "python.org"])
def test_cli_uuid5_ouputted_with_valid_namespace_and_name(self): def test_cli_uuid5_outputted_with_valid_namespace_and_name(self):
stdout = io.StringIO() stdout = io.StringIO()
with contextlib.redirect_stdout(stdout): with contextlib.redirect_stdout(stdout):
self.uuid.main() self.uuid.main()
@@ -1231,6 +1258,33 @@ class CommandLineTestCases:
self.assertEqual(output, str(uuid_output)) self.assertEqual(output, str(uuid_output))
self.assertEqual(uuid_output.version, 5) self.assertEqual(uuid_output.version, 5)
@mock.patch.object(sys, "argv",
["", "-u", "uuid5", "-n",
"0d6a16cc-34a7-47d8-b660-214d0ae184d2",
"-N", "some.user"])
def test_cli_uuid5_ouputted_with_custom_namespace_and_name(self):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.uuid.main()
output = stdout.getvalue().strip()
uuid_output = self.uuid.UUID(output)
# Output should be in the form of uuid5
self.assertEqual(output, str(uuid_output))
self.assertEqual(uuid_output.version, 5)
@mock.patch.object(sys, "argv",
["", "-u", "uuid5", "-n", "any UUID", "-N", "python.org"])
@mock.patch('sys.stderr', new_callable=io.StringIO)
def test_cli_uuid5_with_invalid_namespace(self, mock_err):
with self.assertRaises(SystemExit) as cm:
self.uuid.main()
# Check that exception code is the same as argparse.ArgumentParser.error
self.assertEqual(cm.exception.code, 2)
self.assertIn("error: badly formed hexadecimal UUID string",
mock_err.getvalue())
@mock.patch.object(sys, "argv", ["", "-u", "uuid6"]) @mock.patch.object(sys, "argv", ["", "-u", "uuid6"])
def test_cli_uuid6(self): def test_cli_uuid6(self):
self.do_test_standalone_uuid(6) self.do_test_standalone_uuid(6)

13
Lib/test/test_venv.py vendored
View File

@@ -266,6 +266,7 @@ class BasicTest(BaseTest):
with patch('venv.subprocess.check_output', pip_cmd_checker): with patch('venv.subprocess.check_output', pip_cmd_checker):
builder.upgrade_dependencies(fake_context) builder.upgrade_dependencies(fake_context)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
def test_prefixes(self): def test_prefixes(self):
""" """
@@ -285,6 +286,7 @@ class BasicTest(BaseTest):
self.assertEqual(pathlib.Path(out.strip().decode()), self.assertEqual(pathlib.Path(out.strip().decode()),
pathlib.Path(expected), prefix) pathlib.Path(expected), prefix)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
def test_sysconfig(self): def test_sysconfig(self):
""" """
@@ -318,6 +320,7 @@ class BasicTest(BaseTest):
out, err = check_output(cmd, encoding='utf-8') out, err = check_output(cmd, encoding='utf-8')
self.assertEqual(out.strip(), expected, err) self.assertEqual(out.strip(), expected, err)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
@unittest.skipUnless(can_symlink(), 'Needs symlinks') @unittest.skipUnless(can_symlink(), 'Needs symlinks')
def test_sysconfig_symlinks(self): def test_sysconfig_symlinks(self):
@@ -458,6 +461,7 @@ class BasicTest(BaseTest):
data = self.get_text_file_contents('pyvenv.cfg') data = self.get_text_file_contents('pyvenv.cfg')
self.assertIn('include-system-site-packages = %s\n' % s, data) self.assertIn('include-system-site-packages = %s\n' % s, data)
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(can_symlink(), 'Needs symlinks') @unittest.skipUnless(can_symlink(), 'Needs symlinks')
def test_symlinking(self): def test_symlinking(self):
""" """
@@ -482,6 +486,7 @@ class BasicTest(BaseTest):
# run the test, the pyvenv.cfg in the venv created in the test will # run the test, the pyvenv.cfg in the venv created in the test will
# point to the venv being used to run the test, and we lose the link # point to the venv being used to run the test, and we lose the link
# to the source build - so Python can't initialise properly. # to the source build - so Python can't initialise properly.
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
def test_executable(self): def test_executable(self):
""" """
@@ -494,6 +499,7 @@ class BasicTest(BaseTest):
'import sys; print(sys.executable)']) 'import sys; print(sys.executable)'])
self.assertEqual(out.strip(), envpy.encode()) self.assertEqual(out.strip(), envpy.encode())
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(can_symlink(), 'Needs symlinks') @unittest.skipUnless(can_symlink(), 'Needs symlinks')
def test_executable_symlinks(self): def test_executable_symlinks(self):
""" """
@@ -562,6 +568,7 @@ class BasicTest(BaseTest):
self.assertEndsWith(lines[1], env_name.encode()) self.assertEndsWith(lines[1], env_name.encode())
# gh-124651: test quoted strings on Windows # gh-124651: test quoted strings on Windows
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows')
def test_special_chars_windows(self): def test_special_chars_windows(self):
""" """
@@ -585,6 +592,7 @@ class BasicTest(BaseTest):
self.assertTrue(env_name.encode() in lines[0]) self.assertTrue(env_name.encode() in lines[0])
self.assertEndsWith(lines[1], env_name.encode()) self.assertEndsWith(lines[1], env_name.encode())
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows')
def test_unicode_in_batch_file(self): def test_unicode_in_batch_file(self):
""" """
@@ -616,6 +624,7 @@ class BasicTest(BaseTest):
filepath_regex = r"'[A-Z]:\\\\(?:[^\\\\]+\\\\)*[^\\\\]+'" filepath_regex = r"'[A-Z]:\\\\(?:[^\\\\]+\\\\)*[^\\\\]+'"
self.assertRegex(err, rf"Unable to symlink {filepath_regex} to {filepath_regex}") self.assertRegex(err, rf"Unable to symlink {filepath_regex} to {filepath_regex}")
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
def test_multiprocessing(self): def test_multiprocessing(self):
""" """
@@ -635,6 +644,7 @@ class BasicTest(BaseTest):
'pool.terminate()']) 'pool.terminate()'])
self.assertEqual(out.strip(), "python".encode()) self.assertEqual(out.strip(), "python".encode())
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
def test_multiprocessing_recursion(self): def test_multiprocessing_recursion(self):
""" """
@@ -892,6 +902,7 @@ class BasicTest(BaseTest):
self.assertFalse(same_path(path1, path2)) self.assertFalse(same_path(path1, path2))
# gh-126084: venvwlauncher should run pythonw, not python # gh-126084: venvwlauncher should run pythonw, not python
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
@requireVenvCreate @requireVenvCreate
@unittest.skipUnless(os.name == 'nt', 'only relevant on Windows') @unittest.skipUnless(os.name == 'nt', 'only relevant on Windows')
def test_venvwlauncher(self): def test_venvwlauncher(self):
@@ -926,11 +937,13 @@ class EnsurePipTest(BaseTest):
self.assertEqual(out.strip(), "OK") self.assertEqual(out.strip(), "OK")
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_no_pip_by_default(self): def test_no_pip_by_default(self):
rmtree(self.env_dir) rmtree(self.env_dir)
self.run_with_capture(venv.create, self.env_dir) self.run_with_capture(venv.create, self.env_dir)
self.assert_pip_not_installed() self.assert_pip_not_installed()
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_explicit_no_pip(self): def test_explicit_no_pip(self):
rmtree(self.env_dir) rmtree(self.env_dir)
self.run_with_capture(venv.create, self.env_dir, with_pip=False) self.run_with_capture(venv.create, self.env_dir, with_pip=False)

10
Lib/uuid.py vendored
View File

@@ -961,7 +961,7 @@ def main():
default="uuid4", default="uuid4",
help="function to generate the UUID") help="function to generate the UUID")
parser.add_argument("-n", "--namespace", parser.add_argument("-n", "--namespace",
choices=["any UUID", *namespaces.keys()], metavar=f"{{any UUID,{','.join(namespaces)}}}",
help="uuid3/uuid5 only: " help="uuid3/uuid5 only: "
"a UUID, or a well-known predefined UUID addressed " "a UUID, or a well-known predefined UUID addressed "
"by namespace name") "by namespace name")
@@ -983,7 +983,13 @@ def main():
f"{args.uuid} requires a namespace and a name. " f"{args.uuid} requires a namespace and a name. "
"Run 'python -m uuid -h' for more information." "Run 'python -m uuid -h' for more information."
) )
namespace = namespaces[namespace] if namespace in namespaces else UUID(namespace) if namespace in namespaces:
namespace = namespaces[namespace]
else:
try:
namespace = UUID(namespace)
except ValueError as exc:
parser.error(f"{exc}: {args.namespace!r}")
for _ in range(args.count): for _ in range(args.count):
print(uuid_func(namespace, name)) print(uuid_func(namespace, name))
else: else:

View File

@@ -13,6 +13,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies] [dependencies]
bitflags = { workspace = true } bitflags = { workspace = true }
itertools = { workspace = true }
num-complex = { workspace = true } num-complex = { workspace = true }
rustpython-vm = { workspace = true, features = ["threading", "compiler"] } rustpython-vm = { workspace = true, features = ["threading", "compiler"] }
rustpython-stdlib = {workspace = true, features = ["threading"] } rustpython-stdlib = {workspace = true, features = ["threading"] }

View File

@@ -1,9 +1,14 @@
use crate::{PyObject, pystate::with_vm}; use crate::{PyObject, pystate::with_vm};
use alloc::slice; use alloc::slice;
use core::ffi::c_int; use core::ffi::c_int;
pub use mapping::*;
use rustpython_vm::builtins::{PyDict, PyStr, PyTuple}; use rustpython_vm::builtins::{PyDict, PyStr, PyTuple};
use rustpython_vm::function::{FuncArgs, KwArgs, PosArgs}; use rustpython_vm::function::{FuncArgs, KwArgs, PosArgs};
use rustpython_vm::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine}; use rustpython_vm::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine};
pub use sequence::*;
mod mapping;
mod sequence;
const PY_VECTORCALL_ARGUMENTS_OFFSET: usize = 1usize << (usize::BITS as usize - 1); const PY_VECTORCALL_ARGUMENTS_OFFSET: usize = 1usize << (usize::BITS as usize - 1);

View File

@@ -0,0 +1,66 @@
use crate::{PyObject, pystate::with_vm};
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyMapping_Size(obj: *mut PyObject) -> isize {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_mapping(vm)?.length(vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyMapping_Keys(obj: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
let keys = obj.try_mapping(vm)?.keys(vm)?;
let iter = keys.get_iter(vm)?;
Ok(vm.ctx.new_list(iter.try_to_value(vm)?))
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyMapping_Values(obj: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
let values = obj.try_mapping(vm)?.values(vm)?;
let iter = values.get_iter(vm)?;
Ok(vm.ctx.new_list(iter.try_to_value(vm)?))
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyMapping_Items(obj: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
let items = obj.try_mapping(vm)?.items(vm)?;
let iter = items.get_iter(vm)?;
Ok(vm.ctx.new_list(iter.try_to_value(vm)?))
})
}
#[cfg(false)]
mod tests {
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyMapping, PyMappingMethods, PyTuple};
#[test]
fn size_keys_values_items() {
Python::attach(|py| {
let dict = PyDict::new(py);
dict.set_item("a", 1).unwrap();
dict.set_item("b", 2).unwrap();
let mapping = dict.cast_into::<PyMapping>().unwrap();
assert_eq!(mapping.len().unwrap(), 2);
let keys = mapping.keys().unwrap();
assert_eq!(keys.len(), 2);
let values = mapping.values().unwrap();
assert_eq!(values.len(), 2);
let items = mapping.items().unwrap();
assert_eq!(items.len(), 2);
assert!(items.iter().all(|item| item.cast_into::<PyTuple>().is_ok()));
})
}
}

View File

@@ -0,0 +1,286 @@
use crate::{PyObject, pystate::with_vm};
use core::ffi::c_int;
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Check(obj: *mut PyObject) -> c_int {
with_vm(|_vm| {
let obj = unsafe { &*obj };
Ok(obj.sequence_unchecked().check())
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Concat(
obj1: *mut PyObject,
obj2: *mut PyObject,
) -> *mut PyObject {
with_vm(|vm| {
let obj1 = unsafe { &*obj1 };
let obj2 = unsafe { &*obj2 };
obj1.try_sequence(vm)?.concat(obj2, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Count(obj: *mut PyObject, value: *mut PyObject) -> isize {
with_vm(|vm| {
let obj = unsafe { &*obj };
let value = unsafe { &*value };
obj.try_sequence(vm)?.count(value, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_DelItem(obj: *mut PyObject, index: isize) -> c_int {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.del_item(index, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_DelSlice(obj: *mut PyObject, low: isize, high: isize) -> c_int {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.del_slice(low, high, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_GetItem(obj: *mut PyObject, index: isize) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.get_item(index, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_GetSlice(
obj: *mut PyObject,
low: isize,
high: isize,
) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.get_slice(low, high, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_InPlaceConcat(
obj1: *mut PyObject,
obj2: *mut PyObject,
) -> *mut PyObject {
with_vm(|vm| {
let obj1 = unsafe { &*obj1 };
let obj2 = unsafe { &*obj2 };
obj1.try_sequence(vm)?.inplace_concat(obj2, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_InPlaceRepeat(
obj: *mut PyObject,
count: isize,
) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.inplace_repeat(count, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Index(obj: *mut PyObject, value: *mut PyObject) -> isize {
with_vm(|vm| {
let obj = unsafe { &*obj };
let value = unsafe { &*value };
obj.try_sequence(vm)?.index(value, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_List(obj: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
Ok(obj.try_sequence(vm)?.list(vm))
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Repeat(obj: *mut PyObject, count: isize) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.repeat(count, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_SetItem(
obj: *mut PyObject,
index: isize,
value: *mut PyObject,
) -> c_int {
with_vm(|vm| {
let obj = unsafe { &*obj };
let value = unsafe { &*value };
obj.try_sequence(vm)?.set_item(index, value.to_owned(), vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_SetSlice(
obj: *mut PyObject,
low: isize,
high: isize,
value: *mut PyObject,
) -> c_int {
with_vm(|vm| {
let obj = unsafe { &*obj };
let value = unsafe { &*value };
obj.try_sequence(vm)?
.set_slice(low, high, value.to_owned(), vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Size(obj: *mut PyObject) -> isize {
with_vm(|vm| {
let obj = unsafe { &*obj };
obj.try_sequence(vm)?.length(vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Length(obj: *mut PyObject) -> isize {
unsafe { PySequence_Size(obj) }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Tuple(obj: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let obj = unsafe { &*obj };
Ok(obj.try_sequence(vm)?.tuple(vm))
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_Contains(obj: *mut PyObject, value: *mut PyObject) -> c_int {
with_vm(|vm| {
let obj = unsafe { &*obj };
let value = unsafe { &*value };
obj.sequence_unchecked().contains(value, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySequence_In(obj: *mut PyObject, value: *mut PyObject) -> c_int {
unsafe { PySequence_Contains(obj, value) }
}
#[cfg(false)]
mod tests {
use pyo3::prelude::*;
use pyo3::types::{PyAnyMethods, PyDict, PyList, PySequence, PySequenceMethods, PyTuple};
#[test]
fn item_and_size_ops() {
Python::attach(|py| {
let list = PyList::new(py, [1, 2, 3]).unwrap();
let seq = list.cast_into::<PySequence>().unwrap();
assert_eq!(seq.len().unwrap(), 3);
assert_eq!(seq.get_item(1).unwrap().extract::<i32>().unwrap(), 2);
seq.set_item(1, 4).unwrap();
assert_eq!(seq.get_item(1).unwrap().extract::<i32>().unwrap(), 4);
seq.del_item(1).unwrap();
assert_eq!(seq.get_item(1).unwrap().extract::<i32>().unwrap(), 3);
});
}
#[test]
fn slice_ops() {
Python::attach(|py| {
let list = PyList::new(py, [1, 2, 3, 4]).unwrap();
let seq = list.cast_into::<PySequence>().unwrap();
let sub = seq.get_slice(1, 3).unwrap();
assert_eq!(sub.get_item(0).unwrap().extract::<i32>().unwrap(), 2);
assert_eq!(sub.get_item(1).unwrap().extract::<i32>().unwrap(), 3);
let repl = PyList::new(py, [8, 9]).unwrap();
seq.set_slice(1, 3, &repl).unwrap();
assert_eq!(seq.get_item(1).unwrap().extract::<i32>().unwrap(), 8);
assert_eq!(seq.get_item(2).unwrap().extract::<i32>().unwrap(), 9);
seq.del_slice(1, 3).unwrap();
assert_eq!(seq.get_item(0).unwrap().extract::<i32>().unwrap(), 1);
assert_eq!(seq.get_item(1).unwrap().extract::<i32>().unwrap(), 4);
});
}
#[test]
fn concat_repeat_and_inplace_ops() {
Python::attach(|py| {
let list = PyList::new(py, [1, 2]).unwrap();
let seq = list.cast_into::<PySequence>().unwrap();
let rhs = PyList::new(py, [3])
.unwrap()
.cast_into::<PySequence>()
.unwrap();
let concat = seq.concat(&rhs).unwrap();
assert_eq!(concat.get_item(2).unwrap().extract::<i32>().unwrap(), 3);
let repeat = seq.repeat(2).unwrap();
assert_eq!(repeat.get_item(2).unwrap().extract::<i32>().unwrap(), 1);
let iadd_rhs = PyList::new(py, [4])
.unwrap()
.cast_into::<PySequence>()
.unwrap();
seq.in_place_concat(&iadd_rhs).unwrap();
assert_eq!(seq.get_item(2).unwrap().extract::<i32>().unwrap(), 4);
seq.in_place_repeat(2).unwrap();
assert_eq!(seq.get_item(5).unwrap().extract::<i32>().unwrap(), 4);
});
}
#[test]
fn count_index_contains_and_convert_ops() {
Python::attach(|py| {
let tuple = PyTuple::new(py, [1, 2, 1])
.unwrap()
.cast_into::<PySequence>()
.unwrap();
assert_eq!(tuple.count(1).unwrap(), 2);
assert_eq!(tuple.index(2).unwrap(), 1);
assert!(tuple.contains(1).unwrap());
let as_list = tuple.to_list().unwrap().cast_into::<PySequence>().unwrap();
assert_eq!(as_list.get_item(0).unwrap().extract::<i32>().unwrap(), 1);
let as_tuple = as_list
.to_tuple()
.unwrap()
.cast_into::<PySequence>()
.unwrap();
assert_eq!(as_tuple.get_item(2).unwrap().extract::<i32>().unwrap(), 1);
});
}
#[test]
fn contains_works_for_dict() {
Python::attach(|py| {
let dict = PyDict::new(py);
dict.set_item("k", 1).unwrap();
let any = dict.into_any();
assert!(any.contains("k").unwrap());
assert!(!any.contains("missing").unwrap());
});
}
}

View File

@@ -25,10 +25,12 @@ pub mod pyerrors;
pub mod pylifecycle; pub mod pylifecycle;
pub mod pystate; pub mod pystate;
pub mod refcount; pub mod refcount;
pub mod setobject;
pub mod traceback; pub mod traceback;
pub mod tupleobject; pub mod tupleobject;
pub mod unicodeobject; pub mod unicodeobject;
mod util; mod util;
pub mod weakrefobject;
/// Get main interpreter of this process. Will be None if it has not been initialized yet. /// Get main interpreter of this process. Will be None if it has not been initialized yet.
pub fn get_main_interpreter() -> MutexGuard<'static, Option<Interpreter>> { pub fn get_main_interpreter() -> MutexGuard<'static, Option<Interpreter>> {

View File

@@ -0,0 +1,162 @@
use crate::PyObject;
use crate::object::define_py_check;
use crate::pystate::with_vm;
use core::ffi::c_int;
use itertools::process_results;
use rustpython_vm::AsObject;
use rustpython_vm::PyPayload;
use rustpython_vm::TryFromObject;
use rustpython_vm::builtins::{PyFrozenSet, PySet};
use rustpython_vm::function::ArgIterable;
define_py_check!(fn PySet_Check, types.set_type);
define_py_check!(fn PyFrozenSet_Check, types.frozenset_type);
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_New(iterable: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
if iterable.is_null() {
return Ok(PySet::default().into_ref(&vm.ctx));
}
let iterable = ArgIterable::try_from_object(vm, unsafe { &*iterable }.to_owned())?;
let set = PySet::default().into_ref(&vm.ctx);
for item in iterable.iter(vm)? {
set.add(item?, vm)?;
}
Ok(set)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyFrozenSet_New(iterable: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
if iterable.is_null() {
return Ok(vm.ctx.empty_frozenset.to_owned());
}
let iterable = ArgIterable::try_from_object(vm, unsafe { &*iterable }.to_owned())?;
let set = process_results(iterable.iter(vm)?, |it| PyFrozenSet::from_iter(vm, it))??;
Ok(set.into_ref(&vm.ctx))
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Add(set: *mut PyObject, key: *mut PyObject) -> c_int {
with_vm(|vm| {
let set = unsafe { &*set }.try_downcast_ref::<PySet>(vm)?;
let key = unsafe { &*key }.to_owned();
set.add(key, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Clear(set: *mut PyObject) -> c_int {
with_vm(|vm| {
let set = unsafe { &*set }.try_downcast_ref::<PySet>(vm)?;
set.clear();
Ok(())
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Contains(anyset: *mut PyObject, key: *mut PyObject) -> c_int {
with_vm(|vm| {
let anyset = unsafe { &*anyset };
let key = unsafe { &*key };
if let Some(set) = anyset.downcast_ref::<PySet>() {
set.__contains__(key, vm)
} else if let Some(frozenset) = anyset.downcast_ref::<PyFrozenSet>() {
frozenset.__contains__(key, vm)
} else {
Err(vm.new_type_error(format!(
"expected set or frozenset, got '{}'",
anyset.class().name()
)))
}
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Discard(set: *mut PyObject, key: *mut PyObject) -> c_int {
with_vm(|vm| {
let set = unsafe { &*set }.try_downcast_ref::<PySet>(vm)?;
let key = unsafe { &*key };
let had_item = set.__contains__(key, vm)?;
if had_item {
set.discard(key.to_owned(), vm)?;
}
Ok(had_item)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Pop(set: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let set = unsafe { &*set }.try_downcast_ref::<PySet>(vm)?;
set.pop(vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PySet_Size(anyset: *mut PyObject) -> isize {
with_vm(|vm| {
let anyset = unsafe { &*anyset };
if let Some(set) = anyset.downcast_ref::<PySet>() {
set.as_object().length(vm)
} else if let Some(frozenset) = anyset.downcast_ref::<PyFrozenSet>() {
frozenset.as_object().length(vm)
} else {
Err(vm.new_type_error(format!(
"expected set or frozenset, got '{}'",
anyset.class().name()
)))
}
})
}
#[cfg(false)]
mod tests {
use pyo3::prelude::*;
use pyo3::types::{PyFrozenSet, PyInt, PySet};
#[test]
fn new_and_size() {
Python::attach(|py| {
let set = PySet::empty(py).unwrap();
assert!(set.is_instance_of::<PySet>());
assert_eq!(set.len(), 0);
let frozen = PyFrozenSet::empty(py).unwrap();
assert!(frozen.is_instance_of::<PyFrozenSet>());
assert_eq!(frozen.len(), 0);
})
}
#[test]
fn add_contains_discard() {
Python::attach(|py| {
let set = PySet::empty(py).unwrap();
let item = PyInt::new(py, 42);
set.add(&item).unwrap();
assert!(set.contains(&item).unwrap());
set.discard(&item).unwrap();
assert!(!set.contains(&item).unwrap());
})
}
#[test]
fn pop_reduces_size() {
Python::attach(|py| {
let set = PySet::empty(py).unwrap();
set.add(7).unwrap();
assert_eq!(set.len(), 1);
let popped = set.pop().unwrap();
assert_eq!(popped.extract::<i32>().unwrap(), 7);
assert_eq!(set.len(), 0);
})
}
}

View File

@@ -0,0 +1,106 @@
use crate::PyObject;
use crate::object::define_py_check;
use crate::pystate::with_vm;
use core::ffi::c_int;
use rustpython_vm::builtins::{PyWeak, PyWeakProxy};
define_py_check!(fn PyWeakref_CheckProxy, types.weakproxy_type);
define_py_check!(fn PyWeakref_CheckRef, types.weakref_type);
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyWeakref_GetRef(
reference: *mut PyObject,
result: *mut *mut PyObject,
) -> c_int {
with_vm(|vm| {
unsafe {
*result = core::ptr::null_mut();
}
let reference = unsafe { &*reference };
let upgraded = if let Some(weak) = reference.downcast_ref::<PyWeak>() {
weak.upgrade()
} else if let Some(proxy) = reference.downcast_ref::<PyWeakProxy>() {
proxy.get_weak().upgrade()
} else {
return Err(vm.new_type_error("expected a weakref"));
};
if let Some(obj) = upgraded {
unsafe {
*result = obj.into_raw().as_ptr();
}
Ok(true)
} else {
Ok(false)
}
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyWeakref_NewProxy(
ob: *mut PyObject,
callback: *mut PyObject,
) -> *mut PyObject {
with_vm(|vm| {
let ob = unsafe { &*ob };
let callback = unsafe { callback.as_ref() }
.filter(|callback| !vm.is_none(callback))
.map(ToOwned::to_owned);
PyWeakProxy::new_weakproxy(ob, callback, vm)
})
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyWeakref_NewRef(
ob: *mut PyObject,
callback: *mut PyObject,
) -> *mut PyObject {
with_vm(|vm| {
let ob = unsafe { &*ob };
let callback = unsafe { callback.as_ref() }
.filter(|callback| !vm.is_none(callback))
.map(ToOwned::to_owned);
ob.downgrade(callback, vm)
})
}
#[cfg(false)]
mod tests {
use pyo3::prelude::*;
use pyo3::types::PyAnyMethods;
use pyo3::types::{PyInt, PyWeakrefMethods, PyWeakrefProxy, PyWeakrefReference};
#[test]
fn check_ref_and_proxy() {
Python::attach(|py| {
let object_ty = py.get_type::<PyInt>();
let weak_ref = PyWeakrefReference::new(&object_ty).unwrap();
let weak_proxy = PyWeakrefProxy::new(&object_ty).unwrap();
assert!(weak_ref.is_instance_of::<PyWeakrefReference>());
assert!(weak_proxy.is_instance_of::<PyWeakrefProxy>());
});
}
#[test]
fn new_ref_and_get_ref() {
Python::attach(|py| {
let object_ty = py.get_type::<PyInt>();
let weak_ref = PyWeakrefReference::new(&object_ty).unwrap();
assert!(weak_ref.upgrade().is_some_and(|obj| obj.is(&object_ty)));
});
}
#[test]
fn new_proxy() {
Python::attach(|py| {
let object_ty = py.get_type::<PyInt>();
let weak_proxy = PyWeakrefProxy::new(&object_ty).unwrap();
assert!(weak_proxy.upgrade().is_some_and(|obj| obj.is(&object_ty)));
});
}
}

309
crates/stdlib/src/_queue.rs Normal file
View File

@@ -0,0 +1,309 @@
pub(crate) use _queue::module_def;
#[pymodule]
mod _queue {
use alloc::collections::VecDeque;
use core::time::Duration;
use std::time::Instant;
use crate::vm::{
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBaseExceptionRef, PyException, PyGenericAlias, PyStr, PyType, PyTypeRef},
function::{PyComparisonValue, TimeoutSeconds},
protocol::PyNumberMethods,
types::{AsNumber, Comparable, Constructor, PyComparisonOp, Representable},
};
type BufInner = VecDeque<PyObjectRef>;
cfg_select! {
feature = "threading" => {
use parking_lot::{Condvar, Mutex, MutexGuard};
type Buf = Mutex<BufInner>;
},
_ => {
use crate::common::lock::PyMutex;
type Buf = PyMutex<BufInner>;
}
}
const INITIAL_RING_BUF_CAPACITY: usize = 8;
#[pyattr]
#[pyclass(module = "_queue", name = "Empty", base = PyException)]
#[repr(transparent)]
pub(crate) struct PyEmptyError(PyException);
#[pyclass(flags(HAS_WEAKREF))]
impl PyEmptyError {}
/// ## See Also
///
/// [`empty_error`](https://github.com/python/cpython/blob/v3.14.5/Modules/_queuemodule.c#L347-L355).
fn empty_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_empty(PyEmptyError::class(&vm.ctx).to_owned())
}
#[cfg(feature = "threading")]
#[derive(Debug)]
struct Semaphore {
mutex: Mutex<usize>,
cond: Condvar,
}
#[cfg(feature = "threading")]
impl Semaphore {
#[must_use]
fn new() -> Self {
Self {
mutex: Mutex::new(0),
cond: Condvar::new(),
}
}
fn release(&self) {
{
let mut count = self.mutex.lock();
*count += 1;
} // lock dropped. now we can notify a waiting thread
self.cond.notify_one();
}
/// Returns `true` if the semaphore was acquired, `false` on timeout.
#[must_use]
fn acquire(&self, block: bool, deadline: Option<Instant>, vm: &VirtualMachine) -> bool {
let mut count = self.mutex.lock();
loop {
if *count > 0 {
*count -= 1;
return true;
}
if !block {
return false;
}
match deadline {
Some(dl) => {
let result = vm.allow_threads(|| self.cond.wait_until(&mut count, dl));
if result.timed_out() && *count == 0 {
return false;
}
}
None => {
vm.allow_threads(|| self.cond.wait(&mut count));
}
}
}
}
}
#[pyattr]
#[pyclass(module = "_queue", name = "SimpleQueue", unhashable = true)]
#[derive(Debug, PyPayload)]
struct PySimpleQueue {
buf: Buf,
#[cfg(feature = "threading")]
sem: Semaphore,
}
impl Default for PySimpleQueue {
fn default() -> Self {
Self {
buf: Buf::new(VecDeque::with_capacity(INITIAL_RING_BUF_CAPACITY)),
#[cfg(feature = "threading")]
sem: Semaphore::new(),
}
}
}
impl PySimpleQueue {
fn push(&self, item: PyObjectRef) {
self.buf.lock().push_back(item);
#[cfg(feature = "threading")]
self.sem.release();
}
/// Returns a strong reference from the head of the buffer.
///
/// ## See Also
///
/// [`RingBuf_Get`](https://github.com/python/cpython/blob/v3.14.5/Modules/_queuemodule.c#L133-L154).
fn get_inner(
#[cfg(feature = "threading")] buf: &mut MutexGuard<'_, BufInner>,
#[cfg(not(feature = "threading"))] buf: &mut BufInner,
) -> Option<PyObjectRef> {
let cap = buf.capacity();
if buf.len() < (cap / 4) {
// Items is less than 25% occupied, shrink it by 50%. This allows for
// growth without immediately needing to resize the underlying items array
buf.shrink_to(cap / 2)
}
buf.pop_front()
}
}
#[derive(FromArgs)]
struct PutArgs {
#[pyarg(positional)]
item: PyObjectRef,
#[expect(
dead_code,
reason = "Intentional. Provide compatibility with the Queue class"
)]
#[pyarg(any, optional, default = true)]
block: bool,
#[expect(
dead_code,
reason = "Intentional. Provide compatibility with the Queue class"
)]
#[pyarg(any, optional)]
timeout: Option<PyObjectRef>,
}
#[derive(FromArgs)]
struct GetArgs {
#[pyarg(any, optional, default = true)]
block: bool,
#[pyarg(any, optional)]
timeout: Option<TimeoutSeconds>,
}
#[pyclass(
with(Constructor, Comparable, Representable),
flags(BASETYPE, HAS_WEAKREF, IMMUTABLETYPE)
)]
impl PySimpleQueue {
#[pymethod]
fn empty(&self) -> bool {
self.buf.lock().is_empty()
}
#[pymethod]
fn qsize(&self) -> usize {
self.buf.lock().len()
}
#[pymethod]
fn put(&self, args: PutArgs) {
let PutArgs { item, .. } = args;
self.push(item);
}
#[pymethod]
fn put_nowait(&self, item: PyObjectRef) {
self.push(item);
}
#[pymethod]
fn get(&self, args: GetArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
let GetArgs { block, timeout } = args;
// Non-blocking: just try once
if !block {
return Self::get_inner(&mut self.buf.lock()).ok_or_else(|| empty_error(vm));
}
#[cfg_attr(
not(feature = "threading"),
expect(
unused_variables,
reason = "We are still validating the 'timeout' arg even if we don't have threading"
)
)]
let deadline = match timeout.map(|v| v.to_secs_f64()) {
Some(v) if v < 0.0 => {
return Err(vm.new_value_error("'timeout' must be a non-negative number"));
}
Some(v) => Some(Instant::now() + Duration::from_secs_f64(v)),
None => None,
};
#[cfg(feature = "threading")]
{
if !self.sem.acquire(block, deadline, vm) {
return Err(empty_error(vm));
}
}
Self::get_inner(&mut self.buf.lock()).ok_or_else(|| empty_error(vm))
}
#[pymethod]
fn get_nowait(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
#[cfg(feature = "threading")]
{
if !self.sem.acquire(false, None, vm) {
return Err(empty_error(vm));
}
}
Self::get_inner(&mut self.buf.lock()).ok_or_else(|| empty_error(vm))
}
#[pyclassmethod]
fn __class_getitem__(
cls: PyTypeRef,
args: PyObjectRef,
vm: &VirtualMachine,
) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
}
impl Constructor for PySimpleQueue {
type Args = ();
fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self::default())
}
}
impl AsNumber for PySimpleQueue {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
boolean: Some(|number, _vm| {
let zelf = number.obj.downcast_ref::<PySimpleQueue>().unwrap();
Ok(!zelf.buf.lock().is_empty())
}),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
}
}
impl Comparable for PySimpleQueue {
fn cmp(
zelf: &Py<Self>,
other: &PyObject,
op: PyComparisonOp,
_vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Ok(if let Some(res) = op.identical_optimization(zelf, other) {
res.into()
} else {
PyComparisonValue::NotImplemented
})
}
}
impl Representable for PySimpleQueue {
fn repr(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyRef<PyStr>> {
Ok(vm.ctx.new_str(format!(
"<{} at {:#x}>",
Self::class(&vm.ctx).slot_name(),
zelf.get_id()
)))
}
fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
unreachable!("repr() is overridden directly")
}
}
}

View File

@@ -53,6 +53,7 @@ mod math;
#[cfg(all(feature = "host_env", any(unix, windows)))] #[cfg(all(feature = "host_env", any(unix, windows)))]
mod mmap; mod mmap;
mod _queue;
mod pyexpat; mod pyexpat;
mod pystruct; mod pystruct;
mod random; mod random;
@@ -225,6 +226,7 @@ pub fn stdlib_module_defs(ctx: &Context) -> Vec<&'static builtins::PyModuleDef>
posixshmem::module_def(ctx), posixshmem::module_def(ctx),
pyexpat::module_def(ctx), pyexpat::module_def(ctx),
pystruct::module_def(ctx), pystruct::module_def(ctx),
_queue::module_def(ctx),
random::module_def(ctx), random::module_def(ctx),
#[cfg(all(feature = "host_env", unix, not(target_os = "redox")))] #[cfg(all(feature = "host_env", unix, not(target_os = "redox")))]
resource::module_def(ctx), resource::module_def(ctx),

View File

@@ -537,7 +537,7 @@ impl PySet {
self.inner.len() self.inner.len()
} }
fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> { pub fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
self.inner.contains(needle, vm) self.inner.contains(needle, vm)
} }
@@ -679,17 +679,17 @@ impl PySet {
} }
#[pymethod] #[pymethod]
fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { pub fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.inner.discard(&item, vm).map(|_| ()) self.inner.discard(&item, vm).map(|_| ())
} }
#[pymethod] #[pymethod]
fn clear(&self) { pub fn clear(&self) {
self.inner.clear() self.inner.clear()
} }
#[pymethod] #[pymethod]
fn pop(&self, vm: &VirtualMachine) -> PyResult { pub fn pop(&self, vm: &VirtualMachine) -> PyResult {
self.inner.pop(vm) self.inner.pop(vm)
} }
@@ -995,7 +995,7 @@ impl PyFrozenSet {
self.inner.len() self.inner.len()
} }
fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> { pub fn __contains__(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
self.inner.contains(needle, vm) self.inner.contains(needle, vm)
} }

View File

@@ -42,6 +42,22 @@ impl Constructor for PyWeakProxy {
Self::Args { referent, callback }: Self::Args, Self::Args { referent, callback }: Self::Args,
vm: &VirtualMachine, vm: &VirtualMachine,
) -> PyResult<Self> { ) -> PyResult<Self> {
let weak = Self::new_weak(referent.as_ref(), callback.into_option(), vm)?;
// TODO: PyWeakProxy should use the same payload as PyWeak
Ok(Self { weak })
}
}
crate::common::static_cell! {
static WEAK_SUBCLASS: PyTypeRef;
}
impl PyWeakProxy {
fn new_weak(
referent: &PyObject,
callback: Option<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyRef<PyWeak>> {
// using an internal subclass as the class prevents us from getting the generic weakref, // using an internal subclass as the class prevents us from getting the generic weakref,
// which would mess up the weakref count // which would mess up the weakref count
let weak_cls = WEAK_SUBCLASS.get_or_init(|| { let weak_cls = WEAK_SUBCLASS.get_or_init(|| {
@@ -52,15 +68,22 @@ impl Constructor for PyWeakProxy {
super::PyWeak::make_slots(), super::PyWeak::make_slots(),
) )
}); });
// TODO: PyWeakProxy should use the same payload as PyWeak referent.downgrade_with_typ(callback, weak_cls.clone(), vm)
Ok(Self {
weak: referent.downgrade_with_typ(callback.into_option(), weak_cls.clone(), vm)?,
})
} }
}
crate::common::static_cell! { pub fn new_weakproxy(
static WEAK_SUBCLASS: PyTypeRef; referent: &PyObject,
callback: Option<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let weak = Self::new_weak(referent, callback, vm)?;
Ok(Self { weak }.into_ref(&vm.ctx))
}
#[must_use]
pub fn get_weak(&self) -> &PyRef<PyWeak> {
&self.weak
}
} }
#[pyclass(with( #[pyclass(with(