diff --git a/crates/core/src/rows.rs b/crates/core/src/rows.rs index c3fd871e3f..6bf2685303 100644 --- a/crates/core/src/rows.rs +++ b/crates/core/src/rows.rs @@ -1,11 +1,12 @@ use crate::{errors, Error, Params, Result, Statement}; +use crate::statement::StatementInner; use std::cell::RefCell; +use std::rc::Rc; /// Query result rows. pub struct Rows { - pub(crate) raw: *mut libsql_sys::ffi::sqlite3, - pub(crate) raw_stmt: *mut libsql_sys::ffi::sqlite3_stmt, + pub(crate) stmt: Rc, pub(crate) err: RefCell>, } @@ -15,21 +16,21 @@ impl Rows { pub fn next(&self) -> Result> { let err = match self.err.take() { Some(err) => err, - None => unsafe { libsql_sys::ffi::sqlite3_step(self.raw_stmt) }, + None => unsafe { libsql_sys::ffi::sqlite3_step(self.stmt.raw_stmt) }, }; match err as u32 { libsql_sys::ffi::SQLITE_OK => Ok(None), libsql_sys::ffi::SQLITE_DONE => Ok(None), - libsql_sys::ffi::SQLITE_ROW => Ok(Some(Row { raw: self.raw_stmt })), + libsql_sys::ffi::SQLITE_ROW => Ok(Some(Row { raw: self.stmt.raw_stmt })), _ => Err(Error::QueryFailed(format!( "Failed to fetch next row: {}", - errors::sqlite_error_message(self.raw) + errors::sqlite_error_message(self.stmt.raw) ))), } } pub fn column_count(&self) -> i32 { - unsafe { libsql_sys::ffi::sqlite3_column_count(self.raw_stmt) } + unsafe { libsql_sys::ffi::sqlite3_column_count(self.stmt.raw_stmt) } } } diff --git a/crates/core/src/statement.rs b/crates/core/src/statement.rs index 51d32c86a6..8b235d76ba 100644 --- a/crates/core/src/statement.rs +++ b/crates/core/src/statement.rs @@ -1,11 +1,24 @@ use crate::{errors, Error, Params, Result, Rows, Value}; +use std::rc::Rc; use std::cell::RefCell; /// A prepared statement. pub struct Statement { - raw: *mut libsql_sys::ffi::sqlite3, - raw_stmt: *mut libsql_sys::ffi::sqlite3_stmt, + inner: Rc, +} + +pub(crate) struct StatementInner { + pub(crate) raw: *mut libsql_sys::ffi::sqlite3, + pub(crate) raw_stmt: *mut libsql_sys::ffi::sqlite3_stmt, +} + +impl Drop for StatementInner { + fn drop(&mut self) { + if !self.raw_stmt.is_null() { + unsafe { libsql_sys::ffi::sqlite3_finalize(self.raw_stmt); } + } + } } impl Statement { @@ -21,7 +34,7 @@ impl Statement { ) }; match err as u32 { - libsql_sys::ffi::SQLITE_OK => Ok(Statement { raw, raw_stmt }), + libsql_sys::ffi::SQLITE_OK => Ok(Statement { inner: Rc::new(StatementInner { raw, raw_stmt }) }), _ => Err(Error::QueryFailed(format!( "Failed to prepare statement: `{}`: {}", sql, @@ -38,18 +51,18 @@ impl Statement { let i = i as i32 + 1; match param { Value::Null => unsafe { - libsql_sys::ffi::sqlite3_bind_null(self.raw_stmt, i); + libsql_sys::ffi::sqlite3_bind_null(self.inner.raw_stmt, i); }, Value::Integer(value) => unsafe { - libsql_sys::ffi::sqlite3_bind_int64(self.raw_stmt, i, *value); + libsql_sys::ffi::sqlite3_bind_int64(self.inner.raw_stmt, i, *value); }, Value::Float(value) => unsafe { - libsql_sys::ffi::sqlite3_bind_double(self.raw_stmt, i, *value); + libsql_sys::ffi::sqlite3_bind_double(self.inner.raw_stmt, i, *value); }, Value::Text(value) => unsafe { let value = value.as_bytes(); libsql_sys::ffi::sqlite3_bind_text( - self.raw_stmt, + self.inner.raw_stmt, i, value.as_ptr() as *const i8, value.len() as i32, @@ -58,7 +71,7 @@ impl Statement { }, Value::Blob(value) => unsafe { libsql_sys::ffi::sqlite3_bind_blob( - self.raw_stmt, + self.inner.raw_stmt, i, value.as_ptr() as *const std::ffi::c_void, value.len() as i32, @@ -73,13 +86,12 @@ impl Statement { pub fn execute(&self, params: &Params) -> Option { self.bind(params); - let err = unsafe { libsql_sys::ffi::sqlite3_step(self.raw_stmt) }; + let err = unsafe { libsql_sys::ffi::sqlite3_step(self.inner.raw_stmt) }; match err as u32 { libsql_sys::ffi::SQLITE_OK => None, libsql_sys::ffi::SQLITE_DONE => None, _ => Some(Rows { - raw: self.raw, - raw_stmt: self.raw_stmt, + stmt: self.inner.clone(), err: RefCell::new(Some(err)), }), } @@ -87,6 +99,6 @@ impl Statement { /// Reset the prepared statement to initial state for reuse. pub fn reset(&self) { - unsafe { libsql_sys::ffi::sqlite3_reset(self.raw_stmt) }; + unsafe { libsql_sys::ffi::sqlite3_reset(self.inner.raw_stmt) }; } }