Skip to content

Commit

Permalink
add InitDestroyAnnotationBeanPostProcessor to BeanFactoryInitializati…
Browse files Browse the repository at this point in the history
…onAotProcessor

closes spring-projectsgh-30755

Signed-off-by: Dmitrii Bocharov <bdshadow@gmail.com>
  • Loading branch information
bdshadow committed Nov 26, 2023
1 parent dd97dee commit 5fa4e74
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,34 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;

/**
Expand Down Expand Up @@ -85,7 +94,7 @@
*/
@SuppressWarnings("serial")
public class InitDestroyAnnotationBeanPostProcessor implements DestructionAwareBeanPostProcessor,
MergedBeanDefinitionPostProcessor, BeanRegistrationAotProcessor, PriorityOrdered, Serializable {
MergedBeanDefinitionPostProcessor, BeanRegistrationAotProcessor, BeanFactoryInitializationAotProcessor, PriorityOrdered, Serializable {

private final transient LifecycleMetadata emptyLifecycleMetadata =
new LifecycleMetadata(Object.class, Collections.emptyList(), Collections.emptyList()) {
Expand Down Expand Up @@ -188,15 +197,22 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe
RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition();
beanDefinition.resolveDestroyMethodIfNecessary();
LifecycleMetadata metadata = findLifecycleMetadata(beanDefinition, registeredBean.getBeanClass());
if (!CollectionUtils.isEmpty(metadata.initMethods)) {
String[] initMethodNames = safeMerge(beanDefinition.getInitMethodNames(), metadata.initMethods);
beanDefinition.setInitMethodNames(initMethodNames);
}
if (!CollectionUtils.isEmpty(metadata.destroyMethods)) {
String[] destroyMethodNames = safeMerge(beanDefinition.getDestroyMethodNames(), metadata.destroyMethods);
beanDefinition.setDestroyMethodNames(destroyMethodNames);
return (generationContext, beanRegistrationCode) -> {
metadata.initMethods.forEach(lm -> registerLifecycleMethodForInvoke(generationContext, lm));
metadata.destroyMethods.forEach(lm -> registerLifecycleMethodForInvoke(generationContext, lm));
};
}

private void registerLifecycleMethodForInvoke(GenerationContext generationContext, LifecycleMethod lm) {
generationContext.getRuntimeHints().reflection().registerMethod(lm.getMethod(), ExecutableMode.INVOKE);
}

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
if (this.initAnnotationTypes.isEmpty() && this.destroyAnnotationTypes.isEmpty()) {
return null;
}
return null;
return new BeanFactoryAotContribution(this.initAnnotationTypes, this.destroyAnnotationTypes);
}

private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinition, Class<?> beanClass) {
Expand All @@ -205,13 +221,6 @@ private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinitio
return metadata;
}

private static String[] safeMerge(@Nullable String[] existingNames, Collection<LifecycleMethod> detectedMethods) {
Stream<String> detectedNames = detectedMethods.stream().map(LifecycleMethod::getIdentifier);
Stream<String> mergedNames = (existingNames != null ?
Stream.concat(detectedNames, Stream.of(existingNames)) : detectedNames);
return mergedNames.distinct().toArray(String[]::new);
}

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
LifecycleMetadata metadata = findLifecycleMetadata(bean.getClass());
Expand Down Expand Up @@ -486,4 +495,33 @@ private static boolean isPrivateOrNotVisible(Method method, Class<?> beanClass)

}

private record BeanFactoryAotContribution(
Set<Class<? extends Annotation>> initAnnotationTypes,
Set<Class<? extends Annotation>> destroyAnnotationTypes) implements BeanFactoryInitializationAotContribution {
private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory";
private static final String POST_PROCESSOR_PARAMETER_NAME = "postProcessor";

@Override
public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) {
// to generate a unique name in case there are multiple InitDestroyAnnotationBeanPostProcessor-s
String[] methodNameParts = {"addInitDestroyBeanPostProcessorMethod"};
GeneratedMethod generatedMethod = beanFactoryInitializationCode.getMethods()
.add(methodNameParts, this::generateAddInitDestroyBeanPostProcessorMethod);
beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference());
}

private void generateAddInitDestroyBeanPostProcessorMethod(MethodSpec.Builder method) {
method.addJavadoc("Apply known externally managed init/destroy annotation bean post processors");
method.addModifiers(javax.lang.model.element.Modifier.PRIVATE);
method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME);

CodeBlock.Builder code = CodeBlock.builder();
code.addStatement("$T $L = new $T()",
InitDestroyAnnotationBeanPostProcessor.class, POST_PROCESSOR_PARAMETER_NAME, InitDestroyAnnotationBeanPostProcessor.class);
this.initAnnotationTypes.forEach(type -> code.addStatement("$L.addInitAnnotationType($T.class)", POST_PROCESSOR_PARAMETER_NAME, ClassName.get(type)));
this.destroyAnnotationTypes.forEach(type -> code.addStatement("$L.addDestroyAnnotationType($T.class)", POST_PROCESSOR_PARAMETER_NAME, ClassName.get(type)));
code.addStatement("$L.addBeanPostProcessor($L)", BEAN_FACTORY_PARAMETER_NAME, POST_PROCESSOR_PARAMETER_NAME);
method.addCode(code.build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
Expand Down Expand Up @@ -57,21 +58,23 @@ void processAheadOfTimeWhenNoCallbackDoesNotMutateRootBeanDefinition() {
@Test
void processAheadOfTimeWhenHasInitDestroyAnnotationsAddsMethodNames() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(InitDestroyBean.class);
processAheadOfTime(beanDefinition);
BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition);
assertThat(beanRegistrationAotContribution).isNotNull();
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("destroyMethod");
assertThat(mergedBeanDefinition.getInitMethodNames()).isNull();
assertThat(mergedBeanDefinition.getDestroyMethodNames()).isNull();
}

@Test
void processAheadOfTimeWhenHasInitDestroyAnnotationsAndCustomDefinedMethodNamesAddsMethodNames() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(InitDestroyBean.class);
beanDefinition.setInitMethodName("customInitMethod");
beanDefinition.setDestroyMethodNames("customDestroyMethod");
processAheadOfTime(beanDefinition);
BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition);
assertThat(beanRegistrationAotContribution).isNotNull();
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod", "customInitMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("destroyMethod", "customDestroyMethod");
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("customInitMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("customDestroyMethod");
}

@Test
Expand Down Expand Up @@ -108,10 +111,11 @@ void processAheadOfTimeWhenHasInferredDestroyMethodAndNoCandidateDoesNotMutateRo
@Test
void processAheadOfTimeWhenHasMultipleInitDestroyAnnotationsAddsAllMethodNames() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(MultiInitDestroyBean.class);
processAheadOfTime(beanDefinition);
BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition);
assertThat(beanRegistrationAotContribution).isNotNull();
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod", "anotherInitMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("anotherDestroyMethod", "destroyMethod");
assertThat(mergedBeanDefinition.getInitMethodNames()).isNull();
assertThat(mergedBeanDefinition.getDestroyMethodNames()).isNull();
}

@Test
Expand All @@ -125,27 +129,24 @@ void processAheadOfTimeWithMultipleLevelsOfPublicAndPrivateInitAndDestroyMethods
// to ensure that it will be tracked as such even though it has the same
// name as DisposableBean#destroy().
beanDefinition.setDestroyMethodNames("destroy", "customDestroy");
processAheadOfTime(beanDefinition);
BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition);
assertThat(beanRegistrationAotContribution).isNotNull();
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertSoftly(softly -> {
softly.assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly(
CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method
CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method
"afterPropertiesSet",
"customInit"
);
softly.assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly(
CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method
CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method
"destroy",
"customDestroy"
);
});
}

private void processAheadOfTime(RootBeanDefinition beanDefinition) {
private BeanRegistrationAotContribution processAheadOfTime(RootBeanDefinition beanDefinition) {
RegisteredBean registeredBean = registerBean(beanDefinition);
assertThat(createAotBeanPostProcessor().processAheadOfTime(registeredBean)).isNull();
return createAotBeanPostProcessor().processAheadOfTime(registeredBean);
}

private RegisteredBean registerBean(RootBeanDefinition beanDefinition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,17 @@ void jakartaAnnotationsWithCustomSameMethodNamesWithAotProcessingAndAotRuntime()
CustomAnnotatedPrivateSameNameInitDestroyBean bean = aotApplicationContext.getBean("lifecycleTestBean", beanClass);

assertThat(bean.initMethods).as("init-methods").containsExactly(
"afterPropertiesSet",
"@PostConstruct.privateCustomInit1",
"@PostConstruct.sameNameCustomInit1",
"afterPropertiesSet",
"customInit"
);

aotApplicationContext.close();
assertThat(bean.destroyMethods).as("destroy-methods").containsExactly(
"destroy",
"@PreDestroy.sameNameCustomDestroy1",
"@PreDestroy.privateCustomDestroy1",
"destroy",
"customDestroy"
);
});
Expand All @@ -220,17 +220,17 @@ void jakartaAnnotationsWithPackagePrivateInitDestroyMethodsWithAotProcessingAndA
SubPackagePrivateInitDestroyBean bean = aotApplicationContext.getBean("lifecycleTestBean", beanClass);

assertThat(bean.initMethods).as("init-methods").containsExactly(
"InitializingBean.afterPropertiesSet",
"PackagePrivateInitDestroyBean.postConstruct",
"SubPackagePrivateInitDestroyBean.postConstruct",
"InitializingBean.afterPropertiesSet",
"initMethod"
);

aotApplicationContext.close();
assertThat(bean.destroyMethods).as("destroy-methods").containsExactly(
"DisposableBean.destroy",
"SubPackagePrivateInitDestroyBean.preDestroy",
"PackagePrivateInitDestroyBean.preDestroy",
"DisposableBean.destroy",
"destroyMethod"
);
});
Expand Down

0 comments on commit 5fa4e74

Please sign in to comment.