1 package io.extact.rms.test.junit5;
2
3 import java.util.Map;
4 import java.util.stream.Collectors;
5 import java.util.stream.StreamSupport;
6
7 import org.apache.commons.lang3.StringUtils;
8 import org.eclipse.microprofile.config.ConfigProvider;
9 import org.junit.jupiter.api.extension.AfterAllCallback;
10 import org.junit.jupiter.api.extension.AfterEachCallback;
11 import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
12 import org.junit.jupiter.api.extension.BeforeAllCallback;
13 import org.junit.jupiter.api.extension.BeforeEachCallback;
14 import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
15 import org.junit.jupiter.api.extension.ExtensionContext;
16 import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
17 import org.junit.jupiter.api.extension.ExtensionContext.Store;
18 import org.junit.jupiter.api.extension.ParameterContext;
19 import org.junit.jupiter.api.extension.ParameterResolver;
20 import org.junit.platform.commons.support.AnnotationSupport;
21
22 import jakarta.persistence.EntityManager;
23 import jakarta.persistence.EntityManagerFactory;
24 import jakarta.persistence.EntityTransaction;
25 import jakarta.persistence.Persistence;
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 public class JpaTransactionalExtension implements
46 BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback,
47 BeforeTestExecutionCallback, AfterTestExecutionCallback, ParameterResolver {
48
49 private static final String CURRENT_ENTITY_FACTORY = "CURRENT_ENTITY_MANAGER_FACTORY";
50 private static final String CURRENT_ENTITY_MANAGER = "CURRENT_ENTITY_MANAGER";
51 private static final String CURRENT_ENTITY_TRANSACTION = "CURRENT_ENTITY_TRANSACTION";
52
53
54
55
56 @Override
57 public void beforeAll(ExtensionContext context) {
58 var unitName = geTragetUnitName();
59 var properties = getPersistenceProperties();
60 var emf = Persistence.createEntityManagerFactory(unitName, properties);
61 getEntityManagerFactoryStore(context).put(CURRENT_ENTITY_FACTORY, new CloseableWrapper(emf));
62 }
63
64 @Override
65 public void beforeEach(ExtensionContext context) throws Exception {
66 EntityManagerFactory emf = getEntityManagerFactoryStore(context).get(CURRENT_ENTITY_FACTORY, CloseableWrapper.class).unwrap();
67 getEntityManagerStore(context).put(CURRENT_ENTITY_MANAGER, new CloseableWrapper(emf.createEntityManager()));
68 }
69
70 @Override
71 public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
72 return parameterContext.getParameter().getType() == EntityManager.class;
73 }
74
75 @Override
76 public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
77 return getEntityManagerStore(extensionContext).get(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).unwrap();
78 }
79
80 @Override
81 public void beforeTestExecution(ExtensionContext context) throws Exception {
82 if (!AnnotationSupport.isAnnotated(context.getTestClass(), TransactionalForTest.class)
83 && !AnnotationSupport.isAnnotated(context.getTestMethod(), TransactionalForTest.class)) {
84 return;
85 }
86
87 EntityManager em = getEntityManagerStore(context).get(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).unwrap();
88 if (em == null) {
89 throw new IllegalStateException("EntityManage is unset.");
90 }
91
92 var tx = em.getTransaction();
93 tx.begin();
94 getEntityManagerStore(context).put(CURRENT_ENTITY_TRANSACTION, tx);
95 }
96
97
98
99
100 @Override
101 public void afterTestExecution(ExtensionContext context) {
102
103 TransactionalForTest transactionalTest = AnnotationSupport
104 .findAnnotation(context.getRequiredTestMethod(), TransactionalForTest.class)
105 .orElse(AnnotationSupport.findAnnotation(context.getRequiredTestClass(), TransactionalForTest.class).orElse(null));
106 if (transactionalTest == null) {
107 return;
108 }
109
110 var tx = getEntityManagerStore(context).remove(CURRENT_ENTITY_TRANSACTION, EntityTransaction.class);
111 if (transactionalTest.shouldCommit()) {
112 tx.commit();
113 } else {
114 tx.rollback();
115 }
116 }
117
118 @Override
119 public void afterEach(ExtensionContext context) {
120 getEntityManagerStore(context).remove(CURRENT_ENTITY_MANAGER, CloseableWrapper.class).close();
121 }
122
123 @Override
124 public void afterAll(ExtensionContext context) {
125 getEntityManagerFactoryStore(context).remove(CURRENT_ENTITY_FACTORY, CloseableWrapper.class).close();
126 }
127
128
129
130
131 private String geTragetUnitName() {
132 return ConfigProvider.getConfig().getValue("test.db.connection.unitname", String.class);
133 }
134
135 private Map<String, String> getPersistenceProperties() {
136 var config = ConfigProvider.getConfig();
137 var keys = StreamSupport.stream(config.getPropertyNames().spliterator(), false)
138 .filter(key -> key.startsWith("test.db.connection.properties."))
139 .toList();
140 return keys.stream().collect(Collectors.toMap(
141 key -> StringUtils.remove(key, "test.db.connection.properties."),
142 key -> config.getOptionalValue(key, String.class).orElse("")
143 ));
144 }
145
146 private Store getEntityManagerFactoryStore(ExtensionContext context) {
147 return context.getStore(Namespace.create(context.getRequiredTestClass()));
148 }
149
150 private Store getEntityManagerStore(ExtensionContext context) {
151 return context.getStore(Namespace.create(getClass(), context.getRequiredTestMethod()));
152 }
153
154
155
156
157 static class CloseableWrapper implements Store.CloseableResource {
158 private Object org;
159
160 public CloseableWrapper(EntityManagerFactory closeable) {
161 this.org = closeable;
162 }
163 public CloseableWrapper(EntityManager closeable) {
164 this.org = closeable;
165 }
166
167 @Override
168 public void close() {
169 if (org instanceof EntityManagerFactory) {
170 var closeable = (EntityManagerFactory) org;
171 if (closeable.isOpen()) {
172 closeable.close();
173 }
174 }
175 if (org instanceof EntityManager) {
176 var closeable = (EntityManager) org;
177 if (closeable.isOpen()) {
178 closeable.close();
179 }
180 }
181 }
182 @SuppressWarnings("unchecked")
183 public <T> T unwrap() {
184 return (T) org;
185 }
186 }
187 }