Skip to content

MyBatisNativeConfiguration.java

Kazuki Shimizu edited this page Jun 6, 2023 · 3 revisions

This class is a configuration class for building a native image with MyBatis. If you want to use the MyBatis on Native image, please this class copy to your application.

NOTE:

This configuration will move to mybatis-spring-native.

package com.example.mybatisnativesample;

import org.apache.commons.logging.LogFactory;
import org.apache.ibatis.annotations.DeleteProvider;
import org.apache.ibatis.annotations.InsertProvider;
import org.apache.ibatis.annotations.SelectProvider;
import org.apache.ibatis.annotations.UpdateProvider;
import org.apache.ibatis.cache.decorators.FifoCache;
import org.apache.ibatis.cache.decorators.LruCache;
import org.apache.ibatis.cache.decorators.SoftCache;
import org.apache.ibatis.cache.decorators.WeakCache;
import org.apache.ibatis.cache.impl.PerpetualCache;
import org.apache.ibatis.javassist.util.proxy.ProxyFactory;
import org.apache.ibatis.javassist.util.proxy.RuntimeSupport;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.commons.JakartaCommonsLoggingImpl;
import org.apache.ibatis.logging.jdk14.Jdk14LoggingImpl;
import org.apache.ibatis.logging.log4j2.Log4j2Impl;
import org.apache.ibatis.logging.nologging.NoLoggingImpl;
import org.apache.ibatis.logging.slf4j.Slf4jImpl;
import org.apache.ibatis.logging.stdout.StdOutImpl;
import org.apache.ibatis.reflection.TypeParameterResolver;
import org.apache.ibatis.scripting.defaults.RawLanguageDriver;
import org.apache.ibatis.scripting.xmltags.XMLLanguageDriver;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.mapper.MapperFactoryBean;
import org.mybatis.spring.mapper.MapperScannerConfigurer;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.beans.PropertyValue;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.core.ResolvableType;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Configuration(proxyBeanMethods = false)
@ImportRuntimeHints(MyBatisNativeConfiguration.MyBaitsRuntimeHintsRegistrar.class)
public class MyBatisNativeConfiguration {

  @Bean
  MyBatisBeanFactoryInitializationAotProcessor myBatisBeanFactoryInitializationAotProcessor() {
    return new MyBatisBeanFactoryInitializationAotProcessor();
  }

  @Bean
  static MyBatisMapperFactoryBeanPostProcessor myBatisMapperFactoryBeanPostProcessor() {
    return new MyBatisMapperFactoryBeanPostProcessor();
  }

  static class MyBaitsRuntimeHintsRegistrar implements RuntimeHintsRegistrar {

    @Override
    public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
      Stream.of(RawLanguageDriver.class,
          XMLLanguageDriver.class,
          RuntimeSupport.class,
          ProxyFactory.class,
          Slf4jImpl.class,
          Log.class,
          JakartaCommonsLoggingImpl.class,
          Log4j2Impl.class,
          Jdk14LoggingImpl.class,
          StdOutImpl.class,
          NoLoggingImpl.class,
          SqlSessionFactory.class,
          PerpetualCache.class,
          FifoCache.class,
          LruCache.class,
          SoftCache.class,
          WeakCache.class,
          SqlSessionFactoryBean.class,
          ArrayList.class,
          HashMap.class,
          TreeSet.class,
          HashSet.class
      ).forEach(x -> hints.reflection().registerType(x, MemberCategory.values()));
      Stream.of(
          "org/apache/ibatis/builder/xml/*.dtd",
          "org/apache/ibatis/builder/xml/*.xsd"
      ).forEach(hints.resources()::registerPattern);
    }
  }

  static class MyBatisBeanFactoryInitializationAotProcessor
      implements BeanFactoryInitializationAotProcessor, BeanRegistrationExcludeFilter {

    private final Set<Class<?>> excludeClasses = new HashSet<>();

    MyBatisBeanFactoryInitializationAotProcessor() {
      excludeClasses.add(MapperScannerConfigurer.class);
    }

    @Override public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
      return excludeClasses.contains(registeredBean.getBeanClass());
    }

    @Override
    public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
      String[] beanNames = beanFactory.getBeanNamesForType(MapperFactoryBean.class);
      if (beanNames.length == 0) {
        return null;
      }
      return (context, code) -> {
        RuntimeHints hints = context.getRuntimeHints();
        for (String beanName : beanNames) {
          BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName.substring(1));
          PropertyValue mapperInterface = beanDefinition.getPropertyValues().getPropertyValue("mapperInterface");
          if (mapperInterface != null && mapperInterface.getValue() != null) {
            Class<?> mapperInterfaceType = (Class<?>) mapperInterface.getValue();
            if (mapperInterfaceType != null) {
              registerReflectionTypeIfNecessary(mapperInterfaceType, hints);
              hints.proxies().registerJdkProxy(mapperInterfaceType);
              hints.resources()
                  .registerPattern(mapperInterfaceType.getName().replace('.', '/').concat(".xml"));
              registerMapperRelationships(mapperInterfaceType, hints);
            }
          }
        }
      };
    }

    private void registerMapperRelationships(Class<?> mapperInterfaceType, RuntimeHints hints) {
      Method[] methods = ReflectionUtils.getAllDeclaredMethods(mapperInterfaceType);
      for (Method method : methods) {
        if (method.getDeclaringClass() != Object.class) {
          ReflectionUtils.makeAccessible(method);
          registerSqlProviderTypes(method, hints, SelectProvider.class, SelectProvider::value, SelectProvider::type);
          registerSqlProviderTypes(method, hints, InsertProvider.class, InsertProvider::value, InsertProvider::type);
          registerSqlProviderTypes(method, hints, UpdateProvider.class, UpdateProvider::value, UpdateProvider::type);
          registerSqlProviderTypes(method, hints, DeleteProvider.class, DeleteProvider::value, DeleteProvider::type);
          Class<?> returnType = MyBatisMapperTypeUtils.resolveReturnClass(mapperInterfaceType, method);
          registerReflectionTypeIfNecessary(returnType, hints);
          MyBatisMapperTypeUtils.resolveParameterClasses(mapperInterfaceType, method)
              .forEach(x -> registerReflectionTypeIfNecessary(x, hints));
        }
      }
    }

    @SafeVarargs
    private <T extends Annotation> void registerSqlProviderTypes(
        Method method, RuntimeHints hints, Class<T> annotationType, Function<T, Class<?>>... providerTypeResolvers) {
      for (T annotation : method.getAnnotationsByType(annotationType)) {
        for (Function<T, Class<?>> providerTypeResolver : providerTypeResolvers) {
          registerReflectionTypeIfNecessary(providerTypeResolver.apply(annotation), hints);
        }
      }
    }

    private void registerReflectionTypeIfNecessary(Class<?> type, RuntimeHints hints) {
      if (!type.isPrimitive() && !type.getName().startsWith("java")) {
        hints.reflection().registerType(type, MemberCategory.values());
      }
    }

  }

  static class MyBatisMapperTypeUtils {
    private MyBatisMapperTypeUtils() {
      // NOP
    }

    static Class<?> resolveReturnClass(Class<?> mapperInterface, Method method) {
      Type resolvedReturnType = TypeParameterResolver.resolveReturnType(method, mapperInterface);
      return typeToClass(resolvedReturnType, method.getReturnType());
    }

    static Set<Class<?>> resolveParameterClasses(Class<?> mapperInterface, Method method) {
      return Stream.of(TypeParameterResolver.resolveParamTypes(method, mapperInterface))
          .map(x -> typeToClass(x, x instanceof Class ? (Class<?>) x : Object.class)).collect(Collectors.toSet());
    }

    private static Class<?> typeToClass(Type src, Class<?> fallback) {
      Class<?> result = null;
      if (src instanceof Class<?>) {
        if (((Class<?>) src).isArray()) {
          result = ((Class<?>) src).getComponentType();
        } else {
          result = (Class<?>) src;
        }
      } else if (src instanceof ParameterizedType) {
        ParameterizedType parameterizedType = (ParameterizedType) src;
        int index = (parameterizedType.getRawType() instanceof Class
            && Map.class.isAssignableFrom((Class<?>) parameterizedType.getRawType())
            && parameterizedType.getActualTypeArguments().length > 1) ? 1 : 0;
        Type actualType = parameterizedType.getActualTypeArguments()[index];
        result = typeToClass(actualType, fallback);
      }
      if (result == null) {
        result = fallback;
      }
      return result;
    }

  }

  static class MyBatisMapperFactoryBeanPostProcessor implements MergedBeanDefinitionPostProcessor, BeanFactoryAware {

    private static final org.apache.commons.logging.Log LOG = LogFactory.getLog(
        MyBatisMapperFactoryBeanPostProcessor.class);

    private static final String MAPPER_FACTORY_BEAN = "org.mybatis.spring.mapper.MapperFactoryBean";

    private ConfigurableBeanFactory beanFactory;

    @Override
    public void setBeanFactory(BeanFactory beanFactory) {
      this.beanFactory = (ConfigurableBeanFactory) beanFactory;
    }

    @Override
    public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
      if (ClassUtils.isPresent(MAPPER_FACTORY_BEAN, this.beanFactory.getBeanClassLoader())) {
        resolveMapperFactoryBeanTypeIfNecessary(beanDefinition);
      }
    }

    private void resolveMapperFactoryBeanTypeIfNecessary(RootBeanDefinition beanDefinition) {
      if (!beanDefinition.hasBeanClass() || !MapperFactoryBean.class.isAssignableFrom(beanDefinition.getBeanClass())) {
        return;
      }
      if (beanDefinition.getResolvableType().hasUnresolvableGenerics()) {
        Class<?> mapperInterface = getMapperInterface(beanDefinition);
        if (mapperInterface != null) {
          // Exposes a generic type information to context for prevent early initializing
          beanDefinition
              .setTargetType(ResolvableType.forClassWithGenerics(beanDefinition.getBeanClass(), mapperInterface));
        }
      }
    }

    private Class<?> getMapperInterface(RootBeanDefinition beanDefinition) {
      try {
        return (Class<?>) beanDefinition.getPropertyValues().get("mapperInterface");
      }
      catch (Exception e) {
        LOG.debug("Fail getting mapper interface type.", e);
        return null;
      }
    }

  }
}
Clone this wiki locally