JPA Association Fetching Validator
Introduction
In this article, I’m going to show you how we can build a JPA Association Fetching Validator that asserts whether JPA and Hibernate associations are fetched using joins or secondary queries.
While Hibernate does not provide built-in support for checking the entity association fetching behavior programmatically, the API is very flexible and allows us to customize it so that we can achieve this non-trivial requirement.
Domain Model
Let’s assume we have the following Post, PostComment, and PostCommentDetails entities:
The Post parent entity looks as follows:

    @Entity(name = "Post")
    @Table(name = "post")
    public class Post {

        @Id
        private Long id;

        private String title;

        //Getters and setters omitted for brevity
    }
Next, we define the PostComment child entity, like this:

    @Entity(name = "PostComment")
    @Table(name = "post_comment")
    public class PostComment {

        @Id
        private Long id;

        @ManyToOne
        private Post post;

        private String review;

        //Getters and setters omitted for brevity
    }
Notice that the post association uses the default fetch strategy of the @ManyToOne annotation, the infamous FetchType.EAGER strategy that’s responsible for causing lots of performance problems, as explained in this article.
And the PostCommentDetails child entity defines a one-to-one child association to the PostComment parent entity. Again, the comment association uses the default FetchType.EAGER fetching strategy.

    @Entity(name = "PostCommentDetails")
    @Table(name = "post_comment_details")
    public class PostCommentDetails {

        @Id
        private Long id;

        @OneToOne
        @MapsId
        @OnDelete(action = OnDeleteAction.CASCADE)
        private PostComment comment;

        private int votes;

        //Getters and setters omitted for brevity
    }
The problem with the FetchType.EAGER strategy
So, we have two associations using the FetchType.EAGER anti-pattern. Therefore, when executing the following JPQL query:

    List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
        select pcd
        from PostCommentDetails pcd
        order by pcd.id
        """, PostCommentDetails.class)
    .getResultList();

Hibernate executes the following 3 SQL queries:

    SELECT pce.comment_id AS comment_2_2_, pce.votes AS votes1_2_
    FROM post_comment_details pce
    ORDER BY pce.comment_id

    SELECT pc.id AS id1_1_0_, pc.post_id AS post_id3_1_0_, pc.review AS review2_1_0_,
           p.id AS id1_0_1_, p.title AS title2_0_1_
    FROM post_comment pc
    LEFT OUTER JOIN post p ON pc.post_id=p.id
    WHERE pc.id = 1

    SELECT pc.id AS id1_1_0_, pc.post_id AS post_id3_1_0_, pc.review AS review2_1_0_,
           p.id AS id1_0_1_, p.title AS title2_0_1_
    FROM post_comment pc
    LEFT OUTER JOIN post p ON pc.post_id=p.id
    WHERE pc.id = 2
This is a classic N+1 query issue. However, not only are extra secondary queries executed to fetch the PostComment associations, but these queries are using JOINs to fetch the associated Post entity as well.
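To make the arithmetic behind the N+1 issue concrete, here is a minimal, Hibernate-free sketch. The CountingSession and NPlusOneDemo classes are hypothetical names introduced only for illustration: they do not talk to a database, they just record one statement for the root query and one for every secondary query.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a database session that only counts statements.
class CountingSession {
    private final List<String> executedStatements = new ArrayList<>();

    // Simulates the root query: one statement returning rowCount detail ids.
    List<Long> selectAllDetailIds(int rowCount) {
        executedStatements.add("SELECT ... FROM post_comment_details");
        List<Long> ids = new ArrayList<>();
        for (long id = 1; id <= rowCount; id++) {
            ids.add(id);
        }
        return ids;
    }

    // Simulates the secondary query issued for one EAGER association.
    void selectCommentWithPost(long commentId) {
        executedStatements.add(
            "SELECT ... FROM post_comment pc LEFT OUTER JOIN post p ... WHERE pc.id = "
                + commentId
        );
    }

    int statementCount() {
        return executedStatements.size();
    }
}

class NPlusOneDemo {
    // Returns how many statements are needed to load rowCount details
    // when every association is resolved by its own secondary query.
    static int loadDetails(int rowCount) {
        CountingSession session = new CountingSession();
        for (long id : session.selectAllDetailIds(rowCount)) {
            session.selectCommentWithPost(id);
        }
        return session.statementCount();
    }
}
```

With two rows, the sketch records 3 statements, which matches the 3 SQL queries listed above. The count grows linearly with the size of the result set.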
Unless you want to load the entire database with a single query, it’s best to avoid using the FetchType.EAGER anti-pattern.
So, let’s see if we can detect these extra secondary queries and JOINs programmatically.
Hibernate Statistics to detect secondary queries
As I explained in this article, not only can Hibernate collect statistical information, but we can even customize the data that gets collected.
For instance, we could monitor how many entities have been fetched per Session using the following SessionStatistics utility:

    public class SessionStatistics extends StatisticsImpl {

        private static final ThreadLocal<Map<Class, AtomicInteger>>
            entityFetchCountContext = new ThreadLocal<>();

        public SessionStatistics(
                SessionFactoryImplementor sessionFactory) {
            super(sessionFactory);
        }

        @Override
        public void openSession() {
            entityFetchCountContext.set(new LinkedHashMap<>());
            super.openSession();
        }

        @Override
        public void fetchEntity(String entityName) {
            Map<Class, AtomicInteger> entityFetchCountMap =
                entityFetchCountContext.get();

            entityFetchCountMap
                .computeIfAbsent(
                    ReflectionUtils.getClass(entityName),
                    clazz -> new AtomicInteger()
                )
                .incrementAndGet();

            super.fetchEntity(entityName);
        }

        @Override
        public void closeSession() {
            entityFetchCountContext.remove();
            super.closeSession();
        }

        public static int getEntityFetchCount(
                String entityClassName) {
            return getEntityFetchCount(
                ReflectionUtils.getClass(entityClassName)
            );
        }

        public static int getEntityFetchCount(
                Class entityClass) {
            AtomicInteger entityFetchCount =
                entityFetchCountContext.get().get(entityClass);

            return entityFetchCount != null ?
                entityFetchCount.get() : 0;
        }

        public static class Factory implements StatisticsFactory {

            public static final Factory INSTANCE = new Factory();

            @Override
            public StatisticsImplementor buildStatistics(
                    SessionFactoryImplementor sessionFactory) {
                return new SessionStatistics(sessionFactory);
            }
        }
    }
The SessionStatistics class extends the default Hibernate StatisticsImpl class and overrides the following methods:
- openSession: this callback method is called when a Hibernate Session is created for the first time. We are using this callback to initialize the ThreadLocal storage that contains the entity fetching registry.
- fetchEntity: this callback is called whenever an entity is fetched from the database using a secondary query. And we use this callback to increase the entity fetching counter.
- closeSession: this callback method is called when a Hibernate Session is closed. In our case, this is when we need to reset the ThreadLocal storage.
The getEntityFetchCount method will allow us to inspect how many entity instances have been fetched from the database for a given entity class.
The Factory nested class implements the StatisticsFactory interface and its buildStatistics method, which is called by the SessionFactory at bootstrap time.
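The open, record, and close life cycle described above can be sketched without any Hibernate dependency. The FetchCounterRegistry class below is a hypothetical, simplified stand-in for SessionStatistics: it keys the counters by entity name instead of Class, and, unlike the original, it tolerates being queried when no registry was opened on the current thread.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

// Simplified, Hibernate-free sketch of a per-thread fetch counter registry.
final class FetchCounterRegistry {
    private static final ThreadLocal<Map<String, AtomicInteger>> COUNTERS =
        new ThreadLocal<>();

    private FetchCounterRegistry() {}

    // Mirrors openSession: start with an empty registry for this thread.
    static void open() {
        COUNTERS.set(new LinkedHashMap<>());
    }

    // Mirrors fetchEntity: bump the counter for the given entity name.
    static void recordFetch(String entityName) {
        Map<String, AtomicInteger> counters = COUNTERS.get();
        if (counters == null) {
            return; // no registry open on this thread, nothing to record
        }
        counters.computeIfAbsent(entityName, name -> new AtomicInteger())
                .incrementAndGet();
    }

    // Mirrors getEntityFetchCount, returning 0 when nothing was recorded.
    static int fetchCount(String entityName) {
        Map<String, AtomicInteger> counters = COUNTERS.get();
        if (counters == null) {
            return 0;
        }
        AtomicInteger counter = counters.get(entityName);
        return counter != null ? counter.get() : 0;
    }

    // Mirrors closeSession: drop the registry so no state leaks between uses.
    static void close() {
        COUNTERS.remove();
    }
}
```

Because the storage is a ThreadLocal, this approach assumes that a Session is opened, used, and closed on the same thread, which is the typical case for a test or a request-scoped transaction.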
To configure Hibernate to use the custom SessionStatistics, we have to provide the following two settings:

    properties.put(
        AvailableSettings.GENERATE_STATISTICS,
        Boolean.TRUE.toString()
    );
    properties.put(
        StatisticsInitiator.STATS_BUILDER,
        SessionStatistics.Factory.INSTANCE
    );

The first one activates the Hibernate statistics mechanism while the second one tells Hibernate to use a custom StatisticsFactory.
So, let’s see it in action!

    assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
    assertEquals(0, SessionStatistics.getEntityFetchCount(PostComment.class));
    assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

    List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
        select pcd
        from PostCommentDetails pcd
        order by pcd.id
        """, PostCommentDetails.class)
    .getResultList();

    assertEquals(2, commentDetailsList.size());

    assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
    assertEquals(2, SessionStatistics.getEntityFetchCount(PostComment.class));
    assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));
So, the SessionStatistics can only help us detect the extra secondary queries: the two PostComment fetches are counted, while the Post count stays at 0. It does not work for the extra JOINs that are executed because of FetchType.EAGER associations.
Hibernate Event Listeners to detect both secondary queries and extra JOINs
Fortunately for us, Hibernate is extremely customizable since, internally, it’s built on top of the Observer pattern.
Every entity action generates an event that’s handled by an event listener, and we can use this mechanism to monitor the entity fetching behavior.
When an entity is fetched directly using the find method or via a query, a LoadEvent is going to be triggered. The LoadEvent is handled first by the LoadEventListener and PostLoadEventListener Hibernate event handlers.
While Hibernate provides default event handlers for all entity events, we can also prepend or append our own listeners using an Integrator, like the following one:

    public class AssociationFetchingEventListenerIntegrator
            implements Integrator {

        public static final AssociationFetchingEventListenerIntegrator INSTANCE =
            new AssociationFetchingEventListenerIntegrator();

        @Override
        public void integrate(
                Metadata metadata,
                SessionFactoryImplementor sessionFactory,
                SessionFactoryServiceRegistry serviceRegistry) {

            final EventListenerRegistry eventListenerRegistry =
                serviceRegistry.getService(EventListenerRegistry.class);

            eventListenerRegistry.prependListeners(
                EventType.LOAD,
                AssociationFetchPreLoadEventListener.INSTANCE
            );

            eventListenerRegistry.appendListeners(
                EventType.LOAD,
                AssociationFetchLoadEventListener.INSTANCE
            );

            eventListenerRegistry.appendListeners(
                EventType.POST_LOAD,
                AssociationFetchPostLoadEventListener.INSTANCE
            );
        }

        @Override
        public void disintegrate(
                SessionFactoryImplementor sessionFactory,
                SessionFactoryServiceRegistry serviceRegistry) {
        }
    }
Our AssociationFetchingEventListenerIntegrator registers three extra event listeners:
- An AssociationFetchPreLoadEventListener that is executed before the default Hibernate LoadEventListener
- An AssociationFetchLoadEventListener that is executed after the default Hibernate LoadEventListener
- And an AssociationFetchPostLoadEventListener that is executed after the default Hibernate PostLoadEventListener
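The effect of prepending and appending listeners around a default handler can be illustrated with a small, Hibernate-free sketch. The ListenerChain and ListenerOrderDemo classes are hypothetical and only mimic the ordering semantics. They are not the EventListenerRegistry API.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical miniature of a listener registry with prepend/append support.
class ListenerChain<E> {
    private final LinkedList<Consumer<E>> listeners = new LinkedList<>();

    void prepend(Consumer<E> listener) {
        listeners.addFirst(listener);
    }

    void append(Consumer<E> listener) {
        listeners.addLast(listener);
    }

    // Delivers the event to every listener, from first to last.
    void fire(E event) {
        for (Consumer<E> listener : listeners) {
            listener.accept(event);
        }
    }
}

class ListenerOrderDemo {
    // Registers a "default" listener, then wraps it with a prepended
    // and an appended one, and returns the order in which they ran.
    static List<String> run() {
        List<String> calls = new ArrayList<>();
        ListenerChain<String> loadChain = new ListenerChain<>();
        loadChain.append(event -> calls.add("default-load:" + event));
        loadChain.prepend(event -> calls.add("pre-load:" + event));
        loadChain.append(event -> calls.add("after-load:" + event));
        loadChain.fire("PostComment#1");
        return calls;
    }
}
```

Running the default handler between our two listeners is what lets us sample a counter before and after the actual load takes place.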
To instruct Hibernate to use our custom AssociationFetchingEventListenerIntegrator that registers the extra event listeners, we just have to set the hibernate.integrator_provider configuration property:

    properties.put(
        "hibernate.integrator_provider",
        (IntegratorProvider) () -> Collections.singletonList(
            AssociationFetchingEventListenerIntegrator.INSTANCE
        )
    );
The AssociationFetchPreLoadEventListener implements the LoadEventListener interface and looks like this:

    public class AssociationFetchPreLoadEventListener
            implements LoadEventListener {

        public static final AssociationFetchPreLoadEventListener INSTANCE =
            new AssociationFetchPreLoadEventListener();

        @Override
        public void onLoad(
                LoadEvent event,
                LoadType loadType) {
            AssociationFetch.Context
                .get(event.getSession())
                .preLoad(event);
        }
    }

The AssociationFetchLoadEventListener also implements the LoadEventListener interface and looks as follows:

    public class AssociationFetchLoadEventListener
            implements LoadEventListener {

        public static final AssociationFetchLoadEventListener INSTANCE =
            new AssociationFetchLoadEventListener();

        @Override
        public void onLoad(
                LoadEvent event,
                LoadType loadType) {
            AssociationFetch.Context
                .get(event.getSession())
                .load(event);
        }
    }

And, the AssociationFetchPostLoadEventListener implements the PostLoadEventListener interface and looks like this:

    public class AssociationFetchPostLoadEventListener
            implements PostLoadEventListener {

        public static final AssociationFetchPostLoadEventListener INSTANCE =
            new AssociationFetchPostLoadEventListener();

        @Override
        public void onPostLoad(
                PostLoadEvent event) {
            AssociationFetch.Context
                .get(event.getSession())
                .postLoad(event);
        }
    }
Notice that all the entity fetching monitoring logic is encapsulated in the following AssociationFetch class:

    public class AssociationFetch {

        private final Object entity;

        public AssociationFetch(Object entity) {
            this.entity = entity;
        }

        public Object getEntity() {
            return entity;
        }

        public static class Context implements Serializable {

            public static final String SESSION_PROPERTY_KEY = "ASSOCIATION_FETCH_LIST";

            private Map<String, Integer> entityFetchCountByClassNameMap =
                new LinkedHashMap<>();

            private Set<EntityIdentifier> joinedFetchedEntities =
                new LinkedHashSet<>();

            private Set<EntityIdentifier> secondaryFetchedEntities =
                new LinkedHashSet<>();

            private Map<EntityIdentifier, Object> loadedEntities =
                new LinkedHashMap<>();

            public List<AssociationFetch> getAssociationFetches() {
                List<AssociationFetch> associationFetches = new ArrayList<>();

                for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                        loadedEntities.entrySet()) {
                    EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                    Object entity = loadedEntityMapEntry.getValue();

                    if (joinedFetchedEntities.contains(entityIdentifier) ||
                        secondaryFetchedEntities.contains(entityIdentifier)) {
                        associationFetches.add(new AssociationFetch(entity));
                    }
                }

                return associationFetches;
            }

            public List<AssociationFetch> getJoinedAssociationFetches() {
                List<AssociationFetch> associationFetches = new ArrayList<>();

                for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                        loadedEntities.entrySet()) {
                    EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                    Object entity = loadedEntityMapEntry.getValue();

                    if (joinedFetchedEntities.contains(entityIdentifier)) {
                        associationFetches.add(new AssociationFetch(entity));
                    }
                }

                return associationFetches;
            }

            public List<AssociationFetch> getSecondaryAssociationFetches() {
                List<AssociationFetch> associationFetches = new ArrayList<>();

                for (Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry :
                        loadedEntities.entrySet()) {
                    EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                    Object entity = loadedEntityMapEntry.getValue();

                    if (secondaryFetchedEntities.contains(entityIdentifier)) {
                        associationFetches.add(new AssociationFetch(entity));
                    }
                }

                return associationFetches;
            }

            public Map<Class, List<Object>> getAssociationFetchEntityMap() {
                return getAssociationFetches()
                    .stream()
                    .map(AssociationFetch::getEntity)
                    .collect(groupingBy(Object::getClass));
            }

            public void preLoad(LoadEvent loadEvent) {
                String entityClassName = loadEvent.getEntityClassName();

                entityFetchCountByClassNameMap.put(
                    entityClassName,
                    SessionStatistics.getEntityFetchCount(
                        entityClassName
                    )
                );
            }

            public void load(LoadEvent loadEvent) {
                String entityClassName = loadEvent.getEntityClassName();

                int previousFetchCount = entityFetchCountByClassNameMap.get(
                    entityClassName
                );
                int currentFetchCount = SessionStatistics.getEntityFetchCount(
                    entityClassName
                );

                EntityIdentifier entityIdentifier = new EntityIdentifier(
                    ReflectionUtils.getClass(loadEvent.getEntityClassName()),
                    loadEvent.getEntityId()
                );

                if (loadEvent.isAssociationFetch()) {
                    if (currentFetchCount == previousFetchCount) {
                        joinedFetchedEntities.add(entityIdentifier);
                    } else if (currentFetchCount > previousFetchCount) {
                        secondaryFetchedEntities.add(entityIdentifier);
                    }
                }
            }

            public void postLoad(PostLoadEvent postLoadEvent) {
                loadedEntities.put(
                    new EntityIdentifier(
                        postLoadEvent.getEntity().getClass(),
                        postLoadEvent.getId()
                    ),
                    postLoadEvent.getEntity()
                );
            }

            public static Context get(Session session) {
                Context context = (Context) session.getProperties()
                    .get(SESSION_PROPERTY_KEY);

                if (context == null) {
                    context = new Context();
                    session.setProperty(SESSION_PROPERTY_KEY, context);
                }

                return context;
            }

            public static Context get(EntityManager entityManager) {
                return get(entityManager.unwrap(Session.class));
            }
        }

        private static class EntityIdentifier {

            private final Class entityClass;

            private final Serializable entityId;

            public EntityIdentifier(Class entityClass, Serializable entityId) {
                this.entityClass = entityClass;
                this.entityId = entityId;
            }

            public Class getEntityClass() {
                return entityClass;
            }

            public Serializable getEntityId() {
                return entityId;
            }

            @Override
            public boolean equals(Object o) {
                if (this == o) return true;
                if (!(o instanceof EntityIdentifier)) return false;
                EntityIdentifier that = (EntityIdentifier) o;
                return Objects.equals(getEntityClass(), that.getEntityClass()) &&
                       Objects.equals(getEntityId(), that.getEntityId());
            }

            @Override
            public int hashCode() {
                return Objects.hash(getEntityClass(), getEntityId());
            }
        }
    }
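The core rule of the load callback is easier to see in isolation. The following is a minimal sketch with a hypothetical FetchClassifier class that takes the two counter samples as plain arguments instead of reading them from SessionStatistics: if the fetch counter did not move while the entity was loaded, the entity came from a JOIN, and if it increased, a secondary query was needed.

```java
import java.util.LinkedHashSet;
import java.util.Set;

// Hibernate-free sketch of the joined-vs-secondary classification rule.
class FetchClassifier {
    private final Set<String> joined = new LinkedHashSet<>();
    private final Set<String> secondary = new LinkedHashSet<>();

    // countBefore and countAfter are the fetch counters sampled
    // right before and right after the default load handling.
    void classify(String entityKey, boolean associationFetch,
                  int countBefore, int countAfter) {
        if (!associationFetch) {
            return; // root entities are not association fetches
        }
        if (countAfter == countBefore) {
            joined.add(entityKey);    // no extra statement was needed
        } else if (countAfter > countBefore) {
            secondary.add(entityKey); // a secondary query loaded it
        }
    }

    Set<String> joined() {
        return joined;
    }

    Set<String> secondary() {
        return secondary;
    }
}
```

For the query used in this article, feeding the sketch with two PostComment loads whose counter increased and one Post load whose counter stayed the same yields two secondary fetches and one joined fetch.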
And, that’s it!
Testing Time
So, let’s see how this new utility works. When running the same query that was used at the beginning of this article, we can see that we can now capture all the association fetches that were done while executing the JPQL query:

    AssociationFetch.Context context = AssociationFetch.Context.get(
        entityManager
    );
    assertTrue(context.getAssociationFetches().isEmpty());

    List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
        select pcd
        from PostCommentDetails pcd
        order by pcd.id
        """, PostCommentDetails.class)
    .getResultList();

    assertEquals(3, context.getAssociationFetches().size());
    assertEquals(2, context.getSecondaryAssociationFetches().size());
    assertEquals(1, context.getJoinedAssociationFetches().size());

    Map<Class, List<Object>> associationFetchMap = context
        .getAssociationFetchEntityMap();

    assertEquals(2, associationFetchMap.size());

    for (PostCommentDetails commentDetails : commentDetailsList) {
        assertTrue(
            associationFetchMap.get(PostComment.class)
                .contains(commentDetails.getComment())
        );
        assertTrue(
            associationFetchMap.get(Post.class)
                .contains(commentDetails.getComment().getPost())
        );
    }
The tool tells us that 3 more entities are fetched by that query:
- 2 PostComment entities using two secondary queries
- one Post entity that’s fetched using a JOIN clause by the secondary queries
If we rewrite the previous query to use JOIN FETCH for both associations, so that all 3 associated entities are loaded by the main query:

    AssociationFetch.Context context = AssociationFetch.Context.get(
        entityManager
    );
    assertTrue(context.getAssociationFetches().isEmpty());

    List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
        select pcd
        from PostCommentDetails pcd
        join fetch pcd.comment pc
        join fetch pc.post
        order by pcd.id
        """, PostCommentDetails.class)
    .getResultList();

    assertEquals(3, context.getJoinedAssociationFetches().size());
    assertTrue(context.getSecondaryAssociationFetches().isEmpty());

We can see that, indeed, no secondary SQL query is executed this time, and the 3 associated entities are fetched using JOIN clauses.
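To turn these checks into a reusable validation step, we could wrap them in a small assertion helper. The FetchAssertions class below is a hypothetical sketch that is not part of Hibernate or of the code above. It only assumes that we can hand it the list of entities that were loaded by secondary queries.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical validator built on top of a list of secondary-fetched entities.
final class FetchAssertions {
    private FetchAssertions() {}

    // Fails when any entity was loaded through a secondary query,
    // listing the simple class names of the offending entities.
    static void assertNoSecondaryFetches(List<?> secondaryFetchedEntities) {
        if (secondaryFetchedEntities.isEmpty()) {
            return;
        }
        List<String> classNames = new ArrayList<>();
        for (Object entity : secondaryFetchedEntities) {
            classNames.add(entity.getClass().getSimpleName());
        }
        throw new AssertionError(
            "Entities fetched by secondary queries: " + classNames
        );
    }
}
```

In a test that uses the Context shown earlier, the argument could be obtained by mapping the result of getSecondaryAssociationFetches through AssociationFetch::getEntity, so that the error message names the entity classes rather than the wrapper.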
Cool, right?
Conclusion
Building a JPA Association Fetching Validator can be done just fine using the Hibernate ORM since the API provides many extension points.
If you like this JPA Association Fetching Validator tool, then you are going to love Hypersistence Optimizer, which promises tens of checks and validations so that you can get the most out of your Spring Boot or Jakarta EE application.
