JpaTransactionalExtension.java

package io.extact.rms.test.junit5;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.apache.commons.lang3.StringUtils;
import org.eclipse.microprofile.config.ConfigProvider;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.AnnotationSupport;

import jakarta.persistence.EntityManager;
import jakarta.persistence.EntityManagerFactory;
import jakarta.persistence.EntityTransaction;
import jakarta.persistence.Persistence;

/**
 * JUnit 5 extension that acquires and releases JPA resources around each test.
 *
 * <ul>
 *   <li>{@code beforeAll}: creates an {@link EntityManagerFactory} for the persistence unit named by the
 *       config key {@code test.db.connection.unitname}. Every {@code test.db.connection.properties.*}
 *       entry is passed as a persistence property, with the prefix stripped.</li>
 *   <li>{@code beforeEach}: creates an {@link EntityManager} for the current test method. It can be
 *       injected into any parameter of type {@link EntityManager} (e.g. in {@code @BeforeEach} methods
 *       or in the test method itself).</li>
 *   <li>{@code beforeTestExecution} / {@code afterTestExecution}: when the test method or the test class
 *       is annotated with {@code @TransactionalForTest}, a transaction is begun right before the test body
 *       runs. Right after the body it is committed or rolled back according to {@code shouldCommit}, so the
 *       test method itself forms the transaction boundary. A method-level annotation takes precedence over
 *       a class-level one.</li>
 *   <li>{@code afterEach} / {@code afterAll}: close the entity manager and the factory.</li>
 * </ul>
 *
 * <pre>
 * &#64;ExtendWith(JpaTransactionalExtension.class)
 * class RentalItemJpaRepositoryTest {
 *
 *   &#64;BeforeEach
 *   void setup(EntityManager em) {
 *      ....
 *   }
 *
 *   &#64;Test
 *   &#64;TransactionalForTest(shouldCommit = false)
 *   void addTest(EntityManager em) {
 *      ....
 *   }
 * }
 * </pre>
 */
public class JpaTransactionalExtension implements
        BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback,
        BeforeTestExecutionCallback, AfterTestExecutionCallback, ParameterResolver {

    private static final String CURRENT_ENTITY_FACTORY = "CURRENT_ENTITY_MANAGER_FACTORY";
    private static final String CURRENT_ENTITY_MANAGER = "CURRENT_ENTITY_MANAGER";
    private static final String CURRENT_ENTITY_TRANSACTION = "CURRENT_ENTITY_TRANSACTION";

    /** Config key holding the persistence unit name. */
    private static final String UNIT_NAME_KEY = "test.db.connection.unitname";
    /** Config keys starting with this prefix become persistence properties (prefix stripped). */
    private static final String PERSISTENCE_PROPERTY_PREFIX = "test.db.connection.properties.";


    // ----------------------------------------------------- before methods

    /** Creates the EntityManagerFactory shared by every test of the class. */
    @Override
    public void beforeAll(ExtensionContext context) {
        var emf = Persistence.createEntityManagerFactory(getTargetUnitName(), getPersistenceProperties());
        getEntityManagerFactoryStore(context).put(CURRENT_ENTITY_FACTORY, new CloseableWrapper(emf));
    }

    /**
     * Creates a fresh EntityManager for the current test method.
     *
     * @throws IllegalStateException if {@code beforeAll} did not register a factory
     */
    @Override
    public void beforeEach(ExtensionContext context) throws Exception {
        var emfWrapper = getEntityManagerFactoryStore(context).get(CURRENT_ENTITY_FACTORY, CloseableWrapper.class);
        if (emfWrapper == null) {
            throw new IllegalStateException("EntityManagerFactory is unset.");
        }
        EntityManagerFactory emf = emfWrapper.unwrap();
        getEntityManagerStore(context).put(CURRENT_ENTITY_MANAGER, new CloseableWrapper(emf.createEntityManager()));
    }

    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
        return parameterContext.getParameter().getType() == EntityManager.class;
    }

    /** Injects the EntityManager created in {@code beforeEach} for the current test method. */
    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
        return getCurrentEntityManager(extensionContext);
    }

    /**
     * Begins a transaction right before the test body when {@code @TransactionalForTest} applies.
     *
     * @throws IllegalStateException if no EntityManager is registered for the current test method
     */
    @Override
    public void beforeTestExecution(ExtensionContext context) throws Exception {
        if (findTransactionalForTest(context) == null) {
            return;
        }
        var tx = getCurrentEntityManager(context).getTransaction();
        tx.begin();
        getEntityManagerStore(context).put(CURRENT_ENTITY_TRANSACTION, tx);
    }


    // ----------------------------------------------------- after methods

    /**
     * Ends the transaction begun in {@code beforeTestExecution}. It commits only when
     * {@code shouldCommit} is true and the transaction is not marked rollback-only; otherwise it rolls back.
     */
    @Override
    public void afterTestExecution(ExtensionContext context) {
        var transactionalTest = findTransactionalForTest(context);
        var tx = getEntityManagerStore(context).remove(CURRENT_ENTITY_TRANSACTION, EntityTransaction.class);
        // No transaction was started, or the test body already committed/rolled it back itself.
        if (transactionalTest == null || tx == null || !tx.isActive()) {
            return;
        }
        // A rollback-only transaction cannot be committed; committing would only raise RollbackException.
        if (transactionalTest.shouldCommit() && !tx.getRollbackOnly()) {
            tx.commit();
        } else {
            tx.rollback();
        }
    }

    /** Closes the EntityManager of the current test method, if one was created. */
    @Override
    public void afterEach(ExtensionContext context) {
        var wrapper = getEntityManagerStore(context).remove(CURRENT_ENTITY_MANAGER, CloseableWrapper.class);
        if (wrapper != null) {
            wrapper.close();
        }
    }

    /** Closes the EntityManagerFactory of the test class, if one was created. */
    @Override
    public void afterAll(ExtensionContext context) {
        var wrapper = getEntityManagerFactoryStore(context).remove(CURRENT_ENTITY_FACTORY, CloseableWrapper.class);
        if (wrapper != null) {
            wrapper.close();
        }
    }


    // ----------------------------------------------------- private methods

    private String getTargetUnitName() {
        return ConfigProvider.getConfig().getValue(UNIT_NAME_KEY, String.class);
    }

    /**
     * Collects every {@code test.db.connection.properties.*} config entry into a map. Each key has
     * the leading prefix removed; a missing value becomes an empty string.
     */
    private Map<String, String> getPersistenceProperties() {
        var config = ConfigProvider.getConfig();
        return StreamSupport.stream(config.getPropertyNames().spliterator(), false)
                .filter(key -> key.startsWith(PERSISTENCE_PROPERTY_PREFIX))
                .distinct() // toMap rejects duplicate keys
                .collect(Collectors.toMap(
                        key -> StringUtils.removeStart(key, PERSISTENCE_PROPERTY_PREFIX), // prop-key
                        key -> config.getOptionalValue(key, String.class).orElse(""))); // prop-value
    }

    /**
     * Returns the effective {@code @TransactionalForTest}: the method-level annotation if present,
     * otherwise the class-level one, or {@code null} when neither exists.
     */
    private TransactionalForTest findTransactionalForTest(ExtensionContext context) {
        return AnnotationSupport.findAnnotation(context.getTestMethod(), TransactionalForTest.class)
                .or(() -> AnnotationSupport.findAnnotation(context.getTestClass(), TransactionalForTest.class))
                .orElse(null);
    }

    /**
     * Returns the EntityManager registered for the current test method.
     *
     * @throws IllegalStateException if none is registered
     */
    private EntityManager getCurrentEntityManager(ExtensionContext context) {
        var wrapper = getEntityManagerStore(context).get(CURRENT_ENTITY_MANAGER, CloseableWrapper.class);
        if (wrapper == null) {
            throw new IllegalStateException("EntityManager is unset.");
        }
        return wrapper.unwrap();
    }

    /** Store scoped to this extension and the test class (holds the factory). */
    private Store getEntityManagerFactoryStore(ExtensionContext context) {
        return context.getStore(Namespace.create(getClass(), context.getRequiredTestClass()));
    }

    /** Store scoped to this extension and the test method (holds the manager and transaction). */
    private Store getEntityManagerStore(ExtensionContext context) {
        return context.getStore(Namespace.create(getClass(), context.getRequiredTestMethod()));
    }


    // ----------------------------------------------------- nested classes

    /**
     * Adapts an EntityManagerFactory or EntityManager to {@link Store.CloseableResource}. If the wrapper
     * is left in a store, JUnit closes it when that store's context is closed. {@link #close()} is a
     * no-op for an already-closed resource.
     */
    static class CloseableWrapper implements Store.CloseableResource {
        private final Object delegate;

        public CloseableWrapper(EntityManagerFactory closeable) {
            this.delegate = closeable;
        }

        public CloseableWrapper(EntityManager closeable) {
            this.delegate = closeable;
        }

        @Override
        public void close() {
            if (delegate instanceof EntityManagerFactory emf && emf.isOpen()) {
                emf.close();
            }
            if (delegate instanceof EntityManager em && em.isOpen()) {
                em.close();
            }
        }

        /** Returns the wrapped resource, cast to the caller's expected type. */
        @SuppressWarnings("unchecked")
        public <T> T unwrap() {
            return (T) delegate;
        }
    }
}