Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support [package-]private init/destroy methods in AOT mode #30724

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -182,7 +182,7 @@ private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinitio
private static String[] safeMerge(@Nullable String[] existingNames, Collection<LifecycleMethod> detectedMethods) {
Stream<String> detectedNames = detectedMethods.stream().map(LifecycleMethod::getIdentifier);
Stream<String> mergedNames = (existingNames != null ?
Stream.concat(Stream.of(existingNames), detectedNames) : detectedNames);
Stream.concat(detectedNames, Stream.of(existingNames)) : detectedNames);
return mergedNames.distinct().toArray(String[]::new);
}

Expand Down
Expand Up @@ -72,6 +72,7 @@
*
* @author Phillip Webb
* @author Stephane Nicoll
* @author Sam Brannen
* @since 6.0
*/
class BeanDefinitionPropertiesCodeGenerator {
Expand Down Expand Up @@ -138,7 +139,25 @@ private void addInitDestroyMethods(Builder code, AbstractBeanDefinition beanDefi
}

private void addInitDestroyHint(Class<?> beanUserClass, String methodName) {
Method method = ReflectionUtils.findMethod(beanUserClass, methodName);
Class<?> methodDeclaringClass = beanUserClass;

// Parse fully-qualified method name if necessary.
int indexOfDot = methodName.lastIndexOf('.');
if (indexOfDot > 0) {
String className = methodName.substring(0, indexOfDot);
methodName = methodName.substring(indexOfDot + 1);
if (!beanUserClass.getName().equals(className)) {
try {
methodDeclaringClass = ClassUtils.forName(className, beanUserClass.getClassLoader());
}
catch (Throwable ex) {
throw new IllegalStateException("Failed to load Class [" + className +
"] from ClassLoader [" + beanUserClass.getClassLoader() + "]", ex);
}
}
}

Method method = ReflectionUtils.findMethod(methodDeclaringClass, methodName);
if (method != null) {
this.hints.reflection().registerMethod(method, ExecutableMode.INVOKE);
}
Expand Down
Expand Up @@ -1841,18 +1841,22 @@ protected void invokeInitMethods(String beanName, Object bean, @Nullable RootBea
protected void invokeCustomInitMethod(String beanName, Object bean, RootBeanDefinition mbd, String initMethodName)
throws Throwable {

Class<?> beanClass = bean.getClass();
MethodDescriptor descriptor = MethodDescriptor.create(beanName, beanClass, initMethodName);
String methodName = descriptor.methodName();

Method initMethod = (mbd.isNonPublicAccessAllowed() ?
BeanUtils.findMethod(bean.getClass(), initMethodName) :
ClassUtils.getMethodIfAvailable(bean.getClass(), initMethodName));
BeanUtils.findMethod(descriptor.declaringClass(), methodName) :
ClassUtils.getMethodIfAvailable(beanClass, methodName));

if (initMethod == null) {
if (mbd.isEnforceInitMethod()) {
throw new BeanDefinitionValidationException("Could not find an init method named '" +
initMethodName + "' on bean with name '" + beanName + "'");
methodName + "' on bean with name '" + beanName + "'");
}
else {
if (logger.isTraceEnabled()) {
logger.trace("No default init method named '" + initMethodName +
logger.trace("No default init method named '" + methodName +
"' found on bean with name '" + beanName + "'");
}
// Ignore non-existent default lifecycle methods.
Expand All @@ -1861,9 +1865,9 @@ protected void invokeCustomInitMethod(String beanName, Object bean, RootBeanDefi
}

if (logger.isTraceEnabled()) {
logger.trace("Invoking init method '" + initMethodName + "' on bean with name '" + beanName + "'");
logger.trace("Invoking init method '" + methodName + "' on bean with name '" + beanName + "'");
}
Method methodToInvoke = ClassUtils.getInterfaceMethodIfPossible(initMethod, bean.getClass());
Method methodToInvoke = ClassUtils.getInterfaceMethodIfPossible(initMethod, beanClass);

try {
ReflectionUtils.makeAccessible(methodToInvoke);
Expand Down
Expand Up @@ -255,12 +255,15 @@ else if (this.destroyMethodNames != null) {
private Method determineDestroyMethod(String destroyMethodName) {
try {
Class<?> beanClass = this.bean.getClass();
Method destroyMethod = findDestroyMethod(beanClass, destroyMethodName);
MethodDescriptor descriptor = MethodDescriptor.create(this.beanName, beanClass, destroyMethodName);
String methodName = descriptor.methodName();

Method destroyMethod = findDestroyMethod(descriptor.declaringClass(), methodName);
if (destroyMethod != null) {
return destroyMethod;
}
for (Class<?> beanInterface : beanClass.getInterfaces()) {
destroyMethod = findDestroyMethod(beanInterface, destroyMethodName);
destroyMethod = findDestroyMethod(beanInterface, methodName);
if (destroyMethod != null) {
return destroyMethod;
}
Expand Down
@@ -0,0 +1,73 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.beans.factory.support;

import org.springframework.util.ClassUtils;

/**
* Descriptor for a {@link java.lang.reflect.Method Method} which holds a
* reference to the method's {@linkplain #declaringClass declaring class},
* {@linkplain #methodName name}, and {@linkplain #parameterTypes parameter types}.
*
* @param declaringClass the method's declaring class
* @param methodName the name of the method
* @param parameterTypes the types of parameters accepted by the method
* @author Sam Brannen
* @since 6.0.11
*/
record MethodDescriptor(Class<?> declaringClass, String methodName, Class<?>... parameterTypes) {

/**
* Create a {@link MethodDescriptor} for the supplied bean class and method name.
* <p>The supplied {@code methodName} may be a {@linkplain Method#getName()
* simple method name} or a
* {@linkplain org.springframework.util.ClassUtils#getQualifiedMethodName(Method)
* qualified method name}.
* <p>If the method name is fully qualified, this utility will parse the
* method name and its declaring class from the qualified method name and then
* attempt to load the method's declaring class using the {@link ClassLoader}
* of the supplied {@code beanClass}. Otherwise, the returned descriptor will
* reference the supplied {@code beanClass} and {@code methodName}.
* @param beanName the bean name in the factory (for debugging purposes)
* @param beanClass the bean class
* @param methodName the name of the method
* @return a new {@code MethodDescriptor}; never {@code null}
*/
static MethodDescriptor create(String beanName, Class<?> beanClass, String methodName) {
try {
Class<?> declaringClass = beanClass;
String methodNameToUse = methodName;

// Parse fully-qualified method name if necessary.
int indexOfDot = methodName.lastIndexOf('.');
if (indexOfDot > 0) {
String className = methodName.substring(0, indexOfDot);
methodNameToUse = methodName.substring(indexOfDot + 1);
if (!beanClass.getName().equals(className)) {
declaringClass = ClassUtils.forName(className, beanClass.getClassLoader());
}
}
return new MethodDescriptor(declaringClass, methodNameToUse);
}
catch (Exception | LinkageError ex) {
throw new BeanDefinitionValidationException(
"Could not create MethodDescriptor for method '%s' on bean with name '%s': %s"
.formatted(methodName, beanName, ex.getMessage()));
}
}

}
Expand Up @@ -70,8 +70,8 @@ void processAheadOfTimeWhenHasInitDestroyAnnotationsAndCustomDefinedMethodNamesA
beanDefinition.setDestroyMethodNames("customDestroyMethod");
processAheadOfTime(beanDefinition);
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("customInitMethod", "initMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("customDestroyMethod", "destroyMethod");
assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod", "customInitMethod");
assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("destroyMethod", "customDestroyMethod");
}

@Test
Expand Down Expand Up @@ -129,16 +129,16 @@ void processAheadOfTimeWithMultipleLevelsOfPublicAndPrivateInitAndDestroyMethods
RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition();
assertSoftly(softly -> {
softly.assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly(
"afterPropertiesSet",
"customInit",
CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method
CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateInit" // fully-qualified private method
CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method
"afterPropertiesSet",
"customInit"
);
softly.assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly(
"destroy",
"customDestroy",
CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method
CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateDestroy" // fully-qualified private method
CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method
"destroy",
"customDestroy"
);
});
}
Expand Down
Expand Up @@ -376,6 +376,9 @@ void multipleItems() {
@Nested
class InitDestroyMethodTests {

private final String privateInitMethod = InitDestroyBean.class.getName() + ".privateInit";
private final String privateDestroyMethod = InitDestroyBean.class.getName() + ".privateDestroy";

@BeforeEach
void setTargetType() {
beanDefinition.setTargetType(InitDestroyBean.class);
Expand All @@ -393,11 +396,18 @@ void singleInitMethod() {
assertHasMethodInvokeHints(InitDestroyBean.class, "init");
}

@Test
void privateInitMethod() {
beanDefinition.setInitMethodName(privateInitMethod);
compile((beanDef, compiled) -> assertThat(beanDef.getInitMethodNames()).containsExactly(privateInitMethod));
assertHasMethodInvokeHints(InitDestroyBean.class, "privateInit");
}

@Test
void multipleInitMethods() {
beanDefinition.setInitMethodNames("init", "init2");
compile((beanDef, compiled) -> assertThat(beanDef.getInitMethodNames()).containsExactly("init", "init2"));
assertHasMethodInvokeHints(InitDestroyBean.class, "init", "init2");
beanDefinition.setInitMethodNames("init", privateInitMethod);
compile((beanDef, compiled) -> assertThat(beanDef.getInitMethodNames()).containsExactly("init", privateInitMethod));
assertHasMethodInvokeHints(InitDestroyBean.class, "init", "privateInit");
}

@Test
Expand All @@ -412,11 +422,18 @@ void singleDestroyMethod() {
assertHasMethodInvokeHints(InitDestroyBean.class, "destroy");
}

@Test
void privateDestroyMethod() {
beanDefinition.setDestroyMethodName(privateDestroyMethod);
compile((beanDef, compiled) -> assertThat(beanDef.getDestroyMethodNames()).containsExactly(privateDestroyMethod));
assertHasMethodInvokeHints(InitDestroyBean.class, "privateDestroy");
}

@Test
void multipleDestroyMethods() {
beanDefinition.setDestroyMethodNames("destroy", "destroy2");
compile((beanDef, compiled) -> assertThat(beanDef.getDestroyMethodNames()).containsExactly("destroy", "destroy2"));
assertHasMethodInvokeHints(InitDestroyBean.class, "destroy", "destroy2");
beanDefinition.setDestroyMethodNames("destroy", privateDestroyMethod);
compile((beanDef, compiled) -> assertThat(beanDef.getDestroyMethodNames()).containsExactly("destroy", privateDestroyMethod));
assertHasMethodInvokeHints(InitDestroyBean.class, "destroy", "privateDestroy");
}

}
Expand Down Expand Up @@ -461,13 +478,15 @@ static class InitDestroyBean {
void init() {
}

void init2() {
@SuppressWarnings("unused")
private void privateInit() {
}

void destroy() {
}

void destroy2() {
@SuppressWarnings("unused")
private void privateDestroy() {
}

}
Expand Down