DatabaseExtension.java
/**
* Copyright (C) 2022 Christopher J. Stehno
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.github.cjstehno.testthings.junit;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.*;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.util.Optional;
import java.util.function.Function;
import static io.github.cjstehno.testthings.Resources.resourceToString;
import static org.junit.jupiter.api.extension.ExtensionContext.Namespace.create;
import static org.junit.platform.commons.support.HierarchyTraversalMode.TOP_DOWN;
import static org.junit.platform.commons.support.ModifierSupport.isStatic;
import static org.junit.platform.commons.support.ReflectionSupport.findFields;
import static org.junit.platform.commons.support.ReflectionSupport.findMethod;
/**
 * A JUnit 5 extension used to set up and tear down a database using a provided {@link DataSource}.
 *
 * <p>Before each test, the {@link DataSource} is resolved using the first of the following that applies:</p>
 * <ol>
 *     <li>the {@code creator} method named by a {@link PrepareDatabase} annotation on the test method,</li>
 *     <li>the {@code creator} method named by a {@link PrepareDatabase} annotation on the test class,</li>
 *     <li>a no-argument {@code createDataSource()} method on the test class,</li>
 *     <li>the first field assignable to {@link DataSource} (static or not), searching the class hierarchy top-down.</li>
 * </ol>
 *
 * <p>After each test, the data source is destroyed with the same precedence, using the {@code destroyer} attribute and
 * a {@code destroyDataSource(DataSource)} method. Only the first applicable strategy is used, so a data source is
 * never destroyed twice. When no destroyer method applies, a non-static, non-final {@link DataSource} field is
 * cleared instead. Creator and destroyer methods and data source fields may be static or instance members, and need
 * not be public.</p>
 *
 * <p>The {@link PrepareDatabase} annotation may be applied at the class or test method level to append to, or
 * override, the setup and teardown scripts.</p>
 *
 * <p>If a {@link Test} method is given a {@link DataSource} parameter, it is populated with the current data source
 * for that method, for use in the test.</p>
 *
 * <p>See the User Guide for more details and examples.</p>
 */
@Slf4j
public class DatabaseExtension implements BeforeEachCallback, AfterEachCallback, ParameterResolver {

    private static final Namespace NAMESPACE = create("test-things", "database");
    private static final String DATA_SOURCE = "data-source";
    private static final String DEFAULT_CREATOR = "createDataSource";
    private static final String DEFAULT_DESTROYER = "destroyDataSource";

    /**
     * Resolves the data source for the current test, stores it in the extension store, and runs the setup scripts.
     *
     * @param context the current extension context
     * @throws IllegalStateException if no (non-null) data source could be resolved for the test
     */
    @Override public void beforeEach(final ExtensionContext context) throws Exception {
        val testClass = context.getRequiredTestClass();
        final DataSource dataSource = createDataSource(context).orElseThrow(() -> new IllegalStateException(
            "Unable to resolve a non-null DataSource for " + testClass.getName()
                + ": configure a PrepareDatabase creator, a " + DEFAULT_CREATOR
                + "() method, or a DataSource field."
        ));
        context.getStore(NAMESPACE).put(DATA_SOURCE, dataSource);

        runSetupScripts(context, dataSource);
        log.info("The database is set-up.");
    }

    /**
     * Runs the teardown scripts and then destroys the data source stored by {@link #beforeEach(ExtensionContext)}.
     *
     * @param context the current extension context
     */
    @Override public void afterEach(final ExtensionContext context) throws Exception {
        final DataSource dataSource = context.getStore(NAMESPACE).remove(DATA_SOURCE, DataSource.class);
        if (dataSource == null) {
            // Only beforeEach stores a data source; presumably it failed before doing so, so there is nothing to
            // run teardown scripts against or to destroy.
            log.warn("No data source was available; skipping database tear-down.");
            return;
        }

        runTeardownScripts(context, dataSource);
        destroyDataSource(context, dataSource);
        log.info("The database was torn-down.");
    }

    /**
     * Supports {@link DataSource} (or subtype) parameters when the current test method is annotated with {@link Test}.
     */
    @Override
    public boolean supportsParameter(final ParameterContext parameterContext, final ExtensionContext extensionContext) throws ParameterResolutionException {
        // getTestMethod() is empty for class-level contexts (e.g. constructor parameters), where
        // getRequiredTestMethod() would throw instead of simply declining the parameter.
        return DataSource.class.isAssignableFrom(parameterContext.getParameter().getType())
            && extensionContext.getTestMethod().map(method -> method.isAnnotationPresent(Test.class)).orElse(false);
    }

    /**
     * Resolves the parameter to the data source stored for the current test.
     */
    @Override
    public Object resolveParameter(final ParameterContext parameterContext, final ExtensionContext extensionContext) throws ParameterResolutionException {
        return extensionContext.getStore(NAMESPACE).get(DATA_SOURCE, DataSource.class);
    }

    /**
     * Resolves the data source using the precedence documented on the class.
     *
     * @return the resolved data source, or empty when no strategy applies or the applicable one yields {@code null}
     */
    private static Optional<DataSource> createDataSource(final ExtensionContext context) {
        // an explicit creator on the test method wins over one on the test class
        final Optional<String> configured = findConfiguredName(context, PrepareDatabase::creator);
        if (configured.isPresent()) {
            return invokeDataSourceCreator(context, configured.get());
        }

        // conventionally-named provider method
        if (findMethod(context.getRequiredTestClass(), DEFAULT_CREATOR).isPresent()) {
            return invokeDataSourceCreator(context, DEFAULT_CREATOR);
        }

        // field of DataSource type (static or not)
        return findDataSourceField(context).flatMap(field -> readDataSourceField(context, field));
    }

    /**
     * Invokes the named no-argument creator method on the test class (or test instance, when not static).
     *
     * @throws IllegalStateException if the method does not exist or cannot be invoked
     */
    private static Optional<DataSource> invokeDataSourceCreator(final ExtensionContext context, final String methodName) {
        val testClass = context.getRequiredTestClass();
        final Method method = findMethod(testClass, methodName).orElseThrow(() -> new IllegalStateException(
            "No DataSource creator method (" + methodName + "()) found on " + testClass.getName()
        ));
        try {
            // best-effort: allows non-public creators; if access is still denied, invoke() reports it
            method.trySetAccessible();
            return Optional.ofNullable((DataSource) method.invoke(targetFor(context, method)));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Unable to create DataSource using (" + methodName + ")", e);
        }
    }

    /**
     * Reads the data source from the given field.
     *
     * @throws IllegalStateException if the field cannot be read
     */
    private static Optional<DataSource> readDataSourceField(final ExtensionContext context, final Field field) {
        try {
            field.trySetAccessible();
            return Optional.ofNullable((DataSource) field.get(targetFor(context, field)));
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to read DataSource field (" + field.getName() + ")", e);
        }
    }

    /**
     * Destroys the data source using the first applicable strategy, mirroring the precedence used by
     * {@link #createDataSource(ExtensionContext)} so that a data source is never destroyed more than once.
     */
    private static void destroyDataSource(final ExtensionContext context, final DataSource dataSource) {
        log.info("Destroying data source.");

        // an explicit destroyer on the test method wins over one on the test class
        final Optional<String> configured = findConfiguredName(context, PrepareDatabase::destroyer);
        if (configured.isPresent()) {
            invokeDataSourceDestroyer(context, configured.get(), dataSource);
            return;
        }

        // conventionally-named destroyer method
        if (findMethod(context.getRequiredTestClass(), DEFAULT_DESTROYER, DataSource.class).isPresent()) {
            invokeDataSourceDestroyer(context, DEFAULT_DESTROYER, dataSource);
            return;
        }

        // field of DataSource type
        findDataSourceField(context).ifPresent(field -> clearDataSourceField(context, field));
    }

    /**
     * Invokes the named destroyer method, which takes the data source as its only argument. Failures during the
     * invocation itself are logged rather than rethrown (best-effort cleanup).
     *
     * @throws IllegalStateException if no such method exists on the test class
     */
    private static void invokeDataSourceDestroyer(final ExtensionContext context, final String methodName, final DataSource dataSource) {
        val testClass = context.getRequiredTestClass();
        final Method method = findMethod(testClass, methodName, DataSource.class).orElseThrow(() -> new IllegalStateException(
            "No DataSource destroyer method (" + methodName + "(DataSource)) found on " + testClass.getName()
        ));
        try {
            method.trySetAccessible();
            method.invoke(targetFor(context, method), dataSource);
        } catch (Exception e) {
            log.error("Unable to destroy DataSource using ({}): {}", methodName, e.getMessage(), e);
        }
    }

    /**
     * Clears the given instance field. Static fields are left alone because createDataSource reads the field again
     * before the next test, and clearing a static one would leave nothing to read. Final fields are not reassigned.
     */
    private static void clearDataSourceField(final ExtensionContext context, final Field field) {
        if (isStatic(field) || Modifier.isFinal(field.getModifiers())) {
            return;
        }
        try {
            field.trySetAccessible();
            field.set(context.getRequiredTestInstance(), null);
        } catch (IllegalAccessException e) {
            log.error("Unable to destroy DataSource (field): {}", e.getMessage(), e);
        }
    }

    /** Finds the first field assignable to {@link DataSource}, searching the test class hierarchy top-down. */
    private static Optional<Field> findDataSourceField(final ExtensionContext context) {
        return findFields(context.getRequiredTestClass(), f -> DataSource.class.isAssignableFrom(f.getType()), TOP_DOWN)
            .stream()
            .findFirst();
    }

    /** Returns the object to invoke or read a member on: {@code null} for static members, otherwise the test instance. */
    private static Object targetFor(final ExtensionContext context, final Member member) {
        return isStatic(member) ? null : context.getRequiredTestInstance();
    }

    /**
     * Returns the first non-blank value of the given annotation attribute, checking the test method annotation
     * before the test class annotation.
     */
    private static Optional<String> findConfiguredName(final ExtensionContext context, final Function<PrepareDatabase, String> attribute) {
        if (isTestMethodAnnotated(context)) {
            val name = attribute.apply(getMethodAnnotation(context));
            if (!name.isBlank()) {
                return Optional.of(name);
            }
        }
        if (isTestClassAnnotated(context)) {
            val name = attribute.apply(getClassAnnotation(context));
            if (!name.isBlank()) {
                return Optional.of(name);
            }
        }
        return Optional.empty();
    }

    private static void runSetupScripts(final ExtensionContext context, final DataSource dataSource) {
        runConfiguredScripts(context, dataSource, PrepareDatabase::setup);
    }

    private static void runTeardownScripts(final ExtensionContext context, final DataSource dataSource) {
        runConfiguredScripts(context, dataSource, PrepareDatabase::teardown);
    }

    /**
     * Runs the scripts selected by {@code scripts}. A method annotation replaces the class annotation's scripts,
     * unless it is {@code additive}, in which case the class scripts run first. Without a method annotation, the
     * class annotation's scripts (if any) are run.
     */
    private static void runConfiguredScripts(final ExtensionContext context, final DataSource dataSource, final Function<PrepareDatabase, String[]> scripts) {
        if (isTestMethodAnnotated(context)) {
            val methodAnno = getMethodAnnotation(context);
            if (methodAnno.additive()) {
                findClassPrepareDatabase(context).ifPresent(classAnno -> runScripts(dataSource, scripts.apply(classAnno)));
            }
            runScripts(dataSource, scripts.apply(methodAnno));
        } else {
            findClassPrepareDatabase(context).ifPresent(classAnno -> runScripts(dataSource, scripts.apply(classAnno)));
        }
    }

    private static Optional<PrepareDatabase> findClassPrepareDatabase(final ExtensionContext context) {
        return Optional.ofNullable(getClassAnnotation(context));
    }

    /**
     * Runs each script on a single connection. Connection problems are logged rather than rethrown.
     */
    private static void runScripts(final DataSource dataSource, final String[] scripts) {
        // nothing to run: do not open a connection at all
        if (scripts == null || scripts.length == 0) {
            return;
        }
        try (val conn = dataSource.getConnection()) {
            for (val scriptPath : scripts) {
                runScript(conn, scriptPath);
            }
            log.info("Done running scripts.");
        } catch (Exception ex) {
            log.error("Connection problem while running scripts: {}", ex.getMessage(), ex);
        }
    }

    /**
     * Executes the content of the given script resource as a single statement; blank paths are skipped and
     * failures are logged rather than rethrown.
     */
    private static void runScript(final Connection conn, final String scriptPath) {
        if (scriptPath.isBlank()) {
            return;
        }
        try (val stmt = conn.createStatement()) {
            stmt.execute(resourceToString(scriptPath));
            // commit() is only meaningful (and, for some drivers, only legal) outside auto-commit mode
            if (!conn.getAutoCommit()) {
                conn.commit();
            }
            log.info("Executed database script ({}).", scriptPath);
        } catch (Exception se) {
            log.error("Problem executing database script ({}): {}", scriptPath, se.getMessage(), se);
        }
    }

    private static PrepareDatabase getMethodAnnotation(final ExtensionContext context) {
        return context.getRequiredTestMethod().getAnnotation(PrepareDatabase.class);
    }

    private static PrepareDatabase getClassAnnotation(final ExtensionContext context) {
        return context.getRequiredTestClass().getAnnotation(PrepareDatabase.class);
    }

    private static boolean isTestMethodAnnotated(final ExtensionContext context) {
        return context.getRequiredTestMethod().isAnnotationPresent(PrepareDatabase.class);
    }

    private static boolean isTestClassAnnotated(final ExtensionContext context) {
        return context.getRequiredTestClass().isAnnotationPresent(PrepareDatabase.class);
    }
}