diff --git a/lib/database.js b/lib/database.js index fef430b..f1e269e 100644 --- a/lib/database.js +++ b/lib/database.js @@ -84,6 +84,7 @@ Database.prototype.exec = wrappers.exec; Database.prototype.close = wrappers.close; Database.prototype.defaultSafeIntegers = wrappers.defaultSafeIntegers; Database.prototype.unsafeMode = wrappers.unsafeMode; +Database.prototype.loadExtension = wrappers.loadExtension; Database.prototype[util.inspect] = require('./methods/inspect'); // Export SQLITE_SCANSTAT_* constants from native addon diff --git a/lib/index.d.ts b/lib/index.d.ts index c4949dd..e57c14c 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -96,7 +96,7 @@ declare namespace BetterSqlite3 { result?: ((total: T) => unknown) | undefined; }, ): this; - loadExtension(path: string): this; + loadExtension(path: string, entryPoint?: string): this; close(): this; defaultSafeIntegers(toggleState?: boolean): this; backup(destinationFile: string, options?: Database.BackupOptions): Promise; diff --git a/lib/methods/wrappers.js b/lib/methods/wrappers.js index 0914ed3..67d80a4 100644 --- a/lib/methods/wrappers.js +++ b/lib/methods/wrappers.js @@ -25,6 +25,11 @@ exports.unsafeMode = function unsafeMode(...args) { return this; }; +exports.loadExtension = function loadExtension(...args) { + this[cppdb].loadExtension(...args); + return this; +}; + exports.getters = { name: { get: function name() { return this[cppdb].name; }, diff --git a/src/objects/database.cpp b/src/objects/database.cpp index 1fcacc0..19e615f 100644 --- a/src/objects/database.cpp +++ b/src/objects/database.cpp @@ -132,6 +132,7 @@ INIT(Database::Init) { SetPrototypeMethod(isolate, data, t, "function", JS_function); SetPrototypeMethod(isolate, data, t, "aggregate", JS_aggregate); SetPrototypeMethod(isolate, data, t, "table", JS_table); + SetPrototypeMethod(isolate, data, t, "loadExtension", JS_loadExtension); SetPrototypeMethod(isolate, data, t, "close", JS_close); SetPrototypeMethod(isolate, data, t, "defaultSafeIntegers", JS_defaultSafeIntegers); SetPrototypeMethod(isolate, data, t, "unsafeMode", JS_unsafeMode); @@ -171,7 +172,9 @@ NODE_METHOD(Database::JS_new) { sqlite3_busy_timeout(db_handle, timeout); sqlite3_limit(db_handle, SQLITE_LIMIT_LENGTH, MAX_BUFFER_SIZE < MAX_STRING_SIZE ? MAX_BUFFER_SIZE : MAX_STRING_SIZE); sqlite3_limit(db_handle, SQLITE_LIMIT_SQL_LENGTH, MAX_STRING_SIZE); - int status = sqlite3_db_config(db_handle, SQLITE_DBCONFIG_DEFENSIVE, 1, NULL); + int status = sqlite3_db_config(db_handle, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, NULL); + assert(status == SQLITE_OK); + status = sqlite3_db_config(db_handle, SQLITE_DBCONFIG_DEFENSIVE, 1, NULL); assert(status == SQLITE_OK); ((void)status); if (node::Buffer::HasInstance(buffer) && !Deserialize(buffer.As(), addon, db_handle, readonly)) { @@ -206,6 +209,28 @@ NODE_METHOD(Database::JS_prepare) { if (!maybeStatement.IsEmpty()) info.GetReturnValue().Set(maybeStatement.ToLocalChecked()); } +NODE_METHOD(Database::JS_loadExtension) { + Database* db = Unwrap(info.This()); + REQUIRE_ARGUMENT_STRING(first, v8::Local filename); + v8::Local entryPoint; + if (info.Length() > 1) { REQUIRE_ARGUMENT_STRING(second, entryPoint); } + REQUIRE_DATABASE_OPEN(db); + REQUIRE_DATABASE_NOT_BUSY(db); + REQUIRE_DATABASE_NO_ITERATORS(db); + UseIsolate; + char* error; + int status = sqlite3_load_extension( + db->db_handle, + *v8::String::Utf8Value(isolate, filename), + entryPoint.IsEmpty() ? NULL : *v8::String::Utf8Value(isolate, entryPoint), + &error + ); + if (status != SQLITE_OK) { + ThrowSqliteError(db->addon, error, status); + } + sqlite3_free(error); +} + NODE_METHOD(Database::JS_exec) { Database* db = Unwrap(info.This()); REQUIRE_ARGUMENT_STRING(first, v8::Local source); diff --git a/src/objects/database.hpp b/src/objects/database.hpp index 87fb2c3..bd8889d 100644 --- a/src/objects/database.hpp +++ b/src/objects/database.hpp @@ -75,6 +75,7 @@ class Database : public node::ObjectWrap { static NODE_METHOD(JS_function); static NODE_METHOD(JS_aggregate); static NODE_METHOD(JS_table); + static NODE_METHOD(JS_loadExtension); static NODE_METHOD(JS_close); static NODE_METHOD(JS_defaultSafeIntegers); static NODE_METHOD(JS_unsafeMode); diff --git a/test/15.database.load-extension.js b/test/15.database.load-extension.js new file mode 100644 index 0000000..6ee725a --- /dev/null +++ b/test/15.database.load-extension.js @@ -0,0 +1,50 @@ +'use strict'; +const { execSync } = require('child_process'); +const path = require('path'); +const Database = require('../.'); + +const isWindows = process.platform === 'win32'; +const extensionSrc = path.join(__dirname, '..', 'deps', 'test_extension.c'); +const sqliteInclude = path.join(__dirname, '..', 'deps', 'sqlite3'); +const extensionPath = path.join(__dirname, '..', 'temp', 'test_extension'); + +(isWindows ? describe.skip : describe)('Database#loadExtension()', function () { + before(function () { + const ext = process.platform === 'darwin' ? '.dylib' : '.so'; + this.extensionFile = extensionPath + ext; + execSync(`cc -shared -fPIC -I "${sqliteInclude}" -o "${this.extensionFile}" "${extensionSrc}"`); + }); + beforeEach(function () { + this.db = new Database(util.next()); + }); + afterEach(function () { + this.db.close(); + }); + + it('should throw an exception if a string is not provided', function () { + expect(() => this.db.loadExtension(123)).to.throw(TypeError); + expect(() => this.db.loadExtension(null)).to.throw(TypeError); + expect(() => this.db.loadExtension()).to.throw(TypeError); + }); + it('should throw an exception if the extension is not found', function () { + expect(() => this.db.loadExtension('/tmp/nonexistent_extension')).to.throw(Database.SqliteError); + }); + it('should load the extension and make its functions available', function () { + const r = this.db.loadExtension(extensionPath); + expect(r).to.equal(this.db); + const result = this.db.prepare('SELECT testExtensionFunction(1, 2, 3) AS val').get(); + expect(result.val).to.equal(3); + }); + it('should not allow loading extensions while the database is busy', function () { + this.db.exec('CREATE TABLE data (x)'); + this.db.exec('INSERT INTO data VALUES (1)'); + const iter = this.db.prepare('SELECT * FROM data').iterate(); + iter.next(); + expect(() => this.db.loadExtension(extensionPath)).to.throw(TypeError); + iter.return(); + }); + it('should not allow loading extensions after the database is closed', function () { + this.db.close(); + expect(() => this.db.loadExtension(extensionPath)).to.throw(TypeError); + }); +});