diff --git a/modules/database-commons/src/main/java/org/testcontainers/ext/ScriptUtils.java b/modules/database-commons/src/main/java/org/testcontainers/ext/ScriptUtils.java index 7d89517f200..73c49c0947c 100644 --- a/modules/database-commons/src/main/java/org/testcontainers/ext/ScriptUtils.java +++ b/modules/database-commons/src/main/java/org/testcontainers/ext/ScriptUtils.java @@ -16,12 +16,16 @@ package org.testcontainers.ext; +import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testcontainers.delegate.DatabaseDelegate; import javax.script.ScriptException; +import java.io.IOException; +import java.net.URL; +import java.nio.charset.StandardCharsets; import java.util.LinkedList; import java.util.List; @@ -210,6 +214,30 @@ public static boolean containsSqlScriptDelimiters(String script, String delim) { return false; } + /** + * Load script from classpath and apply it to the given database + * + * @param databaseDelegate database delegate for script execution + * @param initScriptPath the resource to load the init script from + */ + public static void runInitScript(DatabaseDelegate databaseDelegate, String initScriptPath) { + try { + URL resource = ScriptUtils.class.getClassLoader().getResource(initScriptPath); + if (resource == null) { + LOGGER.warn("Could not load classpath init script: {}", initScriptPath); + throw new ScriptLoadException("Could not load classpath init script: " + initScriptPath + ". Resource not found."); + } + String scripts = IOUtils.toString(resource, StandardCharsets.UTF_8); + executeDatabaseScript(databaseDelegate, initScriptPath, scripts); + } catch (IOException e) { + LOGGER.warn("Could not load classpath init script: {}", initScriptPath); + throw new ScriptLoadException("Could not load classpath init script: " + initScriptPath, e); + } catch (ScriptException e) { + LOGGER.error("Error while executing init script: {}", initScriptPath, e); + throw new UncategorizedScriptException("Error while executing init script: " + initScriptPath, e); + } + } + public static void executeDatabaseScript(DatabaseDelegate databaseDelegate, String scriptPath, String script) throws ScriptException { executeDatabaseScript(databaseDelegate, scriptPath, script, false, false, DEFAULT_COMMENT_PREFIX, DEFAULT_STATEMENT_SEPARATOR, DEFAULT_BLOCK_COMMENT_START_DELIMITER, DEFAULT_BLOCK_COMMENT_END_DELIMITER); } diff --git a/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimpleMySQLTest.java b/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimpleMySQLTest.java index 4c3731fcaf9..5c8e4133a22 100644 --- a/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimpleMySQLTest.java +++ b/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimpleMySQLTest.java @@ -132,6 +132,20 @@ public void testMySQL8() throws SQLException { } } + @Test + public void testExplicitInitScript() throws SQLException { + try (MySQLContainer container = (MySQLContainer) new MySQLContainer() + .withInitScript("somepath/init_mysql.sql") + .withLogConsumer(new Slf4jLogConsumer(logger))) { + container.start(); + + ResultSet resultSet = performQuery(container, "SELECT foo FROM bar"); + String firstColumnValue = resultSet.getString(1); + + assertEquals("Value from init script should equal real value", "hello world", firstColumnValue); + } + } + @Test public void testEmptyPasswordWithNonRootUser() { diff --git a/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimplePostgreSQLTest.java b/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimplePostgreSQLTest.java index 942d8f1c1ba..2581240cd84 100644 --- a/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimplePostgreSQLTest.java +++ b/modules/jdbc-test/src/test/java/org/testcontainers/junit/SimplePostgreSQLTest.java @@ -2,8 +2,8 @@ import com.zaxxer.hikari.HikariConfig; import com.zaxxer.hikari.HikariDataSource; -import org.junit.Rule; import org.junit.Test; +import org.testcontainers.containers.JdbcDatabaseContainer; import org.testcontainers.containers.PostgreSQLContainer; import java.sql.ResultSet; @@ -17,23 +17,42 @@ */ public class SimplePostgreSQLTest { - @Rule - public PostgreSQLContainer postgres = new PostgreSQLContainer(); - @Test public void testSimple() throws SQLException { + try (PostgreSQLContainer postgres = new PostgreSQLContainer<>()) { + postgres.start(); + + ResultSet resultSet = performQuery(postgres, "SELECT 1"); + + int resultSetInt = resultSet.getInt(1); + assertEquals("A basic SELECT query succeeds", 1, resultSetInt); + } + } + + @Test + public void testExplicitInitScript() throws SQLException { + try (PostgreSQLContainer postgres = new PostgreSQLContainer<>() + .withInitScript("somepath/init_postgresql.sql")) { + postgres.start(); + + ResultSet resultSet = performQuery(postgres, "SELECT foo FROM bar"); + + String firstColumnValue = resultSet.getString(1); + assertEquals("Value from init script should equal real value", "hello world", firstColumnValue); + } + } + + private ResultSet performQuery(JdbcDatabaseContainer container, String sql) throws SQLException { HikariConfig hikariConfig = new HikariConfig(); - hikariConfig.setJdbcUrl(postgres.getJdbcUrl()); - hikariConfig.setUsername(postgres.getUsername()); - hikariConfig.setPassword(postgres.getPassword()); + hikariConfig.setJdbcUrl(container.getJdbcUrl()); + hikariConfig.setUsername(container.getUsername()); + hikariConfig.setPassword(container.getPassword()); HikariDataSource ds = new HikariDataSource(hikariConfig); Statement statement = ds.getConnection().createStatement(); - statement.execute("SELECT 1"); + statement.execute(sql); ResultSet resultSet = statement.getResultSet(); - resultSet.next(); - int resultSetInt = resultSet.getInt(1); - assertEquals("A basic SELECT query succeeds", 1, resultSetInt); + return resultSet; } } diff --git a/modules/jdbc-test/src/test/resources/somepath/init_postgresql.sql b/modules/jdbc-test/src/test/resources/somepath/init_postgresql.sql new file mode 100644 index 00000000000..2b00ee968b0 --- /dev/null +++ b/modules/jdbc-test/src/test/resources/somepath/init_postgresql.sql @@ -0,0 +1,5 @@ +CREATE TABLE bar ( + foo VARCHAR(255) +); + +INSERT INTO bar (foo) VALUES ('hello world'); \ No newline at end of file diff --git a/modules/jdbc/src/main/java/org/testcontainers/containers/JdbcDatabaseContainer.java b/modules/jdbc/src/main/java/org/testcontainers/containers/JdbcDatabaseContainer.java index c43b74105dc..e480070168f 100644 --- a/modules/jdbc/src/main/java/org/testcontainers/containers/JdbcDatabaseContainer.java +++ b/modules/jdbc/src/main/java/org/testcontainers/containers/JdbcDatabaseContainer.java @@ -1,11 +1,15 @@ package org.testcontainers.containers; import lombok.NonNull; +import com.github.dockerjava.api.command.InspectContainerResponse; import org.jetbrains.annotations.NotNull; import org.rnorth.ducttape.ratelimits.RateLimiter; import org.rnorth.ducttape.ratelimits.RateLimiterBuilder; import org.rnorth.ducttape.unreliables.Unreliables; import org.testcontainers.containers.traits.LinkableContainer; +import org.testcontainers.delegate.DatabaseDelegate; +import org.testcontainers.ext.ScriptUtils; +import org.testcontainers.jdbc.JdbcDatabaseDelegate; import org.testcontainers.utility.MountableFile; import java.sql.Connection; @@ -26,6 +30,7 @@ public abstract class JdbcDatabaseContainer parameters = new HashMap<>(); private static final RateLimiter DB_CONNECT_RATE_LIMIT = RateLimiterBuilder.newBuilder() @@ -111,6 +116,11 @@ public SELF withConnectTimeoutSeconds(int connectTimeoutSeconds) { return self(); } + public SELF withInitScript(String initScriptPath) { + this.initScriptPath = initScriptPath; + return self(); + } + @Override protected void waitUntilContainerStarted() { // Repeatedly try and open a connection to the DB and execute a test query @@ -135,6 +145,11 @@ protected void waitUntilContainerStarted() { }); } + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + runInitScriptIfRequired(); + } + /** * Obtain an instance of the correct JDBC driver for this particular database container type * @@ -202,6 +217,15 @@ protected void optionallyMapResourceParameterAsVolume(@NotNull String paramName, } } + /** + * Load init script content and apply it to the database if initScriptPath is set + */ + protected void runInitScriptIfRequired() { + if (initScriptPath != null) { + ScriptUtils.runInitScript(getDatabaseDelegate(), initScriptPath); + } + } + public void setParameters(Map parameters) { this.parameters = parameters; } @@ -228,4 +252,8 @@ protected int getStartupTimeoutSeconds() { protected int getConnectTimeoutSeconds() { return connectTimeoutSeconds; } + + protected DatabaseDelegate getDatabaseDelegate() { + return new JdbcDatabaseDelegate(this, ""); + } }