diff --git a/Lib/sqlite3/__init__.py b/Lib/sqlite3/__init__.py index 927267cf0..ed727fae6 100644 --- a/Lib/sqlite3/__init__.py +++ b/Lib/sqlite3/__init__.py @@ -22,7 +22,7 @@ """ The sqlite3 extension module provides a DB-API 2.0 (PEP 249) compliant -interface to the SQLite library, and requires SQLite 3.7.15 or newer. +interface to the SQLite library, and requires SQLite 3.15.2 or newer. To use the module, start by creating a database Connection object: @@ -55,16 +55,3 @@ The sqlite3 module is written by Gerhard Häring . """ from sqlite3.dbapi2 import * -from sqlite3.dbapi2 import (_deprecated_names, - _deprecated_version_info, - _deprecated_version) - - -def __getattr__(name): - if name in _deprecated_names: - from warnings import warn - - warn(f"{name} is deprecated and will be removed in Python 3.14", - DeprecationWarning, stacklevel=2) - return globals()[f"_deprecated_{name}"] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/Lib/sqlite3/__main__.py b/Lib/sqlite3/__main__.py index f8a5cca24..4ccf292dd 100644 --- a/Lib/sqlite3/__main__.py +++ b/Lib/sqlite3/__main__.py @@ -46,26 +46,34 @@ class SqliteInteractiveConsole(InteractiveConsole): """Override runsource, the core of the InteractiveConsole REPL. Return True if more input is needed; buffering is done automatically. - Return False is input is a complete statement ready for execution. + Return False if input is a complete statement ready for execution. """ - match source: - case ".version": - print(f"{sqlite3.sqlite_version}") - case ".help": - print("Enter SQL code and press enter.") - case ".quit": - sys.exit(0) - case _: - if not sqlite3.complete_statement(source): - return True - execute(self._cur, source) + if not source or source.isspace(): + return False + if source[0] == ".": + match source[1:].strip(): + case "version": + print(f"{sqlite3.sqlite_version}") + case "help": + print("Enter SQL code and press enter.") + case "quit": + sys.exit(0) + case "": + pass + case _ as unknown: + self.write("Error: unknown command or invalid arguments:" + f' "{unknown}".\n') + else: + if not sqlite3.complete_statement(source): + return True + execute(self._cur, source) return False -def main(): +def main(*args): parser = ArgumentParser( description="Python sqlite3 CLI", - prog="python -m sqlite3", + color=True, ) parser.add_argument( "filename", type=str, default=":memory:", nargs="?", @@ -86,7 +94,7 @@ def main(): version=f"SQLite version {sqlite3.sqlite_version}", help="Print underlying SQLite library version", ) - args = parser.parse_args() + args = parser.parse_args(*args) if args.filename == ":memory:": db_name = "a transient in-memory database" @@ -94,12 +102,16 @@ def main(): db_name = repr(args.filename) # Prepare REPL banner and prompts. + if sys.platform == "win32" and "idlelib.run" not in sys.modules: + eofkey = "CTRL-Z" + else: + eofkey = "CTRL-D" banner = dedent(f""" sqlite3 shell, running on SQLite version {sqlite3.sqlite_version} Connected to {db_name} Each command will be run using execute() on the cursor. - Type ".help" for more information; type ".quit" or CTRL-D to quit. + Type ".help" for more information; type ".quit" or {eofkey} to quit. """).strip() sys.ps1 = "sqlite> " sys.ps2 = " ... " @@ -112,9 +124,16 @@ def main(): else: # No SQL provided; start the REPL. console = SqliteInteractiveConsole(con) + try: + import readline # noqa: F401 + except ImportError: + pass console.interact(banner, exitmsg="") finally: con.close() + sys.exit(0) -main() + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/Lib/sqlite3/dbapi2.py b/Lib/sqlite3/dbapi2.py index 56fc0461e..031576051 100644 --- a/Lib/sqlite3/dbapi2.py +++ b/Lib/sqlite3/dbapi2.py @@ -25,9 +25,6 @@ import time import collections.abc from _sqlite3 import * -from _sqlite3 import _deprecated_version - -_deprecated_names = frozenset({"version", "version_info"}) paramstyle = "qmark" @@ -48,7 +45,7 @@ def TimeFromTicks(ticks): def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -_deprecated_version_info = tuple(map(int, _deprecated_version.split("."))) + sqlite_version_info = tuple([int(x) for x in sqlite_version.split(".")]) Binary = memoryview @@ -97,12 +94,3 @@ register_adapters_and_converters() # Clean up namespace del(register_adapters_and_converters) - -def __getattr__(name): - if name in _deprecated_names: - from warnings import warn - - warn(f"{name} is deprecated and will be removed in Python 3.14", - DeprecationWarning, stacklevel=2) - return globals()[f"_deprecated_{name}"] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/Lib/sqlite3/dump.py b/Lib/sqlite3/dump.py index 07b9da10b..57e6a3b4f 100644 --- a/Lib/sqlite3/dump.py +++ b/Lib/sqlite3/dump.py @@ -7,7 +7,15 @@ # future enhancements, you should normally quote any identifier that # is an English language word, even if you do not have to." -def _iterdump(connection): +def _quote_name(name): + return '"{0}"'.format(name.replace('"', '""')) + + +def _quote_value(value): + return "'{0}'".format(value.replace("'", "''")) + + +def _iterdump(connection, *, filter=None): """ Returns an iterator to the dump of the database in an SQL text format. @@ -16,64 +24,87 @@ def _iterdump(connection): directly but instead called from the Connection method, iterdump(). """ + writeable_schema = False cu = connection.cursor() + cu.row_factory = None # Make sure we get predictable results. + # Disable foreign key constraints, if there is any foreign key violation. + violations = cu.execute("PRAGMA foreign_key_check").fetchall() + if violations: + yield('PRAGMA foreign_keys=OFF;') yield('BEGIN TRANSACTION;') + if filter: + # Return database objects which match the filter pattern. + filter_name_clause = 'AND "name" LIKE ?' + params = [filter] + else: + filter_name_clause = "" + params = [] # sqlite_master table contains the SQL CREATE statements for the database. - q = """ + q = f""" SELECT "name", "type", "sql" FROM "sqlite_master" WHERE "sql" NOT NULL AND "type" == 'table' + {filter_name_clause} ORDER BY "name" """ - schema_res = cu.execute(q) + schema_res = cu.execute(q, params) sqlite_sequence = [] for table_name, type, sql in schema_res.fetchall(): if table_name == 'sqlite_sequence': - rows = cu.execute('SELECT * FROM "sqlite_sequence";').fetchall() + rows = cu.execute('SELECT * FROM "sqlite_sequence";') sqlite_sequence = ['DELETE FROM "sqlite_sequence"'] sqlite_sequence += [ - f'INSERT INTO "sqlite_sequence" VALUES(\'{row[0]}\',{row[1]})' - for row in rows + f'INSERT INTO "sqlite_sequence" VALUES({_quote_value(table_name)},{seq_value})' + for table_name, seq_value in rows.fetchall() ] continue elif table_name == 'sqlite_stat1': yield('ANALYZE "sqlite_master";') elif table_name.startswith('sqlite_'): continue - # NOTE: Virtual table support not implemented - #elif sql.startswith('CREATE VIRTUAL TABLE'): - # qtable = table_name.replace("'", "''") - # yield("INSERT INTO sqlite_master(type,name,tbl_name,rootpage,sql)"\ - # "VALUES('table','{0}','{0}',0,'{1}');".format( - # qtable, - # sql.replace("''"))) + elif sql.startswith('CREATE VIRTUAL TABLE'): + if not writeable_schema: + writeable_schema = True + yield('PRAGMA writable_schema=ON;') + yield("INSERT INTO sqlite_master(type,name,tbl_name,rootpage,sql)" + "VALUES('table',{0},{0},0,{1});".format( + _quote_value(table_name), + _quote_value(sql), + )) else: yield('{0};'.format(sql)) # Build the insert statement for each row of the current table - table_name_ident = table_name.replace('"', '""') - res = cu.execute('PRAGMA table_info("{0}")'.format(table_name_ident)) + table_name_ident = _quote_name(table_name) + res = cu.execute(f'PRAGMA table_info({table_name_ident})') column_names = [str(table_info[1]) for table_info in res.fetchall()] - q = """SELECT 'INSERT INTO "{0}" VALUES({1})' FROM "{0}";""".format( + q = "SELECT 'INSERT INTO {0} VALUES('{1}')' FROM {0};".format( table_name_ident, - ",".join("""'||quote("{0}")||'""".format(col.replace('"', '""')) for col in column_names)) + "','".join( + "||quote({0})||".format(_quote_name(col)) for col in column_names + ) + ) query_res = cu.execute(q) for row in query_res: yield("{0};".format(row[0])) # Now when the type is 'index', 'trigger', or 'view' - q = """ + q = f""" SELECT "name", "type", "sql" FROM "sqlite_master" WHERE "sql" NOT NULL AND "type" IN ('index', 'trigger', 'view') + {filter_name_clause} """ - schema_res = cu.execute(q) + schema_res = cu.execute(q, params) for name, type, sql in schema_res.fetchall(): yield('{0};'.format(sql)) + if writeable_schema: + yield('PRAGMA writable_schema=OFF;') + # gh-79009: Yield statements concerning the sqlite_sequence table at the # end of the transaction. for row in sqlite_sequence: diff --git a/Lib/test/test_dbm_sqlite3.py b/Lib/test/test_dbm_sqlite3.py index 39eac7a35..f367a9886 100644 --- a/Lib/test/test_dbm_sqlite3.py +++ b/Lib/test/test_dbm_sqlite3.py @@ -1,12 +1,11 @@ import os import stat import sys -import test.support import unittest from contextlib import closing from functools import partial from pathlib import Path -from test.support import cpython_only, import_helper, os_helper +from test.support import import_helper, os_helper dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") # N.B. The test will fail on some platforms without sqlite3 @@ -44,7 +43,7 @@ class URI(unittest.TestCase): ) for path, normalized in dataset: with self.subTest(path=path, normalized=normalized): - self.assertTrue(_normalize_uri(path).endswith(normalized)) + self.assertEndsWith(_normalize_uri(path), normalized) @unittest.skipUnless(sys.platform == "win32", "requires Windows") def test_uri_windows(self): @@ -63,7 +62,7 @@ class URI(unittest.TestCase): with self.subTest(path=path, normalized=normalized): if not Path(path).is_absolute(): self.skipTest(f"skipping relative path: {path!r}") - self.assertTrue(_normalize_uri(path).endswith(normalized)) + self.assertEndsWith(_normalize_uri(path), normalized) class ReadOnly(_SQLiteDbmTests): diff --git a/Lib/test/test_sqlite3/__init__.py b/Lib/test/test_sqlite3/__init__.py index d777fca82..78a1e2078 100644 --- a/Lib/test/test_sqlite3/__init__.py +++ b/Lib/test/test_sqlite3/__init__.py @@ -8,8 +8,7 @@ import sqlite3 # Implement the unittest "load tests" protocol. def load_tests(*args): + if verbose: + print(f"test_sqlite3: testing with SQLite version {sqlite3.sqlite_version}") pkg_dir = os.path.dirname(__file__) return load_package_tests(pkg_dir, *args) - -if verbose: - print(f"test_sqlite3: testing with SQLite version {sqlite3.sqlite_version}") diff --git a/Lib/test/test_sqlite3/test_backup.py b/Lib/test/test_sqlite3/test_backup.py index fb3a83e3b..9d31978b1 100644 --- a/Lib/test/test_sqlite3/test_backup.py +++ b/Lib/test/test_sqlite3/test_backup.py @@ -1,6 +1,8 @@ import sqlite3 as sqlite import unittest +from .util import memory_database + class BackupTests(unittest.TestCase): def setUp(self): @@ -32,34 +34,32 @@ class BackupTests(unittest.TestCase): self.cx.backup(self.cx) def test_bad_target_closed_connection(self): - bck = sqlite.connect(':memory:') - bck.close() - with self.assertRaises(sqlite.ProgrammingError): - self.cx.backup(bck) + with memory_database() as bck: + bck.close() + with self.assertRaises(sqlite.ProgrammingError): + self.cx.backup(bck) def test_bad_source_closed_connection(self): - bck = sqlite.connect(':memory:') - source = sqlite.connect(":memory:") - source.close() - with self.assertRaises(sqlite.ProgrammingError): - source.backup(bck) + with memory_database() as bck: + source = sqlite.connect(":memory:") + source.close() + with self.assertRaises(sqlite.ProgrammingError): + source.backup(bck) def test_bad_target_in_transaction(self): - bck = sqlite.connect(':memory:') - bck.execute('CREATE TABLE bar (key INTEGER)') - bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)]) - with self.assertRaises(sqlite.OperationalError) as cm: - self.cx.backup(bck) - if sqlite.sqlite_version_info < (3, 8, 8): - self.assertEqual(str(cm.exception), 'target is in transaction') + with memory_database() as bck: + bck.execute('CREATE TABLE bar (key INTEGER)') + bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)]) + with self.assertRaises(sqlite.OperationalError) as cm: + self.cx.backup(bck) def test_keyword_only_args(self): with self.assertRaises(TypeError): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, 1) def test_simple(self): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck) self.verify_backup(bck) @@ -69,7 +69,7 @@ class BackupTests(unittest.TestCase): def progress(status, remaining, total): journal.append(status) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress=progress) self.verify_backup(bck) @@ -83,7 +83,7 @@ class BackupTests(unittest.TestCase): def progress(status, remaining, total): journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, progress=progress) self.verify_backup(bck) @@ -96,18 +96,17 @@ class BackupTests(unittest.TestCase): def progress(status, remaining, total): journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=-1, progress=progress) self.verify_backup(bck) self.assertEqual(len(journal), 1) self.assertEqual(journal[0], 0) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_non_callable_progress(self): with self.assertRaises(TypeError) as cm: - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress='bar') self.assertEqual(str(cm.exception), 'progress argument must be a callable') @@ -120,7 +119,7 @@ class BackupTests(unittest.TestCase): self.cx.commit() journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress=progress) self.verify_backup(bck) @@ -139,17 +138,17 @@ class BackupTests(unittest.TestCase): raise SystemError('nearly out of space') with self.assertRaises(SystemError) as err: - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, progress=progress) self.assertEqual(str(err.exception), 'nearly out of space') def test_database_source_name(self): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='main') - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='temp') with self.assertRaises(sqlite.OperationalError) as cm: - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='non-existing') self.assertIn("unknown database", str(cm.exception)) @@ -157,7 +156,7 @@ class BackupTests(unittest.TestCase): self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)') self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)]) self.cx.commit() - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='attached_db') self.verify_backup(bck) diff --git a/Lib/test/test_sqlite3/test_cli.py b/Lib/test/test_sqlite3/test_cli.py index 560cd9efc..a03d7cbe1 100644 --- a/Lib/test/test_sqlite3/test_cli.py +++ b/Lib/test/test_sqlite3/test_cli.py @@ -1,52 +1,52 @@ """sqlite3 CLI tests.""" - -import sqlite3 as sqlite -import subprocess -import sys +import sqlite3 import unittest -from test.support import SHORT_TIMEOUT#, requires_subprocess +from sqlite3.__main__ import main as cli from test.support.os_helper import TESTFN, unlink +from test.support import ( + captured_stdout, + captured_stderr, + captured_stdin, + force_not_colorized, +) -# TODO: RUSTPYTHON -#@requires_subprocess() class CommandLineInterface(unittest.TestCase): def _do_test(self, *args, expect_success=True): - with subprocess.Popen( - [sys.executable, "-Xutf8", "-m", "sqlite3", *args], - encoding="utf-8", - bufsize=0, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) as proc: - proc.wait() - if expect_success == bool(proc.returncode): - self.fail("".join(proc.stderr)) - stdout = proc.stdout.read() - stderr = proc.stderr.read() - if expect_success: - self.assertEqual(stderr, "") - else: - self.assertEqual(stdout, "") - return stdout, stderr + with ( + captured_stdout() as out, + captured_stderr() as err, + self.assertRaises(SystemExit) as cm + ): + cli(args) + return out.getvalue(), err.getvalue(), cm.exception.code def expect_success(self, *args): - out, _ = self._do_test(*args) + out, err, code = self._do_test(*args) + self.assertEqual(code, 0, + "\n".join([f"Unexpected failure: {args=}", out, err])) + self.assertEqual(err, "") return out def expect_failure(self, *args): - _, err = self._do_test(*args, expect_success=False) + out, err, code = self._do_test(*args, expect_success=False) + self.assertNotEqual(code, 0, + "\n".join([f"Unexpected failure: {args=}", out, err])) + self.assertEqual(out, "") return err + @force_not_colorized def test_cli_help(self): out = self.expect_success("-h") - self.assertIn("usage: python -m sqlite3", out) + self.assertIn("usage: ", out) + self.assertIn(" [-h] [-v] [filename] [sql]", out) + self.assertIn("Python sqlite3 CLI", out) def test_cli_version(self): out = self.expect_success("-v") - self.assertIn(sqlite.sqlite_version, out) + self.assertIn(sqlite3.sqlite_version, out) def test_cli_execute_sql(self): out = self.expect_success(":memory:", "select 1") @@ -69,88 +69,126 @@ class CommandLineInterface(unittest.TestCase): self.assertIn("(0,)", out) -# TODO: RUSTPYTHON -#@requires_subprocess() class InteractiveSession(unittest.TestCase): - TIMEOUT = SHORT_TIMEOUT / 10. MEMORY_DB_MSG = "Connected to a transient in-memory database" PS1 = "sqlite> " PS2 = "... " - def start_cli(self, *args): - return subprocess.Popen( - [sys.executable, "-Xutf8", "-m", "sqlite3", *args], - encoding="utf-8", - bufsize=0, - stdin=subprocess.PIPE, - # Note: the banner is printed to stderr, the prompt to stdout. - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) + def run_cli(self, *args, commands=()): + with ( + captured_stdin() as stdin, + captured_stdout() as stdout, + captured_stderr() as stderr, + self.assertRaises(SystemExit) as cm + ): + for cmd in commands: + stdin.write(cmd + "\n") + stdin.seek(0) + cli(args) - def expect_success(self, proc): - proc.wait() - if proc.returncode: - self.fail("".join(proc.stderr)) + out = stdout.getvalue() + err = stderr.getvalue() + self.assertEqual(cm.exception.code, 0, + f"Unexpected failure: {args=}\n{out}\n{err}") + return out, err def test_interact(self): - with self.start_cli() as proc: - out, err = proc.communicate(timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn(self.PS1, out) - self.expect_success(proc) + out, err = self.run_cli() + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 0) def test_interact_quit(self): - with self.start_cli() as proc: - out, err = proc.communicate(input=".quit", timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn(self.PS1, out) - self.expect_success(proc) + out, err = self.run_cli(commands=(".quit",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 0) def test_interact_version(self): - with self.start_cli() as proc: - out, err = proc.communicate(input=".version", timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn(sqlite.sqlite_version, out) - self.expect_success(proc) + out, err = self.run_cli(commands=(".version",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(sqlite3.sqlite_version + "\n", out) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) + self.assertIn(sqlite3.sqlite_version, out) + + def test_interact_empty_source(self): + out, err = self.run_cli(commands=("", " ")) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 3) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_dot_commands_unknown(self): + out, err = self.run_cli(commands=(".unknown_command", )) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) + self.assertIn("Error", err) + # test "unknown_command" is pointed out in the error message + self.assertIn("unknown_command", err) + + def test_interact_dot_commands_empty(self): + out, err = self.run_cli(commands=(".")) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_dot_commands_with_whitespaces(self): + out, err = self.run_cli(commands=(".version ", ". version")) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEqual(out.count(sqlite3.sqlite_version + "\n"), 2) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 3) + self.assertEqual(out.count(self.PS2), 0) def test_interact_valid_sql(self): - with self.start_cli() as proc: - out, err = proc.communicate(input="select 1;", - timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn("(1,)", out) - self.expect_success(proc) + out, err = self.run_cli(commands=("SELECT 1;",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn("(1,)\n", out) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) + + def test_interact_incomplete_multiline_sql(self): + out, err = self.run_cli(commands=("SELECT 1",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertEndsWith(out, self.PS2) + self.assertEqual(out.count(self.PS1), 1) + self.assertEqual(out.count(self.PS2), 1) def test_interact_valid_multiline_sql(self): - with self.start_cli() as proc: - out, err = proc.communicate(input="select 1\n;", - timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn(self.PS2, out) - self.assertIn("(1,)", out) - self.expect_success(proc) + out, err = self.run_cli(commands=("SELECT 1\n;",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn(self.PS2, out) + self.assertIn("(1,)\n", out) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 1) def test_interact_invalid_sql(self): - with self.start_cli() as proc: - out, err = proc.communicate(input="sel;", timeout=self.TIMEOUT) - self.assertIn(self.MEMORY_DB_MSG, err) - self.assertIn("OperationalError (SQLITE_ERROR)", err) - self.expect_success(proc) + out, err = self.run_cli(commands=("sel;",)) + self.assertIn(self.MEMORY_DB_MSG, err) + self.assertIn("OperationalError (SQLITE_ERROR)", err) + self.assertEndsWith(out, self.PS1) + self.assertEqual(out.count(self.PS1), 2) + self.assertEqual(out.count(self.PS2), 0) def test_interact_on_disk_file(self): self.addCleanup(unlink, TESTFN) - with self.start_cli(TESTFN) as proc: - out, err = proc.communicate(input="create table t(t);", - timeout=self.TIMEOUT) - self.assertIn(TESTFN, err) - self.assertIn(self.PS1, out) - self.expect_success(proc) - with self.start_cli(TESTFN, "select count(t) from t") as proc: - out = proc.stdout.read() - err = proc.stderr.read() - self.assertIn("(0,)", out) - self.expect_success(proc) + + out, err = self.run_cli(TESTFN, commands=("CREATE TABLE t(t);",)) + self.assertIn(TESTFN, err) + self.assertEndsWith(out, self.PS1) + + out, _ = self.run_cli(TESTFN, commands=("SELECT count(t) FROM t;",)) + self.assertIn("(0,)\n", out) if __name__ == "__main__": diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index a3ddd660c..46f098e46 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -21,6 +21,7 @@ # 3. This notice may not be removed or altered from any source distribution. import contextlib +import functools import os import sqlite3 as sqlite import subprocess @@ -28,33 +29,18 @@ import sys import threading import unittest import urllib.parse +import warnings from test.support import ( - SHORT_TIMEOUT, check_disallow_instantiation,# requires_subprocess, - #is_emscripten, is_wasi -# TODO: RUSTPYTHON + SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess ) -from test.support import threading_helper -# TODO: RUSTPYTHON -#from _testcapi import INT_MAX, ULLONG_MAX +from test.support import gc_collect +from test.support import threading_helper, import_helper from os import SEEK_SET, SEEK_CUR, SEEK_END from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath - -# Helper for temporary memory databases -def memory_database(*args, **kwargs): - cx = sqlite.connect(":memory:", *args, **kwargs) - return contextlib.closing(cx) - - -# Temporarily limit a database connection parameter -@contextlib.contextmanager -def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128): - try: - _prev = cx.setlimit(category, limit) - yield limit - finally: - cx.setlimit(category, _prev) +from .util import memory_database, cx_limit +from .util import MemoryDatabaseMixin class ModuleTests(unittest.TestCase): @@ -62,17 +48,6 @@ class ModuleTests(unittest.TestCase): self.assertEqual(sqlite.apilevel, "2.0", "apilevel is %s, should be 2.0" % sqlite.apilevel) - def test_deprecated_version(self): - msg = "deprecated and will be removed in Python 3.14" - for attr in "version", "version_info": - with self.subTest(attr=attr): - with self.assertWarnsRegex(DeprecationWarning, msg) as cm: - getattr(sqlite, attr) - self.assertEqual(cm.filename, __file__) - with self.assertWarnsRegex(DeprecationWarning, msg) as cm: - getattr(sqlite.dbapi2, attr) - self.assertEqual(cm.filename, __file__) - def test_thread_safety(self): self.assertIn(sqlite.threadsafety, {0, 1, 3}, "threadsafety is %d, should be 0, 1 or 3" % @@ -84,45 +59,34 @@ class ModuleTests(unittest.TestCase): sqlite.paramstyle) def test_warning(self): - self.assertTrue(issubclass(sqlite.Warning, Exception), - "Warning is not a subclass of Exception") + self.assertIsSubclass(sqlite.Warning, Exception) def test_error(self): - self.assertTrue(issubclass(sqlite.Error, Exception), - "Error is not a subclass of Exception") + self.assertIsSubclass(sqlite.Error, Exception) def test_interface_error(self): - self.assertTrue(issubclass(sqlite.InterfaceError, sqlite.Error), - "InterfaceError is not a subclass of Error") + self.assertIsSubclass(sqlite.InterfaceError, sqlite.Error) def test_database_error(self): - self.assertTrue(issubclass(sqlite.DatabaseError, sqlite.Error), - "DatabaseError is not a subclass of Error") + self.assertIsSubclass(sqlite.DatabaseError, sqlite.Error) def test_data_error(self): - self.assertTrue(issubclass(sqlite.DataError, sqlite.DatabaseError), - "DataError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.DataError, sqlite.DatabaseError) def test_operational_error(self): - self.assertTrue(issubclass(sqlite.OperationalError, sqlite.DatabaseError), - "OperationalError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.OperationalError, sqlite.DatabaseError) def test_integrity_error(self): - self.assertTrue(issubclass(sqlite.IntegrityError, sqlite.DatabaseError), - "IntegrityError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.IntegrityError, sqlite.DatabaseError) def test_internal_error(self): - self.assertTrue(issubclass(sqlite.InternalError, sqlite.DatabaseError), - "InternalError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.InternalError, sqlite.DatabaseError) def test_programming_error(self): - self.assertTrue(issubclass(sqlite.ProgrammingError, sqlite.DatabaseError), - "ProgrammingError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.ProgrammingError, sqlite.DatabaseError) def test_not_supported_error(self): - self.assertTrue(issubclass(sqlite.NotSupportedError, - sqlite.DatabaseError), - "NotSupportedError is not a subclass of DatabaseError") + self.assertIsSubclass(sqlite.NotSupportedError, sqlite.DatabaseError) def test_module_constants(self): consts = [ @@ -167,6 +131,7 @@ class ModuleTests(unittest.TestCase): "SQLITE_INTERNAL", "SQLITE_INTERRUPT", "SQLITE_IOERR", + "SQLITE_LIMIT_WORKER_THREADS", "SQLITE_LOCKED", "SQLITE_MISMATCH", "SQLITE_MISUSE", @@ -174,6 +139,7 @@ class ModuleTests(unittest.TestCase): "SQLITE_NOMEM", "SQLITE_NOTADB", "SQLITE_NOTFOUND", + "SQLITE_NOTICE", "SQLITE_OK", "SQLITE_PERM", "SQLITE_PRAGMA", @@ -181,6 +147,7 @@ class ModuleTests(unittest.TestCase): "SQLITE_RANGE", "SQLITE_READ", "SQLITE_READONLY", + "SQLITE_RECURSIVE", "SQLITE_REINDEX", "SQLITE_ROW", "SQLITE_SAVEPOINT", @@ -189,6 +156,7 @@ class ModuleTests(unittest.TestCase): "SQLITE_TOOBIG", "SQLITE_TRANSACTION", "SQLITE_UPDATE", + "SQLITE_WARNING", # Run-time limit categories "SQLITE_LIMIT_LENGTH", "SQLITE_LIMIT_SQL_LENGTH", @@ -202,32 +170,43 @@ class ModuleTests(unittest.TestCase): "SQLITE_LIMIT_VARIABLE_NUMBER", "SQLITE_LIMIT_TRIGGER_DEPTH", ] - if sqlite.sqlite_version_info >= (3, 7, 17): - consts += ["SQLITE_NOTICE", "SQLITE_WARNING"] - if sqlite.sqlite_version_info >= (3, 8, 3): - consts.append("SQLITE_RECURSIVE") - if sqlite.sqlite_version_info >= (3, 8, 7): - consts.append("SQLITE_LIMIT_WORKER_THREADS") consts += ["PARSE_DECLTYPES", "PARSE_COLNAMES"] # Extended result codes consts += [ "SQLITE_ABORT_ROLLBACK", + "SQLITE_AUTH_USER", "SQLITE_BUSY_RECOVERY", + "SQLITE_BUSY_SNAPSHOT", + "SQLITE_CANTOPEN_CONVPATH", "SQLITE_CANTOPEN_FULLPATH", "SQLITE_CANTOPEN_ISDIR", "SQLITE_CANTOPEN_NOTEMPDIR", + "SQLITE_CONSTRAINT_CHECK", + "SQLITE_CONSTRAINT_COMMITHOOK", + "SQLITE_CONSTRAINT_FOREIGNKEY", + "SQLITE_CONSTRAINT_FUNCTION", + "SQLITE_CONSTRAINT_NOTNULL", + "SQLITE_CONSTRAINT_PRIMARYKEY", + "SQLITE_CONSTRAINT_ROWID", + "SQLITE_CONSTRAINT_TRIGGER", + "SQLITE_CONSTRAINT_UNIQUE", + "SQLITE_CONSTRAINT_VTAB", "SQLITE_CORRUPT_VTAB", "SQLITE_IOERR_ACCESS", + "SQLITE_IOERR_AUTH", "SQLITE_IOERR_BLOCKED", "SQLITE_IOERR_CHECKRESERVEDLOCK", "SQLITE_IOERR_CLOSE", + "SQLITE_IOERR_CONVPATH", "SQLITE_IOERR_DELETE", "SQLITE_IOERR_DELETE_NOENT", "SQLITE_IOERR_DIR_CLOSE", "SQLITE_IOERR_DIR_FSYNC", "SQLITE_IOERR_FSTAT", "SQLITE_IOERR_FSYNC", + "SQLITE_IOERR_GETTEMPPATH", "SQLITE_IOERR_LOCK", + "SQLITE_IOERR_MMAP", "SQLITE_IOERR_NOMEM", "SQLITE_IOERR_RDLOCK", "SQLITE_IOERR_READ", @@ -239,50 +218,18 @@ class ModuleTests(unittest.TestCase): "SQLITE_IOERR_SHORT_READ", "SQLITE_IOERR_TRUNCATE", "SQLITE_IOERR_UNLOCK", + "SQLITE_IOERR_VNODE", "SQLITE_IOERR_WRITE", "SQLITE_LOCKED_SHAREDCACHE", + "SQLITE_NOTICE_RECOVER_ROLLBACK", + "SQLITE_NOTICE_RECOVER_WAL", + "SQLITE_OK_LOAD_PERMANENTLY", "SQLITE_READONLY_CANTLOCK", + "SQLITE_READONLY_DBMOVED", "SQLITE_READONLY_RECOVERY", + "SQLITE_READONLY_ROLLBACK", + "SQLITE_WARNING_AUTOINDEX", ] - if sqlite.sqlite_version_info >= (3, 7, 16): - consts += [ - "SQLITE_CONSTRAINT_CHECK", - "SQLITE_CONSTRAINT_COMMITHOOK", - "SQLITE_CONSTRAINT_FOREIGNKEY", - "SQLITE_CONSTRAINT_FUNCTION", - "SQLITE_CONSTRAINT_NOTNULL", - "SQLITE_CONSTRAINT_PRIMARYKEY", - "SQLITE_CONSTRAINT_TRIGGER", - "SQLITE_CONSTRAINT_UNIQUE", - "SQLITE_CONSTRAINT_VTAB", - "SQLITE_READONLY_ROLLBACK", - ] - if sqlite.sqlite_version_info >= (3, 7, 17): - consts += [ - "SQLITE_IOERR_MMAP", - "SQLITE_NOTICE_RECOVER_ROLLBACK", - "SQLITE_NOTICE_RECOVER_WAL", - ] - if sqlite.sqlite_version_info >= (3, 8, 0): - consts += [ - "SQLITE_BUSY_SNAPSHOT", - "SQLITE_IOERR_GETTEMPPATH", - "SQLITE_WARNING_AUTOINDEX", - ] - if sqlite.sqlite_version_info >= (3, 8, 1): - consts += ["SQLITE_CANTOPEN_CONVPATH", "SQLITE_IOERR_CONVPATH"] - if sqlite.sqlite_version_info >= (3, 8, 2): - consts.append("SQLITE_CONSTRAINT_ROWID") - if sqlite.sqlite_version_info >= (3, 8, 3): - consts.append("SQLITE_READONLY_DBMOVED") - if sqlite.sqlite_version_info >= (3, 8, 7): - consts.append("SQLITE_AUTH_USER") - if sqlite.sqlite_version_info >= (3, 9, 0): - consts.append("SQLITE_IOERR_VNODE") - if sqlite.sqlite_version_info >= (3, 10, 0): - consts.append("SQLITE_IOERR_AUTH") - if sqlite.sqlite_version_info >= (3, 14, 1): - consts.append("SQLITE_OK_LOAD_PERMANENTLY") if sqlite.sqlite_version_info >= (3, 21, 0): consts += [ "SQLITE_IOERR_BEGIN_ATOMIC", @@ -316,7 +263,7 @@ class ModuleTests(unittest.TestCase): consts.append("SQLITE_IOERR_CORRUPTFS") for const in consts: with self.subTest(const=const): - self.assertTrue(hasattr(sqlite, const)) + self.assertHasAttr(sqlite, const) def test_error_code_on_exception(self): err_msg = "unable to open database file" @@ -330,10 +277,8 @@ class ModuleTests(unittest.TestCase): sqlite.connect(db) e = cm.exception self.assertEqual(e.sqlite_errorcode, err_code) - self.assertTrue(e.sqlite_errorname.startswith("SQLITE_CANTOPEN")) + self.assertStartsWith(e.sqlite_errorname, "SQLITE_CANTOPEN") - @unittest.skipIf(sqlite.sqlite_version_info <= (3, 7, 16), - "Requires SQLite 3.7.16 or newer") def test_extended_error_code_on_exception(self): with memory_database() as con: with con: @@ -347,9 +292,9 @@ class ModuleTests(unittest.TestCase): self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK") def test_disallow_instantiation(self): - cx = sqlite.connect(":memory:") - check_disallow_instantiation(self, type(cx("select 1"))) - check_disallow_instantiation(self, sqlite.Blob) + with memory_database() as cx: + check_disallow_instantiation(self, type(cx("select 1"))) + check_disallow_instantiation(self, sqlite.Blob) def test_complete_statement(self): self.assertFalse(sqlite.complete_statement("select t")) @@ -363,6 +308,7 @@ class ConnectionTests(unittest.TestCase): cu = self.cx.cursor() cu.execute("create table test(id integer primary key, name text)") cu.execute("insert into test(name) values (?)", ("foo",)) + cu.close() def tearDown(self): self.cx.close() @@ -418,8 +364,7 @@ class ConnectionTests(unittest.TestCase): with self.cx: pass - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_exceptions(self): # Optional DB-API extension. self.assertEqual(self.cx.Warning, sqlite.Warning) @@ -435,28 +380,28 @@ class ConnectionTests(unittest.TestCase): def test_in_transaction(self): # Can't use db from setUp because we want to test initial state. - cx = sqlite.connect(":memory:") - cu = cx.cursor() - self.assertEqual(cx.in_transaction, False) - cu.execute("create table transactiontest(id integer primary key, name text)") - self.assertEqual(cx.in_transaction, False) - cu.execute("insert into transactiontest(name) values (?)", ("foo",)) - self.assertEqual(cx.in_transaction, True) - cu.execute("select name from transactiontest where name=?", ["foo"]) - row = cu.fetchone() - self.assertEqual(cx.in_transaction, True) - cx.commit() - self.assertEqual(cx.in_transaction, False) - cu.execute("select name from transactiontest where name=?", ["foo"]) - row = cu.fetchone() - self.assertEqual(cx.in_transaction, False) + with memory_database() as cx: + cu = cx.cursor() + self.assertEqual(cx.in_transaction, False) + cu.execute("create table transactiontest(id integer primary key, name text)") + self.assertEqual(cx.in_transaction, False) + cu.execute("insert into transactiontest(name) values (?)", ("foo",)) + self.assertEqual(cx.in_transaction, True) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, True) + cx.commit() + self.assertEqual(cx.in_transaction, False) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, False) + cu.close() def test_in_transaction_ro(self): with self.assertRaises(AttributeError): self.cx.in_transaction = True - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_connection_exceptions(self): exceptions = [ "DataError", @@ -471,14 +416,13 @@ class ConnectionTests(unittest.TestCase): ] for exc in exceptions: with self.subTest(exc=exc): - self.assertTrue(hasattr(self.cx, exc)) + self.assertHasAttr(self.cx, exc) self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc)) def test_interrupt_on_closed_db(self): - cx = sqlite.connect(":memory:") - cx.close() + self.cx.close() with self.assertRaises(sqlite.ProgrammingError): - cx.interrupt() + self.cx.interrupt() def test_interrupt(self): self.assertIsNone(self.cx.interrupt()) @@ -546,29 +490,29 @@ class ConnectionTests(unittest.TestCase): self.assertEqual(cx.isolation_level, level) def test_connection_reinit(self): - db = ":memory:" - cx = sqlite.connect(db) - cx.text_factory = bytes - cx.row_factory = sqlite.Row - cu = cx.cursor() - cu.execute("create table foo (bar)") - cu.executemany("insert into foo (bar) values (?)", - ((str(v),) for v in range(4))) - cu.execute("select bar from foo") + with memory_database() as cx: + cx.text_factory = bytes + cx.row_factory = sqlite.Row + cu = cx.cursor() + cu.execute("create table foo (bar)") + cu.executemany("insert into foo (bar) values (?)", + ((str(v),) for v in range(4))) + cu.execute("select bar from foo") - rows = [r for r in cu.fetchmany(2)] - self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) - self.assertEqual([r[0] for r in rows], [b"0", b"1"]) + rows = [r for r in cu.fetchmany(2)] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], [b"0", b"1"]) - cx.__init__(db) - cx.execute("create table foo (bar)") - cx.executemany("insert into foo (bar) values (?)", - ((v,) for v in ("a", "b", "c", "d"))) + cx.__init__(":memory:") + cx.execute("create table foo (bar)") + cx.executemany("insert into foo (bar) values (?)", + ((v,) for v in ("a", "b", "c", "d"))) - # This uses the old database, old row factory, but new text factory - rows = [r for r in cu.fetchall()] - self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) - self.assertEqual([r[0] for r in rows], ["2", "3"]) + # This uses the old database, old row factory, but new text factory + rows = [r for r in cu.fetchall()] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], ["2", "3"]) + cu.close() def test_connection_bad_reinit(self): cx = sqlite.connect(":memory:") @@ -583,6 +527,58 @@ class ConnectionTests(unittest.TestCase): cx.executemany, "insert into t values(?)", ((v,) for v in range(3))) + @unittest.expectedFailure # TODO: RUSTPYTHON SQLITE_DBCONFIG constants not implemented + def test_connection_config(self): + op = sqlite.SQLITE_DBCONFIG_ENABLE_FKEY + with memory_database() as cx: + with self.assertRaisesRegex(ValueError, "unknown"): + cx.getconfig(-1) + + # Toggle and verify. + old = cx.getconfig(op) + new = not old + cx.setconfig(op, new) + self.assertEqual(cx.getconfig(op), new) + + cx.setconfig(op) # defaults to True + self.assertTrue(cx.getconfig(op)) + + # Check that foreign key support was actually enabled. + with cx: + cx.executescript(""" + create table t(t integer primary key); + create table u(u, foreign key(u) references t(t)); + """) + with self.assertRaisesRegex(sqlite.IntegrityError, "constraint"): + cx.execute("insert into u values(0)") + + @unittest.expectedFailure # TODO: RUSTPYTHON deprecation warning not emitted for positional args + def test_connect_positional_arguments(self): + regex = ( + r"Passing more than 1 positional argument to sqlite3.connect\(\)" + " is deprecated. Parameters 'timeout', 'detect_types', " + "'isolation_level', 'check_same_thread', 'factory', " + "'cached_statements' and 'uri' will become keyword-only " + "parameters in Python 3.15." + ) + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + cx = sqlite.connect(":memory:", 1.0) + cx.close() + self.assertEqual(cm.filename, __file__) + + @unittest.expectedFailure # TODO: RUSTPYTHON ResourceWarning not emitted + def test_connection_resource_warning(self): + with self.assertWarns(ResourceWarning): + cx = sqlite.connect(":memory:") + del cx + gc_collect() + + @unittest.expectedFailure # TODO: RUSTPYTHON Connection signature inspection not working + def test_connection_signature(self): + from inspect import signature + sig = signature(self.cx) + self.assertEqual(str(sig), "(sql, /)") + class UninitialisedConnectionTests(unittest.TestCase): def setUp(self): @@ -613,7 +609,6 @@ class SerializeTests(unittest.TestCase): with cx: cx.execute("create table t(t)") data = cx.serialize() - self.assertEqual(len(data), 8192) # Remove test table, verify that it was removed. with cx: @@ -651,6 +646,14 @@ class SerializeTests(unittest.TestCase): class OpenTests(unittest.TestCase): _sql = "create table test(id integer)" + def test_open_with_bytes_path(self): + path = os.fsencode(TESTFN) + self.addCleanup(unlink, path) + self.assertFalse(os.path.exists(path)) + with contextlib.closing(sqlite.connect(path)) as cx: + self.assertTrue(os.path.exists(path)) + cx.execute(self._sql) + def test_open_with_path_like_object(self): """ Checks that we can successfully connect to a database using an object that is PathLike, i.e. has __fspath__(). """ @@ -661,15 +664,21 @@ class OpenTests(unittest.TestCase): self.assertTrue(os.path.exists(path)) cx.execute(self._sql) - @unittest.skipIf(sys.platform == "win32", "skipped on Windows") - @unittest.skipIf(sys.platform == "darwin", "skipped on macOS") - # TODO: RUSTPYTHON - # @unittest.skipIf(is_emscripten or is_wasi, "not supported on Emscripten/WASI") - @unittest.skipUnless(TESTFN_UNDECODABLE, "only works if there are undecodable paths") - def test_open_with_undecodable_path(self): + def get_undecodable_path(self): path = TESTFN_UNDECODABLE + if not path: + self.skipTest("only works if there are undecodable paths") + try: + open(path, 'wb').close() + except OSError: + self.skipTest(f"can't create file with undecodable path {path!r}") + unlink(path) + return path + + @unittest.skipIf(sys.platform == "win32", "skipped on Windows") + def test_open_with_undecodable_path(self): + path = self.get_undecodable_path() self.addCleanup(unlink, path) - self.assertFalse(os.path.exists(path)) with contextlib.closing(sqlite.connect(path)) as cx: self.assertTrue(os.path.exists(path)) cx.execute(self._sql) @@ -709,21 +718,15 @@ class OpenTests(unittest.TestCase): cx.execute(self._sql) @unittest.skipIf(sys.platform == "win32", "skipped on Windows") - @unittest.skipIf(sys.platform == "darwin", "skipped on macOS") - # TODO: RUSTPYTHON - # @unittest.skipIf(is_emscripten or is_wasi, "not supported on Emscripten/WASI") - @unittest.skipUnless(TESTFN_UNDECODABLE, "only works if there are undecodable paths") def test_open_undecodable_uri(self): - path = TESTFN_UNDECODABLE + path = self.get_undecodable_path() self.addCleanup(unlink, path) uri = "file:" + urllib.parse.quote(path) - self.assertFalse(os.path.exists(path)) with contextlib.closing(sqlite.connect(uri, uri=True)) as cx: self.assertTrue(os.path.exists(path)) cx.execute(self._sql) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_factory_database_arg(self): def factory(database, *args, **kwargs): nonlocal database_arg @@ -872,6 +875,34 @@ class CursorTests(unittest.TestCase): with self.assertRaises(ZeroDivisionError): self.cu.execute("select name from test where name=?", L()) + @unittest.expectedFailure # TODO: RUSTPYTHON mixed named and positional parameters not validated + def test_execute_named_param_and_sequence(self): + dataset = ( + ("select :a", (1,)), + ("select :a, ?, ?", (1, 2, 3)), + ("select ?, :b, ?", (1, 2, 3)), + ("select ?, ?, :c", (1, 2, 3)), + ("select :a, :b, ?", (1, 2, 3)), + ) + msg = "Binding.*is a named parameter" + for query, params in dataset: + with self.subTest(query=query, params=params): + with self.assertRaisesRegex(sqlite.ProgrammingError, msg) as cm: + self.cu.execute(query, params) + + def test_execute_indexed_nameless_params(self): + # See gh-117995: "'?1' is considered a named placeholder" + for query, params, expected in ( + ("select ?1, ?2", (1, 2), (1, 2)), + ("select ?2, ?1", (1, 2), (2, 1)), + ): + with self.subTest(query=query, params=params): + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + cu = self.cu.execute(query, params) + actual, = cu.fetchall() + self.assertEqual(actual, expected) + def test_execute_too_many_params(self): category = sqlite.SQLITE_LIMIT_VARIABLE_NUMBER msg = "too many SQL variables" @@ -1047,7 +1078,7 @@ class CursorTests(unittest.TestCase): # now set to 2 self.cu.arraysize = 2 - # now make the query return 3 rows + # now make the query return 2 rows from a table of 3 rows self.cu.execute("delete from test") self.cu.execute("insert into test(name) values ('A')") self.cu.execute("insert into test(name) values ('B')") @@ -1057,13 +1088,53 @@ class CursorTests(unittest.TestCase): self.assertEqual(len(res), 2) + @unittest.expectedFailure # TODO: RUSTPYTHON arraysize validation not implemented + def test_invalid_array_size(self): + UINT32_MAX = (1 << 32) - 1 + setter = functools.partial(setattr, self.cu, 'arraysize') + + self.assertRaises(TypeError, setter, 1.0) + self.assertRaises(ValueError, setter, -3) + self.assertRaises(OverflowError, setter, UINT32_MAX + 1) + + @unittest.expectedFailure # TODO: RUSTPYTHON fetchmany behavior with exhausted cursor differs def test_fetchmany(self): + # no active SQL statement + res = self.cu.fetchmany() + self.assertEqual(res, []) + res = self.cu.fetchmany(1000) + self.assertEqual(res, []) + + # test default parameter + self.cu.execute("select name from test") + res = self.cu.fetchmany() + self.assertEqual(len(res), 1) + + # test when the number of requested rows exceeds the actual count self.cu.execute("select name from test") res = self.cu.fetchmany(100) self.assertEqual(len(res), 1) res = self.cu.fetchmany(100) self.assertEqual(res, []) + # test when size = 0 + self.cu.execute("select name from test") + res = self.cu.fetchmany(0) + self.assertEqual(res, []) + res = self.cu.fetchmany(100) + self.assertEqual(len(res), 1) + res = self.cu.fetchmany(100) + self.assertEqual(res, []) + + @unittest.expectedFailure # TODO: RUSTPYTHON fetchmany size validation not implemented + def test_invalid_fetchmany(self): + UINT32_MAX = (1 << 32) - 1 + fetchmany = self.cu.fetchmany + + self.assertRaises(TypeError, fetchmany, 1.0) + self.assertRaises(ValueError, fetchmany, -3) + self.assertRaises(OverflowError, fetchmany, UINT32_MAX + 1) + def test_fetchmany_kw_arg(self): """Checks if fetchmany works with keyword arguments""" self.cu.execute("select name from test") @@ -1183,12 +1254,9 @@ class BlobTests(unittest.TestCase): self.blob.seek(-10, SEEK_END) self.assertEqual(self.blob.tell(), 40) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_blob_seek_error(self): msg_oor = "offset out of blob range" msg_orig = "'origin' should be os.SEEK_SET, os.SEEK_CUR, or os.SEEK_END" - msg_of = "seek offset results in overflow" dataset = ( (ValueError, msg_oor, lambda: self.blob.seek(1000)), @@ -1200,12 +1268,15 @@ class BlobTests(unittest.TestCase): with self.subTest(exc=exc, msg=msg, fn=fn): self.assertRaisesRegex(exc, msg, fn) + def test_blob_seek_overflow_error(self): # Force overflow errors + msg_of = "seek offset results in overflow" + _testcapi = import_helper.import_module("_testcapi") self.blob.seek(1, SEEK_SET) with self.assertRaisesRegex(OverflowError, msg_of): - self.blob.seek(INT_MAX, SEEK_CUR) + self.blob.seek(_testcapi.INT_MAX, SEEK_CUR) with self.assertRaisesRegex(OverflowError, msg_of): - self.blob.seek(INT_MAX, SEEK_END) + self.blob.seek(_testcapi.INT_MAX, SEEK_END) def test_blob_read(self): buf = self.blob.read() @@ -1359,24 +1430,24 @@ class BlobTests(unittest.TestCase): with self.assertRaisesRegex(TypeError, msg): self.blob["a"] = b"b" - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_blob_get_item_error(self): dataset = [len(self.blob), 105, -105] for idx in dataset: with self.subTest(idx=idx): with self.assertRaisesRegex(IndexError, "index out of range"): self.blob[idx] - with self.assertRaisesRegex(IndexError, "cannot fit 'int'"): - self.blob[ULLONG_MAX] # Provoke read error self.cx.execute("update test set b='aaaa' where rowid=1") with self.assertRaises(sqlite.OperationalError): self.blob[0] - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_blob_get_item_error_bigint(self): + _testcapi = import_helper.import_module("_testcapi") + with self.assertRaisesRegex(IndexError, "cannot fit 'int'"): + self.blob[_testcapi.ULLONG_MAX] + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_blob_set_item_error(self): with self.assertRaisesRegex(TypeError, "cannot be interpreted"): self.blob[0] = b"multiple" @@ -1413,7 +1484,7 @@ class BlobTests(unittest.TestCase): self.blob + self.blob with self.assertRaisesRegex(TypeError, "unsupported operand"): self.blob * 5 - with self.assertRaisesRegex(TypeError, "is not iterable"): + with self.assertRaisesRegex(TypeError, "is not.+iterable"): b"a" in self.blob def test_blob_context_manager(self): @@ -1474,9 +1545,16 @@ class BlobTests(unittest.TestCase): "Cannot operate on a closed database", blob.read) + def test_blob_32bit_rowid(self): + # gh-100370: we should not get an OverflowError for 32-bit rowids + with memory_database() as cx: + rowid = 2**32 + cx.execute("create table t(t blob)") + cx.execute("insert into t(rowid, t) values (?, zeroblob(1))", (rowid,)) + cx.blobopen('t', 't', rowid) -# TODO: RUSTPYTHON -# @threading_helper.requires_working_threading() + +@threading_helper.requires_working_threading() class ThreadTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -1529,8 +1607,7 @@ class ThreadTests(unittest.TestCase): with self.subTest(fn=fn): self._run_test(fn) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_check_cursor_thread(self): fns = [ lambda: self.cur.execute("insert into test(name) values('a')"), @@ -1551,12 +1628,12 @@ class ThreadTests(unittest.TestCase): except sqlite.Error: err.append("multi-threading not allowed") - con = sqlite.connect(":memory:", check_same_thread=False) - err = [] - t = threading.Thread(target=run, kwargs={"con": con, "err": err}) - t.start() - t.join() - self.assertEqual(len(err), 0, "\n".join(err)) + with memory_database(check_same_thread=False) as con: + err = [] + t = threading.Thread(target=run, kwargs={"con": con, "err": err}) + t.start() + t.join() + self.assertEqual(len(err), 0, "\n".join(err)) class ConstructorTests(unittest.TestCase): @@ -1582,9 +1659,16 @@ class ConstructorTests(unittest.TestCase): b = sqlite.Binary(b"\0'") class ExtensionTests(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + + def tearDown(self): + self.cur.close() + self.con.close() + def test_script_string_sql(self): - con = sqlite.connect(":memory:") - cur = con.cursor() + cur = self.cur cur.executescript(""" -- bla bla /* a stupid comment */ @@ -1596,40 +1680,40 @@ class ExtensionTests(unittest.TestCase): self.assertEqual(res, 5) def test_script_syntax_error(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(sqlite.OperationalError): - cur.executescript("create table test(x); asdf; create table test2(x)") + self.cur.executescript(""" + CREATE TABLE test(x); + asdf; + CREATE TABLE test2(x) + """) def test_script_error_normal(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(sqlite.OperationalError): - cur.executescript("create table test(sadfsadfdsa); select foo from hurz;") + self.cur.executescript(""" + CREATE TABLE test(sadfsadfdsa); + SELECT foo FROM hurz; + """) def test_cursor_executescript_as_bytes(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(TypeError): - cur.executescript(b"create table test(foo); insert into test(foo) values (5);") + self.cur.executescript(b""" + CREATE TABLE test(foo); + INSERT INTO test(foo) VALUES (5); + """) def test_cursor_executescript_with_null_characters(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(ValueError): - cur.executescript(""" - create table a(i);\0 - insert into a(i) values (5); - """) + self.cur.executescript(""" + CREATE TABLE a(i);\0 + INSERT INTO a(i) VALUES (5); + """) def test_cursor_executescript_with_surrogates(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(UnicodeEncodeError): - cur.executescript(""" - create table a(s); - insert into a(s) values ('\ud8ff'); - """) + self.cur.executescript(""" + CREATE TABLE a(s); + INSERT INTO a(s) VALUES ('\ud8ff'); + """) def test_cursor_executescript_too_large_script(self): msg = "query string is too large" @@ -1639,19 +1723,18 @@ class ExtensionTests(unittest.TestCase): cx.executescript("select 'too large'".ljust(lim+1)) def test_cursor_executescript_tx_control(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("begin") self.assertTrue(con.in_transaction) con.executescript("select 1") self.assertFalse(con.in_transaction) def test_connection_execute(self): - con = sqlite.connect(":memory:") - result = con.execute("select 5").fetchone()[0] + result = self.con.execute("select 5").fetchone()[0] self.assertEqual(result, 5, "Basic test of Connection.execute") def test_connection_executemany(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("create table test(foo)") con.executemany("insert into test(foo) values (?)", [(3,), (4,)]) result = con.execute("select foo from test order by foo").fetchall() @@ -1659,47 +1742,50 @@ class ExtensionTests(unittest.TestCase): self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany") def test_connection_executescript(self): - con = sqlite.connect(":memory:") - con.executescript("create table test(foo); insert into test(foo) values (5);") + con = self.con + con.executescript(""" + CREATE TABLE test(foo); + INSERT INTO test(foo) VALUES (5); + """) result = con.execute("select foo from test").fetchone()[0] self.assertEqual(result, 5, "Basic test of Connection.executescript") + class ClosedConTests(unittest.TestCase): + def check(self, fn, *args, **kwds): + regex = "Cannot operate on a closed database." + with self.assertRaisesRegex(sqlite.ProgrammingError, regex): + fn(*args, **kwds) + + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + self.con.close() + + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_con_cursor(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - cur = con.cursor() + self.check(self.con.cursor) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_con_commit(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con.commit() + self.check(self.con.commit) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_con_rollback(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con.rollback() + self.check(self.con.rollback) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_cur_execute(self): - con = sqlite.connect(":memory:") - cur = con.cursor() - con.close() - with self.assertRaises(sqlite.ProgrammingError): - cur.execute("select 4") + self.check(self.cur.execute, "select 4") + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_create_function(self): - con = sqlite.connect(":memory:") - con.close() - def f(x): return 17 - with self.assertRaises(sqlite.ProgrammingError): - con.create_function("foo", 1, f) + def f(x): + return 17 + self.check(self.con.create_function, "foo", 1, f) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_create_aggregate(self): - con = sqlite.connect(":memory:") - con.close() class Agg: def __init__(self): pass @@ -1707,34 +1793,28 @@ class ClosedConTests(unittest.TestCase): pass def finalize(self): return 17 - with self.assertRaises(sqlite.ProgrammingError): - con.create_aggregate("foo", 1, Agg) + self.check(self.con.create_aggregate, "foo", 1, Agg) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_set_authorizer(self): - con = sqlite.connect(":memory:") - con.close() def authorizer(*args): return sqlite.DENY - with self.assertRaises(sqlite.ProgrammingError): - con.set_authorizer(authorizer) + self.check(self.con.set_authorizer, authorizer) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_set_progress_callback(self): - con = sqlite.connect(":memory:") - con.close() - def progress(): pass - with self.assertRaises(sqlite.ProgrammingError): - con.set_progress_handler(progress, 100) + def progress(): + pass + self.check(self.con.set_progress_handler, progress, 100) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for closed connection def test_closed_call(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con() + self.check(self.con) -class ClosedCurTests(unittest.TestCase): + +class ClosedCurTests(MemoryDatabaseMixin, unittest.TestCase): def test_closed(self): - con = sqlite.connect(":memory:") - cur = con.cursor() + cur = self.cx.cursor() cur.close() for method_name in ("execute", "executemany", "executescript", "fetchall", "fetchmany", "fetchone"): @@ -1843,14 +1923,14 @@ class SqliteOnConflictTests(unittest.TestCase): self.assertEqual(self.cu.fetchall(), [('Very different data!', 'foo')]) -# TODO: RUSTPYTHON -# @requires_subprocess() +@requires_subprocess() class MultiprocessTests(unittest.TestCase): - CONNECTION_TIMEOUT = SHORT_TIMEOUT / 1000. # Defaults to 30 ms + CONNECTION_TIMEOUT = 0 # Disable the busy timeout. def tearDown(self): unlink(TESTFN) + @unittest.expectedFailure # TODO: RUSTPYTHON multiprocess test fails def test_ctx_mgr_rollback_if_commit_failed(self): # bpo-27334: ctx manager does not rollback if commit fails SCRIPT = f"""if 1: @@ -1916,5 +1996,71 @@ class MultiprocessTests(unittest.TestCase): self.assertEqual(proc.returncode, 0) +class RowTests(unittest.TestCase): + + def setUp(self): + self.cx = sqlite.connect(":memory:") + self.cx.row_factory = sqlite.Row + + def tearDown(self): + self.cx.close() + + def test_row_keys(self): + cu = self.cx.execute("SELECT 1 as first, 2 as second") + row = cu.fetchone() + self.assertEqual(row.keys(), ["first", "second"]) + + def test_row_length(self): + cu = self.cx.execute("SELECT 1, 2, 3") + row = cu.fetchone() + self.assertEqual(len(row), 3) + + def test_row_getitem(self): + cu = self.cx.execute("SELECT 1 as a, 2 as b") + row = cu.fetchone() + self.assertEqual(row[0], 1) + self.assertEqual(row[1], 2) + self.assertEqual(row["a"], 1) + self.assertEqual(row["b"], 2) + for key in "nokey", 4, 1.2: + with self.subTest(key=key): + with self.assertRaises(IndexError): + row[key] + + def test_row_equality(self): + c1 = self.cx.execute("SELECT 1 as a") + r1 = c1.fetchone() + + c2 = self.cx.execute("SELECT 1 as a") + r2 = c2.fetchone() + + self.assertIsNot(r1, r2) + self.assertEqual(r1, r2) + + c3 = self.cx.execute("SELECT 1 as b") + r3 = c3.fetchone() + + self.assertNotEqual(r1, r3) + + @unittest.expectedFailure # TODO: RUSTPYTHON Row with no description fails + def test_row_no_description(self): + cu = self.cx.cursor() + self.assertIsNone(cu.description) + + row = sqlite.Row(cu, ()) + self.assertEqual(row.keys(), []) + with self.assertRaisesRegex(IndexError, "nokey"): + row["nokey"] + + def test_row_is_a_sequence(self): + from collections.abc import Sequence + + cu = self.cx.execute("SELECT 1") + row = cu.fetchone() + + self.assertIsSubclass(sqlite.Row, Sequence) + self.assertIsInstance(row, Sequence) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_dump.py b/Lib/test/test_sqlite3/test_dump.py index ec4a11da8..74aacc05c 100644 --- a/Lib/test/test_sqlite3/test_dump.py +++ b/Lib/test/test_sqlite3/test_dump.py @@ -1,22 +1,18 @@ # Author: Paul Kippes import unittest -import sqlite3 as sqlite -from .test_dbapi import memory_database + +from .util import memory_database +from .util import MemoryDatabaseMixin +from .util import requires_virtual_table -# TODO: RUSTPYTHON -@unittest.expectedFailure -class DumpTests(unittest.TestCase): - def setUp(self): - self.cx = sqlite.connect(":memory:") - self.cu = self.cx.cursor() - - def tearDown(self): - self.cx.close() +class DumpTests(MemoryDatabaseMixin, unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON def test_table_dump(self): expected_sqls = [ + "PRAGMA foreign_keys=OFF;", """CREATE TABLE "index"("index" blob);""" , """INSERT INTO "index" VALUES(X'01');""" @@ -27,7 +23,8 @@ class DumpTests(unittest.TestCase): , "CREATE TABLE t1(id integer primary key, s1 text, " \ "t1_i1 integer not null, i2 integer, unique (s1), " \ - "constraint t1_idx1 unique (i2));" + "constraint t1_idx1 unique (i2), " \ + "constraint t1_i1_idx1 unique (t1_i1));" , "INSERT INTO \"t1\" VALUES(1,'foo',10,20);" , @@ -37,6 +34,9 @@ class DumpTests(unittest.TestCase): "t2_i2 integer, primary key (id)," \ "foreign key(t2_i1) references t1(t1_i1));" , + # Foreign key violation. + "INSERT INTO \"t2\" VALUES(1,2,3);" + , "CREATE TRIGGER trigger_1 update of t1_i1 on t1 " \ "begin " \ "update t2 set t2_i1 = new.t1_i1 where t2_i1 = old.t1_i1; " \ @@ -48,11 +48,87 @@ class DumpTests(unittest.TestCase): [self.cu.execute(s) for s in expected_sqls] i = self.cx.iterdump() actual_sqls = [s for s in i] - expected_sqls = ['BEGIN TRANSACTION;'] + expected_sqls + \ - ['COMMIT;'] + expected_sqls = [ + "PRAGMA foreign_keys=OFF;", + "BEGIN TRANSACTION;", + *expected_sqls[1:], + "COMMIT;", + ] [self.assertEqual(expected_sqls[i], actual_sqls[i]) for i in range(len(expected_sqls))] + @unittest.expectedFailure # TODO: RUSTPYTHON iterdump filter parameter not implemented + def test_table_dump_filter(self): + all_table_sqls = [ + """CREATE TABLE "some_table_2" ("id_1" INTEGER);""", + """INSERT INTO "some_table_2" VALUES(3);""", + """INSERT INTO "some_table_2" VALUES(4);""", + """CREATE TABLE "test_table_1" ("id_2" INTEGER);""", + """INSERT INTO "test_table_1" VALUES(1);""", + """INSERT INTO "test_table_1" VALUES(2);""", + ] + all_views_sqls = [ + """CREATE VIEW "view_1" AS SELECT * FROM "some_table_2";""", + """CREATE VIEW "view_2" AS SELECT * FROM "test_table_1";""", + ] + # Create database structure. + for sql in [*all_table_sqls, *all_views_sqls]: + self.cu.execute(sql) + # %_table_% matches all tables. + dump_sqls = list(self.cx.iterdump(filter="%_table_%")) + self.assertEqual( + dump_sqls, + ["BEGIN TRANSACTION;", *all_table_sqls, "COMMIT;"], + ) + # view_% matches all views. + dump_sqls = list(self.cx.iterdump(filter="view_%")) + self.assertEqual( + dump_sqls, + ["BEGIN TRANSACTION;", *all_views_sqls, "COMMIT;"], + ) + # %_1 matches tables and views with the _1 suffix. + dump_sqls = list(self.cx.iterdump(filter="%_1")) + self.assertEqual( + dump_sqls, + [ + "BEGIN TRANSACTION;", + """CREATE TABLE "test_table_1" ("id_2" INTEGER);""", + """INSERT INTO "test_table_1" VALUES(1);""", + """INSERT INTO "test_table_1" VALUES(2);""", + """CREATE VIEW "view_1" AS SELECT * FROM "some_table_2";""", + "COMMIT;" + ], + ) + # some_% matches some_table_2. + dump_sqls = list(self.cx.iterdump(filter="some_%")) + self.assertEqual( + dump_sqls, + [ + "BEGIN TRANSACTION;", + """CREATE TABLE "some_table_2" ("id_1" INTEGER);""", + """INSERT INTO "some_table_2" VALUES(3);""", + """INSERT INTO "some_table_2" VALUES(4);""", + "COMMIT;" + ], + ) + # Only single object. + dump_sqls = list(self.cx.iterdump(filter="view_2")) + self.assertEqual( + dump_sqls, + [ + "BEGIN TRANSACTION;", + """CREATE VIEW "view_2" AS SELECT * FROM "test_table_1";""", + "COMMIT;" + ], + ) + # % matches all objects. + dump_sqls = list(self.cx.iterdump(filter="%")) + self.assertEqual( + dump_sqls, + ["BEGIN TRANSACTION;", *all_table_sqls, *all_views_sqls, "COMMIT;"], + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON _iterdump not implemented def test_dump_autoincrement(self): expected = [ 'CREATE TABLE "t1" (id integer primary key autoincrement);', @@ -73,6 +149,7 @@ class DumpTests(unittest.TestCase): actual = [stmt for stmt in self.cx.iterdump()] self.assertEqual(expected, actual) + @unittest.expectedFailure # TODO: RUSTPYTHON _iterdump not implemented def test_dump_autoincrement_create_new_db(self): self.cu.execute("BEGIN TRANSACTION") self.cu.execute("CREATE TABLE t1 (id integer primary key autoincrement)") @@ -98,6 +175,7 @@ class DumpTests(unittest.TestCase): rows = res.fetchall() self.assertEqual(rows[0][0], seq) + @unittest.expectedFailure # TODO: RUSTPYTHON _iterdump not implemented def test_unorderable_row(self): # iterdump() should be able to cope with unorderable row types (issue #15545) class UnorderableRow: @@ -119,6 +197,44 @@ class DumpTests(unittest.TestCase): got = list(self.cx.iterdump()) self.assertEqual(expected, got) + @unittest.expectedFailure # TODO: RUSTPYTHON _iterdump not implemented + def test_dump_custom_row_factory(self): + # gh-118221: iterdump should be able to cope with custom row factories. + def dict_factory(cu, row): + fields = [col[0] for col in cu.description] + return dict(zip(fields, row)) + + self.cx.row_factory = dict_factory + CREATE_TABLE = "CREATE TABLE test(t);" + expected = ["BEGIN TRANSACTION;", CREATE_TABLE, "COMMIT;"] + + self.cu.execute(CREATE_TABLE) + actual = list(self.cx.iterdump()) + self.assertEqual(expected, actual) + self.assertEqual(self.cx.row_factory, dict_factory) + + @unittest.expectedFailure # TODO: RUSTPYTHON _iterdump not implemented + @requires_virtual_table("fts4") + def test_dump_virtual_tables(self): + # gh-64662 + expected = [ + "BEGIN TRANSACTION;", + "PRAGMA writable_schema=ON;", + ("INSERT INTO sqlite_master(type,name,tbl_name,rootpage,sql)" + "VALUES('table','test','test',0,'CREATE VIRTUAL TABLE test USING fts4(example)');"), + "CREATE TABLE 'test_content'(docid INTEGER PRIMARY KEY, 'c0example');", + "CREATE TABLE 'test_docsize'(docid INTEGER PRIMARY KEY, size BLOB);", + ("CREATE TABLE 'test_segdir'(level INTEGER,idx INTEGER,start_block INTEGER," + "leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx));"), + "CREATE TABLE 'test_segments'(blockid INTEGER PRIMARY KEY, block BLOB);", + "CREATE TABLE 'test_stat'(id INTEGER PRIMARY KEY, value BLOB);", + "PRAGMA writable_schema=OFF;", + "COMMIT;" + ] + self.cu.execute("CREATE VIRTUAL TABLE test USING fts4(example)") + actual = list(self.cx.iterdump()) + self.assertEqual(expected, actual) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_factory.py b/Lib/test/test_sqlite3/test_factory.py index 20cff8b58..d19b98b20 100644 --- a/Lib/test/test_sqlite3/test_factory.py +++ b/Lib/test/test_sqlite3/test_factory.py @@ -24,6 +24,9 @@ import unittest import sqlite3 as sqlite from collections.abc import Sequence +from .util import memory_database +from .util import MemoryDatabaseMixin + def dict_factory(cursor, row): d = {} @@ -37,6 +40,7 @@ class MyCursor(sqlite.Cursor): self.row_factory = dict_factory class ConnectionFactoryTests(unittest.TestCase): + @unittest.expectedFailure # TODO: RUSTPYTHON def test_connection_factories(self): class DefectFactory(sqlite.Connection): def __init__(self, *args, **kwargs): @@ -45,13 +49,14 @@ class ConnectionFactoryTests(unittest.TestCase): def __init__(self, *args, **kwargs): sqlite.Connection.__init__(self, *args, **kwargs) - for factory in DefectFactory, OkFactory: - with self.subTest(factory=factory): - con = sqlite.connect(":memory:", factory=factory) - self.assertIsInstance(con, factory) + with memory_database(factory=OkFactory) as con: + self.assertIsInstance(con, OkFactory) + regex = "Base Connection.__init__ not called." + with self.assertRaisesRegex(sqlite.ProgrammingError, regex): + with memory_database(factory=DefectFactory) as con: + self.assertIsInstance(con, DefectFactory) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_connection_factory_relayed_call(self): # gh-95132: keyword args must not be passed as positional args class Factory(sqlite.Connection): @@ -59,23 +64,31 @@ class ConnectionFactoryTests(unittest.TestCase): kwargs["isolation_level"] = None super(Factory, self).__init__(*args, **kwargs) - con = sqlite.connect(":memory:", factory=Factory) - self.assertIsNone(con.isolation_level) - self.assertIsInstance(con, Factory) + with memory_database(factory=Factory) as con: + self.assertIsNone(con.isolation_level) + self.assertIsInstance(con, Factory) + @unittest.expectedFailure # TODO: RUSTPYTHON def test_connection_factory_as_positional_arg(self): class Factory(sqlite.Connection): def __init__(self, *args, **kwargs): super(Factory, self).__init__(*args, **kwargs) - con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory) - self.assertIsNone(con.isolation_level) - self.assertIsInstance(con, Factory) + regex = ( + r"Passing more than 1 positional argument to _sqlite3.Connection\(\) " + r"is deprecated. Parameters 'timeout', 'detect_types', " + r"'isolation_level', 'check_same_thread', 'factory', " + r"'cached_statements' and 'uri' will become keyword-only " + r"parameters in Python 3.15." + ) + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + with memory_database(5.0, 0, None, True, Factory) as con: + self.assertIsNone(con.isolation_level) + self.assertIsInstance(con, Factory) + self.assertEqual(cm.filename, __file__) -class CursorFactoryTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase): def tearDown(self): self.con.close() @@ -96,12 +109,10 @@ class CursorFactoryTests(unittest.TestCase): # invalid callable returning non-cursor self.assertRaises(TypeError, self.con.cursor, lambda con: None) -class RowFactoryTestsBackwardsCompat(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - # TODO: RUSTPYTHON - @unittest.expectedFailure +class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase): + + @unittest.expectedFailure # TODO: RUSTPYTHON def test_is_produced_by_factory(self): cur = self.con.cursor(factory=MyCursor) cur.execute("select 4+5 as foo") @@ -109,22 +120,20 @@ class RowFactoryTestsBackwardsCompat(unittest.TestCase): self.assertIsInstance(row, dict) cur.close() - def tearDown(self): - self.con.close() -class RowFactoryTests(unittest.TestCase): +class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase): + def setUp(self): - self.con = sqlite.connect(":memory:") + super().setUp() + self.con.row_factory = sqlite.Row def test_custom_factory(self): self.con.row_factory = lambda cur, row: list(row) row = self.con.execute("select 1, 2").fetchone() self.assertIsInstance(row, list) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sqlite_row_index(self): - self.con.row_factory = sqlite.Row row = self.con.execute("select 1 as a_1, 2 as b").fetchone() self.assertIsInstance(row, sqlite.Row) @@ -154,10 +163,8 @@ class RowFactoryTests(unittest.TestCase): with self.assertRaises(IndexError): row[complex()] # index must be int or string - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sqlite_row_index_unicode(self): - self.con.row_factory = sqlite.Row row = self.con.execute("select 1 as \xff").fetchone() self.assertEqual(row["\xff"], 1) with self.assertRaises(IndexError): @@ -167,7 +174,6 @@ class RowFactoryTests(unittest.TestCase): def test_sqlite_row_slice(self): # A sqlite.Row can be sliced like a list. - self.con.row_factory = sqlite.Row row = self.con.execute("select 1, 2, 3, 4").fetchone() self.assertEqual(row[0:0], ()) self.assertEqual(row[0:1], (1,)) @@ -184,30 +190,32 @@ class RowFactoryTests(unittest.TestCase): self.assertEqual(row[3:0:-2], (4, 2)) def test_sqlite_row_iter(self): - """Checks if the row object is iterable""" - self.con.row_factory = sqlite.Row + # Checks if the row object is iterable. row = self.con.execute("select 1 as a, 2 as b").fetchone() - for col in row: - pass + + # Is iterable in correct order and produces valid results: + items = [col for col in row] + self.assertEqual(items, [1, 2]) + + # Is iterable the second time: + items = [col for col in row] + self.assertEqual(items, [1, 2]) def test_sqlite_row_as_tuple(self): - """Checks if the row object can be converted to a tuple""" - self.con.row_factory = sqlite.Row + # Checks if the row object can be converted to a tuple. row = self.con.execute("select 1 as a, 2 as b").fetchone() t = tuple(row) self.assertEqual(t, (row['a'], row['b'])) def test_sqlite_row_as_dict(self): - """Checks if the row object can be correctly converted to a dictionary""" - self.con.row_factory = sqlite.Row + # Checks if the row object can be correctly converted to a dictionary. row = self.con.execute("select 1 as a, 2 as b").fetchone() d = dict(row) self.assertEqual(d["a"], row["a"]) self.assertEqual(d["b"], row["b"]) def test_sqlite_row_hash_cmp(self): - """Checks if the row object compares and hashes correctly""" - self.con.row_factory = sqlite.Row + # Checks if the row object compares and hashes correctly. row_1 = self.con.execute("select 1 as a, 2 as b").fetchone() row_2 = self.con.execute("select 1 as a, 2 as b").fetchone() row_3 = self.con.execute("select 1 as a, 3 as b").fetchone() @@ -240,30 +248,29 @@ class RowFactoryTests(unittest.TestCase): self.assertEqual(hash(row_1), hash(row_2)) def test_sqlite_row_as_sequence(self): - """ Checks if the row object can act like a sequence """ - self.con.row_factory = sqlite.Row + # Checks if the row object can act like a sequence. row = self.con.execute("select 1 as a, 2 as b").fetchone() as_tuple = tuple(row) self.assertEqual(list(reversed(row)), list(reversed(as_tuple))) self.assertIsInstance(row, Sequence) + def test_sqlite_row_keys(self): + # Checks if the row object can return a list of columns as strings. + row = self.con.execute("select 1 as a, 2 as b").fetchone() + self.assertEqual(row.keys(), ['a', 'b']) + def test_fake_cursor_class(self): # Issue #24257: Incorrect use of PyObject_IsInstance() caused # segmentation fault. # Issue #27861: Also applies for cursor factory. class FakeCursor(str): __class__ = sqlite.Cursor - self.con.row_factory = sqlite.Row self.assertRaises(TypeError, self.con.cursor, FakeCursor) self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ()) - def tearDown(self): - self.con.close() -class TextFactoryTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase): def test_unicode(self): austria = "Österreich" @@ -282,17 +289,19 @@ class TextFactoryTests(unittest.TestCase): austria = "Österreich" row = self.con.execute("select ?", (austria,)).fetchone() self.assertEqual(type(row[0]), str, "type of row[0] must be unicode") - self.assertTrue(row[0].endswith("reich"), "column must contain original data") + self.assertEndsWith(row[0], "reich", "column must contain original data") - def tearDown(self): - self.con.close() class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase): + def setUp(self): self.con = sqlite.connect(":memory:") self.con.execute("create table test (value text)") self.con.execute("insert into test (value) values (?)", ("a\x00b",)) + def tearDown(self): + self.con.close() + def test_string(self): # text_factory defaults to str row = self.con.execute("select value from test").fetchone() @@ -318,9 +327,6 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase): self.assertIs(type(row[0]), bytes) self.assertEqual(row[0], b"a\x00b") - def tearDown(self): - self.con.close() - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 21042b9bf..c47cfab18 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -26,34 +26,31 @@ import unittest from test.support.os_helper import TESTFN, unlink -from test.test_sqlite3.test_dbapi import memory_database, cx_limit -from test.test_sqlite3.test_userfunctions import with_tracebacks +from .util import memory_database, cx_limit, with_tracebacks +from .util import MemoryDatabaseMixin -class CollationTests(unittest.TestCase): +class CollationTests(MemoryDatabaseMixin, unittest.TestCase): + def test_create_collation_not_string(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError): - con.create_collation(None, lambda x, y: (x > y) - (x < y)) + self.con.create_collation(None, lambda x, y: (x > y) - (x < y)) def test_create_collation_not_callable(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError) as cm: - con.create_collation("X", 42) + self.con.create_collation("X", 42) self.assertEqual(str(cm.exception), 'parameter must be callable') def test_create_collation_not_ascii(self): - con = sqlite.connect(":memory:") - con.create_collation("collä", lambda x, y: (x > y) - (x < y)) + self.con.create_collation("collä", lambda x, y: (x > y) - (x < y)) def test_create_collation_bad_upper(self): class BadUpperStr(str): def upper(self): return None - con = sqlite.connect(":memory:") mycoll = lambda x, y: -((x > y) - (x < y)) - con.create_collation(BadUpperStr("mycoll"), mycoll) - result = con.execute(""" + self.con.create_collation(BadUpperStr("mycoll"), mycoll) + result = self.con.execute(""" select x from ( select 'a' as x union @@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase): # reverse order return -((x > y) - (x < y)) - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg='the expected order was not returned') - con.create_collation("mycoll", None) + self.con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') def test_collation_returns_large_integer(self): def mycoll(x, y): # reverse order return -((x > y) - (x < y)) * 2**32 - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg="the expected order was not returned") @@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase): Register two different collation functions under the same name. Verify that the last one is actually used. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) result = con.execute(""" @@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase): Register a collation, then deregister it. Make sure an error is raised if we try to use it. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') -class ProgressTests(unittest.TestCase): + +class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. """ - con = sqlite.connect(":memory:") progress_calls = [] def progress(): progress_calls.append(None) return 0 - con.set_progress_handler(progress, 1) - con.execute(""" + self.con.set_progress_handler(progress, 1) + self.con.execute(""" create table foo(a, b) """) self.assertTrue(progress_calls) @@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase): """ Test that the opcode argument is respected. """ - con = sqlite.connect(":memory:") + con = self.con progress_calls = [] def progress(): progress_calls.append(None) @@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase): """ Test that returning a non-zero value stops the operation in progress. """ - con = sqlite.connect(":memory:") def progress(): return 1 - con.set_progress_handler(progress, 1) - curs = con.cursor() + self.con.set_progress_handler(progress, 1) + curs = self.con.cursor() self.assertRaises( sqlite.OperationalError, curs.execute, @@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase): """ Test that setting the progress handler to None clears the previously set handler. """ - con = sqlite.connect(":memory:") + con = self.con action = 0 def progress(): nonlocal action @@ -201,33 +196,47 @@ class ProgressTests(unittest.TestCase): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") - @with_tracebacks(ZeroDivisionError, name="bad_progress") + @unittest.expectedFailure # TODO: RUSTPYTHON + @with_tracebacks(ZeroDivisionError, msg_regex="bad_progress") def test_error_in_progress_handler(self): - con = sqlite.connect(":memory:") def bad_progress(): 1 / 0 - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) - @with_tracebacks(ZeroDivisionError, name="bad_progress") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ZeroDivisionError, msg_regex="bad_progress") def test_error_in_progress_handler_result(self): - con = sqlite.connect(":memory:") class BadBool: def __bool__(self): 1 / 0 def bad_progress(): return BadBool() - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) + @unittest.expectedFailure # TODO: RUSTPYTHON keyword-only arguments not supported for set_progress_handler + def test_progress_handler_keyword_args(self): + regex = ( + r"Passing keyword argument 'progress_handler' to " + r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. " + r"Parameter 'progress_handler' will become positional-only in " + r"Python 3.15." + ) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.set_progress_handler(progress_handler=lambda: None, n=1) + self.assertEqual(cm.filename, __file__) + + +class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): -class TraceCallbackTests(unittest.TestCase): @contextlib.contextmanager def check_stmt_trace(self, cx, expected): try: @@ -242,12 +251,11 @@ class TraceCallbackTests(unittest.TestCase): """ Test that the trace callback is invoked once it is set. """ - con = sqlite.connect(":memory:") traced_statements = [] def trace(statement): traced_statements.append(statement) - con.set_trace_callback(trace) - con.execute("create table foo(a, b)") + self.con.set_trace_callback(trace) + self.con.execute("create table foo(a, b)") self.assertTrue(traced_statements) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) @@ -255,7 +263,7 @@ class TraceCallbackTests(unittest.TestCase): """ Test that setting the trace callback to None clears the previously set callback. """ - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) @@ -269,7 +277,7 @@ class TraceCallbackTests(unittest.TestCase): Test that the statement can contain unicode literals. """ unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) @@ -317,13 +325,14 @@ class TraceCallbackTests(unittest.TestCase): cx.execute("create table t(t)") cx.executemany("insert into t values(?)", ((v,) for v in range(3))) + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks( sqlite.DataError, regex="Expanded SQL string exceeds the maximum string length" ) def test_trace_too_much_expanded_sql(self): # If the expanded string is too large, we'll fall back to the - # unexpanded SQL statement (for SQLite 3.14.0 and newer). + # unexpanded SQL statement. # The resulting string length is limited by the runtime limit # SQLITE_LIMIT_LENGTH. template = "select 1 as a where a=" @@ -334,8 +343,6 @@ class TraceCallbackTests(unittest.TestCase): unexpanded_query = template + "?" expected = [unexpanded_query] - if sqlite.sqlite_version_info < (3, 14, 0): - expected = [] with self.check_stmt_trace(cx, expected): cx.execute(unexpanded_query, (bad_param,)) @@ -343,12 +350,26 @@ class TraceCallbackTests(unittest.TestCase): with self.check_stmt_trace(cx, [expanded_query]): cx.execute(unexpanded_query, (ok_param,)) + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(ZeroDivisionError, regex="division by zero") def test_trace_bad_handler(self): with memory_database() as cx: cx.set_trace_callback(lambda stmt: 5/0) cx.execute("select 1") + @unittest.expectedFailure # TODO: RUSTPYTHON keyword-only arguments not supported for set_trace_callback + def test_trace_keyword_args(self): + regex = ( + r"Passing keyword argument 'trace_callback' to " + r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. " + r"Parameter 'trace_callback' will become positional-only in " + r"Python 3.15." + ) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.set_trace_callback(trace_callback=lambda: None) + self.assertEqual(cm.filename, __file__) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_regression.py b/Lib/test/test_sqlite3/test_regression.py index d746be647..0ebd6d5e9 100644 --- a/Lib/test/test_sqlite3/test_regression.py +++ b/Lib/test/test_sqlite3/test_regression.py @@ -28,15 +28,12 @@ import functools from test import support from unittest.mock import patch -from test.test_sqlite3.test_dbapi import memory_database, cx_limit + +from .util import memory_database, cx_limit +from .util import MemoryDatabaseMixin -class RegressionTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - - def tearDown(self): - self.con.close() +class RegressionTests(MemoryDatabaseMixin, unittest.TestCase): def test_pragma_user_version(self): # This used to crash pysqlite because this pragma command returns NULL for the column name @@ -45,28 +42,24 @@ class RegressionTests(unittest.TestCase): def test_pragma_schema_version(self): # This still crashed pysqlite <= 2.2.1 - con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES) - try: + with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con: cur = self.con.cursor() cur.execute("pragma schema_version") - finally: - cur.close() - con.close() def test_statement_reset(self): # pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are # reset before a rollback, but only those that are still in the # statement cache. The others are not accessible from the connection object. - con = sqlite.connect(":memory:", cached_statements=5) - cursors = [con.cursor() for x in range(5)] - cursors[0].execute("create table test(x)") - for i in range(10): - cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)]) + with memory_database(cached_statements=5) as con: + cursors = [con.cursor() for x in range(5)] + cursors[0].execute("create table test(x)") + for i in range(10): + cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)]) - for i in range(5): - cursors[i].execute(" " * i + "select x from test") + for i in range(5): + cursors[i].execute(" " * i + "select x from test") - con.rollback() + con.rollback() def test_column_name_with_spaces(self): cur = self.con.cursor() @@ -81,17 +74,15 @@ class RegressionTests(unittest.TestCase): # cache when closing the database. statements that were still # referenced in cursors weren't closed and could provoke " # "OperationalError: Unable to close due to unfinalised statements". - con = sqlite.connect(":memory:") cursors = [] # default statement cache size is 100 for i in range(105): - cur = con.cursor() + cur = self.con.cursor() cursors.append(cur) cur.execute("select 1 x union select " + str(i)) - con.close() def test_on_conflict_rollback(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("create table foo(x, unique(x) on conflict rollback)") con.execute("insert into foo(x) values (1)") try: @@ -126,16 +117,16 @@ class RegressionTests(unittest.TestCase): a statement. This test exhibits the problem. """ SELECT = "select * from foo" - con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) - cur = con.cursor() - cur.execute("create table foo(bar timestamp)") - with self.assertWarnsRegex(DeprecationWarning, "adapter"): - cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) - cur.execute(SELECT) - cur.execute("drop table foo") - cur.execute("create table foo(bar integer)") - cur.execute("insert into foo(bar) values (5)") - cur.execute(SELECT) + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + cur = con.cursor() + cur.execute("create table foo(bar timestamp)") + with self.assertWarnsRegex(DeprecationWarning, "adapter"): + cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) + cur.execute(SELECT) + cur.execute("drop table foo") + cur.execute("create table foo(bar integer)") + cur.execute("insert into foo(bar) values (5)") + cur.execute(SELECT) def test_bind_mutating_list(self): # Issue41662: Crash when mutate a list of parameters during iteration. @@ -144,11 +135,11 @@ class RegressionTests(unittest.TestCase): parameters.clear() return "..." parameters = [X(), 0] - con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) - con.execute("create table foo(bar X, baz integer)") - # Should not crash - with self.assertRaises(IndexError): - con.execute("insert into foo(bar, baz) values (?, ?)", parameters) + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + con.execute("create table foo(bar X, baz integer)") + # Should not crash + with self.assertRaises(IndexError): + con.execute("insert into foo(bar, baz) values (?, ?)", parameters) def test_error_msg_decode_error(self): # When porting the module to Python 3.0, the error message about @@ -173,7 +164,7 @@ class RegressionTests(unittest.TestCase): def __del__(self): con.isolation_level = "" - con = sqlite.connect(":memory:") + con = self.con con.isolation_level = None for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE": with self.subTest(level=level): @@ -204,8 +195,7 @@ class RegressionTests(unittest.TestCase): def __init__(self, con): pass - con = sqlite.connect(":memory:") - cur = Cursor(con) + cur = Cursor(self.con) with self.assertRaises(sqlite.ProgrammingError): cur.execute("select 4+5").fetchall() with self.assertRaisesRegex(sqlite.ProgrammingError, @@ -238,7 +228,9 @@ class RegressionTests(unittest.TestCase): 2.5.3 introduced a regression so that these could no longer be created. """ - con = sqlite.connect(":memory:", isolation_level=None) + with memory_database(isolation_level=None) as con: + self.assertIsNone(con.isolation_level) + self.assertFalse(con.in_transaction) def test_pragma_autocommit(self): """ @@ -266,7 +258,7 @@ class RegressionTests(unittest.TestCase): # Lone surrogate cannot be encoded to the default encoding (utf8) "\uDC80", collation_cb) - @unittest.skip("TODO: RUSTPYTHON deadlock") + @unittest.skip('TODO: RUSTPYTHON; recursive cursor use causes lock contention') def test_recursive_cursor_use(self): """ http://bugs.python.org/issue10811 @@ -274,9 +266,7 @@ class RegressionTests(unittest.TestCase): Recursively using a cursor, such as when reusing it from a generator led to segfaults. Now we catch recursive cursor usage and raise a ProgrammingError. """ - con = sqlite.connect(":memory:") - - cur = con.cursor() + cur = self.con.cursor() cur.execute("create table a (bar)") cur.execute("create table b (baz)") @@ -296,29 +286,31 @@ class RegressionTests(unittest.TestCase): since the microsecond string "456" actually represents "456000". """ - con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES) - cur = con.cursor() - cur.execute("CREATE TABLE t (x TIMESTAMP)") + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + cur = con.cursor() + cur.execute("CREATE TABLE t (x TIMESTAMP)") - # Microseconds should be 456000 - cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')") + # Microseconds should be 456000 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')") - # Microseconds should be truncated to 123456 - cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')") + # Microseconds should be truncated to 123456 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')") - cur.execute("SELECT * FROM t") - with self.assertWarnsRegex(DeprecationWarning, "converter"): - values = [x[0] for x in cur.fetchall()] + cur.execute("SELECT * FROM t") + with self.assertWarnsRegex(DeprecationWarning, "converter"): + values = [x[0] for x in cur.fetchall()] - self.assertEqual(values, [ - datetime.datetime(2012, 4, 4, 15, 6, 0, 456000), - datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), - ]) + self.assertEqual(values, [ + datetime.datetime(2012, 4, 4, 15, 6, 0, 456000), + datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), + ]) + @unittest.expectedFailure # TODO: RUSTPYTHON; error message mismatch def test_invalid_isolation_level_type(self): # isolation level is a string, not an integer - self.assertRaises(TypeError, - sqlite.connect, ":memory:", isolation_level=123) + regex = "isolation_level must be str or None" + with self.assertRaisesRegex(TypeError, regex): + memory_database(isolation_level=123).__enter__() def test_null_character(self): @@ -334,7 +326,7 @@ class RegressionTests(unittest.TestCase): cur.execute, query) def test_surrogates(self): - con = sqlite.connect(":memory:") + con = self.con self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") self.assertRaises(UnicodeEncodeError, con, "select '\udcff'") cur = con.cursor() @@ -360,7 +352,7 @@ class RegressionTests(unittest.TestCase): to return rows multiple times when fetched from cursors after commit. See issues 10513 and 23129 for details. """ - con = sqlite.connect(":memory:") + con = self.con con.executescript(""" create table t(c); create table t2(c); @@ -392,10 +384,9 @@ class RegressionTests(unittest.TestCase): """ def callback(*args): pass - con = sqlite.connect(":memory:") - cur = sqlite.Cursor(con) + cur = sqlite.Cursor(self.con) ref = weakref.ref(cur, callback) - cur.__init__(con) + cur.__init__(self.con) del cur # The interpreter shouldn't crash when ref is collected. del ref @@ -405,8 +396,7 @@ class RegressionTests(unittest.TestCase): with self.assertRaises(AttributeError): del self.con.isolation_level - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bpo37347(self): class Printer: def log(self, *args): @@ -428,6 +418,7 @@ class RegressionTests(unittest.TestCase): def test_table_lock_cursor_replace_stmt(self): with memory_database() as con: + con = self.con cur = con.cursor() cur.execute("create table t(t)") cur.executemany("insert into t values(?)", @@ -445,10 +436,11 @@ class RegressionTests(unittest.TestCase): con.commit() cur = con.execute("select t from t") del cur + support.gc_collect() con.execute("drop table t") con.commit() - @unittest.skip("TODO: RUSTPYTHON deadlock") + @unittest.skip('TODO: RUSTPYTHON; recursive cursor use causes lock contention') def test_table_lock_cursor_non_readonly_select(self): with memory_database() as con: con.execute("create table t(t)") @@ -461,6 +453,7 @@ class RegressionTests(unittest.TestCase): con.create_function("dup", 1, dup) cur = con.execute("select dup(t) from t") del cur + support.gc_collect() con.execute("drop table t") con.commit() @@ -476,7 +469,7 @@ class RegressionTests(unittest.TestCase): self.assertEqual(steps, values) -@unittest.skip("TODO: RUSTPYTHON deadlock") +@unittest.skip('TODO: RUSTPYTHON; recursive cursor use causes lock contention') class RecursiveUseOfCursors(unittest.TestCase): # GH-80254: sqlite3 should not segfault for recursive use of cursors. msg = "Recursive use of cursors not allowed" @@ -496,21 +489,21 @@ class RecursiveUseOfCursors(unittest.TestCase): def test_recursive_cursor_init(self): conv = lambda x: self.cur.__init__(self.con) with patch.dict(sqlite.converters, {"INIT": conv}): - self.cur.execute(f'select x as "x [INIT]", x from test') + self.cur.execute('select x as "x [INIT]", x from test') self.assertRaisesRegex(sqlite.ProgrammingError, self.msg, self.cur.fetchall) def test_recursive_cursor_close(self): conv = lambda x: self.cur.close() with patch.dict(sqlite.converters, {"CLOSE": conv}): - self.cur.execute(f'select x as "x [CLOSE]", x from test') + self.cur.execute('select x as "x [CLOSE]", x from test') self.assertRaisesRegex(sqlite.ProgrammingError, self.msg, self.cur.fetchall) def test_recursive_cursor_iter(self): conv = lambda x, l=[]: self.cur.fetchone() if l else l.append(None) with patch.dict(sqlite.converters, {"ITER": conv}): - self.cur.execute(f'select x as "x [ITER]", x from test') + self.cur.execute('select x as "x [ITER]", x from test') self.assertRaisesRegex(sqlite.ProgrammingError, self.msg, self.cur.fetchall) diff --git a/Lib/test/test_sqlite3/test_transactions.py b/Lib/test/test_sqlite3/test_transactions.py index 9c3d19e79..f38d042e5 100644 --- a/Lib/test/test_sqlite3/test_transactions.py +++ b/Lib/test/test_sqlite3/test_transactions.py @@ -22,22 +22,24 @@ import unittest import sqlite3 as sqlite +from contextlib import contextmanager -from test.support import LOOPBACK_TIMEOUT from test.support.os_helper import TESTFN, unlink +from test.support.script_helper import assert_python_ok -from test.test_sqlite3.test_dbapi import memory_database - - -TIMEOUT = LOOPBACK_TIMEOUT / 10 +from .util import memory_database +from .util import MemoryDatabaseMixin +@unittest.skip("TODO: RUSTPYTHON timeout parameter does not accept int type") class TransactionTests(unittest.TestCase): def setUp(self): - self.con1 = sqlite.connect(TESTFN, timeout=TIMEOUT) + # We can disable the busy handlers, since we control + # the order of SQLite C API operations. + self.con1 = sqlite.connect(TESTFN, timeout=0) self.cur1 = self.con1.cursor() - self.con2 = sqlite.connect(TESTFN, timeout=TIMEOUT) + self.con2 = sqlite.connect(TESTFN, timeout=0) self.cur2 = self.con2.cursor() def tearDown(self): @@ -117,10 +119,8 @@ class TransactionTests(unittest.TestCase): self.cur2.execute("insert into test(i) values (5)") def test_locking(self): - """ - This tests the improved concurrency with pysqlite 2.3.4. You needed - to roll back con2 before you could commit con1. - """ + # This tests the improved concurrency with pysqlite 2.3.4. You needed + # to roll back con2 before you could commit con1. self.cur1.execute("create table test(i)") self.cur1.execute("insert into test(i) values (5)") with self.assertRaises(sqlite.OperationalError): @@ -130,14 +130,14 @@ class TransactionTests(unittest.TestCase): def test_rollback_cursor_consistency(self): """Check that cursors behave correctly after rollback.""" - con = sqlite.connect(":memory:") - cur = con.cursor() - cur.execute("create table test(x)") - cur.execute("insert into test(x) values (5)") - cur.execute("select 1 union select 2 union select 3") + with memory_database() as con: + cur = con.cursor() + cur.execute("create table test(x)") + cur.execute("insert into test(x) values (5)") + cur.execute("select 1 union select 2 union select 3") - con.rollback() - self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)]) + con.rollback() + self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)]) def test_multiple_cursors_and_iternext(self): # gh-94028: statements are cleared and reset in cursor iternext. @@ -216,10 +216,7 @@ class RollbackTests(unittest.TestCase): -class SpecialCommandTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - self.cur = self.con.cursor() +class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase): def test_drop_table(self): self.cur.execute("create table test(i)") @@ -231,14 +228,8 @@ class SpecialCommandTests(unittest.TestCase): self.cur.execute("insert into test(i) values (5)") self.cur.execute("pragma count_changes=1") - def tearDown(self): - self.cur.close() - self.con.close() - -class TransactionalDDL(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase): def test_ddl_does_not_autostart_transaction(self): # For backwards compatibility reasons, DDL statements should not @@ -266,9 +257,6 @@ class TransactionalDDL(unittest.TestCase): with self.assertRaises(sqlite.OperationalError): self.con.execute("select * from test") - def tearDown(self): - self.con.close() - class IsolationLevelFromInit(unittest.TestCase): CREATE = "create table t(t)" @@ -366,5 +354,183 @@ class IsolationLevelPostInit(unittest.TestCase): self.assertEqual(self.traced, [self.QUERY]) +class AutocommitAttribute(unittest.TestCase): + """Test PEP 249-compliant autocommit behaviour.""" + legacy = sqlite.LEGACY_TRANSACTION_CONTROL + + @contextmanager + def check_stmt_trace(self, cx, expected, reset=True): + try: + traced = [] + cx.set_trace_callback(lambda stmt: traced.append(stmt)) + yield + finally: + self.assertEqual(traced, expected) + if reset: + cx.set_trace_callback(None) + + def test_autocommit_default(self): + with memory_database() as cx: + self.assertEqual(cx.autocommit, + sqlite.LEGACY_TRANSACTION_CONTROL) + + def test_autocommit_setget(self): + dataset = ( + True, + False, + sqlite.LEGACY_TRANSACTION_CONTROL, + ) + for mode in dataset: + with self.subTest(mode=mode): + with memory_database(autocommit=mode) as cx: + self.assertEqual(cx.autocommit, mode) + with memory_database() as cx: + cx.autocommit = mode + self.assertEqual(cx.autocommit, mode) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit validation error messages differ + def test_autocommit_setget_invalid(self): + msg = "autocommit must be True, False, or.*LEGACY" + for mode in "a", 12, (), None: + with self.subTest(mode=mode): + with self.assertRaisesRegex(ValueError, msg): + sqlite.connect(":memory:", autocommit=mode) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_disabled(self): + expected = [ + "SELECT 1", + "COMMIT", + "BEGIN", + "ROLLBACK", + "BEGIN", + ] + with memory_database(autocommit=False) as cx: + self.assertTrue(cx.in_transaction) + with self.check_stmt_trace(cx, expected): + cx.execute("SELECT 1") + cx.commit() + cx.rollback() + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_disabled_implicit_rollback(self): + expected = ["ROLLBACK"] + with memory_database(autocommit=False) as cx: + self.assertTrue(cx.in_transaction) + with self.check_stmt_trace(cx, expected, reset=False): + cx.close() + + def test_autocommit_enabled(self): + expected = ["CREATE TABLE t(t)", "INSERT INTO t VALUES(1)"] + with memory_database(autocommit=True) as cx: + self.assertFalse(cx.in_transaction) + with self.check_stmt_trace(cx, expected): + cx.execute("CREATE TABLE t(t)") + cx.execute("INSERT INTO t VALUES(1)") + self.assertFalse(cx.in_transaction) + + def test_autocommit_enabled_txn_ctl(self): + for op in "commit", "rollback": + with self.subTest(op=op): + with memory_database(autocommit=True) as cx: + meth = getattr(cx, op) + self.assertFalse(cx.in_transaction) + with self.check_stmt_trace(cx, []): + meth() # expect this to pass silently + self.assertFalse(cx.in_transaction) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_disabled_then_enabled(self): + expected = ["COMMIT"] + with memory_database(autocommit=False) as cx: + self.assertTrue(cx.in_transaction) + with self.check_stmt_trace(cx, expected): + cx.autocommit = True # should commit + self.assertFalse(cx.in_transaction) + + def test_autocommit_enabled_then_disabled(self): + expected = ["BEGIN"] + with memory_database(autocommit=True) as cx: + self.assertFalse(cx.in_transaction) + with self.check_stmt_trace(cx, expected): + cx.autocommit = False # should begin + self.assertTrue(cx.in_transaction) + + def test_autocommit_explicit_then_disabled(self): + expected = ["BEGIN DEFERRED"] + with memory_database(autocommit=True) as cx: + self.assertFalse(cx.in_transaction) + with self.check_stmt_trace(cx, expected): + cx.execute("BEGIN DEFERRED") + cx.autocommit = False # should now be a no-op + self.assertTrue(cx.in_transaction) + + def test_autocommit_enabled_ctx_mgr(self): + with memory_database(autocommit=True) as cx: + # The context manager is a no-op if autocommit=True + with self.check_stmt_trace(cx, []): + with cx: + self.assertFalse(cx.in_transaction) + self.assertFalse(cx.in_transaction) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_disabled_ctx_mgr(self): + expected = ["COMMIT", "BEGIN"] + with memory_database(autocommit=False) as cx: + with self.check_stmt_trace(cx, expected): + with cx: + self.assertTrue(cx.in_transaction) + self.assertTrue(cx.in_transaction) + + def test_autocommit_compat_ctx_mgr(self): + expected = ["BEGIN ", "INSERT INTO T VALUES(1)", "COMMIT"] + with memory_database(autocommit=self.legacy) as cx: + cx.execute("create table t(t)") + with self.check_stmt_trace(cx, expected): + with cx: + self.assertFalse(cx.in_transaction) + cx.execute("INSERT INTO T VALUES(1)") + self.assertTrue(cx.in_transaction) + self.assertFalse(cx.in_transaction) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_enabled_executescript(self): + expected = ["BEGIN", "SELECT 1"] + with memory_database(autocommit=True) as cx: + with self.check_stmt_trace(cx, expected): + self.assertFalse(cx.in_transaction) + cx.execute("BEGIN") + cx.executescript("SELECT 1") + self.assertTrue(cx.in_transaction) + + @unittest.expectedFailure # TODO: RUSTPYTHON autocommit behavior differs + def test_autocommit_disabled_executescript(self): + expected = ["SELECT 1"] + with memory_database(autocommit=False) as cx: + with self.check_stmt_trace(cx, expected): + self.assertTrue(cx.in_transaction) + cx.executescript("SELECT 1") + self.assertTrue(cx.in_transaction) + + def test_autocommit_compat_executescript(self): + expected = ["BEGIN", "COMMIT", "SELECT 1"] + with memory_database(autocommit=self.legacy) as cx: + with self.check_stmt_trace(cx, expected): + self.assertFalse(cx.in_transaction) + cx.execute("BEGIN") + cx.executescript("SELECT 1") + self.assertFalse(cx.in_transaction) + + def test_autocommit_disabled_implicit_shutdown(self): + # The implicit ROLLBACK should not call back into Python during + # interpreter tear-down. + code = """if 1: + import sqlite3 + cx = sqlite3.connect(":memory:", autocommit=False) + cx.set_trace_callback(print) + """ + assert_python_ok("-c", code, PYTHONIOENCODING="utf-8") + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_types.py b/Lib/test/test_sqlite3/test_types.py index 623188235..66d27d21b 100644 --- a/Lib/test/test_sqlite3/test_types.py +++ b/Lib/test/test_sqlite3/test_types.py @@ -106,9 +106,9 @@ class SqliteTypeTests(unittest.TestCase): @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @support.bigmemtest(size=2**31, memuse=4, dry_run=False) def test_too_large_string(self, maxsize): - with self.assertRaises(sqlite.InterfaceError): + with self.assertRaises(sqlite.DataError): self.cur.execute("insert into test(s) values (?)", ('x'*(2**31-1),)) - with self.assertRaises(OverflowError): + with self.assertRaises(sqlite.DataError): self.cur.execute("insert into test(s) values (?)", ('x'*(2**31),)) self.cur.execute("select 1 from test") row = self.cur.fetchone() @@ -117,9 +117,9 @@ class SqliteTypeTests(unittest.TestCase): @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @support.bigmemtest(size=2**31, memuse=3, dry_run=False) def test_too_large_blob(self, maxsize): - with self.assertRaises(sqlite.InterfaceError): + with self.assertRaises(sqlite.DataError): self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31-1),)) - with self.assertRaises(OverflowError): + with self.assertRaises(sqlite.DataError): self.cur.execute("insert into test(s) values (?)", (b'x'*(2**31),)) self.cur.execute("select 1 from test") row = self.cur.fetchone() @@ -371,7 +371,6 @@ class ColNamesTests(unittest.TestCase): self.assertIsNone(self.cur.description) -@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "CTEs not supported") class CommonTableExpressionTests(unittest.TestCase): def setUp(self): @@ -517,7 +516,7 @@ class DateTimeTests(unittest.TestCase): self.assertEqual(ts, ts2) def test_sql_timestamp(self): - now = datetime.datetime.utcnow() + now = datetime.datetime.now(tz=datetime.UTC) self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") with self.assertWarnsRegex(DeprecationWarning, "converter"): diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index e8b98a66a..3fdde4a26 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -21,55 +21,15 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import contextlib -import functools -import io -import re import sys import unittest import sqlite3 as sqlite from unittest.mock import Mock, patch -from test.support import bigmemtest, catch_unraisable_exception, gc_collect +from test.support import bigmemtest, gc_collect -from test.test_sqlite3.test_dbapi import cx_limit - - -def with_tracebacks(exc, regex="", name=""): - """Convenience decorator for testing callback tracebacks.""" - def decorator(func): - _regex = re.compile(regex) if regex else None - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - with catch_unraisable_exception() as cm: - # First, run the test with traceback enabled. - with check_tracebacks(self, cm, exc, _regex, name): - func(self, *args, **kwargs) - - # Then run the test with traceback disabled. - func(self, *args, **kwargs) - return wrapper - return decorator - - -@contextlib.contextmanager -def check_tracebacks(self, cm, exc, regex, obj_name): - """Convenience context manager for testing callback tracebacks.""" - sqlite.enable_callback_tracebacks(True) - try: - buf = io.StringIO() - with contextlib.redirect_stderr(buf): - yield - - # TODO: RUSTPYTHON need unraisable exception - # self.assertEqual(cm.unraisable.exc_type, exc) - # if regex: - # msg = str(cm.unraisable.exc_value) - # self.assertIsNotNone(regex.search(msg)) - # if obj_name: - # self.assertEqual(cm.unraisable.object.__name__, obj_name) - finally: - sqlite.enable_callback_tracebacks(False) +from .util import cx_limit, memory_database +from .util import with_tracebacks def func_returntext(): @@ -196,7 +156,6 @@ class FunctionTests(unittest.TestCase): self.con.create_function("returnblob", 0, func_returnblob) self.con.create_function("returnlonglong", 0, func_returnlonglong) self.con.create_function("returnnan", 0, lambda: float("nan")) - self.con.create_function("returntoolargeint", 0, lambda: 1 << 65) self.con.create_function("return_noncont_blob", 0, lambda: memoryview(b"blob")[::2]) self.con.create_function("raiseexception", 0, func_raiseexception) @@ -211,8 +170,9 @@ class FunctionTests(unittest.TestCase): def tearDown(self): self.con.close() + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for invalid num args def test_func_error_on_create(self): - with self.assertRaises(sqlite.OperationalError): + with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"): self.con.create_function("bla", -100, lambda x: 2*x) def test_func_too_many_args(self): @@ -295,12 +255,8 @@ class FunctionTests(unittest.TestCase): cur.execute("select returnnan()") self.assertIsNone(cur.fetchone()[0]) - def test_func_return_too_large_int(self): - cur = self.con.cursor() - self.assertRaisesRegex(sqlite.DataError, "string or blob too big", - self.con.execute, "select returntoolargeint()") - - @with_tracebacks(ZeroDivisionError, name="func_raiseexception") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ZeroDivisionError, msg_regex="func_raiseexception") def test_func_exception(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -308,14 +264,16 @@ class FunctionTests(unittest.TestCase): cur.fetchone() self.assertEqual(str(cm.exception), 'user-defined function raised exception') - @with_tracebacks(MemoryError, name="func_memoryerror") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(MemoryError, msg_regex="func_memoryerror") def test_func_memory_error(self): cur = self.con.cursor() with self.assertRaises(MemoryError): cur.execute("select memoryerror()") cur.fetchone() - @with_tracebacks(OverflowError, name="func_overflowerror") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(OverflowError, msg_regex="func_overflowerror") def test_func_overflow_error(self): cur = self.con.cursor() with self.assertRaises(sqlite.DataError): @@ -348,6 +306,7 @@ class FunctionTests(unittest.TestCase): self.con.execute, "select spam(?)", (memoryview(b"blob")[::2],)) + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(BufferError, regex="buffer.*contiguous") def test_return_non_contiguous_blob(self): with self.assertRaises(sqlite.OperationalError): @@ -388,38 +347,22 @@ class FunctionTests(unittest.TestCase): # Regarding deterministic functions: # # Between 3.8.3 and 3.15.0, deterministic functions were only used to - # optimize inner loops, so for those versions we can only test if the - # sqlite machinery has factored out a call or not. From 3.15.0 and onward, - # deterministic functions were permitted in WHERE clauses of partial - # indices, which allows testing based on syntax, iso. the query optimizer. - @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") + # optimize inner loops. From 3.15.0 and onward, deterministic functions + # were permitted in WHERE clauses of partial indices, which allows testing + # based on syntax, iso. the query optimizer. def test_func_non_deterministic(self): mock = Mock(return_value=None) self.con.create_function("nondeterministic", 0, mock, deterministic=False) - if sqlite.sqlite_version_info < (3, 15, 0): - self.con.execute("select nondeterministic() = nondeterministic()") - self.assertEqual(mock.call_count, 2) - else: - with self.assertRaises(sqlite.OperationalError): - self.con.execute("create index t on test(t) where nondeterministic() is not null") + with self.assertRaises(sqlite.OperationalError): + self.con.execute("create index t on test(t) where nondeterministic() is not null") - @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") def test_func_deterministic(self): mock = Mock(return_value=None) self.con.create_function("deterministic", 0, mock, deterministic=True) - if sqlite.sqlite_version_info < (3, 15, 0): - self.con.execute("select deterministic() = deterministic()") - self.assertEqual(mock.call_count, 1) - else: - try: - self.con.execute("create index t on test(t) where deterministic() is not null") - except sqlite.OperationalError: - self.fail("Unexpected failure while creating partial index") - - @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed") - def test_func_deterministic_not_supported(self): - with self.assertRaises(sqlite.NotSupportedError): - self.con.create_function("deterministic", 0, int, deterministic=True) + try: + self.con.execute("create index t on test(t) where deterministic() is not null") + except sqlite.OperationalError: + self.fail("Unexpected failure while creating partial index") def test_func_deterministic_keyword_only(self): with self.assertRaises(TypeError): @@ -428,29 +371,32 @@ class FunctionTests(unittest.TestCase): def test_function_destructor_via_gc(self): # See bpo-44304: The destructor of the user function can # crash if is called without the GIL from the gc functions - dest = sqlite.connect(':memory:') def md5sum(t): return - dest.create_function("md5", 1, md5sum) - x = dest("create table lang (name, first_appeared)") - del md5sum, dest + with memory_database() as dest: + dest.create_function("md5", 1, md5sum) + x = dest("create table lang (name, first_appeared)") + del md5sum, dest - y = [x] - y.append(y) + y = [x] + y.append(y) - del x,y - gc_collect() + del x,y + gc_collect() + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(OverflowError) def test_func_return_too_large_int(self): cur = self.con.cursor() + msg = "string or blob too big" for value in 2**63, -2**63-1, 2**64: self.con.create_function("largeint", 0, lambda value=value: value) - with self.assertRaises(sqlite.DataError): + with self.assertRaisesRegex(sqlite.DataError, msg): cur.execute("select largeint()") - @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(UnicodeEncodeError, "surrogates not allowed") def test_func_return_text_with_surrogates(self): cur = self.con.cursor() self.con.create_function("pychr", 1, chr) @@ -482,6 +428,30 @@ class FunctionTests(unittest.TestCase): self.assertRaisesRegex(sqlite.OperationalError, msg, self.con.execute, "select badreturn()") + @unittest.expectedFailure # TODO: RUSTPYTHON deprecation warning not emitted for keyword args + def test_func_keyword_args(self): + regex = ( + r"Passing keyword arguments 'name', 'narg' and 'func' to " + r"_sqlite3.Connection.create_function\(\) is deprecated. " + r"Parameters 'name', 'narg' and 'func' will become " + r"positional-only in Python 3.15." + ) + + def noop(): + return None + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_function("noop", 0, func=noop) + self.assertEqual(cm.filename, __file__) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_function("noop", narg=0, func=noop) + self.assertEqual(cm.filename, __file__) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_function(name="noop", narg=0, func=noop) + self.assertEqual(cm.filename, __file__) + class WindowSumInt: def __init__(self): @@ -536,15 +506,20 @@ class WindowFunctionTests(unittest.TestCase): """ self.con.create_window_function("sumint", 1, WindowSumInt) + def tearDown(self): + self.cur.close() + self.con.close() + def test_win_sum_int(self): self.cur.execute(self.query % "sumint") self.assertEqual(self.cur.fetchall(), self.expected) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for invalid num args def test_win_error_on_create(self): - self.assertRaises(sqlite.ProgrammingError, - self.con.create_window_function, - "shouldfail", -100, WindowSumInt) + with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"): + self.con.create_window_function("shouldfail", -100, WindowSumInt) + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(BadWindow) def test_win_exception_in_method(self): for meth in "__init__", "step", "value", "inverse": @@ -557,17 +532,19 @@ class WindowFunctionTests(unittest.TestCase): self.cur.execute(self.query % name) self.cur.fetchall() + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(BadWindow) def test_win_exception_in_finalize(self): # Note: SQLite does not (as of version 3.38.0) propagate finalize # callback errors to sqlite3_step(); this implies that OperationalError # is _not_ raised. with patch.object(WindowSumInt, "finalize", side_effect=BadWindow): - name = f"exception_in_finalize" + name = "exception_in_finalize" self.con.create_window_function(name, 1, WindowSumInt) self.cur.execute(self.query % name) self.cur.fetchall() + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(AttributeError) def test_win_missing_method(self): class MissingValue: @@ -599,6 +576,7 @@ class WindowFunctionTests(unittest.TestCase): self.cur.execute(self.query % name) self.cur.fetchall() + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented @with_tracebacks(AttributeError) def test_win_missing_finalize(self): # Note: SQLite does not (as of version 3.38.0) propagate finalize @@ -656,6 +634,7 @@ class AggregateTests(unittest.TestCase): """) cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", ("foo", 5, 3.14, None, memoryview(b"blob"),)) + cur.close() self.con.create_aggregate("nostep", 1, AggrNoStep) self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) @@ -668,15 +647,15 @@ class AggregateTests(unittest.TestCase): self.con.create_aggregate("aggtxt", 1, AggrText) def tearDown(self): - #self.cur.close() - #self.con.close() - pass + self.con.close() + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs for invalid num args def test_aggr_error_on_create(self): - with self.assertRaises(sqlite.OperationalError): + with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"): self.con.create_function("bla", -100, AggrSum) - @with_tracebacks(AttributeError, name="AggrNoStep") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(AttributeError, msg_regex="AggrNoStep") def test_aggr_no_step(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -691,7 +670,8 @@ class AggregateTests(unittest.TestCase): cur.execute("select nofinalize(t) from test") val = cur.fetchone()[0] - @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ZeroDivisionError, msg_regex="AggrExceptionInInit") def test_aggr_exception_in_init(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -699,7 +679,8 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") - @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ZeroDivisionError, msg_regex="AggrExceptionInStep") def test_aggr_exception_in_step(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -707,7 +688,8 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") - @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ZeroDivisionError, msg_regex="AggrExceptionInFinalize") def test_aggr_exception_in_finalize(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -772,6 +754,28 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, txt) + @unittest.expectedFailure # TODO: RUSTPYTHON keyword-only arguments not supported for create_aggregate + def test_agg_keyword_args(self): + regex = ( + r"Passing keyword arguments 'name', 'n_arg' and 'aggregate_class' to " + r"_sqlite3.Connection.create_aggregate\(\) is deprecated. " + r"Parameters 'name', 'n_arg' and 'aggregate_class' will become " + r"positional-only in Python 3.15." + ) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_aggregate("test", 1, aggregate_class=AggrText) + self.assertEqual(cm.filename, __file__) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_aggregate("test", n_arg=1, aggregate_class=AggrText) + self.assertEqual(cm.filename, __file__) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.create_aggregate(name="test", n_arg=0, + aggregate_class=AggrText) + self.assertEqual(cm.filename, __file__) + class AuthorizerTests(unittest.TestCase): @staticmethod @@ -783,8 +787,6 @@ class AuthorizerTests(unittest.TestCase): return sqlite.SQLITE_OK def setUp(self): - # TODO: RUSTPYTHON difference 'prohibited' - self.prohibited = 'not authorized' self.con = sqlite.connect(":memory:") self.con.executescript(""" create table t1 (c1, c2); @@ -799,23 +801,38 @@ class AuthorizerTests(unittest.TestCase): self.con.set_authorizer(self.authorizer_cb) def tearDown(self): - pass + self.con.close() + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs def test_table_access(self): with self.assertRaises(sqlite.DatabaseError) as cm: self.con.execute("select * from t2") - self.assertIn(self.prohibited, str(cm.exception)) + self.assertIn('prohibited', str(cm.exception)) + @unittest.expectedFailure # TODO: RUSTPYTHON error message differs def test_column_access(self): with self.assertRaises(sqlite.DatabaseError) as cm: self.con.execute("select c2 from t1") - self.assertIn(self.prohibited, str(cm.exception)) + self.assertIn('prohibited', str(cm.exception)) def test_clear_authorizer(self): self.con.set_authorizer(None) self.con.execute("select * from t2") self.con.execute("select c2 from t1") + @unittest.expectedFailure # TODO: RUSTPYTHON keyword-only arguments not supported for set_authorizer + def test_authorizer_keyword_args(self): + regex = ( + r"Passing keyword argument 'authorizer_callback' to " + r"_sqlite3.Connection.set_authorizer\(\) is deprecated. " + r"Parameter 'authorizer_callback' will become positional-only in " + r"Python 3.15." + ) + + with self.assertWarnsRegex(DeprecationWarning, regex) as cm: + self.con.set_authorizer(authorizer_callback=lambda: None) + self.assertEqual(cm.filename, __file__) + class AuthorizerRaiseExceptionTests(AuthorizerTests): @staticmethod @@ -826,11 +843,13 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests): raise ValueError return sqlite.SQLITE_OK - @with_tracebacks(ValueError, name="authorizer_cb") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ValueError, msg_regex="authorizer_cb") def test_table_access(self): super().test_table_access() - @with_tracebacks(ValueError, name="authorizer_cb") + @unittest.expectedFailure # TODO: RUSTPYTHON unraisable exception handling not implemented + @with_tracebacks(ValueError, msg_regex="authorizer_cb") def test_column_access(self): super().test_table_access() diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py new file mode 100644 index 000000000..cccd06216 --- /dev/null +++ b/Lib/test/test_sqlite3/util.py @@ -0,0 +1,89 @@ +import contextlib +import functools +import io +import re +import sqlite3 +import test.support +import unittest + + +# Helper for temporary memory databases +def memory_database(*args, **kwargs): + cx = sqlite3.connect(":memory:", *args, **kwargs) + return contextlib.closing(cx) + + +# Temporarily limit a database connection parameter +@contextlib.contextmanager +def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128): + try: + _prev = cx.setlimit(category, limit) + yield limit + finally: + cx.setlimit(category, _prev) + + +def with_tracebacks(exc, regex="", name="", msg_regex=""): + """Convenience decorator for testing callback tracebacks.""" + def decorator(func): + exc_regex = re.compile(regex) if regex else None + _msg_regex = re.compile(msg_regex) if msg_regex else None + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with test.support.catch_unraisable_exception() as cm: + # First, run the test with traceback enabled. + with check_tracebacks(self, cm, exc, exc_regex, _msg_regex, name): + func(self, *args, **kwargs) + + # Then run the test with traceback disabled. + func(self, *args, **kwargs) + return wrapper + return decorator + + +@contextlib.contextmanager +def check_tracebacks(self, cm, exc, exc_regex, msg_regex, obj_name): + """Convenience context manager for testing callback tracebacks.""" + sqlite3.enable_callback_tracebacks(True) + try: + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + yield + + self.assertEqual(cm.unraisable.exc_type, exc) + if exc_regex: + msg = str(cm.unraisable.exc_value) + self.assertIsNotNone(exc_regex.search(msg), (exc_regex, msg)) + if msg_regex: + msg = cm.unraisable.err_msg + self.assertIsNotNone(msg_regex.search(msg), (msg_regex, msg)) + if obj_name: + self.assertEqual(cm.unraisable.object.__name__, obj_name) + finally: + sqlite3.enable_callback_tracebacks(False) + + +class MemoryDatabaseMixin: + + def setUp(self): + self.con = sqlite3.connect(":memory:") + self.cur = self.con.cursor() + + def tearDown(self): + self.cur.close() + self.con.close() + + @property + def cx(self): + return self.con + + @property + def cu(self): + return self.cur + + +def requires_virtual_table(module): + with memory_database() as cx: + supported = (module,) in list(cx.execute("PRAGMA module_list")) + reason = f"Requires {module!r} virtual table support" + return unittest.skipUnless(supported, reason)