JPA Association Fetching Validator

Introduction

In this article, I’m going to show you how we can build a JPA Association Fetching Validator that asserts whether JPA and Hibernate associations are fetched using joins or secondary queries.

While Hibernate does not provide built-in support for checking the entity association fetching behavior programmatically, the API is very flexible and allows us to customize it so that we can achieve this non-trivial requirement.

Domain Model

Let’s assume we have the following Post, PostComment, and PostCommentDetails entities:

[Figure: JPA Association Fetching Validator entities]

The Post parent entity looks as follows:

@Entity(name = "Post")
@Table(name = "post")
public class Post {

    @Id
    private Long id;

    private String title;

    //Getters and setters omitted for brevity
}

Next, we define the PostComment child entity, like this:

@Entity(name = "PostComment")
@Table(name = "post_comment")
public class PostComment {

    @Id
    private Long id;

    @ManyToOne
    private Post post;

    private String review;

    //Getters and setters omitted for brevity
}

Notice that the post association relies on the default fetch strategy of the @ManyToOne annotation, the infamous FetchType.EAGER strategy that’s responsible for many performance problems, as explained in this article.

The PostCommentDetails child entity defines a one-to-one association to the PostComment parent entity. Again, the comment association uses the default FetchType.EAGER fetching strategy:

@Entity(name = "PostCommentDetails")
@Table(name = "post_comment_details")
public class PostCommentDetails {

    @Id
    private Long id;

    @OneToOne
    @MapsId
    @OnDelete(action = OnDeleteAction.CASCADE)
    private PostComment comment;

    private int votes;

    //Getters and setters omitted for brevity
}

The problem with the FetchType.EAGER strategy

So, we have two associations using the FetchType.EAGER anti-pattern. Therefore, when executing the following JPQL query:

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

Hibernate executes the following 3 SQL queries:

SELECT 
    pce.comment_id AS comment_2_2_,
    pce.votes AS votes1_2_
FROM 
    post_comment_details pce
ORDER BY 
    pce.comment_id

SELECT 
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM 
    post_comment pc
LEFT OUTER JOIN 
    post p ON pc.post_id=p.id
WHERE 
    pc.id = 1
    
SELECT 
    pc.id AS id1_1_0_,
    pc.post_id AS post_id3_1_0_,
    pc.review AS review2_1_0_,
    p.id AS id1_0_1_,
    p.title AS title2_0_1_
FROM 
    post_comment pc
LEFT OUTER JOIN 
    post p ON pc.post_id=p.id
WHERE 
    pc.id = 2

This is a classic N+1 query issue. However, not only are extra secondary queries executed to fetch the PostComment associations, but these queries are using JOINs to fetch the associated Post entity as well.
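To make the arithmetic behind the N+1 label concrete, here is a minimal sketch that only models the statement count. The NPlusOneSketch class and its methods are hypothetical names, it does not talk to a database, and it assumes the Persistence Context is empty when the query runs, so every distinct associated id needs its own secondary query.

```java
import java.util.HashSet;
import java.util.List;

// Hypothetical statement-count model; it does not execute any SQL.
final class NPlusOneSketch {

    // One root query, plus one secondary query for each distinct
    // associated id (assumption: none of them is already loaded).
    static int statementsWithSecondaryQueries(List<Long> associatedIds) {
        return 1 + new HashSet<>(associatedIds).size();
    }

    // With JOIN FETCH, the associations come back with the root query.
    static int statementsWithJoinFetch(List<Long> associatedIds) {
        return 1;
    }
}
```

For the two PostCommentDetails rows above, whose comments have the ids 1 and 2, the first method yields 3 statements, which matches the three SQL queries we have just seen.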

Unless you want to load the entire database with a single query, it’s best to avoid using the FetchType.EAGER anti-pattern.

So, let’s see if we can detect these extra secondary queries and JOINs programmatically.

Hibernate Statistics to detect secondary queries

As I explained in this article, not only can Hibernate collect statistical information, but we can even customize the data that gets collected.

For instance, we could monitor how many entities have been fetched per Session using the following SessionStatistics utility:

public class SessionStatistics extends StatisticsImpl {

    private static final ThreadLocal<Map<Class, AtomicInteger>> 
        entityFetchCountContext = new ThreadLocal<>();

    public SessionStatistics(
            SessionFactoryImplementor sessionFactory) {
        super(sessionFactory);
    }

    @Override
    public void openSession() {
        entityFetchCountContext.set(new LinkedHashMap<>());
        
        super.openSession();
    }

    @Override
    public void fetchEntity(
            String entityName) {
        Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
            .get();
        
        entityFetchCountMap
            .computeIfAbsent(
                ReflectionUtils.getClass(entityName), 
                clazz -> new AtomicInteger()
            )
            .incrementAndGet();
        
        super.fetchEntity(entityName);
    }

    @Override
    public void closeSession() {
        entityFetchCountContext.remove();
        
        super.closeSession();
    }

    public static int getEntityFetchCount(
            String entityClassName) {        
        return getEntityFetchCount(
            ReflectionUtils.getClass(entityClassName)
        );
    }

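    public static int getEntityFetchCount(
            Class entityClass) {
        Map<Class, AtomicInteger> entityFetchCountMap = entityFetchCountContext
            .get();
        //No Session was opened on the current Thread, so nothing was fetched
        if (entityFetchCountMap == null) {
            return 0;
        }
        AtomicInteger entityFetchCount = entityFetchCountMap.get(entityClass);
            
        return entityFetchCount != null ? entityFetchCount.get() : 0;
    }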

    public static class Factory implements StatisticsFactory {

        public static final Factory INSTANCE = new Factory();

        @Override
        public StatisticsImplementor buildStatistics(
                SessionFactoryImplementor sessionFactory) {
            return new SessionStatistics(sessionFactory);
        }
    }
}

The SessionStatistics class extends the default Hibernate StatisticsImpl class and overrides the following methods:

  • openSession – this callback method is called when a Hibernate Session is opened. We use it to initialize the ThreadLocal storage that holds the entity fetching registry.
  • fetchEntity – this callback is called whenever an entity is fetched from the database using a secondary query. We use it to increment the entity fetching counter.
  • closeSession – this callback method is called when a Hibernate Session is closed. In our case, this is when we need to reset the ThreadLocal storage.

The getEntityFetchCount method will allow us to inspect how many entity instances have been fetched from the database for a given entity class.
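The open, count, and close lifecycle can be tried out without Hibernate. The following is a minimal sketch under that assumption: FetchCounterSketch and its method names are hypothetical stand-ins for the callbacks described above, and entities are keyed by name rather than by Class. As a choice of the sketch, it returns 0 when no session is open on the current thread.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical, Hibernate-free stand-in for the SessionStatistics lifecycle.
final class FetchCounterSketch {

    // One registry per thread, assuming one open session per thread.
    private static final ThreadLocal<Map<String, AtomicInteger>> CONTEXT =
        new ThreadLocal<>();

    // Plays the role of openSession(): start with an empty registry.
    static void open() {
        CONTEXT.set(new LinkedHashMap<>());
    }

    // Plays the role of fetchEntity(entityName): increment the counter.
    static void recordFetch(String entityName) {
        Map<String, AtomicInteger> counts = CONTEXT.get();
        if (counts == null) {
            return; // no open session on this thread, nothing to record
        }
        counts.computeIfAbsent(entityName, name -> new AtomicInteger())
            .incrementAndGet();
    }

    // Plays the role of getEntityFetchCount(...): 0 when never fetched.
    static int fetchCount(String entityName) {
        Map<String, AtomicInteger> counts = CONTEXT.get();
        if (counts == null) {
            return 0;
        }
        AtomicInteger count = counts.get(entityName);
        return count != null ? count.get() : 0;
    }

    // Plays the role of closeSession(): drop the registry.
    static void close() {
        CONTEXT.remove();
    }
}
```

Calling open(), then recordFetch("PostComment") twice, makes fetchCount("PostComment") return 2, while any entity that was never recorded reports 0.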

The Factory nested class implements the StatisticsFactory interface and its buildStatistics method, which is called by the SessionFactory at bootstrap time.

To configure Hibernate to use the custom SessionStatistics, we have to provide the following two settings:

properties.put(
    AvailableSettings.GENERATE_STATISTICS,
    Boolean.TRUE.toString()
);
properties.put(
    StatisticsInitiator.STATS_BUILDER,
    SessionStatistics.Factory.INSTANCE
);

The first one activates the Hibernate statistics mechanism while the second one tells Hibernate to use a custom StatisticsFactory.

So, let’s see it in action!

assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(2, commentDetailsList.size());

assertEquals(0, SessionStatistics.getEntityFetchCount(PostCommentDetails.class));
assertEquals(2, SessionStatistics.getEntityFetchCount(PostComment.class));
assertEquals(0, SessionStatistics.getEntityFetchCount(Post.class));

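The PostComment fetch count is 2, matching the two secondary queries, while the Post fetch count stays 0 even though the Post entity was loaded through the JOIN clause of those queries. So, the SessionStatistics can only help us detect the extra secondary queries, but it does not work for the extra JOINs that are executed because of FetchType.EAGER associations.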

Hibernate Event Listeners to detect both secondary queries and extra JOINs

Fortunately for us, Hibernate is extremely customizable since, internally, it’s built on top of the Observer pattern.

Every entity action generates an event that’s handled by an event listener, and we can use this mechanism to monitor the entity fetching behavior.

When an entity is fetched, either directly using the find method or via a query, a LoadEvent is triggered. The load is handled first by the LoadEventListener and then by the PostLoadEventListener Hibernate event handlers.

While Hibernate provides default event handlers for all entity events, we can also prepend or append our own listeners using an Integrator, like the following one:

public class AssociationFetchingEventListenerIntegrator 
        implements Integrator {

    public static final AssociationFetchingEventListenerIntegrator INSTANCE = 
        new AssociationFetchingEventListenerIntegrator();

    @Override
    public void integrate(
            Metadata metadata,
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {

        final EventListenerRegistry eventListenerRegistry =
            serviceRegistry.getService(EventListenerRegistry.class);

        eventListenerRegistry.prependListeners(
            EventType.LOAD,
            AssociationFetchPreLoadEventListener.INSTANCE
        );

        eventListenerRegistry.appendListeners(
            EventType.LOAD,
            AssociationFetchLoadEventListener.INSTANCE
        );

        eventListenerRegistry.appendListeners(
            EventType.POST_LOAD,
            AssociationFetchPostLoadEventListener.INSTANCE
        );
    }

    @Override
    public void disintegrate(
            SessionFactoryImplementor sessionFactory,
            SessionFactoryServiceRegistry serviceRegistry) {
    }
}

Our AssociationFetchingEventListenerIntegrator registers three extra event listeners:

  • An AssociationFetchPreLoadEventListener that is executed before the default Hibernate LoadEventListener
  • An AssociationFetchLoadEventListener that is executed after the default Hibernate LoadEventListener
  • And an AssociationFetchPostLoadEventListener that is executed after the default Hibernate PostLoadEventListener
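The effect of prepending versus appending can be illustrated with a tiny registry. This is only a sketch of the ordering idea: ListenerOrderSketch is a hypothetical class, not Hibernate's EventListenerRegistry, and its listeners simply write their name into a trace.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical miniature registry; not Hibernate's EventListenerRegistry.
final class ListenerOrderSketch {

    private final Deque<Consumer<List<String>>> listeners = new ArrayDeque<>();

    // A prepended listener runs before everything registered so far.
    void prepend(Consumer<List<String>> listener) {
        listeners.addFirst(listener);
    }

    // An appended listener runs after everything registered so far.
    void append(Consumer<List<String>> listener) {
        listeners.addLast(listener);
    }

    // Fires the event: listeners run from the head to the tail of the queue.
    List<String> fire() {
        List<String> trace = new ArrayList<>();
        for (Consumer<List<String>> listener : listeners) {
            listener.accept(trace);
        }
        return trace;
    }

    // Registers a default listener first, then wraps it on both sides.
    static List<String> demo() {
        ListenerOrderSketch registry = new ListenerOrderSketch();
        registry.append(trace -> trace.add("default-load"));
        registry.prepend(trace -> trace.add("pre-load"));
        registry.append(trace -> trace.add("after-load"));
        return registry.fire();
    }
}
```

In this sketch, the prepended listener sees the event before the default one and the appended listener sees it afterward, which is the property our three listeners rely on to take a snapshot before the load and compare it after.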

To instruct Hibernate to use our custom AssociationFetchingEventListenerIntegrator that registers the extra event listeners, we just have to set the hibernate.integrator_provider configuration property:

properties.put(
    "hibernate.integrator_provider", 
    (IntegratorProvider) () -> Collections.singletonList(
        AssociationFetchingEventListenerIntegrator.INSTANCE
    )
);

The AssociationFetchPreLoadEventListener implements the LoadEventListener interface and looks like this:

public class AssociationFetchPreLoadEventListener 
        implements LoadEventListener {

    public static final AssociationFetchPreLoadEventListener INSTANCE = 
        new AssociationFetchPreLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event, 
            LoadType loadType) {
        AssociationFetch.Context
            .get(event.getSession())
            .preLoad(event);
    }
}

The AssociationFetchLoadEventListener also implements the LoadEventListener interface and looks as follows:

public class AssociationFetchLoadEventListener 
        implements LoadEventListener {

    public static final AssociationFetchLoadEventListener INSTANCE = 
        new AssociationFetchLoadEventListener();

    @Override
    public void onLoad(
            LoadEvent event, 
            LoadType loadType) {
        AssociationFetch.Context
            .get(event.getSession())
            .load(event);
    }
}

And, the AssociationFetchPostLoadEventListener implements the PostLoadEventListener interface and looks like this:

public class AssociationFetchPostLoadEventListener 
        implements PostLoadEventListener {

    public static final AssociationFetchPostLoadEventListener INSTANCE = 
        new AssociationFetchPostLoadEventListener();

    @Override
    public void onPostLoad(
            PostLoadEvent event) {
        AssociationFetch.Context
            .get(event.getSession())
            .postLoad(event);
    }
}

Notice that all the entity fetching monitoring logic is encapsulated in the following AssociationFetch class:

public class AssociationFetch {

    private final Object entity;

    public AssociationFetch(Object entity) {
        this.entity = entity;
    }

    public Object getEntity() {
        return entity;
    }

    public static class Context implements Serializable {
        public static final String SESSION_PROPERTY_KEY = "ASSOCIATION_FETCH_LIST";

        private Map<String, Integer> entityFetchCountByClassNameMap = 
            new LinkedHashMap<>();

        private Set<EntityIdentifier> joinedFetchedEntities = 
            new LinkedHashSet<>();

        private Set<EntityIdentifier> secondaryFetchedEntities = 
            new LinkedHashSet<>();

        private Map<EntityIdentifier, Object> loadedEntities = 
            new LinkedHashMap<>();

        public List<AssociationFetch> getAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                
                if(joinedFetchedEntities.contains(entityIdentifier) ||
                   secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public List<AssociationFetch> getJoinedAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                
                if(joinedFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public List<AssociationFetch> getSecondaryAssociationFetches() {
            List<AssociationFetch> associationFetches = new ArrayList<>();

            for(Map.Entry<EntityIdentifier, Object> loadedEntityMapEntry : 
                    loadedEntities.entrySet()) {
                EntityIdentifier entityIdentifier = loadedEntityMapEntry.getKey();
                Object entity = loadedEntityMapEntry.getValue();
                if(secondaryFetchedEntities.contains(entityIdentifier)) {
                    associationFetches.add(new AssociationFetch(entity));
                }
            }

            return associationFetches;
        }

        public Map<Class, List<Object>> getAssociationFetchEntityMap() {
            return getAssociationFetches()
                .stream()
                .map(AssociationFetch::getEntity)
                .collect(groupingBy(Object::getClass));
        }

        public void preLoad(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            entityFetchCountByClassNameMap.put(
                entityClassName, 
                SessionStatistics.getEntityFetchCount(
                    entityClassName
                )
            );
        }

        public void load(LoadEvent loadEvent) {
            String entityClassName = loadEvent.getEntityClassName();
            int previousFetchCount = entityFetchCountByClassNameMap.get(
                entityClassName
            );
            int currentFetchCount = SessionStatistics.getEntityFetchCount(
                entityClassName
            );

            EntityIdentifier entityIdentifier = new EntityIdentifier(
                ReflectionUtils.getClass(loadEvent.getEntityClassName()),
                loadEvent.getEntityId()
            );

            if (loadEvent.isAssociationFetch()) {
                if (currentFetchCount == previousFetchCount) {
                    joinedFetchedEntities.add(entityIdentifier);
                } else if (currentFetchCount > previousFetchCount){
                    secondaryFetchedEntities.add(entityIdentifier);
                }
            }
        }

        public void postLoad(PostLoadEvent postLoadEvent) {
            loadedEntities.put(
                new EntityIdentifier(
                    postLoadEvent.getEntity().getClass(),
                    postLoadEvent.getId()
                ),
                postLoadEvent.getEntity()
            );
        }

        public static Context get(Session session) {
            Context context = (Context) session.getProperties()
                .get(SESSION_PROPERTY_KEY);
                
            if (context == null) {
                context = new Context();
                session.setProperty(SESSION_PROPERTY_KEY, context);
            }
            return context;
        }

        public static Context get(EntityManager entityManager) {
            return get(entityManager.unwrap(Session.class));
        }
    }

    private static class EntityIdentifier {
        private final Class entityClass;

        private final Serializable entityId;

        public EntityIdentifier(Class entityClass, Serializable entityId) {
            this.entityClass = entityClass;
            this.entityId = entityId;
        }

        public Class getEntityClass() {
            return entityClass;
        }

        public Serializable getEntityId() {
            return entityId;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof EntityIdentifier)) return false;
            EntityIdentifier that = (EntityIdentifier) o;
            return Objects.equals(getEntityClass(), that.getEntityClass()) && 
                   Objects.equals(getEntityId(), that.getEntityId());
        }

        @Override
        public int hashCode() {
            return Objects.hash(getEntityClass(), getEntityId());
        }
    }
}

And, that’s it!
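The decision taken by the load callback boils down to comparing the fetch counter before and after the default load handling. Here is a minimal, Hibernate-free sketch of that comparison. FetchClassifierSketch and its method signatures are hypothetical, the counter values are passed in by the caller instead of being read from SessionStatistics, and entities are identified by a plain name#id String.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical, Hibernate-free model of the joined-vs-secondary decision.
final class FetchClassifierSketch {

    private final Map<String, Integer> fetchCountBeforeLoad = new LinkedHashMap<>();
    private final Set<String> joinedFetches = new LinkedHashSet<>();
    private final Set<String> secondaryFetches = new LinkedHashSet<>();

    // Runs before the default load handling: snapshot the fetch counter.
    void preLoad(String entityName, int currentFetchCount) {
        fetchCountBeforeLoad.put(entityName, currentFetchCount);
    }

    // Runs after the default load handling: compare against the snapshot.
    void load(String entityName, long id, boolean associationFetch,
              int currentFetchCount) {
        if (!associationFetch) {
            return; // entities requested directly are not association fetches
        }
        int previousFetchCount = fetchCountBeforeLoad.getOrDefault(entityName, 0);
        String key = entityName + "#" + id;
        if (currentFetchCount == previousFetchCount) {
            joinedFetches.add(key);    // counter unchanged: treated as JOIN-fetched
        } else if (currentFetchCount > previousFetchCount) {
            secondaryFetches.add(key); // counter grew: a secondary query ran
        }
    }

    Set<String> joinedFetches() {
        return joinedFetches;
    }

    Set<String> secondaryFetches() {
        return secondaryFetches;
    }
}
```

Because both results are sets keyed by entity identity, an entity that is reported several times, like a Post shared by two comments, is counted only once.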

Testing Time

So, let’s see how this new utility works. When running the same query that was used at the beginning of this article, we can now capture all the association fetches that were done while executing the JPQL query:

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getAssociationFetches().size());
assertEquals(2, context.getSecondaryAssociationFetches().size());
assertEquals(1, context.getJoinedAssociationFetches().size());

Map<Class, List<Object>> associationFetchMap = context
    .getAssociationFetchEntityMap();
    
assertEquals(2, associationFetchMap.size());

for (PostCommentDetails commentDetails : commentDetailsList) {
    assertTrue(
        associationFetchMap.get(PostComment.class)
            .contains(commentDetails.getComment())
    );
    assertTrue(
        associationFetchMap.get(Post.class)
            .contains(commentDetails.getComment().getPost())
    );
}

The tool tells us that 3 more entities are fetched by that query:

  • two PostComment entities, fetched using two secondary queries
  • one Post entity, fetched using a JOIN clause by the secondary queries

If we rewrite the previous query to use JOIN FETCH for all three associations:

AssociationFetch.Context context = AssociationFetch.Context.get(
    entityManager
);
assertTrue(context.getAssociationFetches().isEmpty());

List<PostCommentDetails> commentDetailsList = entityManager.createQuery("""
    select pcd
    from PostCommentDetails pcd
    join fetch pcd.comment pc
    join fetch pc.post
    order by pcd.id
    """,
    PostCommentDetails.class)
.getResultList();

assertEquals(3, context.getJoinedAssociationFetches().size());
assertTrue(context.getSecondaryAssociationFetches().isEmpty());

We can see that, indeed, no secondary SQL query is executed this time, and the 3 associations are fetched using JOIN clauses.

Cool, right?
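To turn this check into a reusable validation, we could wrap the secondary fetch assertion in a small helper. The following is only a sketch: FetchAssertionsSketch and assertNoSecondaryFetches are hypothetical names, and the helper takes any List so that it does not depend on the classes defined in this article.

```java
import java.util.List;

// Hypothetical assertion helper; it only inspects the list it is given.
final class FetchAssertionsSketch {

    // Fails when at least one association was fetched by a secondary query.
    static void assertNoSecondaryFetches(List<?> secondaryFetches) {
        if (!secondaryFetches.isEmpty()) {
            throw new AssertionError(
                "Expected no secondary association fetches, but found "
                    + secondaryFetches.size() + ": " + secondaryFetches
            );
        }
    }
}
```

In an integration test, the result of context.getSecondaryAssociationFetches() could be passed to this helper right after the query is executed.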

Conclusion

Building a JPA Association Fetching Validator can be done just fine using the Hibernate ORM since the API provides many extension points.

If you like this JPA Association Fetching Validator tool, then you are going to love Hypersistence Optimizer, which promises tens of checks and validations so that you can get the most out of your Spring Boot or Jakarta EE application.
