View Javadoc
1   package io.extact.rms.test.junit5;
2   
3   import java.util.Map;
4   import java.util.stream.Collectors;
5   import java.util.stream.StreamSupport;
6   
7   import org.apache.commons.lang3.StringUtils;
8   import org.eclipse.microprofile.config.ConfigProvider;
9   import org.junit.jupiter.api.extension.AfterAllCallback;
10  import org.junit.jupiter.api.extension.AfterEachCallback;
11  import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
12  import org.junit.jupiter.api.extension.BeforeAllCallback;
13  import org.junit.jupiter.api.extension.BeforeEachCallback;
14  import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
15  import org.junit.jupiter.api.extension.ExtensionContext;
16  import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
17  import org.junit.jupiter.api.extension.ExtensionContext.Store;
18  import org.junit.jupiter.api.extension.ParameterContext;
19  import org.junit.jupiter.api.extension.ParameterResolver;
20  import org.junit.platform.commons.support.AnnotationSupport;
21  
22  import jakarta.persistence.EntityManager;
23  import jakarta.persistence.EntityManagerFactory;
24  import jakarta.persistence.EntityTransaction;
25  import jakarta.persistence.Persistence;
26  
27  /**
28   * Acquire and release EntityManager in pre-processing and post-processing for each test.
29   * If a transaction is required, @TransactionalTest can be annotated so that the test method
30   * can multiply the transaction that is the transaction boundary.
31   * <pre>
32   * @ExtendWith(JpaTransactionalExtension.class)
33   * class RentalItemJpaRepositoryTest {
34   *
35   *   @BeforeEach
36   *   void setup(EntityManager em) {
37   *      ....
38   *   }
39   *   @TransactionalTest(shouldCommit = false)
40   *   void addTest() {
41   *      ....
42   *   }
43   * </pre>
44   */
45  public class JpaTransactionalExtension implements
46          BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback,
47          BeforeTestExecutionCallback, AfterTestExecutionCallback, ParameterResolver {
48  
49      private static final String CURRENT_ENTITY_FACTORY = "CURRENT_ENTITY_MANAGER_FACTORY";
50      private static final String CURRENT_ENTITY_MANAGER = "CURRENT_ENTITY_MANAGER";
51      private static final String CURRENT_ENTITY_TRANSACTION = "CURRENT_ENTITY_TRANSACTION";
52  
53  
54      // ----------------------------------------------------- before methods
55  
56      @Override
57      public void beforeAll(ExtensionContext context) {
58          var unitName = geTragetUnitName();
59          var properties = getPersistenceProperties();
60          var emf = Persistence.createEntityManagerFactory(unitName, properties);
61          getEntityManagerFactoryStore(context).put(CURRENT_ENTITY_FACTORY, new CloseableWrapper(emf));
62      }
63  
64      @Override
65      public void beforeEach(ExtensionContext context) throws Exception {
66          EntityManagerFactory emf = getEntityManagerFactoryStore(context).get(CURRENT_ENTITY_FACTORY, CloseableWrapper.class).unwrap();
67          getEntityManagerStore(context).put(CURRENT_ENTITY_MANAGER, new CloseableWrapper(emf.createEntityManager()));
68      }
69  
70      @Override
71      public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
72          return parameterContext.getParameter().getType() == EntityManager.class;
73      }
74  
75      @Override
76      public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
77          return getEntityManagerStore(extensionContext).get(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).unwrap();
78      }
79  
80      @Override
81      public void beforeTestExecution(ExtensionContext context) throws Exception {
82          if (!AnnotationSupport.isAnnotated(context.getTestClass(), TransactionalForTest.class)
83                  && !AnnotationSupport.isAnnotated(context.getTestMethod(), TransactionalForTest.class)) {
84              return;
85          }
86  
87          EntityManager em = getEntityManagerStore(context).get(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).unwrap();
88          if (em == null) {
89              throw new IllegalStateException("EntityManage is unset.");
90          }
91  
92          var tx = em.getTransaction();
93          tx.begin();
94          getEntityManagerStore(context).put(CURRENT_ENTITY_TRANSACTION, tx);
95      }
96  
97  
98      // ----------------------------------------------------- after methods
99  
100     @Override
101     public void afterTestExecution(ExtensionContext context) {
102         // Give priority to Method Annotation
103         TransactionalForTest transactionalTest = AnnotationSupport
104                     .findAnnotation(context.getRequiredTestMethod(), TransactionalForTest.class)
105                 .orElse(AnnotationSupport.findAnnotation(context.getRequiredTestClass(), TransactionalForTest.class).orElse(null));
106         if (transactionalTest == null) {
107             return;
108         }
109 
110         var tx = getEntityManagerStore(context).remove(CURRENT_ENTITY_TRANSACTION, EntityTransaction.class);
111         if (transactionalTest.shouldCommit()) {
112             tx.commit();
113         } else {
114             tx.rollback();
115         }
116     }
117 
118     @Override
119     public void afterEach(ExtensionContext context) {
120         getEntityManagerStore(context).remove(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).close();
121     }
122 
123     @Override
124     public void afterAll(ExtensionContext context) {
125         getEntityManagerFactoryStore(context).remove(CURRENT_ENTITY_FACTORY, CloseableWrapper.class).close();
126     }
127 
128 
129     // ----------------------------------------------------- private methods
130 
131     private String geTragetUnitName() {
132         return ConfigProvider.getConfig().getValue("test.db.connection.unitname", String.class);
133     }
134 
135     private Map<String, String> getPersistenceProperties() {
136         var config = ConfigProvider.getConfig();
137         var keys = StreamSupport.stream(config.getPropertyNames().spliterator(), false)
138                 .filter(key -> key.startsWith("test.db.connection.properties."))
139                 .toList();
140         return keys.stream().collect(Collectors.toMap(
141                     key -> StringUtils.remove(key, "test.db.connection.properties."), // prop-key
142                     key -> config.getOptionalValue(key, String.class).orElse("") // prop-value
143                 ));
144     }
145 
146     private Store getEntityManagerFactoryStore(ExtensionContext context) {
147         return context.getStore(Namespace.create(context.getRequiredTestClass()));
148     }
149 
150     private Store getEntityManagerStore(ExtensionContext context) {
151         return context.getStore(Namespace.create(getClass(), context.getRequiredTestMethod()));
152     }
153 
154 
155     // ----------------------------------------------------- private methods
156 
157     static class CloseableWrapper implements Store.CloseableResource {
158         private Object org;
159 
160         public CloseableWrapper(EntityManagerFactory closeable) {
161             this.org = closeable;
162         }
163         public CloseableWrapper(EntityManager closeable) {
164             this.org = closeable;
165         }
166 
167         @Override
168         public void close() {
169             if (org instanceof EntityManagerFactory) {
170                 var closeable = (EntityManagerFactory) org;
171                 if (closeable.isOpen()) {
172                     closeable.close();
173                 }
174             }
175             if (org instanceof EntityManager) {
176                 var closeable = (EntityManager) org;
177                 if (closeable.isOpen()) {
178                     closeable.close();
179                 }
180             }
181         }
182         @SuppressWarnings("unchecked")
183         public <T> T unwrap() {
184             return (T) org;
185         }
186     }
187 }