mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
3547 lines
121 KiB
Rust
3547 lines
121 KiB
Rust
// spell-checker:ignore libsqlite3 threadsafety PYSQLITE decltypes colnames collseq cantinit dirtywal
|
||
// spell-checker:ignore corruptfs narg setinputsizes setoutputsize lastrowid arraysize executemany
|
||
// spell-checker:ignore blobopen executescript iterdump getlimit setlimit errorcode errorname
|
||
// spell-checker:ignore rowid rowcount fetchone fetchmany fetchall errcode errname vtable pagecount
|
||
// spell-checker:ignore autocommit libversion toobig errmsg nomem threadsafe longlong vdbe reindex
|
||
// spell-checker:ignore savepoint cantopen ioerr nolfs nomem notadb notfound fullpath notempdir vtab
|
||
// spell-checker:ignore checkreservedlock noent fstat rdlock shmlock shmmap shmopen shmsize sharedcache
|
||
// spell-checker:ignore cantlock commithook foreignkey notnull primarykey gettemppath autoindex convpath
|
||
// spell-checker:ignore dbmoved vnode nbytes
|
||
|
||
pub(crate) use _sqlite3::module_def;
|
||
|
||
#[pymodule]
|
||
mod _sqlite3 {
|
||
use core::{
|
||
ffi::{CStr, c_int, c_longlong, c_uint, c_void},
|
||
fmt::Debug,
|
||
ops::Deref,
|
||
ptr::{NonNull, null, null_mut},
|
||
};
|
||
use libsqlite3_sys::{
|
||
SQLITE_BLOB, SQLITE_DETERMINISTIC, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL,
|
||
SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE, SQLITE_OPEN_URI, SQLITE_TEXT, SQLITE_TRACE_STMT,
|
||
SQLITE_TRANSIENT, SQLITE_UTF8, sqlite3, sqlite3_aggregate_context, sqlite3_backup_finish,
|
||
sqlite3_backup_init, sqlite3_backup_pagecount, sqlite3_backup_remaining,
|
||
sqlite3_backup_step, sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int64,
|
||
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
|
||
sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes, sqlite3_blob_close, sqlite3_blob_open,
|
||
sqlite3_blob_read, sqlite3_blob_write, sqlite3_busy_timeout, sqlite3_changes,
|
||
sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_decltype,
|
||
sqlite3_column_double, sqlite3_column_int64, sqlite3_column_name, sqlite3_column_text,
|
||
sqlite3_column_type, sqlite3_complete, sqlite3_context, sqlite3_context_db_handle,
|
||
sqlite3_create_collation_v2, sqlite3_create_function_v2, sqlite3_create_window_function,
|
||
sqlite3_data_count, sqlite3_db_handle, sqlite3_errcode, sqlite3_errmsg, sqlite3_exec,
|
||
sqlite3_expanded_sql, sqlite3_extended_errcode, sqlite3_finalize, sqlite3_get_autocommit,
|
||
sqlite3_interrupt, sqlite3_last_insert_rowid, sqlite3_libversion, sqlite3_limit,
|
||
sqlite3_open_v2, sqlite3_prepare_v2, sqlite3_progress_handler, sqlite3_reset,
|
||
sqlite3_result_blob, sqlite3_result_double, sqlite3_result_error,
|
||
sqlite3_result_error_nomem, sqlite3_result_error_toobig, sqlite3_result_int64,
|
||
sqlite3_result_null, sqlite3_result_text, sqlite3_set_authorizer, sqlite3_sleep,
|
||
sqlite3_step, sqlite3_stmt, sqlite3_stmt_busy, sqlite3_stmt_readonly, sqlite3_threadsafe,
|
||
sqlite3_total_changes, sqlite3_trace_v2, sqlite3_user_data, sqlite3_value,
|
||
sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, sqlite3_value_int64,
|
||
sqlite3_value_text, sqlite3_value_type,
|
||
};
|
||
use malachite_bigint::Sign;
|
||
use rustpython_common::{
|
||
atomic::{Ordering, PyAtomic, Radium},
|
||
hash::PyHash,
|
||
lock::{PyMappedMutexGuard, PyMutex, PyMutexGuard},
|
||
static_cell,
|
||
};
|
||
use rustpython_vm::{
|
||
__exports::paste,
|
||
AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
|
||
TryFromBorrowedObject, VirtualMachine, atomic_func,
|
||
builtins::{
|
||
PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat,
|
||
PyInt, PyIntRef, PyModule, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType,
|
||
PyTypeRef, PyUtf8Str, PyUtf8StrRef,
|
||
},
|
||
convert::IntoObject,
|
||
function::{
|
||
ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue,
|
||
PySetterValue,
|
||
},
|
||
object::{Traverse, TraverseFn},
|
||
protocol::{
|
||
PyBuffer, PyIterReturn, PyMappingMethods, PyNumberMethods, PySequence,
|
||
PySequenceMethods,
|
||
},
|
||
sliceable::{SaturatedSliceIter, SliceableSequenceOp},
|
||
types::{
|
||
AsMapping, AsNumber, AsSequence, Callable, Comparable, Constructor, Hashable,
|
||
Initializer, IterNext, Iterable, PyComparisonOp, SelfIter,
|
||
},
|
||
utils::ToCString,
|
||
};
|
||
use std::thread::ThreadId;
|
||
|
||
/// Declares the module's exception hierarchy.
///
/// For every `(Name, base_fn)` pair this expands to:
/// - a `static_cell!` slot `NAME` holding the created type,
/// - `new_name(vm, msg)`, building an instance of that exception,
/// - `name_type()`, returning the cached type (panics if it is called before
///   `setup_module_exceptions` has run),
///
/// plus a single `setup_module_exceptions` that creates each type (deriving
/// from whatever `base_fn(vm)` returns) in declaration order and publishes it
/// on the module.
macro_rules! exceptions {
    ($(($name:ident, $base_fn:expr)),*) => {
        paste::paste! {
            static_cell! {
                $(static [<$name:snake:upper>]: PyTypeRef;)*
            }

            $(
                #[allow(dead_code)]
                fn [<new_ $name:snake>](vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
                    let ty = [<$name:snake _type>]();
                    vm.new_exception_msg(ty.to_owned(), msg)
                }

                fn [<$name:snake _type>]() -> &'static Py<PyType> {
                    [<$name:snake:upper>]
                        .get()
                        .expect("exception type not initialize")
                }
            )*

            fn setup_module_exceptions(module: &PyObject, vm: &VirtualMachine) {
                $(
                    #[allow(clippy::redundant_closure_call)]
                    let created = [<$name:snake:upper>].get_or_init(|| {
                        let parent = $base_fn(vm);
                        vm.ctx.new_exception_type(
                            "_sqlite3",
                            stringify!($name),
                            Some(vec![parent.to_owned()]),
                        )
                    });
                    module
                        .set_attr(stringify!($name), created.clone().into_object(), vm)
                        .unwrap();
                )*
            }
        }
    };
}
|
||
|
||
// Order matters: a base type must be created before any type derived from it,
// because `setup_module_exceptions` initializes them in this order.
exceptions!(
    // Roots of the DB-API hierarchy, both deriving from `Exception`.
    (Warning, |vm: &VirtualMachine| vm.ctx.exceptions.exception_type),
    (Error, |vm: &VirtualMachine| vm.ctx.exceptions.exception_type),
    // Direct subclasses of `Error`.
    (InterfaceError, |_| error_type()),
    (DatabaseError, |_| error_type()),
    // Subclasses of `DatabaseError`.
    (DataError, |_| database_error_type()),
    (OperationalError, |_| database_error_type()),
    (IntegrityError, |_| database_error_type()),
    (InternalError, |_| database_error_type()),
    (ProgrammingError, |_| database_error_type()),
    (NotSupportedError, |_| database_error_type())
);
|
||
|
||
#[pyattr]
|
||
fn sqlite_version(vm: &VirtualMachine) -> String {
|
||
let s = unsafe { sqlite3_libversion() };
|
||
ptr_to_str(s, vm).unwrap().to_owned()
|
||
}
|
||
|
||
#[pyattr]
|
||
fn threadsafety(_: &VirtualMachine) -> c_int {
|
||
let mode = unsafe { sqlite3_threadsafe() };
|
||
match mode {
|
||
0 => 0,
|
||
1 => 3,
|
||
2 => 1,
|
||
_ => panic!("Unable to interpret SQLite threadsafety mode"),
|
||
}
|
||
}
|
||
|
||
#[pyattr(name = "_deprecated_version")]
|
||
const PYSQLITE_VERSION: &str = "2.6.0";
|
||
|
||
#[pyattr]
|
||
const PARSE_DECLTYPES: c_int = 1;
|
||
#[pyattr]
|
||
const PARSE_COLNAMES: c_int = 2;
|
||
#[pyattr]
|
||
const LEGACY_TRANSACTION_CONTROL: c_int = -1;
|
||
|
||
#[pyattr]
|
||
use libsqlite3_sys::{
|
||
SQLITE_ALTER_TABLE, SQLITE_ANALYZE, SQLITE_ATTACH, SQLITE_CREATE_INDEX,
|
||
SQLITE_CREATE_TABLE, SQLITE_CREATE_TEMP_INDEX, SQLITE_CREATE_TEMP_TABLE,
|
||
SQLITE_CREATE_TEMP_TRIGGER, SQLITE_CREATE_TEMP_VIEW, SQLITE_CREATE_TRIGGER,
|
||
SQLITE_CREATE_VIEW, SQLITE_CREATE_VTABLE, SQLITE_DELETE, SQLITE_DENY, SQLITE_DETACH,
|
||
SQLITE_DROP_INDEX, SQLITE_DROP_TABLE, SQLITE_DROP_TEMP_INDEX, SQLITE_DROP_TEMP_TABLE,
|
||
SQLITE_DROP_TEMP_TRIGGER, SQLITE_DROP_TEMP_VIEW, SQLITE_DROP_TRIGGER, SQLITE_DROP_VIEW,
|
||
SQLITE_DROP_VTABLE, SQLITE_FUNCTION, SQLITE_IGNORE, SQLITE_INSERT, SQLITE_LIMIT_ATTACHED,
|
||
SQLITE_LIMIT_COLUMN, SQLITE_LIMIT_COMPOUND_SELECT, SQLITE_LIMIT_EXPR_DEPTH,
|
||
SQLITE_LIMIT_FUNCTION_ARG, SQLITE_LIMIT_LENGTH, SQLITE_LIMIT_LIKE_PATTERN_LENGTH,
|
||
SQLITE_LIMIT_SQL_LENGTH, SQLITE_LIMIT_TRIGGER_DEPTH, SQLITE_LIMIT_VARIABLE_NUMBER,
|
||
SQLITE_LIMIT_VDBE_OP, SQLITE_LIMIT_WORKER_THREADS, SQLITE_PRAGMA, SQLITE_READ,
|
||
SQLITE_RECURSIVE, SQLITE_REINDEX, SQLITE_SAVEPOINT, SQLITE_SELECT, SQLITE_TRANSACTION,
|
||
SQLITE_UPDATE,
|
||
};
|
||
|
||
/// Imports each listed SQLite result-code constant and collects them into
/// `ERROR_CODES`, a `(name, value)` table that `module_exec` publishes as
/// module attributes.
macro_rules! error_codes {
    ($($code:ident),*) => {
        $(
            #[allow(unused_imports)]
            use libsqlite3_sys::$code;
        )*

        static ERROR_CODES: &[(&str, c_int)] = &[
            $((stringify!($code), libsqlite3_sys::$code),)*
        ];
    };
}
|
||
|
||
// Result codes exported as `sqlite3.SQLITE_*` attributes (see `module_exec`).
// No trailing comma: the macro pattern is `$($code:ident),*`.
error_codes!(
    // Primary result codes.
    SQLITE_ABORT, SQLITE_AUTH, SQLITE_BUSY, SQLITE_CANTOPEN, SQLITE_CONSTRAINT,
    SQLITE_CORRUPT, SQLITE_DONE, SQLITE_EMPTY, SQLITE_ERROR, SQLITE_FORMAT,
    SQLITE_FULL, SQLITE_INTERNAL, SQLITE_INTERRUPT, SQLITE_IOERR, SQLITE_LOCKED,
    SQLITE_MISMATCH, SQLITE_MISUSE, SQLITE_NOLFS, SQLITE_NOMEM, SQLITE_NOTADB,
    SQLITE_NOTFOUND, SQLITE_OK, SQLITE_PERM, SQLITE_PROTOCOL, SQLITE_RANGE,
    SQLITE_READONLY, SQLITE_ROW, SQLITE_SCHEMA, SQLITE_TOOBIG, SQLITE_NOTICE,
    SQLITE_WARNING,
    // Extended result codes.
    SQLITE_ABORT_ROLLBACK, SQLITE_BUSY_RECOVERY, SQLITE_CANTOPEN_FULLPATH,
    SQLITE_CANTOPEN_ISDIR, SQLITE_CANTOPEN_NOTEMPDIR, SQLITE_CORRUPT_VTAB,
    SQLITE_IOERR_ACCESS, SQLITE_IOERR_BLOCKED, SQLITE_IOERR_CHECKRESERVEDLOCK,
    SQLITE_IOERR_CLOSE, SQLITE_IOERR_DELETE, SQLITE_IOERR_DELETE_NOENT,
    SQLITE_IOERR_DIR_CLOSE, SQLITE_IOERR_DIR_FSYNC, SQLITE_IOERR_FSTAT,
    SQLITE_IOERR_FSYNC, SQLITE_IOERR_LOCK, SQLITE_IOERR_NOMEM,
    SQLITE_IOERR_RDLOCK, SQLITE_IOERR_READ, SQLITE_IOERR_SEEK,
    SQLITE_IOERR_SHMLOCK, SQLITE_IOERR_SHMMAP, SQLITE_IOERR_SHMOPEN,
    SQLITE_IOERR_SHMSIZE, SQLITE_IOERR_SHORT_READ, SQLITE_IOERR_TRUNCATE,
    SQLITE_IOERR_UNLOCK, SQLITE_IOERR_WRITE, SQLITE_LOCKED_SHAREDCACHE,
    SQLITE_READONLY_CANTLOCK, SQLITE_READONLY_RECOVERY, SQLITE_CONSTRAINT_CHECK,
    SQLITE_CONSTRAINT_COMMITHOOK, SQLITE_CONSTRAINT_FOREIGNKEY,
    SQLITE_CONSTRAINT_FUNCTION, SQLITE_CONSTRAINT_NOTNULL,
    SQLITE_CONSTRAINT_PRIMARYKEY, SQLITE_CONSTRAINT_TRIGGER,
    SQLITE_CONSTRAINT_UNIQUE, SQLITE_CONSTRAINT_VTAB, SQLITE_READONLY_ROLLBACK,
    SQLITE_IOERR_MMAP, SQLITE_NOTICE_RECOVER_ROLLBACK, SQLITE_NOTICE_RECOVER_WAL,
    SQLITE_BUSY_SNAPSHOT, SQLITE_IOERR_GETTEMPPATH, SQLITE_WARNING_AUTOINDEX,
    SQLITE_CANTOPEN_CONVPATH, SQLITE_IOERR_CONVPATH, SQLITE_CONSTRAINT_ROWID,
    SQLITE_READONLY_DBMOVED, SQLITE_AUTH_USER, SQLITE_OK_LOAD_PERMANENTLY,
    SQLITE_IOERR_VNODE, SQLITE_IOERR_AUTH, SQLITE_IOERR_BEGIN_ATOMIC,
    SQLITE_IOERR_COMMIT_ATOMIC, SQLITE_IOERR_ROLLBACK_ATOMIC,
    SQLITE_ERROR_MISSING_COLLSEQ, SQLITE_ERROR_RETRY, SQLITE_READONLY_CANTINIT,
    SQLITE_READONLY_DIRECTORY, SQLITE_CORRUPT_SEQUENCE, SQLITE_LOCKED_VTAB,
    SQLITE_CANTOPEN_DIRTYWAL, SQLITE_ERROR_SNAPSHOT, SQLITE_CANTOPEN_SYMLINK,
    SQLITE_CONSTRAINT_PINNED, SQLITE_OK_SYMLINK, SQLITE_BUSY_TIMEOUT,
    SQLITE_CORRUPT_INDEX, SQLITE_IOERR_DATA, SQLITE_IOERR_CORRUPTFS
);
|
||
|
||
/// Transaction-handling mode of a connection, chosen by the `autocommit`
/// argument.
///
/// - `Legacy` (default; `autocommit=LEGACY_TRANSACTION_CONTROL`): transactions
///   are controlled through `isolation_level`.
/// - `Enabled` (`autocommit=True`): autocommit mode, no automatic transactions.
/// - `Disabled` (`autocommit=False`): manual commit mode.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum AutocommitMode {
    #[default]
    Legacy,
    Enabled,
    Disabled,
}
|
||
|
||
impl TryFromBorrowedObject<'_> for AutocommitMode {
|
||
fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult<Self> {
|
||
if obj.is(&vm.ctx.true_value) {
|
||
Ok(Self::Enabled)
|
||
} else if obj.is(&vm.ctx.false_value) {
|
||
Ok(Self::Disabled)
|
||
} else if let Ok(val) = obj.try_to_value::<c_int>(vm) {
|
||
if val == LEGACY_TRANSACTION_CONTROL {
|
||
Ok(Self::Legacy)
|
||
} else {
|
||
Err(vm.new_value_error(format!(
|
||
"autocommit must be True, False, or sqlite3.LEGACY_TRANSACTION_CONTROL, not {val}"
|
||
)))
|
||
}
|
||
} else {
|
||
Err(vm.new_type_error(format!(
|
||
"autocommit must be True, False, or sqlite3.LEGACY_TRANSACTION_CONTROL, not {}",
|
||
obj.class().name()
|
||
)))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct ConnectArgs {
|
||
#[pyarg(any)]
|
||
database: FsPath,
|
||
#[pyarg(any, default = 5.0)]
|
||
timeout: f64,
|
||
#[pyarg(any, default = 0)]
|
||
detect_types: c_int,
|
||
#[pyarg(any, default = Some(vm.ctx.empty_str.to_owned()))]
|
||
isolation_level: Option<PyStrRef>,
|
||
#[pyarg(any, default = true)]
|
||
check_same_thread: bool,
|
||
#[pyarg(any, default = Connection::class(&vm.ctx).to_owned())]
|
||
factory: PyTypeRef,
|
||
// TODO: cache statements
|
||
#[allow(dead_code)]
|
||
#[pyarg(any, default = 0)]
|
||
cached_statements: c_int,
|
||
#[pyarg(any, default = false)]
|
||
uri: bool,
|
||
#[pyarg(any, default)]
|
||
autocommit: AutocommitMode,
|
||
}
|
||
|
||
unsafe impl Traverse for ConnectArgs {
|
||
fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
|
||
self.isolation_level.traverse(tracer_fn);
|
||
self.factory.traverse(tracer_fn);
|
||
}
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct BackupArgs {
|
||
#[pyarg(any)]
|
||
target: PyRef<Connection>,
|
||
#[pyarg(named, default = -1)]
|
||
pages: c_int,
|
||
#[pyarg(named, optional)]
|
||
progress: Option<ArgCallable>,
|
||
#[pyarg(named, optional)]
|
||
name: Option<PyStrRef>,
|
||
#[pyarg(named, default = 0.250)]
|
||
sleep: f64,
|
||
}
|
||
|
||
unsafe impl Traverse for BackupArgs {
|
||
fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
|
||
self.progress.traverse(tracer_fn);
|
||
self.name.traverse(tracer_fn);
|
||
}
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct CreateFunctionArgs {
|
||
#[pyarg(any)]
|
||
name: PyStrRef,
|
||
#[pyarg(any)]
|
||
narg: c_int,
|
||
#[pyarg(any)]
|
||
func: PyObjectRef,
|
||
#[pyarg(named, default)]
|
||
deterministic: bool,
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct CreateAggregateArgs {
|
||
#[pyarg(any)]
|
||
name: PyStrRef,
|
||
#[pyarg(positional)]
|
||
narg: c_int,
|
||
#[pyarg(positional)]
|
||
aggregate_class: PyObjectRef,
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct BlobOpenArgs {
|
||
#[pyarg(positional)]
|
||
table: PyStrRef,
|
||
#[pyarg(positional)]
|
||
column: PyStrRef,
|
||
#[pyarg(positional)]
|
||
row: i64,
|
||
#[pyarg(named, default)]
|
||
readonly: bool,
|
||
#[pyarg(named, default = vm.ctx.new_str("main"))]
|
||
name: PyStrRef,
|
||
}
|
||
|
||
#[derive(FromArgs)]
|
||
struct CursorArgs {
|
||
#[pyarg(any, default)]
|
||
factory: OptionalArg<PyObjectRef>,
|
||
}
|
||
|
||
struct CallbackData {
|
||
obj: NonNull<PyObject>,
|
||
vm: *const VirtualMachine,
|
||
}
|
||
|
||
impl CallbackData {
|
||
fn new(obj: PyObjectRef, vm: &VirtualMachine) -> Option<Self> {
|
||
(!vm.is_none(&obj)).then_some(Self {
|
||
obj: obj.into_raw(),
|
||
vm,
|
||
})
|
||
}
|
||
|
||
fn retrieve(&self) -> (&PyObject, &VirtualMachine) {
|
||
unsafe { (self.obj.as_ref(), &*self.vm) }
|
||
}
|
||
|
||
unsafe extern "C" fn destructor(data: *mut c_void) {
|
||
drop(unsafe { Box::from_raw(data.cast::<Self>()) });
|
||
}
|
||
|
||
unsafe extern "C" fn func_callback(
|
||
context: *mut sqlite3_context,
|
||
argc: c_int,
|
||
argv: *mut *mut sqlite3_value,
|
||
) {
|
||
let context = SqliteContext::from(context);
|
||
let (func, vm) = unsafe { (*context.user_data::<Self>()).retrieve() };
|
||
let args = unsafe { core::slice::from_raw_parts(argv, argc as usize) };
|
||
|
||
let f = || -> PyResult<()> {
|
||
let db = context.db_handle();
|
||
let args = args
|
||
.iter()
|
||
.cloned()
|
||
.map(|val| value_to_object(val, db, vm))
|
||
.collect::<PyResult<Vec<PyObjectRef>>>()?;
|
||
|
||
let val = func.call(args, vm)?;
|
||
|
||
context.result_from_object(&val, vm)
|
||
};
|
||
|
||
if let Err(exc) = f() {
|
||
context.result_exception(vm, exc, "user-defined function raised exception\0")
|
||
}
|
||
}
|
||
|
||
unsafe extern "C" fn step_callback(
|
||
context: *mut sqlite3_context,
|
||
argc: c_int,
|
||
argv: *mut *mut sqlite3_value,
|
||
) {
|
||
let context = SqliteContext::from(context);
|
||
let (cls, vm) = unsafe { (*context.user_data::<Self>()).retrieve() };
|
||
let args = unsafe { core::slice::from_raw_parts(argv, argc as usize) };
|
||
let instance = context.aggregate_context::<*const PyObject>();
|
||
if unsafe { (*instance).is_null() } {
|
||
match cls.call((), vm) {
|
||
Ok(obj) => unsafe { *instance = obj.into_raw().as_ptr() },
|
||
Err(exc) => {
|
||
return context.result_exception(
|
||
vm,
|
||
exc,
|
||
"user-defined aggregate's '__init__' method raised error\0",
|
||
);
|
||
}
|
||
}
|
||
}
|
||
let instance = unsafe { &**instance };
|
||
|
||
Self::call_method_with_args(context, instance, "step", args, vm);
|
||
}
|
||
|
||
unsafe extern "C" fn finalize_callback(context: *mut sqlite3_context) {
|
||
let context = SqliteContext::from(context);
|
||
let (_, vm) = unsafe { (*context.user_data::<Self>()).retrieve() };
|
||
let instance = context.aggregate_context::<*const PyObject>();
|
||
let Some(instance) = (unsafe { (*instance).as_ref() }) else {
|
||
return;
|
||
};
|
||
|
||
Self::callback_result_from_method(context, instance, "finalize", vm);
|
||
}
|
||
|
||
unsafe extern "C" fn collation_callback(
|
||
data: *mut c_void,
|
||
a_len: c_int,
|
||
a_ptr: *const c_void,
|
||
b_len: c_int,
|
||
b_ptr: *const c_void,
|
||
) -> c_int {
|
||
let (callable, vm) = unsafe { (*data.cast::<Self>()).retrieve() };
|
||
|
||
let f = || -> PyResult<c_int> {
|
||
let text1 = ptr_to_string(a_ptr.cast(), a_len, null_mut(), vm)?;
|
||
let text1 = vm.ctx.new_str(text1);
|
||
let text2 = ptr_to_string(b_ptr.cast(), b_len, null_mut(), vm)?;
|
||
let text2 = vm.ctx.new_str(text2);
|
||
|
||
let val = callable.call((text1, text2), vm)?;
|
||
let Some(val) = val.number().index(vm) else {
|
||
return Ok(0);
|
||
};
|
||
|
||
let val = match val?.as_bigint().sign() {
|
||
Sign::Plus => 1,
|
||
Sign::Minus => -1,
|
||
Sign::NoSign => 0,
|
||
};
|
||
|
||
Ok(val)
|
||
};
|
||
|
||
f().unwrap_or(0)
|
||
}
|
||
|
||
unsafe extern "C" fn value_callback(context: *mut sqlite3_context) {
|
||
let context = SqliteContext::from(context);
|
||
let (_, vm) = unsafe { (*context.user_data::<Self>()).retrieve() };
|
||
let instance = context.aggregate_context::<*const PyObject>();
|
||
let instance = unsafe { &**instance };
|
||
|
||
Self::callback_result_from_method(context, instance, "value", vm);
|
||
}
|
||
|
||
unsafe extern "C" fn inverse_callback(
|
||
context: *mut sqlite3_context,
|
||
argc: c_int,
|
||
argv: *mut *mut sqlite3_value,
|
||
) {
|
||
let context = SqliteContext::from(context);
|
||
let (_, vm) = unsafe { (*context.user_data::<Self>()).retrieve() };
|
||
let args = unsafe { core::slice::from_raw_parts(argv, argc as usize) };
|
||
let instance = context.aggregate_context::<*const PyObject>();
|
||
let instance = unsafe { &**instance };
|
||
|
||
Self::call_method_with_args(context, instance, "inverse", args, vm);
|
||
}
|
||
|
||
unsafe extern "C" fn authorizer_callback(
|
||
data: *mut c_void,
|
||
action: c_int,
|
||
arg1: *const libc::c_char,
|
||
arg2: *const libc::c_char,
|
||
db_name: *const libc::c_char,
|
||
access: *const libc::c_char,
|
||
) -> c_int {
|
||
let (callable, vm) = unsafe { (*data.cast::<Self>()).retrieve() };
|
||
let f = || -> PyResult<c_int> {
|
||
let arg1 = ptr_to_str(arg1, vm)?;
|
||
let arg2 = ptr_to_str(arg2, vm)?;
|
||
let db_name = ptr_to_str(db_name, vm)?;
|
||
let access = ptr_to_str(access, vm)?;
|
||
|
||
let val = callable.call((action, arg1, arg2, db_name, access), vm)?;
|
||
let Some(val) = val.downcast_ref::<PyInt>() else {
|
||
return Ok(SQLITE_DENY);
|
||
};
|
||
val.try_to_primitive::<c_int>(vm)
|
||
};
|
||
|
||
f().unwrap_or(SQLITE_DENY)
|
||
}
|
||
|
||
unsafe extern "C" fn trace_callback(
|
||
_typ: c_uint,
|
||
data: *mut c_void,
|
||
stmt: *mut c_void,
|
||
sql: *mut c_void,
|
||
) -> c_int {
|
||
let (callable, vm) = unsafe { (*data.cast::<Self>()).retrieve() };
|
||
let expanded = unsafe { sqlite3_expanded_sql(stmt.cast()) };
|
||
let f = || -> PyResult<()> {
|
||
let stmt = ptr_to_str(expanded, vm).or_else(|_| ptr_to_str(sql.cast(), vm))?;
|
||
callable.call((stmt,), vm)?;
|
||
Ok(())
|
||
};
|
||
let _ = f();
|
||
0
|
||
}
|
||
|
||
unsafe extern "C" fn progress_callback(data: *mut c_void) -> c_int {
|
||
let (callable, vm) = unsafe { (*data.cast::<Self>()).retrieve() };
|
||
if let Ok(val) = callable.call((), vm)
|
||
&& let Ok(val) = val.is_true(vm)
|
||
{
|
||
return val as c_int;
|
||
}
|
||
-1
|
||
}
|
||
|
||
fn callback_result_from_method(
|
||
context: SqliteContext,
|
||
instance: &PyObject,
|
||
name: &str,
|
||
vm: &VirtualMachine,
|
||
) {
|
||
let f = || -> PyResult<()> {
|
||
let val = vm.call_method(instance, name, ())?;
|
||
context.result_from_object(&val, vm)
|
||
};
|
||
|
||
if let Err(exc) = f() {
|
||
if exc.fast_isinstance(vm.ctx.exceptions.attribute_error) {
|
||
context.result_exception(
|
||
vm,
|
||
exc,
|
||
&format!("user-defined aggregate's '{name}' method not defined\0"),
|
||
)
|
||
} else {
|
||
context.result_exception(
|
||
vm,
|
||
exc,
|
||
&format!("user-defined aggregate's '{name}' method raised error\0"),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
fn call_method_with_args(
|
||
context: SqliteContext,
|
||
instance: &PyObject,
|
||
name: &str,
|
||
args: &[*mut sqlite3_value],
|
||
vm: &VirtualMachine,
|
||
) {
|
||
let f = || -> PyResult<()> {
|
||
let db = context.db_handle();
|
||
let args = args
|
||
.iter()
|
||
.cloned()
|
||
.map(|val| value_to_object(val, db, vm))
|
||
.collect::<PyResult<Vec<PyObjectRef>>>()?;
|
||
vm.call_method(instance, name, args).map(drop)
|
||
};
|
||
|
||
if let Err(exc) = f() {
|
||
if exc.fast_isinstance(vm.ctx.exceptions.attribute_error) {
|
||
context.result_exception(
|
||
vm,
|
||
exc,
|
||
&format!("user-defined aggregate's '{name}' method not defined\0"),
|
||
)
|
||
} else {
|
||
context.result_exception(
|
||
vm,
|
||
exc,
|
||
&format!("user-defined aggregate's '{name}' method raised error\0"),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Drop for CallbackData {
|
||
fn drop(&mut self) {
|
||
unsafe { PyObjectRef::from_raw(self.obj) };
|
||
}
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn connect(args: ConnectArgs, vm: &VirtualMachine) -> PyResult {
|
||
let factory = args.factory.clone();
|
||
let conn = Connection::py_new(&factory, args, vm)?;
|
||
conn.into_ref_with_type(vm, factory).map(Into::into)
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn complete_statement(statement: PyStrRef, vm: &VirtualMachine) -> PyResult<bool> {
|
||
let s = statement.to_cstring(vm)?;
|
||
let ret = unsafe { sqlite3_complete(s.as_ptr()) };
|
||
Ok(ret == 1)
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn enable_callback_tracebacks(flag: bool) {
|
||
enable_traceback().store(flag, Ordering::Relaxed);
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn register_adapter(typ: PyTypeRef, adapter: ArgCallable, vm: &VirtualMachine) -> PyResult<()> {
|
||
if typ.is(PyInt::class(&vm.ctx))
|
||
|| typ.is(PyFloat::class(&vm.ctx))
|
||
|| typ.is(PyStr::class(&vm.ctx))
|
||
|| typ.is(PyByteArray::class(&vm.ctx))
|
||
{
|
||
let _ = BASE_TYPE_ADAPTED.set(());
|
||
}
|
||
let protocol = PrepareProtocol::class(&vm.ctx).to_owned();
|
||
let key = vm.ctx.new_tuple(vec![typ.into(), protocol.into()]);
|
||
adapters().set_item(key.as_object(), adapter.into(), vm)
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn register_converter(
|
||
typename: PyStrRef,
|
||
converter: ArgCallable,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
let name = typename.as_str().to_uppercase();
|
||
converters().set_item(&name, converter.into(), vm)
|
||
}
|
||
|
||
fn _adapt<F>(obj: &PyObject, proto: PyTypeRef, alt: F, vm: &VirtualMachine) -> PyResult
|
||
where
|
||
F: FnOnce(&PyObject) -> PyResult,
|
||
{
|
||
let proto = proto.into_object();
|
||
let key = vm
|
||
.ctx
|
||
.new_tuple(vec![obj.class().to_owned().into(), proto.clone()]);
|
||
|
||
if let Some(adapter) = adapters().get_item_opt(key.as_object(), vm)? {
|
||
return adapter.call((obj,), vm);
|
||
}
|
||
if let Ok(adapter) = proto.get_attr("__adapt__", vm) {
|
||
match adapter.call((obj,), vm) {
|
||
Ok(val) => {
|
||
if !vm.is_none(&val) {
|
||
return Ok(val);
|
||
}
|
||
}
|
||
Err(exc) => {
|
||
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
|
||
return Err(exc);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if let Ok(adapter) = obj.get_attr("__conform__", vm) {
|
||
match adapter.call((proto,), vm) {
|
||
Ok(val) => {
|
||
if !vm.is_none(&val) {
|
||
return Ok(val);
|
||
}
|
||
}
|
||
Err(exc) => {
|
||
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
|
||
return Err(exc);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
alt(obj)
|
||
}
|
||
|
||
#[pyfunction]
|
||
fn adapt(
|
||
obj: PyObjectRef,
|
||
proto: OptionalArg<Option<PyTypeRef>>,
|
||
alt: OptionalArg<PyObjectRef>,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult {
|
||
if matches!(proto, OptionalArg::Present(None)) {
|
||
return if let OptionalArg::Present(alt) = alt {
|
||
Ok(alt)
|
||
} else {
|
||
Err(new_programming_error(vm, "can't adapt".to_owned()))
|
||
};
|
||
}
|
||
|
||
let proto = proto
|
||
.flatten()
|
||
.unwrap_or_else(|| PrepareProtocol::class(&vm.ctx).to_owned());
|
||
|
||
_adapt(
|
||
&obj,
|
||
proto,
|
||
|_| {
|
||
if let OptionalArg::Present(alt) = alt {
|
||
Ok(alt)
|
||
} else {
|
||
Err(new_programming_error(vm, "can't adapt".to_owned()))
|
||
}
|
||
},
|
||
vm,
|
||
)
|
||
}
|
||
|
||
fn need_adapt(obj: &PyObject, vm: &VirtualMachine) -> bool {
|
||
if BASE_TYPE_ADAPTED.get().is_some() {
|
||
true
|
||
} else {
|
||
let cls = obj.class();
|
||
!(cls.is(vm.ctx.types.int_type)
|
||
|| cls.is(vm.ctx.types.float_type)
|
||
|| cls.is(vm.ctx.types.str_type)
|
||
|| cls.is(vm.ctx.types.bytearray_type))
|
||
}
|
||
}
|
||
|
||
static_cell! {
    // `sqlite3.converters`: upper-cased type name -> converter.
    static CONVERTERS: PyDictRef;
    // `sqlite3.adapters`: (type, protocol) tuple -> adapter.
    static ADAPTERS: PyDictRef;
    // Set once an adapter is registered for int/float/str/bytearray.
    static BASE_TYPE_ADAPTED: ();
    // NOTE(review): presumably the last exception raised by a user function;
    // not used in the visible code.
    static USER_FUNCTION_EXCEPTION: PyAtomicRef<Option<PyBaseException>>;
    // Switch written by `enable_callback_tracebacks()`.
    static ENABLE_TRACEBACK: PyAtomic<bool>;
}
|
||
|
||
fn converters() -> &'static Py<PyDict> {
|
||
CONVERTERS.get().expect("converters not initialize")
|
||
}
|
||
|
||
fn adapters() -> &'static Py<PyDict> {
|
||
ADAPTERS.get().expect("adapters not initialize")
|
||
}
|
||
|
||
fn user_function_exception() -> &'static PyAtomicRef<Option<PyBaseException>> {
|
||
USER_FUNCTION_EXCEPTION
|
||
.get()
|
||
.expect("user function exception not initialize")
|
||
}
|
||
|
||
fn enable_traceback() -> &'static PyAtomic<bool> {
|
||
ENABLE_TRACEBACK
|
||
.get()
|
||
.expect("enable traceback not initialize")
|
||
}
|
||
|
||
pub(crate) fn module_exec(vm: &VirtualMachine, module: &Py<PyModule>) -> PyResult<()> {
|
||
__module_exec(vm, module);
|
||
|
||
for (name, code) in ERROR_CODES {
|
||
let name = vm.ctx.intern_str(*name);
|
||
let code = vm.new_pyobj(*code);
|
||
module.set_attr(name, code, vm)?;
|
||
}
|
||
|
||
setup_module_exceptions(module.as_object(), vm);
|
||
|
||
let _ = CONVERTERS.set(vm.ctx.new_dict());
|
||
let _ = ADAPTERS.set(vm.ctx.new_dict());
|
||
let _ = USER_FUNCTION_EXCEPTION.set(PyAtomicRef::from(None));
|
||
let _ = ENABLE_TRACEBACK.set(Radium::new(false));
|
||
|
||
module.set_attr("converters", converters().to_owned(), vm)?;
|
||
module.set_attr("adapters", adapters().to_owned(), vm)?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[pyattr]
|
||
#[pyclass(name)]
|
||
#[derive(PyPayload)]
|
||
struct Connection {
|
||
db: PyMutex<Option<Sqlite>>,
|
||
initialized: PyAtomic<bool>,
|
||
detect_types: PyAtomic<c_int>,
|
||
isolation_level: PyAtomicRef<Option<PyStr>>,
|
||
check_same_thread: PyAtomic<bool>,
|
||
thread_ident: PyMutex<ThreadId>, // TODO: Use atomic
|
||
row_factory: PyAtomicRef<Option<PyObject>>,
|
||
text_factory: PyAtomicRef<PyObject>,
|
||
autocommit: PyMutex<AutocommitMode>,
|
||
}
|
||
|
||
/// Prints a fixed label only; no connection state is shown.
impl Debug for Connection {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
        f.write_str("Sqlite3 Connection")
    }
}
|
||
|
||
impl Constructor for Connection {
|
||
type Args = ConnectArgs;
|
||
|
||
fn py_new(cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
|
||
let text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
|
||
|
||
// For non-subclassed Connection, initialize in __new__
|
||
// For subclassed Connection, leave db as None and require __init__ to be called
|
||
let is_base_class = cls.is(Connection::class(&vm.ctx));
|
||
|
||
let db = if is_base_class {
|
||
// Initialize immediately for base class
|
||
Some(Connection::initialize_db(&args, vm)?)
|
||
} else {
|
||
// For subclasses, require __init__ to be called
|
||
None
|
||
};
|
||
|
||
let initialized = db.is_some();
|
||
|
||
Ok(Self {
|
||
db: PyMutex::new(db),
|
||
initialized: Radium::new(initialized),
|
||
detect_types: Radium::new(args.detect_types),
|
||
isolation_level: PyAtomicRef::from(args.isolation_level),
|
||
check_same_thread: Radium::new(args.check_same_thread),
|
||
thread_ident: PyMutex::new(std::thread::current().id()),
|
||
row_factory: PyAtomicRef::from(None),
|
||
text_factory: PyAtomicRef::from(text_factory),
|
||
autocommit: PyMutex::new(args.autocommit),
|
||
})
|
||
}
|
||
}
|
||
|
||
impl Callable for Connection {
|
||
type Args = FuncArgs;
|
||
|
||
fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
|
||
let _ = zelf.db_lock(vm)?;
|
||
|
||
let (sql,): (PyUtf8StrRef,) = args.bind(vm)?;
|
||
|
||
if let Some(stmt) = Statement::new(zelf, sql, vm)? {
|
||
Ok(stmt.into_ref(&vm.ctx).into())
|
||
} else {
|
||
Ok(vm.ctx.none())
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Initializer for Connection {
|
||
type Args = ConnectArgs;
|
||
|
||
fn init(zelf: PyRef<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
|
||
let was_initialized = Radium::swap(&zelf.initialized, false, Ordering::AcqRel);
|
||
|
||
// Reset factories to their defaults, matching CPython's behavior.
|
||
zelf.reset_factories(vm);
|
||
|
||
if was_initialized {
|
||
zelf.drop_db();
|
||
}
|
||
|
||
// Attempt to open the new database before mutating other state so failures leave
|
||
// the connection uninitialized (and subsequent operations raise ProgrammingError).
|
||
let db = Self::initialize_db(&args, vm)?;
|
||
|
||
let ConnectArgs {
|
||
detect_types,
|
||
isolation_level,
|
||
check_same_thread,
|
||
autocommit,
|
||
..
|
||
} = args;
|
||
|
||
zelf.detect_types.store(detect_types, Ordering::Relaxed);
|
||
zelf.check_same_thread
|
||
.store(check_same_thread, Ordering::Relaxed);
|
||
*zelf.autocommit.lock() = autocommit;
|
||
*zelf.thread_ident.lock() = std::thread::current().id();
|
||
let _ = unsafe { zelf.isolation_level.swap(isolation_level) };
|
||
|
||
let mut guard = zelf.db.lock();
|
||
*guard = Some(db);
|
||
Radium::store(&zelf.initialized, true, Ordering::Release);
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[pyclass(with(Constructor, Callable, Initializer), flags(BASETYPE))]
|
||
impl Connection {
|
||
fn drop_db(&self) {
|
||
self.db.lock().take();
|
||
}
|
||
|
||
fn reset_factories(&self, vm: &VirtualMachine) {
|
||
let default_text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
|
||
let _ = unsafe { self.row_factory.swap(None) };
|
||
let _ = unsafe { self.text_factory.swap(default_text_factory) };
|
||
}
|
||
|
||
fn initialize_db(args: &ConnectArgs, vm: &VirtualMachine) -> PyResult<Sqlite> {
|
||
let path = args.database.to_cstring(vm)?;
|
||
let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?);
|
||
let timeout = (args.timeout * 1000.0) as c_int;
|
||
db.busy_timeout(timeout);
|
||
if let Some(isolation_level) = &args.isolation_level {
|
||
begin_statement_ptr_from_isolation_level(isolation_level, vm)?;
|
||
}
|
||
Ok(db)
|
||
}
|
||
|
||
/// Runs the thread check, then locks and returns the open database.
fn db_lock(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<'_, Sqlite>> {
    self.check_thread(vm).and_then(|_| self._db_lock(vm))
}
|
||
|
||
fn _db_lock(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<'_, Sqlite>> {
|
||
let guard = self.db.lock();
|
||
if guard.is_some() {
|
||
Ok(PyMutexGuard::map(guard, |x| unsafe {
|
||
x.as_mut().unwrap_unchecked()
|
||
}))
|
||
} else {
|
||
Err(new_programming_error(
|
||
vm,
|
||
"Base Connection.__init__ not called.".to_owned(),
|
||
))
|
||
}
|
||
}
|
||
|
||
#[pymethod]
|
||
fn cursor(
|
||
zelf: PyRef<Self>,
|
||
args: CursorArgs,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<PyObjectRef> {
|
||
zelf.db_lock(vm).map(drop)?;
|
||
|
||
let factory = match args.factory {
|
||
OptionalArg::Present(f) => f,
|
||
OptionalArg::Missing => Cursor::class(&vm.ctx).to_owned().into(),
|
||
};
|
||
|
||
let cursor = factory.call((zelf.clone(),), vm)?;
|
||
|
||
if !cursor.class().fast_issubclass(Cursor::class(&vm.ctx)) {
|
||
return Err(vm.new_type_error(format!(
|
||
"factory must return a cursor, not {}",
|
||
cursor.class()
|
||
)));
|
||
}
|
||
|
||
if let Some(cursor_ref) = cursor.downcast_ref::<Cursor>() {
|
||
let _ = unsafe { cursor_ref.row_factory.swap(zelf.row_factory.to_owned()) };
|
||
}
|
||
|
||
Ok(cursor)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn blobopen(
|
||
zelf: PyRef<Self>,
|
||
args: BlobOpenArgs,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<PyRef<Blob>> {
|
||
let table = args.table.to_cstring(vm)?;
|
||
let column = args.column.to_cstring(vm)?;
|
||
let name = args.name.to_cstring(vm)?;
|
||
|
||
let db = zelf.db_lock(vm)?;
|
||
|
||
let mut blob = null_mut();
|
||
let ret = unsafe {
|
||
sqlite3_blob_open(
|
||
db.db,
|
||
name.as_ptr(),
|
||
table.as_ptr(),
|
||
column.as_ptr(),
|
||
args.row,
|
||
(!args.readonly) as c_int,
|
||
&mut blob,
|
||
)
|
||
};
|
||
db.check(ret, vm)?;
|
||
drop(db);
|
||
|
||
let blob = SqliteBlob { blob };
|
||
let blob = Blob {
|
||
connection: zelf,
|
||
inner: PyMutex::new(Some(BlobInner { blob, offset: 0 })),
|
||
};
|
||
Ok(blob.into_ref(&vm.ctx))
|
||
}
|
||
|
||
#[pymethod]
|
||
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
|
||
self.check_thread(vm)?;
|
||
self.drop_db();
|
||
Ok(())
|
||
}
|
||
|
||
fn is_closed(&self) -> bool {
|
||
self.db.lock().is_none()
|
||
}
|
||
|
||
#[pymethod]
|
||
fn commit(&self, vm: &VirtualMachine) -> PyResult<()> {
|
||
self.db_lock(vm)?.implicit_commit(vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn rollback(&self, vm: &VirtualMachine) -> PyResult<()> {
|
||
let db = self.db_lock(vm)?;
|
||
if !db.is_autocommit() {
|
||
db._exec(b"ROLLBACK\0", vm)
|
||
} else {
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[pymethod]
|
||
fn execute(
|
||
zelf: PyRef<Self>,
|
||
sql: PyUtf8StrRef,
|
||
parameters: OptionalArg<PyObjectRef>,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<PyRef<Cursor>> {
|
||
let row_factory = zelf.row_factory.to_owned();
|
||
let cursor = Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx);
|
||
Cursor::execute(cursor, sql, parameters, vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn executemany(
|
||
zelf: PyRef<Self>,
|
||
sql: PyUtf8StrRef,
|
||
seq_of_params: ArgIterable,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<PyRef<Cursor>> {
|
||
let row_factory = zelf.row_factory.to_owned();
|
||
let cursor = Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx);
|
||
Cursor::executemany(cursor, sql, seq_of_params, vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn executescript(
|
||
zelf: PyRef<Self>,
|
||
script: PyUtf8StrRef,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<PyRef<Cursor>> {
|
||
let row_factory = zelf.row_factory.to_owned();
|
||
Cursor::executescript(
|
||
Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx),
|
||
script,
|
||
vm,
|
||
)
|
||
}
|
||
|
||
// TODO: Make it build without clippy::manual_c_str_literals
|
||
#[pymethod]
|
||
#[allow(clippy::manual_c_str_literals)]
|
||
fn backup(zelf: &Py<Self>, args: BackupArgs, vm: &VirtualMachine) -> PyResult<()> {
|
||
let BackupArgs {
|
||
target,
|
||
pages,
|
||
progress,
|
||
name,
|
||
sleep,
|
||
} = args;
|
||
if zelf.is(&target) {
|
||
return Err(vm.new_value_error("target cannot be the same connection instance"));
|
||
}
|
||
|
||
let pages = if pages == 0 { -1 } else { pages };
|
||
|
||
let name_cstring;
|
||
let name_ptr = if let Some(name) = &name {
|
||
name_cstring = name.to_cstring(vm)?;
|
||
name_cstring.as_ptr()
|
||
} else {
|
||
b"main\0".as_ptr().cast()
|
||
};
|
||
|
||
let sleep_ms = (sleep * 1000.0) as c_int;
|
||
|
||
let db = zelf.db_lock(vm)?;
|
||
let target_db = target.db_lock(vm)?;
|
||
|
||
let handle = unsafe {
|
||
sqlite3_backup_init(target_db.db, b"main\0".as_ptr().cast(), db.db, name_ptr)
|
||
};
|
||
|
||
if handle.is_null() {
|
||
return Err(target_db.error_extended(vm));
|
||
}
|
||
|
||
drop(db);
|
||
drop(target_db);
|
||
|
||
loop {
|
||
let ret = unsafe { sqlite3_backup_step(handle, pages) };
|
||
|
||
if let Some(progress) = &progress {
|
||
let remaining = unsafe { sqlite3_backup_remaining(handle) };
|
||
let pagecount = unsafe { sqlite3_backup_pagecount(handle) };
|
||
if let Err(err) = progress.invoke((ret, remaining, pagecount), vm) {
|
||
unsafe { sqlite3_backup_finish(handle) };
|
||
return Err(err);
|
||
}
|
||
}
|
||
|
||
if ret == SQLITE_BUSY || ret == SQLITE_LOCKED {
|
||
unsafe { sqlite3_sleep(sleep_ms) };
|
||
} else if ret != SQLITE_OK {
|
||
break;
|
||
}
|
||
}
|
||
|
||
let ret = unsafe { sqlite3_backup_finish(handle) };
|
||
if ret == SQLITE_OK {
|
||
Ok(())
|
||
} else {
|
||
Err(target.db_lock(vm)?.error_extended(vm))
|
||
}
|
||
}
|
||
|
||
#[pymethod]
|
||
fn create_function(&self, args: CreateFunctionArgs, vm: &VirtualMachine) -> PyResult<()> {
|
||
let name = args.name.to_cstring(vm)?;
|
||
let flags = if args.deterministic {
|
||
SQLITE_UTF8 | SQLITE_DETERMINISTIC
|
||
} else {
|
||
SQLITE_UTF8
|
||
};
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(args.func, vm) else {
|
||
return db.create_function(
|
||
name.as_ptr(),
|
||
args.narg,
|
||
flags,
|
||
null_mut(),
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
vm,
|
||
);
|
||
};
|
||
|
||
db.create_function(
|
||
name.as_ptr(),
|
||
args.narg,
|
||
flags,
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
Some(CallbackData::func_callback),
|
||
None,
|
||
None,
|
||
Some(CallbackData::destructor),
|
||
vm,
|
||
)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn create_aggregate(&self, args: CreateAggregateArgs, vm: &VirtualMachine) -> PyResult<()> {
|
||
let name = args.name.to_cstring(vm)?;
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(args.aggregate_class, vm) else {
|
||
return db.create_function(
|
||
name.as_ptr(),
|
||
args.narg,
|
||
SQLITE_UTF8,
|
||
null_mut(),
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
vm,
|
||
);
|
||
};
|
||
|
||
db.create_function(
|
||
name.as_ptr(),
|
||
args.narg,
|
||
SQLITE_UTF8,
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
None,
|
||
Some(CallbackData::step_callback),
|
||
Some(CallbackData::finalize_callback),
|
||
Some(CallbackData::destructor),
|
||
vm,
|
||
)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn create_collation(
|
||
&self,
|
||
name: PyUtf8StrRef,
|
||
callable: PyObjectRef,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
let name = name.to_cstring(vm)?;
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(callable.clone(), vm) else {
|
||
unsafe {
|
||
sqlite3_create_collation_v2(
|
||
db.db,
|
||
name.as_ptr(),
|
||
SQLITE_UTF8,
|
||
null_mut(),
|
||
None,
|
||
None,
|
||
);
|
||
}
|
||
return Ok(());
|
||
};
|
||
let data = Box::into_raw(Box::new(data));
|
||
|
||
if !callable.is_callable() {
|
||
return Err(vm.new_type_error("parameter must be callable"));
|
||
}
|
||
|
||
let ret = unsafe {
|
||
sqlite3_create_collation_v2(
|
||
db.db,
|
||
name.as_ptr(),
|
||
SQLITE_UTF8,
|
||
data.cast(),
|
||
Some(CallbackData::collation_callback),
|
||
Some(CallbackData::destructor),
|
||
)
|
||
};
|
||
|
||
db.check(ret, vm).inspect_err(|_| {
|
||
// create_collation do not call destructor if error occur
|
||
let _ = unsafe { Box::from_raw(data) };
|
||
})
|
||
}
|
||
|
||
#[pymethod]
|
||
fn create_window_function(
|
||
&self,
|
||
name: PyStrRef,
|
||
narg: c_int,
|
||
aggregate_class: PyObjectRef,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
let name = name.to_cstring(vm)?;
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(aggregate_class, vm) else {
|
||
unsafe {
|
||
sqlite3_create_window_function(
|
||
db.db,
|
||
name.as_ptr(),
|
||
narg,
|
||
SQLITE_UTF8,
|
||
null_mut(),
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
None,
|
||
)
|
||
};
|
||
return Ok(());
|
||
};
|
||
|
||
let ret = unsafe {
|
||
sqlite3_create_window_function(
|
||
db.db,
|
||
name.as_ptr(),
|
||
narg,
|
||
SQLITE_UTF8,
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
Some(CallbackData::step_callback),
|
||
Some(CallbackData::finalize_callback),
|
||
Some(CallbackData::value_callback),
|
||
Some(CallbackData::inverse_callback),
|
||
Some(CallbackData::destructor),
|
||
)
|
||
};
|
||
db.check(ret, vm)
|
||
.map_err(|_| new_programming_error(vm, "Error creating window function".to_owned()))
|
||
}
|
||
|
||
#[pymethod]
|
||
fn set_authorizer(&self, callable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(callable, vm) else {
|
||
unsafe { sqlite3_set_authorizer(db.db, None, null_mut()) };
|
||
return Ok(());
|
||
};
|
||
|
||
let ret = unsafe {
|
||
sqlite3_set_authorizer(
|
||
db.db,
|
||
Some(CallbackData::authorizer_callback),
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
)
|
||
};
|
||
db.check(ret, vm).map_err(|_| {
|
||
new_operational_error(vm, "Error setting authorizer callback".to_owned())
|
||
})
|
||
}
|
||
|
||
#[pymethod]
|
||
fn set_trace_callback(&self, callable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(callable, vm) else {
|
||
unsafe { sqlite3_trace_v2(db.db, SQLITE_TRACE_STMT, None, null_mut()) };
|
||
return Ok(());
|
||
};
|
||
|
||
let ret = unsafe {
|
||
sqlite3_trace_v2(
|
||
db.db,
|
||
SQLITE_TRACE_STMT,
|
||
Some(CallbackData::trace_callback),
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
)
|
||
};
|
||
|
||
db.check(ret, vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn set_progress_handler(
|
||
&self,
|
||
callable: PyObjectRef,
|
||
n: c_int,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
let db = self.db_lock(vm)?;
|
||
let Some(data) = CallbackData::new(callable, vm) else {
|
||
unsafe { sqlite3_progress_handler(db.db, n, None, null_mut()) };
|
||
return Ok(());
|
||
};
|
||
|
||
unsafe {
|
||
sqlite3_progress_handler(
|
||
db.db,
|
||
n,
|
||
Some(CallbackData::progress_callback),
|
||
Box::into_raw(Box::new(data)).cast(),
|
||
)
|
||
};
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[pymethod]
|
||
fn iterdump(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||
let module = vm.import("sqlite3.dump", 0)?;
|
||
let func = module.get_attr("_iterdump", vm)?;
|
||
func.call((zelf,), vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn interrupt(&self, vm: &VirtualMachine) -> PyResult<()> {
|
||
// DO NOT check thread safety
|
||
self._db_lock(vm).map(|x| x.interrupt())
|
||
}
|
||
|
||
#[pymethod]
|
||
fn getlimit(&self, category: c_int, vm: &VirtualMachine) -> PyResult<c_int> {
|
||
self.db_lock(vm)?.limit(category, -1, vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn setlimit(&self, category: c_int, limit: c_int, vm: &VirtualMachine) -> PyResult<c_int> {
|
||
self.db_lock(vm)?.limit(category, limit, vm)
|
||
}
|
||
|
||
#[pymethod]
|
||
fn __enter__(zelf: PyRef<Self>) -> PyRef<Self> {
|
||
zelf
|
||
}
|
||
|
||
#[pymethod]
|
||
fn __exit__(
|
||
&self,
|
||
cls: PyObjectRef,
|
||
exc: PyObjectRef,
|
||
tb: PyObjectRef,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
if vm.is_none(&cls) && vm.is_none(&exc) && vm.is_none(&tb) {
|
||
self.commit(vm)
|
||
} else {
|
||
self.rollback(vm)
|
||
}
|
||
}
|
||
|
||
#[pygetset]
|
||
fn isolation_level(&self) -> Option<PyStrRef> {
|
||
self.isolation_level.deref().map(|x| x.to_owned())
|
||
}
|
||
#[pygetset(setter)]
|
||
fn set_isolation_level(
|
||
&self,
|
||
value: PySetterValue<Option<PyStrRef>>,
|
||
vm: &VirtualMachine,
|
||
) -> PyResult<()> {
|
||
match value {
|
||
PySetterValue::Assign(value) => {
|
||
if let Some(val_str) = &value {
|
||
begin_statement_ptr_from_isolation_level(val_str, vm)?;
|
||
}
|
||
|
||
// If setting isolation_level to None (auto-commit mode), commit any pending transaction
|
||
if value.is_none() {
|
||
let db = self.db_lock(vm)?;
|
||
if !db.is_autocommit() {
|
||
// Keep the lock and call implicit_commit directly to avoid race conditions
|
||
db.implicit_commit(vm)?;
|
||
}
|
||
}
|
||
let _ = unsafe { self.isolation_level.swap(value) };
|
||
Ok(())
|
||
}
|
||
PySetterValue::Delete => Err(vm.new_attribute_error(
|
||
"'isolation_level' attribute cannot be deleted".to_owned(),
|
||
)),
|
||
}
|
||
}
|
||
|
||
#[pygetset]
|
||
fn autocommit(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||
match *self.autocommit.lock() {
|
||
AutocommitMode::Enabled => vm.ctx.true_value.clone().into(),
|
||
AutocommitMode::Disabled => vm.ctx.false_value.clone().into(),
|
||
AutocommitMode::Legacy => vm.ctx.new_int(LEGACY_TRANSACTION_CONTROL).into(),
|
||
}
|
||
}
|
||
#[pygetset(setter)]
|
||
fn set_autocommit(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||
let mode = AutocommitMode::try_from_borrowed_object(vm, &val)?;
|
||
let db = self.db_lock(vm)?;
|
||
|
||
// Handle transaction state based on mode change
|
||
match mode {
|
||
AutocommitMode::Enabled => {
|
||
// If there's a pending transaction, commit it
|
||
if !db.is_autocommit() {
|
||
db._exec(b"COMMIT |