Merge pull request #4064 from youknowone/unittests

Unittests
This commit is contained in:
Jeong YunWon
2022-08-15 01:38:59 +09:00
committed by GitHub
39 changed files with 2159 additions and 1154 deletions

49
Lib/runpy.py vendored
View File

@@ -13,8 +13,9 @@ importers when locating support scripts as well as when importing modules.
import sys
import importlib.machinery # importlib first so we can test #15386 via -m
import importlib.util
import io
import types
from pkgutil import read_code, get_importer
import os
__all__ = [
"run_module", "run_path",
@@ -131,6 +132,9 @@ def _get_module_details(mod_name, error=ImportError):
# importlib, where the latter raises other errors for cases where
# pkgutil previously raised ImportError
msg = "Error while finding module specification for {!r} ({}: {})"
if mod_name.endswith(".py"):
msg += (f". Try using '{mod_name[:-3]}' instead of "
f"'{mod_name}' as the module name.")
raise error(msg.format(mod_name, type(ex).__name__, ex)) from ex
if spec is None:
raise error("No module named %s" % mod_name)
@@ -194,9 +198,24 @@ def _run_module_as_main(mod_name, alter_argv=True):
def run_module(mod_name, init_globals=None,
run_name=None, alter_sys=False):
"""Execute a module's code without importing it
"""Execute a module's code without importing it.
Returns the resulting top level namespace dictionary
mod_name -- an absolute module name or package name.
Optional arguments:
init_globals -- dictionary used to pre-populate the modules
globals dictionary before the code is executed.
run_name -- if not None, this will be used for setting __name__;
otherwise, __name__ will be set to mod_name + '__main__' if the
named module is a package and to just mod_name otherwise.
alter_sys -- if True, sys.argv[0] is updated with the value of
__file__ and sys.modules[__name__] is updated with a temporary
module object for the module being executed. Both are
restored to their original values before the function returns.
Returns the resulting module globals dictionary.
"""
mod_name, mod_spec, code = _get_module_details(mod_name)
if run_name is None:
@@ -228,27 +247,35 @@ def _get_main_module_details(error=ImportError):
def _get_code_from_file(run_name, fname):
# Check for a compiled file first
with open(fname, "rb") as f:
from pkgutil import read_code
decoded_path = os.path.abspath(os.fsdecode(fname))
with io.open_code(decoded_path) as f:
code = read_code(f)
if code is None:
# That didn't work, so try it as normal source code
with open(fname, "rb") as f:
with io.open_code(decoded_path) as f:
code = compile(f.read(), fname, 'exec')
return code, fname
def run_path(path_name, init_globals=None, run_name=None):
"""Execute code located at the specified filesystem location
"""Execute code located at the specified filesystem location.
Returns the resulting top level namespace dictionary
path_name -- filesystem location of a Python script, zipfile,
or directory containing a top level __main__.py script.
The file path may refer directly to a Python script (i.e.
one that could be directly executed with execfile) or else
it may refer to a zipfile or directory containing a top
level __main__.py script.
Optional arguments:
init_globals -- dictionary used to pre-populate the modules
globals dictionary before the code is executed.
run_name -- if not None, this will be used to set __name__;
otherwise, '<run_path>' will be used for __name__.
Returns the resulting module globals dictionary.
"""
if run_name is None:
run_name = "<run_path>"
pkg_name = run_name.rpartition(".")[0]
from pkgutil import get_importer
importer = get_importer(path_name)
# Trying to avoid importing imp so as to not consume the deprecation warning.
is_NullImporter = False

20
Lib/sched.py vendored
View File

@@ -26,23 +26,19 @@ has another way to reference private data (besides global variables).
import time
import heapq
from collections import namedtuple
from itertools import count
import threading
from time import monotonic as _time
__all__ = ["scheduler"]
class Event(namedtuple('Event', 'time, priority, action, argument, kwargs')):
__slots__ = []
def __eq__(s, o): return (s.time, s.priority) == (o.time, o.priority)
def __lt__(s, o): return (s.time, s.priority) < (o.time, o.priority)
def __le__(s, o): return (s.time, s.priority) <= (o.time, o.priority)
def __gt__(s, o): return (s.time, s.priority) > (o.time, o.priority)
def __ge__(s, o): return (s.time, s.priority) >= (o.time, o.priority)
Event = namedtuple('Event', 'time, priority, sequence, action, argument, kwargs')
Event.time.__doc__ = ('''Numeric type compatible with the return value of the
timefunc function passed to the constructor.''')
Event.priority.__doc__ = ('''Events scheduled for the same time will be executed
in the order of their priority.''')
Event.sequence.__doc__ = ('''A continually increasing sequence number that
separates events if time and priority are equal.''')
Event.action.__doc__ = ('''Executing the event means executing
action(*argument, **kwargs)''')
Event.argument.__doc__ = ('''argument is a sequence holding the positional
@@ -61,6 +57,7 @@ class scheduler:
self._lock = threading.RLock()
self.timefunc = timefunc
self.delayfunc = delayfunc
self._sequence_generator = count()
def enterabs(self, time, priority, action, argument=(), kwargs=_sentinel):
"""Enter a new event in the queue at an absolute time.
@@ -71,8 +68,10 @@ class scheduler:
"""
if kwargs is _sentinel:
kwargs = {}
event = Event(time, priority, action, argument, kwargs)
with self._lock:
event = Event(time, priority, next(self._sequence_generator),
action, argument, kwargs)
heapq.heappush(self._queue, event)
return event # The ID
@@ -136,7 +135,8 @@ class scheduler:
with lock:
if not q:
break
time, priority, action, argument, kwargs = q[0]
(time, priority, sequence, action,
argument, kwargs) = q[0]
now = timefunc()
if time > now:
delay = True

3
Lib/secrets.py vendored
View File

@@ -14,7 +14,6 @@ __all__ = ['choice', 'randbelow', 'randbits', 'SystemRandom',
import base64
import binascii
import os
from hmac import compare_digest
from random import SystemRandom
@@ -44,7 +43,7 @@ def token_bytes(nbytes=None):
"""
if nbytes is None:
nbytes = DEFAULT_ENTROPY
return os.urandom(nbytes)
return _sysrand.randbytes(nbytes)
def token_hex(nbytes=None):
"""Return a random text string, in hexadecimal.

49
Lib/socket.py vendored
View File

@@ -12,6 +12,8 @@ Functions:
socket() -- create a new socket object
socketpair() -- create a pair of new socket objects [*]
fromfd() -- create a socket object from an open file descriptor [*]
send_fds() -- Send file descriptor to the socket.
recv_fds() -- Receive file descriptors from the socket.
fromshare() -- create a socket object from data received from socket.share() [*]
gethostname() -- return the current hostname
gethostbyname() -- map a hostname to its IP number
@@ -104,7 +106,6 @@ def _intenum_converter(value, enum_klass):
except ValueError:
return value
_realsocket = socket
# WSA error codes
if sys.platform.lower().startswith("win"):
@@ -336,6 +337,7 @@ class socket(_socket.socket):
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
encoding = io.text_encoding(encoding)
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
@@ -376,7 +378,7 @@ class socket(_socket.socket):
try:
while True:
if timeout and not selector_select(timeout):
raise _socket.timeout('timed out')
raise TimeoutError('timed out')
if count:
blocksize = count - total_sent
if blocksize <= 0:
@@ -543,6 +545,40 @@ def fromfd(fd, family, type, proto=0):
nfd = dup(fd)
return socket(family, type, proto, nfd)
if hasattr(_socket.socket, "sendmsg"):
import array
def send_fds(sock, buffers, fds, flags=0, address=None):
    """ send_fds(sock, buffers, fds[, flags[, address]]) -> integer

    Send the list of file descriptors fds over an AF_UNIX socket.

    buffers -- iterable of bytes-like objects sent as the ordinary data.
    fds -- iterable of integer file descriptors, packed into a single
    SCM_RIGHTS ancillary message.
    flags, address -- forwarded unchanged to sock.sendmsg().

    Returns whatever sock.sendmsg() returns.
    """
    # The descriptors travel as one SOL_SOCKET/SCM_RIGHTS control message
    # holding a packed array of C ints.
    ancdata = [(_socket.SOL_SOCKET, _socket.SCM_RIGHTS,
                array.array("i", fds))]
    # flags and address used to be accepted but silently dropped; pass them
    # through so callers get the behaviour the signature advertises.
    return sock.sendmsg(buffers, ancdata, flags, address)
__all__.append("send_fds")
if hasattr(_socket.socket, "recvmsg"):
import array
def recv_fds(sock, bufsize, maxfds, flags=0):
    """ recv_fds(sock, bufsize, maxfds[, flags]) -> (data, list of file
    descriptors, msg_flags, address)

    Receive up to maxfds file descriptors returning the message
    data and a list containing the descriptors.

    bufsize -- maximum number of bytes of ordinary data to receive.
    maxfds -- used to size the ancillary-data buffer.
    flags -- forwarded unchanged to sock.recvmsg().
    """
    # Array of ints
    fds = array.array("i")
    # Pass the caller's flags through; previously they were ignored and the
    # name was overwritten by the flags returned from recvmsg().
    msg, ancdata, msg_flags, addr = sock.recvmsg(
        bufsize, _socket.CMSG_LEN(maxfds * fds.itemsize), flags)
    for cmsg_level, cmsg_type, cmsg_data in ancdata:
        if (cmsg_level == _socket.SOL_SOCKET
                and cmsg_type == _socket.SCM_RIGHTS):
            # Drop any trailing partial int before decoding.
            usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:usable])
    return msg, list(fds), msg_flags, addr
__all__.append("recv_fds")
if hasattr(_socket.socket, "share"):
def fromshare(info):
""" fromshare(info) -> socket object
@@ -671,7 +707,7 @@ class SocketIO(io.RawIOBase):
self._timeout_occurred = True
raise
except error as e:
if e.args[0] in _blocking_errnos:
if e.errno in _blocking_errnos:
return None
raise
@@ -687,7 +723,7 @@ class SocketIO(io.RawIOBase):
return self._sock.send(b)
except error as e:
# XXX what about EINTR?
if e.args[0] in _blocking_errnos:
if e.errno in _blocking_errnos:
return None
raise
@@ -746,8 +782,9 @@ def getfqdn(name=''):
An empty argument is interpreted as meaning the local host.
First the hostname returned by gethostbyaddr() is checked, then
possibly existing aliases. In case no FQDN is available, hostname
from gethostname() is returned.
possibly existing aliases. In case no FQDN is available and `name`
was given, it is returned unchanged. If `name` was empty or '0.0.0.0',
hostname from gethostname() is returned.
"""
name = name.strip()
if not name or name == '0.0.0.0':

72
Lib/socketserver.py vendored
View File

@@ -24,7 +24,7 @@ For request-based servers (including socket-based):
The classes in this module favor the server type that is simplest to
write: a synchronous TCP/IP server. This is bad class design, but
save some typing. (There's also the issue that a deep class hierarchy
saves some typing. (There's also the issue that a deep class hierarchy
slows down method lookups.)
There are five classes in an inheritance diagram, four of which represent
@@ -126,7 +126,6 @@ __version__ = "0.4"
import socket
import selectors
import os
import errno
import sys
try:
import threading
@@ -234,6 +233,9 @@ class BaseServer:
while not self.__shutdown_request:
ready = selector.select(poll_interval)
# bpo-35017: shutdown() called during select(), exit immediately.
if self.__shutdown_request:
break
if ready:
self._handle_request_noblock()
@@ -375,7 +377,7 @@ class BaseServer:
"""
print('-'*40, file=sys.stderr)
print('Exception happened during processing of request from',
print('Exception occurred during processing of request from',
client_address, file=sys.stderr)
import traceback
traceback.print_exc()
@@ -547,8 +549,10 @@ if hasattr(os, "fork"):
timeout = 300
active_children = None
max_children = 40
# If true, server_close() waits until all child processes complete.
block_on_close = True
def collect_children(self):
def collect_children(self, *, blocking=False):
"""Internal routine to wait for children that have exited."""
if self.active_children is None:
return
@@ -572,7 +576,8 @@ if hasattr(os, "fork"):
# Now reap all defunct children.
for pid in self.active_children.copy():
try:
pid, _ = os.waitpid(pid, os.WNOHANG)
flags = 0 if blocking else os.WNOHANG
pid, _ = os.waitpid(pid, flags)
# if the child hasn't exited yet, pid will be 0 and ignored by
# discard() below
self.active_children.discard(pid)
@@ -592,7 +597,7 @@ if hasattr(os, "fork"):
def service_actions(self):
"""Collect the zombie child processes regularly in the ForkingMixIn.
service_actions is called in the BaseServer's serve_forver loop.
service_actions is called in the BaseServer's serve_forever loop.
"""
self.collect_children()
@@ -621,6 +626,43 @@ if hasattr(os, "fork"):
finally:
os._exit(status)
def server_close(self):
super().server_close()
self.collect_children(blocking=self.block_on_close)
class _Threads(list):
    """
    List of non-daemon threads that can be joined as a group.
    """

    def append(self, thread):
        # Prune finished threads first, then record the new one unless it
        # is a daemon thread.
        self.reap()
        if not thread.daemon:
            super().append(thread)

    def pop_all(self):
        # Hand back a snapshot of the current contents and empty the list.
        drained = self[:]
        self.clear()
        return drained

    def join(self):
        # Join every thread that was tracked at the time of the call.
        for worker in self.pop_all():
            worker.join()

    def reap(self):
        # Keep only the threads that report themselves as still alive.
        self[:] = [worker for worker in self if worker.is_alive()]
class _NoThreads:
    """
    Stand-in for _Threads whose append() and join() do nothing.
    """

    def append(self, thread):
        """Ignore the thread; nothing is recorded."""

    def join(self):
        """Return immediately; there is nothing to join."""
class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""
@@ -628,6 +670,11 @@ class ThreadingMixIn:
# Decides how threads will act upon termination of the
# main process
daemon_threads = False
# If true, server_close() waits until all non-daemonic threads terminate.
block_on_close = True
# Threads object
# used by server_close() to wait for all threads completion.
_threads = _NoThreads()
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
@@ -644,11 +691,18 @@ class ThreadingMixIn:
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
if self.block_on_close:
vars(self).setdefault('_threads', _Threads())
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
self._threads.append(t)
t.start()
def server_close(self):
super().server_close()
self._threads.join()
if hasattr(os, "fork"):
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
@@ -773,10 +827,8 @@ class _SocketWriter(BufferedIOBase):
def write(self, b):
self._sock.sendall(b)
# XXX RustPython TODO: implement memoryview properly
#with memoryview(b) as view:
# return view.nbytes
return len(b)
with memoryview(b) as view:
return view.nbytes
def fileno(self):
return self._sock.fileno()

46
Lib/string.py vendored
View File

@@ -54,30 +54,7 @@ from collections import ChainMap as _ChainMap
_sentinel_dict = {}
class _TemplateMetaclass(type):
pattern = r"""
%(delim)s(?:
(?P<escaped>%(delim)s) | # Escape sequence of two delimiters
(?P<named>%(id)s) | # delimiter and a Python identifier
{(?P<braced>%(bid)s)} | # delimiter and a braced identifier
(?P<invalid>) # Other ill-formed delimiter exprs
)
"""
def __init__(cls, name, bases, dct):
super(_TemplateMetaclass, cls).__init__(name, bases, dct)
if 'pattern' in dct:
pattern = cls.pattern
else:
pattern = _TemplateMetaclass.pattern % {
'delim' : _re.escape(cls.delimiter),
'id' : cls.idpattern,
'bid' : cls.braceidpattern or cls.idpattern,
}
cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE)
class Template(metaclass=_TemplateMetaclass):
class Template:
"""A string class for supporting $-substitutions."""
delimiter = '$'
@@ -89,6 +66,24 @@ class Template(metaclass=_TemplateMetaclass):
braceidpattern = None
flags = _re.IGNORECASE
def __init_subclass__(cls):
super().__init_subclass__()
if 'pattern' in cls.__dict__:
pattern = cls.pattern
else:
delim = _re.escape(cls.delimiter)
id = cls.idpattern
bid = cls.braceidpattern or cls.idpattern
pattern = fr"""
{delim}(?:
(?P<escaped>{delim}) | # Escape sequence of two delimiters
(?P<named>{id}) | # delimiter and a Python identifier
{{(?P<braced>{bid})}} | # delimiter and a braced identifier
(?P<invalid>) # Other ill-formed delimiter exprs
)
"""
cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE)
def __init__(self, template):
self.template = template
@@ -146,6 +141,9 @@ class Template(metaclass=_TemplateMetaclass):
self.pattern)
return self.pattern.sub(convert, self.template)
# Initialize Template.pattern. __init_subclass__() is automatically called
# only for subclasses, not for the Template class itself.
Template.__init_subclass__()
########################################################################

10
Lib/tarfile.py vendored
View File

@@ -1163,6 +1163,11 @@ class TarInfo(object):
# header information.
self._apply_pax_info(tarfile.pax_headers, tarfile.encoding, tarfile.errors)
# Remove redundant slashes from directories. This is to be consistent
# with frombuf().
if self.isdir():
self.name = self.name.rstrip("/")
return self
def _proc_gnulong(self, tarfile):
@@ -1185,6 +1190,11 @@ class TarInfo(object):
elif self.type == GNUTYPE_LONGLINK:
next.linkname = nts(buf, tarfile.encoding, tarfile.errors)
# Remove redundant slashes from directories. This is to be consistent
# with frombuf().
if next.isdir():
next.name = next.name.removesuffix("/")
return next
def _proc_sparse(self, tarfile):

View File

@@ -523,8 +523,6 @@ class CmdLineTest(unittest.TestCase):
self.assertNotIn(b'is a package', err)
self.assertNotIn(b'Traceback', err)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_hint_when_triying_to_import_a_py_file(self):
with os_helper.temp_dir() as script_dir, \
os_helper.change_cwd(path=script_dir):

28
Lib/test/test_repl.py vendored
View File

@@ -29,7 +29,9 @@ def spawn_repl(*args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kw):
# test.support.script_helper.
env = kw.setdefault('env', dict(os.environ))
env['TERM'] = 'vt100'
return subprocess.Popen(cmd_line, executable=sys.executable,
return subprocess.Popen(cmd_line,
executable=sys.executable,
text=True,
stdin=subprocess.PIPE,
stdout=stdout, stderr=stderr,
**kw)
@@ -49,12 +51,11 @@ class TestInteractiveInterpreter(unittest.TestCase):
sys.exit(0)
"""
user_input = dedent(user_input)
user_input = user_input.encode()
p = spawn_repl()
with SuppressCrashReport():
p.stdin.write(user_input)
output = kill_python(p)
self.assertIn(b'After the exception.', output)
self.assertIn('After the exception.', output)
# Exit code 120: Py_FinalizeEx() failed to flush stdout and stderr.
self.assertIn(p.returncode, (1, 120))
@@ -86,13 +87,28 @@ class TestInteractiveInterpreter(unittest.TestCase):
</test>"""
'''
user_input = dedent(user_input)
user_input = user_input.encode()
p = spawn_repl()
with SuppressCrashReport():
p.stdin.write(user_input)
p.stdin.write(user_input)
output = kill_python(p)
self.assertEqual(p.returncode, 0)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_close_stdin(self):
user_input = dedent('''
import os
print("before close")
os.close(0)
''')
prepare_repl = dedent('''
from test.support import suppress_msvcrt_asserts
suppress_msvcrt_asserts()
''')
process = spawn_repl('-c', prepare_repl)
output = process.communicate(user_input)[0]
self.assertEqual(process.returncode, 0)
self.assertIn('before close', output)
if __name__ == "__main__":
unittest.main()

View File

@@ -12,10 +12,10 @@ import tempfile
import textwrap
import unittest
import warnings
from test.support import verbose, no_tracing
from test.support import no_tracing, verbose
from test.support.import_helper import forget, make_legacy_pyc, unload
from test.support.os_helper import create_empty_file, temp_dir
from test.support.script_helper import make_script, make_zip_script
from test.support.import_helper import unload, forget, make_legacy_pyc
import runpy
@@ -745,8 +745,7 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin):
"runpy.run_path(%r)\n") % dummy_dir
script_name = self._make_test_script(script_dir, mod_name, source)
zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name)
msg = "recursion depth exceeded"
self.assertRaisesRegex(RecursionError, msg, run_path, zip_name)
self.assertRaises(RecursionError, run_path, zip_name)
# TODO: RUSTPYTHON, detect encoding comments in files
@unittest.expectedFailure

View File

@@ -142,6 +142,17 @@ class TestCase(unittest.TestCase):
self.assertTrue(q.empty())
self.assertEqual(timer.time(), 4)
def test_cancel_correct_event(self):
# bpo-19270
events = []
scheduler = sched.scheduler()
scheduler.enterabs(1, 1, events.append, ("a",))
b = scheduler.enterabs(1, 1, events.append, ("b",))
scheduler.enterabs(1, 1, events.append, ("c",))
scheduler.cancel(b)
scheduler.run()
self.assertEqual(events, ["a", "c"])
def test_empty(self):
l = []
fun = lambda x: l.append(x)

View File

@@ -2,11 +2,11 @@ import unittest
import weakref
from test.support import check_syntax_error, cpython_only
from test.support import gc_collect
class ScopeTests(unittest.TestCase):
def testSimpleNesting(self):
def make_adder(x):
@@ -20,7 +20,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
def testExtraNesting(self):
def make_adder2(x):
@@ -36,7 +35,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
def testSimpleAndRebinding(self):
def make_adder3(x):
@@ -50,8 +48,7 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
def testNestingGlobalNoFree(self):
def make_adder4(): # XXX add exta level of indirection
@@ -70,7 +67,6 @@ class ScopeTests(unittest.TestCase):
global_x = 10
self.assertEqual(adder(-2), 8)
def testNestingThroughClass(self):
def make_adder5(x):
@@ -85,7 +81,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(inc(1), 2)
self.assertEqual(plus10(-2), 8)
def testNestingPlusFreeRefToGlobal(self):
def make_adder6(x):
@@ -101,7 +96,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(inc(1), 11) # there's only one global
self.assertEqual(plus10(-2), 8)
def testNearestEnclosingScope(self):
def f(x):
@@ -115,7 +109,6 @@ class ScopeTests(unittest.TestCase):
test_func = f(10)
self.assertEqual(test_func(5), 47)
def testMixedFreevarsAndCellvars(self):
def identity(x):
@@ -173,7 +166,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(t.method_and_var(), "method")
self.assertEqual(t.actual_global(), "global")
def testCellIsKwonlyArg(self):
# Issue 1409: Initialisation of a cell value,
# when it comes from a keyword-only parameter
@@ -185,7 +177,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(foo(a=42), 50)
self.assertEqual(foo(), 25)
def testRecursion(self):
def f(x):
@@ -201,6 +192,7 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(f(6), 720)
def testUnoptimizedNamespaces(self):
check_syntax_error(self, """if 1:
@@ -235,7 +227,6 @@ class ScopeTests(unittest.TestCase):
return getrefcount # global or local?
""")
def testLambdas(self):
f1 = lambda x: lambda y: x + y
@@ -312,8 +303,6 @@ class ScopeTests(unittest.TestCase):
fail('scope of global_x not correctly determined')
""", {'fail': self.fail})
def testComplexDefinitions(self):
def makeReturner(*lst):
@@ -348,6 +337,7 @@ class ScopeTests(unittest.TestCase):
return g()
self.assertEqual(f(), 7)
self.assertEqual(x, 7)
# II
x = 7
def f():
@@ -362,6 +352,7 @@ class ScopeTests(unittest.TestCase):
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 7)
# III
x = 7
def f():
@@ -377,6 +368,7 @@ class ScopeTests(unittest.TestCase):
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 2)
# IV
x = 7
def f():
@@ -392,8 +384,10 @@ class ScopeTests(unittest.TestCase):
return g()
self.assertEqual(f(), 2)
self.assertEqual(x, 2)
# XXX what about global statements in class blocks?
# do they affect methods?
x = 12
class Global:
global x
@@ -402,6 +396,7 @@ class ScopeTests(unittest.TestCase):
x = val
def get(self):
return x
g = Global()
self.assertEqual(g.get(), 13)
g.set(15)
@@ -428,6 +423,7 @@ class ScopeTests(unittest.TestCase):
for i in range(100):
f1()
gc_collect() # For PyPy or other GCs.
self.assertEqual(Foo.count, 0)
def testClassAndGlobal(self):
@@ -439,16 +435,19 @@ class ScopeTests(unittest.TestCase):
def __call__(self, y):
return x + y
return Foo()
x = 0
self.assertEqual(test(6)(2), 8)
x = -1
self.assertEqual(test(3)(2), 5)
looked_up_by_load_name = False
class X:
# Implicit globals inside classes are be looked up by LOAD_NAME, not
# LOAD_GLOBAL.
locals()['looked_up_by_load_name'] = True
passed = looked_up_by_load_name
self.assertTrue(X.passed)
""")
@@ -468,7 +467,6 @@ class ScopeTests(unittest.TestCase):
del d['h']
self.assertEqual(d, {'x': 2, 'y': 7, 'w': 6})
def testLocalsClass(self):
# This test verifies that calling locals() does not pollute
# the local namespace of the class with free variables. Old
@@ -502,7 +500,6 @@ class ScopeTests(unittest.TestCase):
self.assertNotIn("x", varnames)
self.assertIn("y", varnames)
@cpython_only
def testLocalsClass_WithTrace(self):
# Issue23728: after the trace function returns, the locals()
@@ -520,7 +517,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(x, 12) # Used to raise UnboundLocalError
def testBoundAndFree(self):
# var is bound and free in class
@@ -534,7 +530,6 @@ class ScopeTests(unittest.TestCase):
inst = f(3)()
self.assertEqual(inst.a, inst.m())
@cpython_only
def testInteractionWithTraceFunc(self):
@@ -574,7 +569,6 @@ class ScopeTests(unittest.TestCase):
else:
self.fail("exec should have failed, because code contained free vars")
def testListCompLocalVars(self):
try:
@@ -603,7 +597,6 @@ class ScopeTests(unittest.TestCase):
f(4)()
def testFreeingCell(self):
# Test what happens when a finalizer accesses
# the cell where the object was stored.
@@ -611,7 +604,6 @@ class ScopeTests(unittest.TestCase):
def __del__(self):
nestedcell_get()
def testNonLocalFunction(self):
def f(x):
@@ -631,7 +623,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(dec(), 1)
self.assertEqual(dec(), 0)
def testNonLocalMethod(self):
def f(x):
class c:
@@ -678,7 +669,6 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(2, global_ns["result2"])
self.assertEqual(9, global_ns["result9"])
def testNonLocalClass(self):
def f(x):
@@ -693,7 +683,7 @@ class ScopeTests(unittest.TestCase):
self.assertEqual(c.get(), 1)
self.assertNotIn("x", c.__class__.__dict__)
def testNonLocalGenerator(self):
def f(x):
@@ -707,7 +697,6 @@ class ScopeTests(unittest.TestCase):
g = f(0)
self.assertEqual(list(g(5)), [1, 2, 3, 4, 5])
def testNestedNonLocal(self):
def f(x):
@@ -725,7 +714,6 @@ class ScopeTests(unittest.TestCase):
h = g()
self.assertEqual(h(), 3)
def testTopIsNotSignificant(self):
# See #9997.
def top(a):
@@ -746,7 +734,6 @@ class ScopeTests(unittest.TestCase):
self.assertFalse(hasattr(X, "x"))
self.assertEqual(x, 42)
@cpython_only
def testCellLeak(self):
# Issue 17927.
@@ -773,8 +760,9 @@ class ScopeTests(unittest.TestCase):
tester.dig()
ref = weakref.ref(tester)
del tester
gc_collect() # For PyPy or other GCs.
self.assertIsNone(ref())
if __name__ == '__main__':
unittest.main()
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -39,7 +39,7 @@ def signal_alarm(n):
# Remember real select() to avoid interferences with mocking
_real_select = select.select
def receive(sock, n, timeout=20):
def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
r, w, x = _real_select([sock], [], [], timeout)
if sock in r:
return sock.recv(n)
@@ -68,9 +68,7 @@ def simple_subprocess(testcase):
except:
raise
finally:
pid2, status = os.waitpid(pid, 0)
testcase.assertEqual(pid2, pid)
testcase.assertEqual(72 << 8, status)
test.support.wait_process(pid, exitcode=72)
class SocketServerTest(unittest.TestCase):
@@ -345,8 +343,11 @@ class ErrorHandlerTest(unittest.TestCase):
self.check_result(handled=True)
def test_threading_not_handled(self):
ThreadingErrorTestServer(SystemExit)
self.check_result(handled=False)
with threading_helper.catch_threading_exception() as cm:
ThreadingErrorTestServer(SystemExit)
self.check_result(handled=False)
self.assertIs(cm.exc_type, SystemExit)
@requires_forking
def test_forking_handled(self):
@@ -520,8 +521,6 @@ class MiscTestCase(unittest.TestCase):
self.assertEqual(server.shutdown_called, 1)
server.server_close()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_threads_reaped(self):
"""
In #37193, users reported a memory leak

17
Lib/test/test_stat.py vendored
View File

@@ -2,9 +2,11 @@ import unittest
import os
import socket
import sys
from test.support.socket_helper import skip_unless_bind_unix_socket
from test.support.os_helper import TESTFN
from test.support import os_helper
from test.support import socket_helper
from test.support.import_helper import import_fresh_module
from test.support.os_helper import TESTFN
c_stat = import_fresh_module('stat', fresh=['_stat'])
py_stat = import_fresh_module('stat', blocked=['_stat'])
@@ -172,11 +174,16 @@ class TestFilemode:
@unittest.skipUnless(hasattr(os, 'mkfifo'), 'os.mkfifo not available')
def test_fifo(self):
if sys.platform == "vxworks":
fifo_path = os.path.join("/fifos/", TESTFN)
else:
fifo_path = TESTFN
self.addCleanup(os_helper.unlink, fifo_path)
try:
os.mkfifo(TESTFN, 0o700)
os.mkfifo(fifo_path, 0o700)
except PermissionError as e:
self.skipTest('os.mkfifo(): %s' % e)
st_mode, modestr = self.get_mode()
st_mode, modestr = self.get_mode(fifo_path)
self.assertEqual(modestr, 'prwx------')
self.assertS_IS("FIFO", st_mode)
@@ -194,7 +201,7 @@ class TestFilemode:
self.assertS_IS("BLK", st_mode)
break
@skip_unless_bind_unix_socket
@socket_helper.skip_unless_bind_unix_socket
def test_socket(self):
with socket.socket(socket.AF_UNIX) as s:
s.bind(TESTFN)

View File

@@ -115,7 +115,7 @@ class StrftimeTest(unittest.TestCase):
)
for e in expectations:
# musn't raise a value error
# mustn't raise a value error
try:
result = time.strftime(e[0], now)
except ValueError as error:

View File

@@ -63,8 +63,6 @@ def byte(i):
class TestLiterals(unittest.TestCase):
from test.support.warnings_helper import check_syntax_warning
def setUp(self):
self.save_path = sys.path[:]
self.tmpdir = tempfile.mkdtemp()
@@ -133,6 +131,7 @@ class TestLiterals(unittest.TestCase):
self.assertEqual(w, [])
self.assertEqual(exc.filename, '<string>')
self.assertEqual(exc.lineno, 1)
self.assertEqual(exc.offset, 1)
# TODO: RUSTPYTHON
@unittest.expectedFailure

View File

@@ -1,12 +1,16 @@
from collections import abc
import array
import gc
import math
import operator
import unittest
import struct
import sys
import weakref
from test import support
from test.support import import_helper
from test.support.script_helper import assert_python_ok
ISBIGENDIAN = sys.byteorder == "big"
@@ -654,6 +658,62 @@ class StructTest(unittest.TestCase):
s2 = struct.Struct(s.format.encode())
self.assertEqual(s2.format, s.format)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_struct_cleans_up_at_runtime_shutdown(self):
code = """if 1:
import struct
class C:
def __init__(self):
self.pack = struct.pack
def __del__(self):
self.pack('I', -42)
struct.x = C()
"""
rc, stdout, stderr = assert_python_ok("-c", code)
self.assertEqual(rc, 0)
self.assertEqual(stdout.rstrip(), b"")
self.assertIn(b"Exception ignored in:", stderr)
self.assertIn(b"C.__del__", stderr)
def test__struct_reference_cycle_cleaned_up(self):
# Regression test for python/cpython#94207.
# When we create a new struct module, trigger use of its cache,
# and then delete it ...
_struct_module = import_helper.import_fresh_module("_struct")
module_ref = weakref.ref(_struct_module)
_struct_module.calcsize("b")
del _struct_module
# Then the module should have been garbage collected.
gc.collect()
self.assertIsNone(
module_ref(), "_struct module was not garbage collected")
@support.cpython_only
def test__struct_types_immutable(self):
# See https://github.com/python/cpython/issues/94254
Struct = struct.Struct
unpack_iterator = type(struct.iter_unpack("b", b'x'))
for cls in (Struct, unpack_iterator):
with self.subTest(cls=cls):
with self.assertRaises(TypeError):
cls.x = 1
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_issue35714(self):
# Embedded null characters should not be allowed in format strings.
for s in '\0', '2\0i', b'\0':
with self.assertRaisesRegex(struct.error,
'embedded null character'):
struct.calcsize(s)
class UnpackIteratorTest(unittest.TestCase):
"""
@@ -681,6 +741,12 @@ class UnpackIteratorTest(unittest.TestCase):
with self.assertRaises(struct.error):
s.iter_unpack(b"12")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_uninstantiable(self):
iter_unpack_type = type(struct.Struct(">ibcp").iter_unpack(b""))
self.assertRaises(TypeError, iter_unpack_type)
def test_iterate(self):
s = struct.Struct('>IB')
b = bytes(range(1, 16))

View File

@@ -124,5 +124,21 @@ class StructSeqTest(unittest.TestCase):
self.assertEqual(list(t[start:stop:step]),
L[start:stop:step])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_match_args(self):
    # struct_time must expose its nine named fields, in order, for
    # positional pattern matching.
    actual_args = time.struct_time.__match_args__
    self.assertEqual(actual_args, (
        'tm_year', 'tm_mon', 'tm_mday',
        'tm_hour', 'tm_min', 'tm_sec',
        'tm_wday', 'tm_yday', 'tm_isdst',
    ))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_match_args_with_unnamed_fields(self):
    # stat_result reports three unnamed fields; __match_args__ is
    # expected to list only the seven named ones below.
    stat_type = os.stat_result
    self.assertEqual(stat_type.n_unnamed_fields, 3)
    self.assertEqual(stat_type.__match_args__, (
        'st_mode', 'st_ino', 'st_dev', 'st_nlink',
        'st_uid', 'st_gid', 'st_size',
    ))
if __name__ == "__main__":
unittest.main()

View File

@@ -220,6 +220,26 @@ class UstarReadTest(ReadTest, unittest.TestCase):
def test_issue14160(self):
self._test_fileobj_link("symtype2", "ustar/regtype")
def test_add_dir_getmember(self):
    # bpo-21987
    # Run the helper once with a short name and once with a 101-character
    # name.
    for dir_name in ('bar', 'a'*101):
        self.add_dir_and_getmember(dir_name)
def add_dir_and_getmember(self, name):
    # Helper: archive a freshly created directory called `name` in USTAR
    # format, then check that getmember() returns equal members for the
    # name with and without a trailing slash.
    with os_helper.temp_cwd():
        with tarfile.open(tmpname, 'w') as writer:
            writer.format = tarfile.USTAR_FORMAT
            try:
                os.mkdir(name)
                writer.add(name)
            finally:
                # The directory is only needed while it is being added.
                os.rmdir(name)
        with tarfile.open(tmpname) as reader:
            without_slash = reader.getmember(name)
            with_slash = reader.getmember(name + '/')
            self.assertEqual(without_slash, with_slash)
class GzipUstarReadTest(GzipTest, UstarReadTest):
    """UstarReadTest combined with the GzipTest mixin; adds nothing of its own."""
@@ -330,6 +350,38 @@ class LzmaListTest(LzmaTest, ListTest):
class CommonReadTest(ReadTest):
def test_is_tarfile_erroneous(self):
    # Start from an empty file at tmpname.
    with open(tmpname, "wb"):
        pass

    # is_tarfile works on filenames
    by_filename = tarfile.is_tarfile(tmpname)
    self.assertFalse(by_filename)

    # is_tarfile works on path-like objects
    by_path = tarfile.is_tarfile(pathlib.Path(tmpname))
    self.assertFalse(by_path)

    # is_tarfile works on file objects
    with open(tmpname, "rb") as fobj:
        self.assertFalse(tarfile.is_tarfile(fobj))

    # is_tarfile works on file-like objects
    garbage = io.BytesIO(b"invalid")
    self.assertFalse(tarfile.is_tarfile(garbage))
def test_is_tarfile_valid(self):
    archive_name = self.tarname

    # is_tarfile works on filenames
    self.assertTrue(tarfile.is_tarfile(archive_name))

    # is_tarfile works on path-like objects
    self.assertTrue(tarfile.is_tarfile(pathlib.Path(archive_name)))

    # is_tarfile works on file objects
    with open(archive_name, "rb") as fobj:
        self.assertTrue(tarfile.is_tarfile(fobj))

    # is_tarfile works on file-like objects
    with open(archive_name, "rb") as fobj:
        in_memory = io.BytesIO(fobj.read())
        self.assertTrue(tarfile.is_tarfile(in_memory))
def test_empty_tarfile(self):
# Test for issue6123: Allow opening empty archives.
# This test checks if tarfile.open() is able to open an empty tar
@@ -365,7 +417,7 @@ class CommonReadTest(ReadTest):
def test_ignore_zeros(self):
# Test TarFile's ignore_zeros option.
# generate 512 pseudorandom bytes
data = Random(0).getrandbits(512*8).to_bytes(512, 'big')
data = Random(0).randbytes(512)
for char in (b'\0', b'a'):
# Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
# are ignored correctly.
@@ -682,6 +734,16 @@ class MiscReadTestBase(CommonReadTest):
self.assertEqual(m1.offset, m2.offset)
self.assertEqual(m1.get_info(), m2.get_info())
@unittest.skipIf(zlib is None, "requires zlib")
def test_zlib_error_does_not_leak(self):
    # bpo-39039: tarfile.open allowed zlib exceptions to bubble up when
    # parsing certain types of invalid data
    header_patch = unittest.mock.patch("tarfile.TarInfo.fromtarfile")
    with header_patch as mock:
        # Make header parsing fail with zlib.error; callers must see it
        # as tarfile.ReadError.
        mock.side_effect = zlib.error
        with self.assertRaises(tarfile.ReadError):
            tarfile.open(self.tarname)
class MiscReadTest(MiscReadTestBase, unittest.TestCase):
test_fail_comp = None
@@ -968,11 +1030,26 @@ class LongnameTest:
"iso8859-1", "strict")
self.assertEqual(tarinfo.type, self.longnametype)
def test_longname_directory(self):
    # Test reading a longlink directory. Issue #47231.
    longdir = ('a' * 101) + '/'
    with os_helper.temp_cwd():
        with tarfile.open(tmpname, 'w') as writer:
            writer.format = self.format
            try:
                os.mkdir(longdir)
                writer.add(longdir)
            finally:
                os.rmdir(longdir)
        with tarfile.open(tmpname) as reader:
            # The member must be found under the name with the trailing
            # slash and under the name without it.
            for member_name in (longdir, longdir.removesuffix('/')):
                self.assertIsNotNone(reader.getmember(member_name))
class GNUReadTest(LongnameTest, ReadTest, unittest.TestCase):
subdir = "gnu"
longnametype = tarfile.GNUTYPE_LONGNAME
format = tarfile.GNU_FORMAT
# Since 3.2 tarfile is supposed to accurately restore sparse members and
# produce files with holes. This is what we actually want to test here.
@@ -1048,6 +1125,7 @@ class PaxReadTest(LongnameTest, ReadTest, unittest.TestCase):
subdir = "pax"
longnametype = tarfile.XHDTYPE
format = tarfile.PAX_FORMAT
def test_pax_global_headers(self):
tar = tarfile.open(tarname, encoding="iso8859-1")
@@ -1600,6 +1678,52 @@ class GNUWriteTest(unittest.TestCase):
("longlnk/" * 127) + "longlink_")
class DeviceHeaderTest(WriteTestBase, unittest.TestCase):
    """Check which members get device-number fields written in their header."""

    prefix = "w:"

    def test_headers_written_only_for_device_files(self):
        # Regression test for bpo-18819.
        tempdir = os.path.join(TEMPDIR, "device_header_test")
        os.mkdir(tempdir)
        try:
            # Write one block-device member and one regular-file member.
            archive = tarfile.open(tmpname, self.mode)
            try:
                blk_in = tarfile.TarInfo(name="my_block_device")
                reg_in = tarfile.TarInfo(name="my_regular_file")
                blk_in.type = tarfile.BLKTYPE
                reg_in.type = tarfile.REGTYPE
                for member in (blk_in, reg_in):
                    archive.addfile(member)
            finally:
                archive.close()

            # devmajor and devminor should be *interpreted* as 0 in both...
            archive = tarfile.open(tmpname, "r")
            try:
                blk_out = archive.getmember("my_block_device")
                reg_out = archive.getmember("my_regular_file")
            finally:
                archive.close()
            for member in (blk_out, reg_out):
                self.assertEqual(member.devmajor, 0)
                self.assertEqual(member.devminor, 0)

            # ...but the fields should not actually be set on regular files:
            with open(tmpname, "rb") as infile:
                raw = infile.read()
            blk_header = raw[blk_out.offset:blk_out.offset_data]
            reg_header = raw[reg_out.offset:reg_out.offset_data]
            # See `struct posixheader` in GNU docs for byte offsets:
            # <https://www.gnu.org/software/tar/manual/html_node/Standard.html>
            device_fields = slice(329, 329 + 16)
            self.assertEqual(blk_header[device_fields], b"0000000\0" * 2)
            self.assertEqual(reg_header[device_fields], b"\0" * 16)
        finally:
            os_helper.rmtree(tempdir)
class CreateTest(WriteTestBase, unittest.TestCase):
prefix = "x:"
@@ -1691,15 +1815,30 @@ class CreateTest(WriteTestBase, unittest.TestCase):
class GzipCreateTest(GzipTest, CreateTest):
pass
def test_create_with_compresslevel(self):
with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj:
tobj.add(self.file_path)
with tarfile.open(tmpname, 'r:gz', compresslevel=1) as tobj:
pass
class Bz2CreateTest(Bz2Test, CreateTest):
pass
def test_create_with_compresslevel(self):
with tarfile.open(tmpname, self.mode, compresslevel=1) as tobj:
tobj.add(self.file_path)
with tarfile.open(tmpname, 'r:bz2', compresslevel=1) as tobj:
pass
class LzmaCreateTest(LzmaTest, CreateTest):
pass
# Unlike gz and bz2, xz uses the preset keyword instead of compresslevel.
# It does not allow for preset to be specified when reading.
def test_create_with_preset(self):
with tarfile.open(tmpname, self.mode, preset=1) as tobj:
tobj.add(self.file_path)
class CreateWithXModeTest(CreateTest):
@@ -1840,6 +1979,61 @@ class PaxWriteTest(GNUWriteTest):
finally:
tar.close()
def test_create_pax_header(self):
# The ustar header should contain values that can be
# represented reasonably, even if a better (e.g. higher
# precision) version is set in the pax header.
# Issue #45863
# values that should be kept
t = tarfile.TarInfo()
t.name = "foo"
t.mtime = 1000.1
t.size = 100
t.uid = 123
t.gid = 124
# The dict from get_info() is handed to create_pax_header(); the
# assertions on `info` below inspect that same dict after the call.
info = t.get_info()
header = t.create_pax_header(info, encoding="iso8859-1")
self.assertEqual(info['name'], "foo")
# mtime should be rounded to nearest second
self.assertIsInstance(info['mtime'], int)
self.assertEqual(info['mtime'], 1000)
self.assertEqual(info['size'], 100)
self.assertEqual(info['uid'], 123)
self.assertEqual(info['gid'], 124)
# Expected output, byte for byte: a block named '././@PaxHeader', the
# 'mtime=1000.1' record with its padding, then a second header block
# that carries the name 'foo'.
self.assertEqual(header,
b'././@PaxHeader' + bytes(86) \
+ b'0000000\x000000000\x000000000\x0000000000020\x0000000000000\x00010205\x00 x' \
+ bytes(100) + b'ustar\x0000'+ bytes(247) \
+ b'16 mtime=1000.1\n' + bytes(496) + b'foo' + bytes(97) \
+ b'0000644\x000000173\x000000174\x0000000000144\x0000000001750\x00006516\x00 0' \
+ bytes(100) + b'ustar\x0000' + bytes(247))
# values that should be changed
t = tarfile.TarInfo()
t.name = "foo\u3374" # can't be represented in ascii
t.mtime = 10**10 # too big
t.size = 10**10 # too big
t.uid = 8**8 # too big
t.gid = 8**8+1 # too big
info = t.get_info()
header = t.create_pax_header(info, encoding="iso8859-1")
# name is kept as-is in info but should be added to pax header
self.assertEqual(info['name'], "foo\u3374")
# The four numeric values flagged "too big" above are expected to read
# back as 0 in `info` after the call.
self.assertEqual(info['mtime'], 0)
self.assertEqual(info['size'], 0)
self.assertEqual(info['uid'], 0)
self.assertEqual(info['gid'], 0)
# Expected output: this time the '././@PaxHeader' block is followed by
# path, uid, gid, size and mtime records, and the second header block
# carries the name as 'foo?' with zeroed numeric fields.
self.assertEqual(header,
b'././@PaxHeader' + bytes(86) \
+ b'0000000\x000000000\x000000000\x0000000000130\x0000000000000\x00010207\x00 x' \
+ bytes(100) + b'ustar\x0000' + bytes(247) \
+ b'15 path=foo\xe3\x8d\xb4\n16 uid=16777216\n' \
+ b'16 gid=16777217\n20 size=10000000000\n' \
+ b'21 mtime=10000000000\n'+ bytes(424) + b'foo?' + bytes(96) \
+ b'0000644\x000000000\x000000000\x0000000000000\x0000000000000\x00006540\x00 0' \
+ bytes(100) + b'ustar\x0000' + bytes(247))
class UnicodeTest:
@@ -2274,23 +2468,34 @@ class MiscTest(unittest.TestCase):
tarfile.itn(0x10000000000, 6, tarfile.GNU_FORMAT)
def test__all__(self):
not_exported = {'version', 'grp', 'pwd', 'symlink_exception',
'NUL', 'BLOCKSIZE', 'RECORDSIZE', 'GNU_MAGIC',
'POSIX_MAGIC', 'LENGTH_NAME', 'LENGTH_LINK',
'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE',
'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE',
'CONTTYPE', 'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK',
'GNUTYPE_SPARSE', 'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE',
'SUPPORTED_TYPES', 'REGULAR_TYPES', 'GNU_TYPES',
'PAX_FIELDS', 'PAX_NAME_FIELDS', 'PAX_NUMBER_FIELDS',
'stn', 'nts', 'nti', 'itn', 'calc_chksums', 'copyfileobj',
'filemode',
'EmptyHeaderError', 'TruncatedHeaderError',
'EOFHeaderError', 'InvalidHeaderError',
'SubsequentHeaderError', 'ExFileObject',
'main'}
not_exported = {
'version', 'grp', 'pwd', 'symlink_exception', 'NUL', 'BLOCKSIZE',
'RECORDSIZE', 'GNU_MAGIC', 'POSIX_MAGIC', 'LENGTH_NAME',
'LENGTH_LINK', 'LENGTH_PREFIX', 'REGTYPE', 'AREGTYPE', 'LNKTYPE',
'SYMTYPE', 'CHRTYPE', 'BLKTYPE', 'DIRTYPE', 'FIFOTYPE', 'CONTTYPE',
'GNUTYPE_LONGNAME', 'GNUTYPE_LONGLINK', 'GNUTYPE_SPARSE',
'XHDTYPE', 'XGLTYPE', 'SOLARIS_XHDTYPE', 'SUPPORTED_TYPES',
'REGULAR_TYPES', 'GNU_TYPES', 'PAX_FIELDS', 'PAX_NAME_FIELDS',
'PAX_NUMBER_FIELDS', 'stn', 'nts', 'nti', 'itn', 'calc_chksums',
'copyfileobj', 'filemode', 'EmptyHeaderError',
'TruncatedHeaderError', 'EOFHeaderError', 'InvalidHeaderError',
'SubsequentHeaderError', 'ExFileObject', 'main'}
support.check__all__(self, tarfile, not_exported=not_exported)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_useful_error_message_when_modules_missing(self):
    # When a compression method cannot be used, the ReadError raised by
    # tarfile.open() must say so per method, e.g. for xz below.
    fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz')
    # A single exception instance, NOT a one-element tuple: a trailing
    # comma here would make mock treat side_effect as an iterable of
    # per-call results, raising this error on the first call only and
    # StopIteration on any later call.
    error = tarfile.CompressionError('lzma module is not available')
    with self.assertRaises(tarfile.ReadError) as excinfo:
        with unittest.mock.patch.object(tarfile.TarFile, 'xzopen',
                                        side_effect=error):
            tarfile.open(fname)

    self.assertIn(
        "\n- method xz: CompressionError('lzma module is not available')\n",
        str(excinfo.exception),
    )
class CommandLineTest(unittest.TestCase):
@@ -2330,7 +2535,8 @@ class CommandLineTest(unittest.TestCase):
def test_test_command_verbose(self):
for tar_name in testtarnames:
for opt in '-v', '--verbose':
out = self.tarfilecmd(opt, '-t', tar_name)
out = self.tarfilecmd(opt, '-t', tar_name,
PYTHONIOENCODING='utf-8')
self.assertIn(b'is a tar archive.\n', out)
def test_test_command_invalid_file(self):
@@ -2405,7 +2611,8 @@ class CommandLineTest(unittest.TestCase):
'and-utf8-bom-sig-only.txt')]
for opt in '-v', '--verbose':
try:
out = self.tarfilecmd(opt, '-c', tmpname, *files)
out = self.tarfilecmd(opt, '-c', tmpname, *files,
PYTHONIOENCODING='utf-8')
self.assertIn(b' file created.', out)
with tarfile.open(tmpname) as tar:
tar.getmembers()
@@ -2463,7 +2670,8 @@ class CommandLineTest(unittest.TestCase):
for opt in '-v', '--verbose':
try:
with os_helper.temp_cwd(tarextdir):
out = self.tarfilecmd(opt, '-e', tmpname)
out = self.tarfilecmd(opt, '-e', tmpname,
PYTHONIOENCODING='utf-8')
self.assertIn(b' file is extracted.', out)
finally:
os_helper.rmtree(tarextdir)

1172
Lib/test/test_trace.py vendored

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,8 @@ import sys
import subprocess
from unittest import mock
from test import support
from test.support import os_helper, import_helper
from test.support import import_helper
from test.support import os_helper
URL = 'http://www.example.com'
@@ -171,29 +172,21 @@ class OperaCommandTest(CommandTestMixin, unittest.TestCase):
browser_class = webbrowser.Opera
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_open(self):
self._test('open',
options=[],
arguments=[URL])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_open_with_autoraise_false(self):
self._test('open', kw=dict(autoraise=False),
options=[],
arguments=[URL])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_open_new(self):
self._test('open_new',
options=['--new-window'],
arguments=[URL])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_open_new_tab(self):
self._test('open_new_tab',
options=[],
@@ -267,23 +260,17 @@ class BrowserRegistrationTest(unittest.TestCase):
self.assertEqual(webbrowser._tryorder, expected_tryorder)
self.assertEqual(webbrowser._browsers, expected_browsers)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register(self):
self._check_registration(preferred=False)
def test_register_default(self):
self._check_registration(preferred=None)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register_preferred(self):
self._check_registration(preferred=True)
class ImportTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register(self):
webbrowser = import_helper.import_fresh_module('webbrowser')
self.assertIsNone(webbrowser._tryorder)
@@ -298,8 +285,6 @@ class ImportTest(unittest.TestCase):
self.assertIn('example1', webbrowser._browsers)
self.assertEqual(webbrowser._browsers['example1'], [ExampleBrowser, None])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_get(self):
webbrowser = import_helper.import_fresh_module('webbrowser')
self.assertIsNone(webbrowser._tryorder)

View File

@@ -7,7 +7,7 @@ __email__ = "mbland at acm dot org"
import sys
import unittest
from collections import deque
from contextlib import _GeneratorContextManager, contextmanager
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
class MockContextManager(_GeneratorContextManager):
@@ -641,6 +641,12 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(blah.two, 2)
self.assertEqual(blah.three, 3)
def testWithExtendedTargets(self):
    # The target of "with ... as" may use starred unpacking: the first and
    # last items bind to plain names, the rest is collected into a list.
    with nullcontext(range(1, 5)) as (first, *middle, last):
        checks = ((first, 1), (middle, [2, 3]), (last, 4))
        for actual, wanted in checks:
            self.assertEqual(actual, wanted)
class ExitSwallowsExceptionTestCase(unittest.TestCase):

View File

@@ -10,7 +10,6 @@ import functools
import html
import io
import itertools
import locale
import operator
import os
import pickle
@@ -130,6 +129,9 @@ def checkwarnings(*filters, quiet=False):
return newtest
return decorator
def convlinesep(data):
    """Return *data* with each LF byte replaced by the encoded ``os.linesep``."""
    native_newline = os.linesep.encode()
    return data.replace(b'\n', native_newline)
class ModuleTest(unittest.TestCase):
def test_sanity(self):
@@ -1023,17 +1025,15 @@ class ElementTreeTest(unittest.TestCase):
@unittest.expectedFailure
def test_tostring_xml_declaration_unicode_encoding(self):
elem = ET.XML('<body><tag/></body>')
preferredencoding = locale.getpreferredencoding()
self.assertEqual(
f"<?xml version='1.0' encoding='{preferredencoding}'?>\n<body><tag /></body>",
ET.tostring(elem, encoding='unicode', xml_declaration=True)
ET.tostring(elem, encoding='unicode', xml_declaration=True),
"<?xml version='1.0' encoding='utf-8'?>\n<body><tag /></body>"
)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_tostring_xml_declaration_cases(self):
elem = ET.XML('<body><tag>ø</tag></body>')
preferredencoding = locale.getpreferredencoding()
TESTCASES = [
# (expected_retval, encoding, xml_declaration)
# ... xml_declaration = None
@@ -1060,7 +1060,7 @@ class ElementTreeTest(unittest.TestCase):
b"<body><tag>&#248;</tag></body>", 'US-ASCII', True),
(b"<?xml version='1.0' encoding='ISO-8859-1'?>\n"
b"<body><tag>\xf8</tag></body>", 'ISO-8859-1', True),
(f"<?xml version='1.0' encoding='{preferredencoding}'?>\n"
("<?xml version='1.0' encoding='utf-8'?>\n"
"<body><tag>ø</tag></body>", 'unicode', True),
]
@@ -1102,11 +1102,10 @@ class ElementTreeTest(unittest.TestCase):
b"<?xml version='1.0' encoding='us-ascii'?>\n<body><tag /></body>"
)
preferredencoding = locale.getpreferredencoding()
stringlist = ET.tostringlist(elem, encoding='unicode', xml_declaration=True)
self.assertEqual(
''.join(stringlist),
f"<?xml version='1.0' encoding='{preferredencoding}'?>\n<body><tag /></body>"
"<?xml version='1.0' encoding='utf-8'?>\n<body><tag /></body>"
)
self.assertRegex(stringlist[0], r"^<\?xml version='1.0' encoding='.+'?>")
self.assertEqual(['<body', '>', '<tag', ' />', '</body>'], stringlist[1:])
@@ -2467,8 +2466,6 @@ class BasicElementTest(ElementTestCase, unittest.TestCase):
self.assertIsNot(element_foo.attrib, attrib)
self.assertNotEqual(element_foo.attrib, attrib)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_copy(self):
# Only run this test if Element.copy() is defined.
if "copy" not in dir(ET.Element):
@@ -3914,38 +3911,97 @@ class IOTest(unittest.TestCase):
@unittest.expectedFailure
def test_write_to_filename(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
tree.write(TESTFN)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>&#248;</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_filename_with_encoding(self):
    self.addCleanup(os_helper.unlink, TESTFN)
    tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))

    # UTF-8: the text is written encoded and no declaration is expected.
    tree.write(TESTFN, encoding='utf-8')
    with open(TESTFN, 'rb') as f:
        written = f.read()
        self.assertEqual(written, b'''<site>\xc3\xb8</site>''')

    # ISO-8859-1: a declaration line is expected, with the platform's
    # line separator after it.
    tree.write(TESTFN, encoding='ISO-8859-1')
    expected = convlinesep(
        b'''<?xml version='1.0' encoding='ISO-8859-1'?>\n'''
        b'''<site>\xf8</site>''')
    with open(TESTFN, 'rb') as f:
        self.assertEqual(f.read(), expected)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_filename_as_unicode(self):
    # Writing to a file name with encoding='unicode' is expected to leave
    # UTF-8 encoded bytes on disk.
    self.addCleanup(os_helper.unlink, TESTFN)

    # The expectation below is fixed to UTF-8, so there is no need to
    # create TESTFN first just to look up the text-mode default encoding.
    tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
    tree.write(TESTFN, encoding='unicode')
    with open(TESTFN, 'rb') as f:
        self.assertEqual(f.read(), b"<site>\xc3\xb8</site>")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_text_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
with open(TESTFN, 'w', encoding='utf-8') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>\xc3\xb8</site>''')
with open(TESTFN, 'w', encoding='ascii', errors='xmlcharrefreplace') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site>&#248;</site>''')
with open(TESTFN, 'w', encoding='ISO-8859-1') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site>\xf8</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_binary_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
with open(TESTFN, 'wb') as f:
tree.write(f)
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>&#248;</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_binary_file_with_encoding(self):
    self.addCleanup(os_helper.unlink, TESTFN)
    tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))

    # (encoding passed to write(), bytes expected in the file)
    expectations = (
        ('utf-8',
         b'''<site>\xc3\xb8</site>'''),
        ('ISO-8859-1',
         b'''<?xml version='1.0' encoding='ISO-8859-1'?>\n'''
         b'''<site>\xf8</site>'''),
    )
    for encoding, expected in expectations:
        with open(TESTFN, 'wb') as f:
            tree.write(f, encoding=encoding)
            # write() must leave a caller-supplied file object open.
            self.assertFalse(f.closed)
        with open(TESTFN, 'rb') as f:
            self.assertEqual(f.read(), expected)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_binary_file_with_bom(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
# test BOM writing to buffered file
with open(TESTFN, 'wb') as f:
tree.write(f, encoding='utf-16')
@@ -3953,7 +4009,7 @@ class IOTest(unittest.TestCase):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
'''<site>\xf8</site>'''.encode("utf-16"))
# test BOM writing to non-buffered file
with open(TESTFN, 'wb', buffering=0) as f:
tree.write(f, encoding='utf-16')
@@ -3961,7 +4017,7 @@ class IOTest(unittest.TestCase):
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
'''<site>\xf8</site>'''.encode("utf-16"))
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -3974,10 +4030,10 @@ class IOTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_stringio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
stream = io.StringIO()
tree.write(stream, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
self.assertEqual(stream.getvalue(), '''<site>\xf8</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -3990,10 +4046,10 @@ class IOTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_bytesio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
raw = io.BytesIO()
tree.write(raw)
self.assertEqual(raw.getvalue(), b'''<site />''')
self.assertEqual(raw.getvalue(), b'''<site>&#248;</site>''')
class dummy:
pass
@@ -4011,12 +4067,12 @@ class IOTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_user_text_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
stream = io.StringIO()
writer = self.dummy()
writer.write = stream.write
tree.write(writer, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
self.assertEqual(stream.getvalue(), '''<site>\xf8</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure
@@ -4032,12 +4088,12 @@ class IOTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_write_to_user_binary_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
raw = io.BytesIO()
writer = self.dummy()
writer.write = raw.write
tree.write(writer)
self.assertEqual(raw.getvalue(), b'''<site />''')
self.assertEqual(raw.getvalue(), b'''<site>&#248;</site>''')
# TODO: RUSTPYTHON
@unittest.expectedFailure

View File

@@ -946,6 +946,9 @@ class TestPEP380Operation(unittest.TestCase):
res.append(g1.throw(MyErr))
except StopIteration:
pass
except:
self.assertEqual(res, [0, 1, 2, 3])
raise
# Check with close
class MyIt(object):
def __iter__(self):

30
Lib/test/test_zlib.py vendored
View File

@@ -1,11 +1,13 @@
import unittest
from test import support
from test.support import import_helper
import binascii
import copy
import pickle
import random
import sys
from test.support import bigmemtest, _1G, _4G, import_helper
from test.support import bigmemtest, _1G, _4G
zlib = import_helper.import_module('zlib')
@@ -129,6 +131,12 @@ class ExceptionTestCase(unittest.TestCase):
with self.assertRaisesRegex(OverflowError, 'int too large'):
zlib.decompressobj().flush(sys.maxsize + 1)
@support.cpython_only
def test_disallow_instantiation(self):
    # Ensure that the type disallows instantiation (bpo-43916)
    # Checked for the compressor type first, then the decompressor type.
    for factory in (zlib.compressobj, zlib.decompressobj):
        object_type = type(factory())
        support.check_disallow_instantiation(self, object_type)
class BaseCompressTestCase(object):
def check_big_compress_buffer(self, size, compress_func):
@@ -136,8 +144,7 @@ class BaseCompressTestCase(object):
# Generate 10 MiB worth of random, and expand it by repeating it.
# The assumption is that zlib's memory is not big enough to exploit
# such spread out redundancy.
data = b''.join([random.getrandbits(8 * _1M).to_bytes(_1M, 'little')
for i in range(10)])
data = random.randbytes(_1M * 10)
data = data * (size // len(data) + 1)
try:
compress_func(data)
@@ -498,7 +505,7 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
# others might simply have a single RNG
gen = random
gen.seed(1)
data = genblock(1, 17 * 1024, generator=gen)
data = gen.randbytes(17 * 1024)
# compress, sync-flush, and decompress
first = co.compress(data)
@@ -847,20 +854,6 @@ class CompressObjectTestCase(BaseCompressTestCase, unittest.TestCase):
self.assertEqual(dco.decompress(gzip), HAMLET_SCENE)
def genblock(seed, length, step=1024, generator=random):
    """Return exactly *length* bytes of random data, built in *step*-byte chunks.

    seed -- if not None, *generator* is re-seeded with it first.
    length -- number of bytes to return; zero or less yields b"".
    step -- chunk size used while generating; values below 2 or above
        *length* mean "one chunk of *length* bytes".
    generator -- object providing seed() and randint(), by default the
        ``random`` module itself.

    For a *length* that is a multiple of *step* the result is the same
    stream as before; otherwise the final chunk is now cut short so the
    result is *length* bytes long instead of being rounded up to a whole
    number of chunks.
    """
    if seed is not None:
        generator.seed(seed)
    if length <= 0:
        # Nothing to generate; also avoids range() with a zero step below.
        return b""
    randint = generator.randint
    if length < step or step < 2:
        step = length
    # Collect the chunks and join once instead of growing a bytes object
    # with += on every iteration.
    chunks = [
        bytes(randint(0, 255) for _ in range(min(step, length - start)))
        for start in range(0, length, step)
    ]
    return b"".join(chunks)
def choose_lines(source, number, seed=None, generator=random):
"""Return a list of number lines randomly chosen from the source"""
if seed is not None:
@@ -869,7 +862,6 @@ def choose_lines(source, number, seed=None, generator=random):
return [generator.choice(sources) for n in range(number)]
HAMLET_SCENE = b"""
LAERTES

213
Lib/webbrowser.py vendored
View File

@@ -1,5 +1,5 @@
#! /usr/bin/env python3
"""Interfaces for launching and remotely controlling Web browsers."""
"""Interfaces for launching and remotely controlling web browsers."""
# Maintained by Georg Brandl.
import os
@@ -7,25 +7,39 @@ import shlex
import shutil
import sys
import subprocess
import threading
__all__ = ["Error", "open", "open_new", "open_new_tab", "get", "register"]
class Error(Exception):
    """Exception type exported by this module (listed in ``__all__``)."""
_browsers = {} # Dictionary of available browser controllers
_tryorder = [] # Preference order of available browsers
_lock = threading.RLock()
_browsers = {} # Dictionary of available browser controllers
_tryorder = None # Preference order of available browsers
_os_preferred_browser = None # The preferred browser
def register(name, klass, instance=None, update_tryorder=1):
"""Register a browser connector and, optionally, connection."""
_browsers[name.lower()] = [klass, instance]
if update_tryorder > 0:
_tryorder.append(name)
elif update_tryorder < 0:
_tryorder.insert(0, name)
def register(name, klass, instance=None, *, preferred=False):
    """Register a browser connector."""
    with _lock:
        # Populate the standard entries first if the try-order list has
        # not been created yet.
        if _tryorder is None:
            register_standard_browsers()
        _browsers[name.lower()] = [klass, instance]

        # Preferred browsers go to the front of the list.
        # Need to match to the default browser returned by xdg-settings, which
        # may be of the form e.g. "firefox.desktop".
        os_default = _os_preferred_browser
        goes_first = preferred or (os_default and name in os_default)
        position = 0 if goes_first else len(_tryorder)
        _tryorder.insert(position, name)
def get(using=None):
"""Return a browser launcher instance appropriate for the environment."""
if _tryorder is None:
with _lock:
if _tryorder is None:
register_standard_browsers()
if using is not None:
alternatives = [using]
else:
@@ -55,6 +69,18 @@ def get(using=None):
# instead of "from webbrowser import *".
def open(url, new=0, autoraise=True):
"""Display url using the default browser.
If possible, open url in a location determined by new.
- 0: the same browser window (the default).
- 1: a new browser window.
- 2: a new browser page ("tab").
If possible, autoraise raises the window (the default) or not.
"""
if _tryorder is None:
with _lock:
if _tryorder is None:
register_standard_browsers()
for name in _tryorder:
browser = get(name)
if browser.open(url, new, autoraise):
@@ -62,14 +88,22 @@ def open(url, new=0, autoraise=True):
return False
def open_new(url):
    """Open url in a new window of the default browser.

    Shorthand for ``open(url, 1)``.  If a new window is not possible,
    url is opened in the only browser window.
    """
    return open(url, new=1)
def open_new_tab(url):
    """Open url in a new page ("tab") of the default browser.

    Shorthand for ``open(url, 2)``.  If a new tab is not possible, the
    behavior becomes equivalent to open_new().
    """
    return open(url, new=2)
def _synthesize(browser, update_tryorder=1):
"""Attempt to synthesize a controller base on existing controllers.
def _synthesize(browser, *, preferred=False):
"""Attempt to synthesize a controller based on existing controllers.
This is useful to create a controller when a user specifies a path to
an entry in the BROWSER environment variable -- we can copy a general
@@ -95,7 +129,7 @@ def _synthesize(browser, update_tryorder=1):
controller = copy.copy(controller)
controller.name = browser
controller.basename = os.path.basename(browser)
register(browser, None, controller, update_tryorder)
register(browser, None, instance=controller, preferred=preferred)
return [None, controller]
return [None, None]
@@ -136,6 +170,7 @@ class GenericBrowser(BaseBrowser):
self.basename = os.path.basename(self.name)
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
cmdline = [self.name] + [arg.replace("%s", url)
for arg in self.args]
try:
@@ -155,6 +190,7 @@ class BackgroundBrowser(GenericBrowser):
def open(self, url, new=0, autoraise=True):
cmdline = [self.name] + [arg.replace("%s", url)
for arg in self.args]
sys.audit("webbrowser.open", url)
try:
if sys.platform[:3] == 'win':
p = subprocess.Popen(cmdline)
@@ -183,7 +219,7 @@ class UnixBrowser(BaseBrowser):
remote_action_newwin = None
remote_action_newtab = None
def _invoke(self, args, remote, autoraise):
def _invoke(self, args, remote, autoraise, url=None):
raise_opt = []
if remote and self.raise_opts:
# use autoraise argument only for remote invocation
@@ -219,6 +255,7 @@ class UnixBrowser(BaseBrowser):
return not p.wait()
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
if new == 0:
action = self.remote_action
elif new == 1:
@@ -235,7 +272,7 @@ class UnixBrowser(BaseBrowser):
args = [arg.replace("%s", url).replace("%action", action)
for arg in self.remote_args]
args = [arg for arg in args if arg]
success = self._invoke(args, True, autoraise)
success = self._invoke(args, True, autoraise, url)
if not success:
# remote invocation failed, try straight way
args = [arg.replace("%s", url) for arg in self.args]
@@ -290,11 +327,10 @@ Chromium = Chrome
class Opera(UnixBrowser):
"Launcher class for Opera browser."
raise_opts = ["-noraise", ""]
remote_args = ['-remote', 'openURL(%s%action)']
remote_args = ['%action', '%s']
remote_action = ""
remote_action_newwin = ",new-window"
remote_action_newtab = ",new-page"
remote_action_newwin = "--new-window"
remote_action_newtab = ""
background = True
@@ -320,6 +356,7 @@ class Konqueror(BaseBrowser):
"""
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
# XXX Currently I know no way to prevent KFM from opening a new win.
if new == 2:
action = "newTab"
@@ -376,7 +413,7 @@ class Grail(BaseBrowser):
tempdir = os.path.join(tempfile.gettempdir(),
".grail-unix")
user = pwd.getpwuid(os.getuid())[0]
filename = os.path.join(tempdir, user + "-*")
filename = os.path.join(glob.escape(tempdir), glob.escape(user) + "-*")
maybes = glob.glob(filename)
if not maybes:
return None
@@ -403,6 +440,7 @@ class Grail(BaseBrowser):
return 1
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
if new:
ok = self._remote("LOADNEW " + url)
else:
@@ -482,25 +520,80 @@ def register_X_browsers():
if shutil.which("grail"):
register("grail", Grail, None)
# Prefer X browsers if present
if os.environ.get("DISPLAY"):
register_X_browsers()
def register_standard_browsers():
    """Rebuild the browser try-order for the current platform.

    Resets ``_tryorder`` and then registers launchers in preference order:
    the platform-specific ones for macOS, SerenityOS and Windows, or, on
    other platforms, the X/Wayland browsers followed by console browsers.
    Finally, every non-empty entry of the ``BROWSER`` environment variable
    is registered as a preferred choice.

    On platforms other than Windows, when DISPLAY or WAYLAND_DISPLAY is set
    and ``xdg-settings`` can be run, its reported default browser is stored
    in ``_os_preferred_browser``.
    """
    global _tryorder, _os_preferred_browser
    _tryorder = []

    if sys.platform == 'darwin':
        register("MacOSX", None, MacOSXOSAScript('default'))
        register("chrome", None, MacOSXOSAScript('chrome'))
        register("firefox", None, MacOSXOSAScript('firefox'))
        register("safari", None, MacOSXOSAScript('safari'))
        # OS X can use below Unix support (but we prefer using the OS X
        # specific stuff)

    if sys.platform == "serenityos":
        # SerenityOS webbrowser, simply called "Browser".
        register("Browser", None, BackgroundBrowser("Browser"))

    if sys.platform[:3] == "win":
        # First try to use the default Windows browser
        register("windows-default", WindowsDefault)

        # Detect some common Windows browsers, fallback to IE
        iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"),
                                "Internet Explorer\\IEXPLORE.EXE")
        for browser in ("firefox", "firebird", "seamonkey", "mozilla",
                        "netscape", "opera", iexplore):
            if shutil.which(browser):
                register(browser, None, BackgroundBrowser(browser))
    else:
        # Prefer X browsers if present
        if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
            try:
                xdg_cmd = ["xdg-settings", "get", "default-web-browser"]
                raw_result = subprocess.check_output(xdg_cmd,
                                                     stderr=subprocess.DEVNULL)
                result = raw_result.decode().strip()
            except (FileNotFoundError, subprocess.CalledProcessError,
                    PermissionError, NotADirectoryError):
                # Best effort only: xdg-settings is missing or failed, so
                # no OS-level preference is recorded.
                pass
            else:
                _os_preferred_browser = result

            register_X_browsers()

        # Also try console browsers
        if os.environ.get("TERM"):
            if shutil.which("www-browser"):
                register("www-browser", None, GenericBrowser("www-browser"))
            # The Links/elinks browsers <http://artax.karlin.mff.cuni.cz/~mikulas/links/>
            if shutil.which("links"):
                register("links", None, GenericBrowser("links"))
            if shutil.which("elinks"):
                register("elinks", None, Elinks("elinks"))
            # The Lynx browser <http://lynx.isc.org/>, <http://lynx.browser.org/>
            if shutil.which("lynx"):
                register("lynx", None, GenericBrowser("lynx"))
            # The w3m browser <http://w3m.sourceforge.net/>
            if shutil.which("w3m"):
                register("w3m", None, GenericBrowser("w3m"))

    # OK, now that we know what the default preference orders for each
    # platform are, allow user to override them with the BROWSER variable.
    if "BROWSER" in os.environ:
        userchoices = os.environ["BROWSER"].split(os.pathsep)
        # Reversed so that, as each choice is registered as preferred in
        # turn, the first entry in BROWSER is handled last.
        userchoices.reverse()

        # Treat choices in same way as if passed into get() but do register
        # and prepend to _tryorder
        for cmdline in userchoices:
            if cmdline != '':
                cmd = _synthesize(cmdline, preferred=True)
                if cmd[1] is None:
                    register(cmdline, None, GenericBrowser(cmdline), preferred=True)

    # what to do if _tryorder is now empty?
# Also try console browsers
if os.environ.get("TERM"):
if shutil.which("www-browser"):
register("www-browser", None, GenericBrowser("www-browser"))
# The Links/elinks browsers <http://artax.karlin.mff.cuni.cz/~mikulas/links/>
if shutil.which("links"):
register("links", None, GenericBrowser("links"))
if shutil.which("elinks"):
register("elinks", None, Elinks("elinks"))
# The Lynx browser <http://lynx.isc.org/>, <http://lynx.browser.org/>
if shutil.which("lynx"):
register("lynx", None, GenericBrowser("lynx"))
# The w3m browser <http://w3m.sourceforge.net/>
if shutil.which("w3m"):
register("w3m", None, GenericBrowser("w3m"))
#
# Platform support for Windows
@@ -509,6 +602,7 @@ if os.environ.get("TERM"):
if sys.platform[:3] == "win":
class WindowsDefault(BaseBrowser):
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
try:
os.startfile(url)
except OSError:
@@ -518,20 +612,6 @@ if sys.platform[:3] == "win":
else:
return True
_tryorder = []
_browsers = {}
# First try to use the default Windows browser
register("windows-default", WindowsDefault)
# Detect some common Windows browsers, fallback to IE
iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"),
"Internet Explorer\\IEXPLORE.EXE")
for browser in ("firefox", "firebird", "seamonkey", "mozilla",
"netscape", "opera", iexplore):
if shutil.which(browser):
register(browser, None, BackgroundBrowser(browser))
#
# Platform support for MacOS
#
@@ -552,6 +632,7 @@ if sys.platform == 'darwin':
self.name = name
def open(self, url, new=0, autoraise=True):
sys.audit("webbrowser.open", url)
assert "'" not in url
# hack for local urls
if not ':' in url:
@@ -608,34 +689,6 @@ if sys.platform == 'darwin':
return not rc
# Don't clear _tryorder or _browsers since OS X can use above Unix support
# (but we prefer using the OS X specific stuff)
register("safari", None, MacOSXOSAScript('safari'), -1)
register("firefox", None, MacOSXOSAScript('firefox'), -1)
register("chrome", None, MacOSXOSAScript('chrome'), -1)
register("MacOSX", None, MacOSXOSAScript('default'), -1)
# OK, now that we know what the default preference orders for each
# platform are, allow user to override them with the BROWSER variable.
if "BROWSER" in os.environ:
_userchoices = os.environ["BROWSER"].split(os.pathsep)
_userchoices.reverse()
# Treat choices in same way as if passed into get() but do register
# and prepend to _tryorder
for cmdline in _userchoices:
if cmdline != '':
cmd = _synthesize(cmdline, -1)
if cmd[1] is None:
register(cmdline, None, GenericBrowser(cmdline), -1)
cmdline = None # to make del work if _userchoices was empty
del cmdline
del _userchoices
# what to do if _tryorder is now empty?
def main():
import getopt
usage = """Usage: %s [-n | -t] url

View File

@@ -204,11 +204,11 @@ class ExpatBuilder:
buffer = file.read(16*1024)
if not buffer:
break
parser.Parse(buffer, 0)
parser.Parse(buffer, False)
if first_buffer and self.document.documentElement:
self._setup_subset(buffer)
first_buffer = False
parser.Parse("", True)
parser.Parse(b"", True)
except ParseEscape:
pass
doc = self.document
@@ -637,7 +637,7 @@ class FragmentBuilder(ExpatBuilder):
nsattrs = self._getNSattrs() # get ns decls from node's ancestors
document = _FRAGMENT_BUILDER_TEMPLATE % (ident, subset, nsattrs)
try:
parser.Parse(document, 1)
parser.Parse(document, True)
except:
self.reset()
raise
@@ -697,7 +697,7 @@ class FragmentBuilder(ExpatBuilder):
self.fragment = self.document.createDocumentFragment()
self.curNode = self.fragment
try:
parser.Parse(self._source, 1)
parser.Parse(self._source, True)
finally:
self.curNode = old_cur_node
self.document = old_document

View File

@@ -43,10 +43,11 @@ class Node(xml.dom.Node):
def __bool__(self):
return True
def toxml(self, encoding=None):
return self.toprettyxml("", "", encoding)
def toxml(self, encoding=None, standalone=None):
return self.toprettyxml("", "", encoding, standalone)
def toprettyxml(self, indent="\t", newl="\n", encoding=None):
def toprettyxml(self, indent="\t", newl="\n", encoding=None,
standalone=None):
if encoding is None:
writer = io.StringIO()
else:
@@ -56,7 +57,7 @@ class Node(xml.dom.Node):
newline='\n')
if self.nodeType == Node.DOCUMENT_NODE:
# Can pass encoding only to document, to put it into XML header
self.writexml(writer, "", indent, newl, encoding)
self.writexml(writer, "", indent, newl, encoding, standalone)
else:
self.writexml(writer, "", indent, newl)
if encoding is None:
@@ -718,6 +719,14 @@ class Element(Node):
Node.unlink(self)
def getAttribute(self, attname):
"""Returns the value of the specified attribute.
Returns the value of the element's attribute named attname as
a string. An empty string is returned if the element does not
have such an attribute. Note that an empty string may also be
returned as an explicitly given attribute value, use the
hasAttribute method to distinguish these two cases.
"""
if self._attrs is None:
return ""
try:
@@ -823,10 +832,16 @@ class Element(Node):
# Restore this since the node is still useful and otherwise
# unlinked
node.ownerDocument = self.ownerDocument
return node
removeAttributeNodeNS = removeAttributeNode
def hasAttribute(self, name):
    """Check whether the element has an attribute with the specified name.

    Returns True if the element has an attribute with the specified name.
    Otherwise, returns False.
    """
    # _attrs is None when no attribute map exists yet for this element.
    if self._attrs is None:
        return False
    return name in self._attrs
@@ -837,6 +852,11 @@ class Element(Node):
return (namespaceURI, localName) in self._attrsNS
def getElementsByTagName(self, name):
    """Return all descendant elements with the given tag name.

    Returns the list of all descendant elements (not direct children
    only) with the specified tag name.
    """
    # The helper fills the freshly created NodeList passed to it.
    return _get_elements_by_tagName_helper(self, name, NodeList())
def getElementsByTagNameNS(self, namespaceURI, localName):
@@ -847,22 +867,27 @@ class Element(Node):
return "<DOM Element: %s at %#x>" % (self.tagName, id(self))
def writexml(self, writer, indent="", addindent="", newl=""):
"""Write an XML element to a file-like object
Write the element to the writer object that must provide
a write method (e.g. a file or StringIO object).
"""
# indent = current indentation
# addindent = indentation to add to higher levels
# newl = newline string
writer.write(indent+"<" + self.tagName)
attrs = self._get_attributes()
a_names = sorted(attrs.keys())
for a_name in a_names:
for a_name in attrs.keys():
writer.write(" %s=\"" % a_name)
_write_data(writer, attrs[a_name].value)
writer.write("\"")
if self.childNodes:
writer.write(">")
if (len(self.childNodes) == 1 and
self.childNodes[0].nodeType == Node.TEXT_NODE):
self.childNodes[0].nodeType in (
Node.TEXT_NODE, Node.CDATA_SECTION_NODE)):
self.childNodes[0].writexml(writer, '', '', '')
else:
writer.write(newl)
@@ -1786,12 +1811,17 @@ class Document(Node, DocumentLS):
raise xml.dom.NotSupportedErr("cannot import document type nodes")
return _clone_node(node, deep, self)
def writexml(self, writer, indent="", addindent="", newl="", encoding=None):
if encoding is None:
writer.write('<?xml version="1.0" ?>'+newl)
else:
writer.write('<?xml version="1.0" encoding="%s"?>%s' % (
encoding, newl))
def writexml(self, writer, indent="", addindent="", newl="", encoding=None,
             standalone=None):
    """Write the XML declaration and then every child node to *writer*.

    *writer* must provide a ``write`` method.  *indent*, *addindent* and
    *newl* are passed through unchanged to each child's ``writexml``;
    *newl* is also written after the declaration.

    *encoding*, when truthy, is emitted as the ``encoding`` pseudo-attribute
    of the declaration.  *standalone*, when not None, is emitted as
    ``standalone="yes"`` or ``standalone="no"`` according to its truth value.
    """
    declarations = []

    if encoding:
        declarations.append(f'encoding="{encoding}"')
    if standalone is not None:
        declarations.append(f'standalone="{"yes" if standalone else "no"}"')

    # With no declarations this still yields '<?xml version="1.0" ?>'
    # (note the space before '?>').
    writer.write(f'<?xml version="1.0" {" ".join(declarations)}?>{newl}')

    for node in self.childNodes:
        node.writexml(writer, indent, addindent, newl)

View File

@@ -217,6 +217,13 @@ class DOMEventStream:
self.parser.setContentHandler(self.pulldom)
def __getitem__(self, pos):
import warnings
warnings.warn(
"DOMEventStream's __getitem__ method ignores 'pos' parameter. "
"Use iterator protocol instead.",
DeprecationWarning,
stacklevel=2
)
rc = self.getEvent()
if rc:
return rc

View File

@@ -1,7 +1,6 @@
"""Implementation of the DOM Level 3 'LS-Load' feature."""
import copy
import warnings
import xml.dom
from xml.dom.NodeFilter import NodeFilter
@@ -80,7 +79,7 @@ class DOMBuilder:
settings = self._settings[(_name_xform(name), state)]
except KeyError:
raise xml.dom.NotSupportedErr(
"unsupported feature: %r" % (name,))
"unsupported feature: %r" % (name,)) from None
else:
for name, value in settings:
setattr(self._options, name, value)
@@ -332,29 +331,10 @@ class DOMBuilderFilter:
del NodeFilter
class _AsyncDeprecatedProperty:
def warn(self, cls):
clsname = cls.__name__
warnings.warn(
"{cls}.async is deprecated; use {cls}.async_".format(cls=clsname),
DeprecationWarning)
def __get__(self, instance, cls):
self.warn(cls)
if instance is not None:
return instance.async_
return False
def __set__(self, instance, value):
self.warn(type(instance))
setattr(instance, 'async_', value)
class DocumentLS:
"""Mixin to create documents that conform to the load/save spec."""
async_ = False
locals()['async'] = _AsyncDeprecatedProperty() # Avoid DeprecationWarning
def _get_async(self):
return False
@@ -384,9 +364,6 @@ class DocumentLS:
return snode.toxml()
del _AsyncDeprecatedProperty
class DOMImplementationLS:
MODE_SYNCHRONOUS = 1
MODE_ASYNCHRONOUS = 2

View File

@@ -42,7 +42,7 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
# See https://www.python.org/psf/license for licensing details.
##
# Limited XInclude support for the ElementTree package.
@@ -50,18 +50,28 @@
import copy
from . import ElementTree
from urllib.parse import urljoin
XINCLUDE = "{http://www.w3.org/2001/XInclude}"
XINCLUDE_INCLUDE = XINCLUDE + "include"
XINCLUDE_FALLBACK = XINCLUDE + "fallback"
# For security reasons, the inclusion depth is limited to this read-only value by default.
DEFAULT_MAX_INCLUSION_DEPTH = 6
##
# Fatal include error.
class FatalIncludeError(SyntaxError):
    """Fatal error raised while processing an XInclude directive."""
class LimitedRecursiveIncludeError(FatalIncludeError):
    """Raised when the maximum XInclude inclusion depth is reached."""
##
# Default loader. This loader reads an included resource from disk.
#
@@ -92,13 +102,33 @@ def default_loader(href, parse, encoding=None):
# @param loader Optional resource loader. If omitted, it defaults
# to {@link default_loader}. If given, it should be a callable
# that implements the same interface as <b>default_loader</b>.
# @param base_url The base URL of the original file, to resolve
# relative include file references.
# @param max_depth The maximum number of recursive inclusions.
# Limited to reduce the risk of malicious content explosion.
# Pass a negative value to disable the limitation.
# @throws LimitedRecursiveIncludeError If the {@link max_depth} was exceeded.
# @throws FatalIncludeError If the function fails to include a given
# resource, or if the tree contains malformed XInclude elements.
# @throws OSError If the function fails to load a given resource.
# @throws IOError If the function fails to load a given resource.
# @returns the node or its replacement if it was an XInclude node
def include(elem, loader=None):
def include(elem, loader=None, base_url=None,
max_depth=DEFAULT_MAX_INCLUSION_DEPTH):
if max_depth is None:
max_depth = -1
elif max_depth < 0:
raise ValueError("expected non-negative depth or None for 'max_depth', got %r" % max_depth)
if hasattr(elem, 'getroot'):
elem = elem.getroot()
if loader is None:
loader = default_loader
_include(elem, loader, base_url, max_depth, set())
def _include(elem, loader, base_url, max_depth, _parent_hrefs):
# look for xinclude elements
i = 0
while i < len(elem):
@@ -106,14 +136,24 @@ def include(elem, loader=None):
if e.tag == XINCLUDE_INCLUDE:
# process xinclude directive
href = e.get("href")
if base_url:
href = urljoin(base_url, href)
parse = e.get("parse", "xml")
if parse == "xml":
if href in _parent_hrefs:
raise FatalIncludeError("recursive include of %s" % href)
if max_depth == 0:
raise LimitedRecursiveIncludeError(
"maximum xinclude depth reached when including file %s" % href)
_parent_hrefs.add(href)
node = loader(href, parse)
if node is None:
raise FatalIncludeError(
"cannot load %r as %r" % (href, parse)
)
node = copy.copy(node)
node = copy.copy(node) # FIXME: this makes little sense with recursive includes
_include(node, loader, href, max_depth - 1, _parent_hrefs)
_parent_hrefs.remove(href)
if e.tail:
node.tail = (node.tail or "") + e.tail
elem[i] = node
@@ -123,11 +163,13 @@ def include(elem, loader=None):
raise FatalIncludeError(
"cannot load %r as %r" % (href, parse)
)
if e.tail:
text += e.tail
if i:
node = elem[i-1]
node.tail = (node.tail or "") + text + (e.tail or "")
node.tail = (node.tail or "") + text
else:
elem.text = (elem.text or "") + text + (e.tail or "")
elem.text = (elem.text or "") + text
del elem[i]
continue
else:
@@ -139,5 +181,5 @@ def include(elem, loader=None):
"xi:fallback tag must be child of xi:include (%r)" % e.tag
)
else:
include(e, loader)
i = i + 1
_include(e, loader, base_url, max_depth, _parent_hrefs)
i += 1

View File

@@ -48,7 +48,7 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
# See https://www.python.org/psf/license for licensing details.
##
# Implementation module for XPath support. There's usually no reason
@@ -65,8 +65,9 @@ xpath_tokenizer_re = re.compile(
r"//?|"
r"\.\.|"
r"\(\)|"
r"!=|"
r"[/.*:\[\]\(\)@=])|"
r"((?:\{[^}]+\})?[^/\[\]\(\)@=\s]+)|"
r"((?:\{[^}]+\})?[^/\[\]\(\)@!=\s]+)|"
r"\s+"
)
@@ -225,7 +226,6 @@ def prepare_parent(next, token):
def prepare_predicate(next, token):
# FIXME: replace with real parser!!! refs:
# http://effbot.org/zone/simple-iterator-parser.htm
# http://javascript.crockford.com/tdop/tdop.html
signature = []
predicate = []
@@ -253,15 +253,19 @@ def prepare_predicate(next, token):
if elem.get(key) is not None:
yield elem
return select
if signature == "@-='":
# [@attribute='value']
if signature == "@-='" or signature == "@-!='":
# [@attribute='value'] or [@attribute!='value']
key = predicate[1]
value = predicate[-1]
def select(context, result):
for elem in result:
if elem.get(key) == value:
yield elem
return select
def select_negated(context, result):
for elem in result:
if (attr_value := elem.get(key)) is not None and attr_value != value:
yield elem
return select_negated if '!=' in signature else select
if signature == "-" and not re.match(r"\-?\d+$", predicate[0]):
# [tag]
tag = predicate[0]
@@ -270,8 +274,10 @@ def prepare_predicate(next, token):
if elem.find(tag) is not None:
yield elem
return select
if signature == ".='" or (signature == "-='" and not re.match(r"\-?\d+$", predicate[0])):
# [.='value'] or [tag='value']
if signature == ".='" or signature == ".!='" or (
(signature == "-='" or signature == "-!='")
and not re.match(r"\-?\d+$", predicate[0])):
# [.='value'] or [tag='value'] or [.!='value'] or [tag!='value']
tag = predicate[0]
value = predicate[-1]
if tag:
@@ -281,12 +287,22 @@ def prepare_predicate(next, token):
if "".join(e.itertext()) == value:
yield elem
break
def select_negated(context, result):
for elem in result:
for e in elem.iterfind(tag):
if "".join(e.itertext()) != value:
yield elem
break
else:
def select(context, result):
for elem in result:
if "".join(elem.itertext()) == value:
yield elem
return select
def select_negated(context, result):
for elem in result:
if "".join(elem.itertext()) != value:
yield elem
return select_negated if '!=' in signature else select
if signature == "-" or signature == "-()" or signature == "-()-":
# [index] or [last()] or [last()-index]
if signature == "-":

View File

@@ -35,7 +35,7 @@
#---------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
# See https://www.python.org/psf/license for licensing details.
#
# ElementTree
# Copyright (c) 1999-2008 by Fredrik Lundh. All rights reserved.
@@ -76,7 +76,7 @@ __all__ = [
"dump",
"Element", "ElementTree",
"fromstring", "fromstringlist",
"iselement", "iterparse",
"indent", "iselement", "iterparse",
"parse", "ParseError",
"PI", "ProcessingInstruction",
"QName",
@@ -195,6 +195,13 @@ class Element:
original tree.
"""
warnings.warn(
"elem.copy() is deprecated. Use copy.copy(elem) instead.",
DeprecationWarning
)
return self.__copy__()
def __copy__(self):
elem = self.makeelement(self.tag, self.attrib)
elem.text = self.text
elem.tail = self.tail
@@ -273,19 +280,6 @@ class Element:
# assert iselement(element)
self._children.remove(subelement)
def getchildren(self):
"""(Deprecated) Return all subelements.
Elements are returned in document order.
"""
warnings.warn(
"This method will be removed in future versions. "
"Use 'list(elem)' or iteration over elem instead.",
DeprecationWarning, stacklevel=2
)
return self._children
def find(self, path, namespaces=None):
"""Find first matching element by tag name or path.
@@ -409,15 +403,6 @@ class Element:
for e in self._children:
yield from e.iter(tag)
# compatibility
def getiterator(self, tag=None):
warnings.warn(
"This method will be removed in future versions. "
"Use 'elem.iter()' or 'list(elem.iter())' instead.",
DeprecationWarning, stacklevel=2
)
return list(self.iter(tag))
def itertext(self):
"""Create text iterator.
@@ -617,15 +602,6 @@ class ElementTree:
# assert self._root is not None
return self._root.iter(tag)
# compatibility
def getiterator(self, tag=None):
warnings.warn(
"This method will be removed in future versions. "
"Use 'tree.iter()' or 'list(tree.iter())' instead.",
DeprecationWarning, stacklevel=2
)
return list(self.iter(tag))
def find(self, path, namespaces=None):
"""Find first matching element by tag name or path.
@@ -752,16 +728,11 @@ class ElementTree:
encoding = "utf-8"
else:
encoding = "us-ascii"
enc_lower = encoding.lower()
with _get_writer(file_or_filename, enc_lower) as write:
with _get_writer(file_or_filename, encoding) as (write, declared_encoding):
if method == "xml" and (xml_declaration or
(xml_declaration is None and
enc_lower not in ("utf-8", "us-ascii", "unicode"))):
declared_encoding = encoding
if enc_lower == "unicode":
# Retrieve the default encoding for the xml declaration
import locale
declared_encoding = locale.getpreferredencoding()
encoding.lower() != "unicode" and
declared_encoding.lower() not in ("utf-8", "us-ascii"))):
write("<?xml version='1.0' encoding='%s'?>\n" % (
declared_encoding,))
if method == "text":
@@ -786,19 +757,17 @@ def _get_writer(file_or_filename, encoding):
write = file_or_filename.write
except AttributeError:
# file_or_filename is a file name
if encoding == "unicode":
file = open(file_or_filename, "w")
else:
file = open(file_or_filename, "w", encoding=encoding,
errors="xmlcharrefreplace")
with file:
yield file.write
if encoding.lower() == "unicode":
encoding="utf-8"
with open(file_or_filename, "w", encoding=encoding,
errors="xmlcharrefreplace") as file:
yield file.write, encoding
else:
# file_or_filename is a file-like object
# encoding determines if it is a text or binary writer
if encoding == "unicode":
if encoding.lower() == "unicode":
# use a text writer as is
yield write
yield write, getattr(file_or_filename, "encoding", None) or "utf-8"
else:
# wrap a binary writer with TextIOWrapper
with contextlib.ExitStack() as stack:
@@ -829,7 +798,7 @@ def _get_writer(file_or_filename, encoding):
# Keep the original file open when the TextIOWrapper is
# destroyed
stack.callback(file.detach)
yield file.write
yield file.write, encoding
def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
@@ -1081,15 +1050,15 @@ def _escape_attrib(text):
text = text.replace(">", "&gt;")
if "\"" in text:
text = text.replace("\"", "&quot;")
# The following business with carriage returns is to satisfy
# Section 2.11 of the XML specification, stating that
# CR or CR LN should be replaced with just LN
# Although section 2.11 of the XML specification states that CR or
# CR LN should be replaced with just LN, it applies only to EOLNs
# which take part of organizing file into lines. Within attributes,
# we are replacing these with entity numbers, so they do not count.
# http://www.w3.org/TR/REC-xml/#sec-line-ends
if "\r\n" in text:
text = text.replace("\r\n", "\n")
# The current solution, contained in following six lines, was
# discussed in issue 17582 and 39011.
if "\r" in text:
text = text.replace("\r", "\n")
#The following four lines are issue 17582
text = text.replace("\r", "&#13;")
if "\n" in text:
text = text.replace("\n", "&#10;")
if "\t" in text:
@@ -1185,6 +1154,57 @@ def dump(elem):
if not tail or tail[-1] != "\n":
sys.stdout.write("\n")
def indent(tree, space="  ", level=0):
    """Indent an XML document by inserting newlines and indentation space
    after elements.

    *tree* is the ElementTree or Element to modify.  The (root) element
    itself will not be changed, but the tail text of all elements in its
    subtree will be adapted.

    *space* is the whitespace to insert for each indentation level, two
    space characters by default.

    *level* is the initial indentation level.  Setting this to a higher
    value than 0 can be used for indenting subtrees that are more deeply
    nested inside of a document.

    Raises ValueError if *level* is negative.
    """
    if isinstance(tree, ElementTree):
        tree = tree.getroot()
    if level < 0:
        raise ValueError(f"Initial indentation level must be >= 0, got {level}")
    if not len(tree):
        # No children: nothing to indent.
        return

    # Reduce the memory consumption by reusing indentation strings.
    # Index 0 holds the indentation of the starting level, so the helper
    # below works with levels relative to *tree*.
    indentations = ["\n" + level * space]

    def _indent_children(elem, level):
        # Start a new indentation level for the first child.
        child_level = level + 1
        try:
            child_indentation = indentations[child_level]
        except IndexError:
            child_indentation = indentations[level] + space
            indentations.append(child_indentation)

        # Only whitespace-only (or missing) text is replaced; real text
        # content is left untouched.
        if not elem.text or not elem.text.strip():
            elem.text = child_indentation

        for child in elem:
            if len(child):
                _indent_children(child, child_level)
            if not child.tail or not child.tail.strip():
                child.tail = child_indentation

        # Dedent after the last child by overwriting the previous indentation.
        if not child.tail.strip():
            child.tail = indentations[level]

    _indent_children(tree, 0)
# --------------------------------------------------------------------
# parsing
@@ -1221,8 +1241,14 @@ def iterparse(source, events=None, parser=None):
# Use the internal, undocumented _parser argument for now; When the
# parser argument of iterparse is removed, this can be killed.
pullparser = XMLPullParser(events=events, _parser=parser)
def iterator():
def iterator(source):
close_source = False
try:
if not hasattr(source, "read"):
source = open(source, "rb")
close_source = True
yield None
while True:
yield from pullparser.read_events()
# load event buffer
@@ -1238,16 +1264,12 @@ def iterparse(source, events=None, parser=None):
source.close()
class IterParseIterator(collections.abc.Iterator):
__next__ = iterator().__next__
__next__ = iterator(source).__next__
it = IterParseIterator()
it.root = None
del iterator, IterParseIterator
close_source = False
if not hasattr(source, "read"):
source = open(source, "rb")
close_source = True
next(it)
return it
@@ -1256,7 +1278,7 @@ class XMLPullParser:
def __init__(self, events=None, *, _parser=None):
# The _parser argument is for internal use only and must not be relied
# upon in user code. It will be removed in a future release.
# See http://bugs.python.org/issue17741 for more details.
# See https://bugs.python.org/issue17741 for more details.
self._events_queue = collections.deque()
self._parser = _parser or XMLParser(target=TreeBuilder())
@@ -1533,7 +1555,6 @@ class XMLParser:
# Configure pyexpat: buffering, new-style attribute handling.
parser.buffer_text = 1
parser.ordered_attributes = 1
parser.specified_attributes = 1
self._doctype = None
self.entity = {}
try:
@@ -1553,7 +1574,6 @@ class XMLParser:
for event_name in events_to_report:
if event_name == "start":
parser.ordered_attributes = 1
parser.specified_attributes = 1
def handler(tag, attrib_in, event=event_name, append=append,
start=self._start):
append((event, start(tag, attrib_in)))
@@ -1690,14 +1710,14 @@ class XMLParser:
def feed(self, data):
"""Feed encoded data to parser."""
try:
self.parser.Parse(data, 0)
self.parser.Parse(data, False)
except self._error as v:
self._raiseerror(v)
def close(self):
"""Finish feeding data to parser and return element structure."""
try:
self.parser.Parse("", 1) # end of data
self.parser.Parse(b"", True) # end of data
except self._error as v:
self._raiseerror(v)
try:

View File

@@ -30,4 +30,4 @@
# --------------------------------------------------------------------
# Licensed to PSF under a Contributor Agreement.
# See http://www.python.org/psf/license for licensing details.
# See https://www.python.org/psf/license for licensing details.

View File

@@ -67,18 +67,18 @@ if sys.platform[:4] == "java" and sys.registry.containsKey(_key):
default_parser_list = sys.registry.getProperty(_key).split(",")
def make_parser(parser_list = []):
def make_parser(parser_list=()):
"""Creates and returns a SAX parser.
Creates the first parser it is able to instantiate of the ones
given in the list created by doing parser_list +
default_parser_list. The lists must contain the names of Python
given in the iterable created by chaining parser_list and
default_parser_list. The iterables must contain the names of Python
modules containing both a SAX parser and a create_parser function."""
for parser_name in parser_list + default_parser_list:
for parser_name in list(parser_list) + default_parser_list:
try:
return _create_parser(parser_name)
except ImportError as e:
except ImportError:
import sys
if parser_name in sys.modules:
# The parser module was found, but importing it

View File

@@ -93,7 +93,7 @@ class ExpatParser(xmlreader.IncrementalParser, xmlreader.Locator):
self._parser = None
self._namespaces = namespaceHandling
self._lex_handler_prop = None
self._parsing = 0
self._parsing = False
self._entity_stack = []
self._external_ges = 0
self._interning = None
@@ -203,10 +203,10 @@ class ExpatParser(xmlreader.IncrementalParser, xmlreader.Locator):
# IncrementalParser methods
def feed(self, data, isFinal = 0):
def feed(self, data, isFinal=False):
if not self._parsing:
self.reset()
self._parsing = 1
self._parsing = True
self._cont_handler.startDocument()
try:
@@ -237,13 +237,13 @@ class ExpatParser(xmlreader.IncrementalParser, xmlreader.Locator):
# If we are completing an external entity, do nothing here
return
try:
self.feed("", isFinal = 1)
self.feed(b"", isFinal=True)
self._cont_handler.endDocument()
self._parsing = 0
self._parsing = False
# break cycle created by expat handlers pointing to our methods
self._parser = None
finally:
self._parsing = 0
self._parsing = False
if self._parser is not None:
# Keep ErrorColumnNumber and ErrorLineNumber after closing.
parser = _ClosedParser()
@@ -307,7 +307,7 @@ class ExpatParser(xmlreader.IncrementalParser, xmlreader.Locator):
self._parser.SetParamEntityParsing(
expat.XML_PARAM_ENTITY_PARSING_UNLESS_STANDALONE)
self._parsing = 0
self._parsing = False
self._entity_stack = []
# Locator methods

View File

@@ -340,3 +340,48 @@ all_properties = [property_lexical_handler,
property_xml_string,
property_encoding,
property_interning_dict]
class LexicalHandler:
    """Optional SAX2 handler for lexical events.

    This handler is used to obtain lexical information about an XML
    document, that is, information about how the document was encoded
    (as opposed to what it contains, which is reported to the
    ContentHandler), such as comments and CDATA marked section
    boundaries.

    To set the LexicalHandler of an XMLReader, use the setProperty
    method with the property identifier
    'http://xml.org/sax/properties/lexical-handler'.

    Every method here is a no-op; subclasses override the events they
    are interested in."""

    def comment(self, content):
        """Reports a comment anywhere in the document (including the
        DTD and outside the document element).

        content is a string that holds the contents of the comment."""

    def startDTD(self, name, public_id, system_id):
        """Report the start of the DTD declarations, if the document
        has an associated DTD.

        A startEntity event will be reported before declaration events
        from the external DTD subset are reported, and this can be
        used to infer from which subset DTD declarations derive.

        name is the name of the document element type, public_id the
        public identifier of the DTD (or None if none were supplied)
        and system_id the system identifier of the external subset (or
        None if none were supplied)."""

    def endDTD(self):
        """Signals the end of DTD declarations."""

    def startCDATA(self):
        """Reports the beginning of a CDATA marked section.

        The contents of the CDATA marked section will be reported
        through the characters event."""

    def endCDATA(self):
        """Reports the end of a CDATA marked section."""

View File

@@ -56,8 +56,7 @@ def quoteattr(data, entities={}):
the optional entities parameter. The keys and values must all be
strings; each key will be replaced with its corresponding value.
"""
entities = entities.copy()
entities.update({'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'})
entities = {**entities, '\n': '&#10;', '\r': '&#13;', '\t':'&#9;'}
data = escape(data, entities)
if '"' in data:
if "'" in data:
@@ -340,6 +339,8 @@ def prepare_input_source(source, base=""):
"""This function takes an InputSource and an optional base URL and
returns a fully resolved InputSource object ready for reading."""
if isinstance(source, os.PathLike):
source = os.fspath(source)
if isinstance(source, str):
source = xmlreader.InputSource(source)
elif hasattr(source, "read"):