Skip to content

Commit

Permalink
Support private init/destroy methods in AOT mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrannen committed Jun 21, 2023
1 parent adcdefc commit e8bbcd0
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 60 deletions.
Expand Up @@ -153,15 +153,15 @@ public int getOrder() {

@Override
public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
findInjectionMetadata(beanDefinition, beanType);
findLifecycleMetadata(beanDefinition, beanType);
}

@Override
@Nullable
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition();
beanDefinition.resolveDestroyMethodIfNecessary();
LifecycleMetadata metadata = findInjectionMetadata(beanDefinition, registeredBean.getBeanClass());
LifecycleMetadata metadata = findLifecycleMetadata(beanDefinition, registeredBean.getBeanClass());
if (!CollectionUtils.isEmpty(metadata.initMethods)) {
String[] initMethodNames = safeMerge(beanDefinition.getInitMethodNames(), metadata.initMethods);
beanDefinition.setInitMethodNames(initMethodNames);
Expand All @@ -173,16 +173,16 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe
return null;
}

private LifecycleMetadata findInjectionMetadata(RootBeanDefinition beanDefinition, Class<?> beanType) {
private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinition, Class<?> beanType) {
LifecycleMetadata metadata = findLifecycleMetadata(beanType);
metadata.checkInitDestroyMethods(beanDefinition);
return metadata;
}

private String[] safeMerge(@Nullable String[] existingNames, Collection<LifecycleElement> detectedElements) {
Stream<String> detectedNames = detectedElements.stream().map(LifecycleElement::getIdentifier);
private String[] safeMerge(@Nullable String[] existingNames, Collection<LifecycleMethod> detectedElements) {
Stream<String> detectedNames = detectedElements.stream().map(LifecycleMethod::getIdentifier);
Stream<String> mergedNames = (existingNames != null ?
Stream.concat(Stream.of(existingNames), detectedNames) : detectedNames);
Stream.concat(detectedNames, Stream.of(existingNames)) : detectedNames);
return mergedNames.distinct().toArray(String[]::new);
}

Expand Down Expand Up @@ -257,24 +257,23 @@ private LifecycleMetadata buildLifecycleMetadata(final Class<?> clazz) {
return this.emptyLifecycleMetadata;
}

List<LifecycleElement> initMethods = new ArrayList<>();
List<LifecycleElement> destroyMethods = new ArrayList<>();
List<LifecycleMethod> initMethods = new ArrayList<>();
List<LifecycleMethod> destroyMethods = new ArrayList<>();
Class<?> targetClass = clazz;

do {
final List<LifecycleElement> currInitMethods = new ArrayList<>();
final List<LifecycleElement> currDestroyMethods = new ArrayList<>();
final List<LifecycleMethod> currInitMethods = new ArrayList<>();
final List<LifecycleMethod> currDestroyMethods = new ArrayList<>();

ReflectionUtils.doWithLocalMethods(targetClass, method -> {
if (this.initAnnotationType != null && method.isAnnotationPresent(this.initAnnotationType)) {
LifecycleElement element = new LifecycleElement(method);
currInitMethods.add(element);
currInitMethods.add(new LifecycleMethod(method));
if (logger.isTraceEnabled()) {
logger.trace("Found init method on class [" + clazz.getName() + "]: " + method);
}
}
if (this.destroyAnnotationType != null && method.isAnnotationPresent(this.destroyAnnotationType)) {
currDestroyMethods.add(new LifecycleElement(method));
currDestroyMethods.add(new LifecycleMethod(method));
if (logger.isTraceEnabled()) {
logger.trace("Found destroy method on class [" + clazz.getName() + "]: " + method);
}
Expand Down Expand Up @@ -312,27 +311,27 @@ private class LifecycleMetadata {

private final Class<?> targetClass;

private final Collection<LifecycleElement> initMethods;
private final Collection<LifecycleMethod> initMethods;

private final Collection<LifecycleElement> destroyMethods;
private final Collection<LifecycleMethod> destroyMethods;

@Nullable
private volatile Set<LifecycleElement> checkedInitMethods;
private volatile Set<LifecycleMethod> checkedInitMethods;

@Nullable
private volatile Set<LifecycleElement> checkedDestroyMethods;
private volatile Set<LifecycleMethod> checkedDestroyMethods;

public LifecycleMetadata(Class<?> targetClass, Collection<LifecycleElement> initMethods,
Collection<LifecycleElement> destroyMethods) {
public LifecycleMetadata(Class<?> targetClass, Collection<LifecycleMethod> initMethods,
Collection<LifecycleMethod> destroyMethods) {

this.targetClass = targetClass;
this.initMethods = initMethods;
this.destroyMethods = destroyMethods;
}

public void checkInitDestroyMethods(RootBeanDefinition beanDefinition) {
Set<LifecycleElement> checkedInitMethods = new LinkedHashSet<>(this.initMethods.size());
for (LifecycleElement element : this.initMethods) {
Set<LifecycleMethod> checkedInitMethods = new LinkedHashSet<>(this.initMethods.size());
for (LifecycleMethod element : this.initMethods) {
String methodIdentifier = element.getIdentifier();
if (!beanDefinition.isExternallyManagedInitMethod(methodIdentifier)) {
beanDefinition.registerExternallyManagedInitMethod(methodIdentifier);
Expand All @@ -342,8 +341,8 @@ public void checkInitDestroyMethods(RootBeanDefinition beanDefinition) {
}
}
}
Set<LifecycleElement> checkedDestroyMethods = new LinkedHashSet<>(this.destroyMethods.size());
for (LifecycleElement element : this.destroyMethods) {
Set<LifecycleMethod> checkedDestroyMethods = new LinkedHashSet<>(this.destroyMethods.size());
for (LifecycleMethod element : this.destroyMethods) {
String methodIdentifier = element.getIdentifier();
if (!beanDefinition.isExternallyManagedDestroyMethod(methodIdentifier)) {
beanDefinition.registerExternallyManagedDestroyMethod(methodIdentifier);
Expand All @@ -358,11 +357,11 @@ public void checkInitDestroyMethods(RootBeanDefinition beanDefinition) {
}

public void invokeInitMethods(Object target, String beanName) throws Throwable {
Collection<LifecycleElement> checkedInitMethods = this.checkedInitMethods;
Collection<LifecycleElement> initMethodsToIterate =
Collection<LifecycleMethod> checkedInitMethods = this.checkedInitMethods;
Collection<LifecycleMethod> initMethodsToIterate =
(checkedInitMethods != null ? checkedInitMethods : this.initMethods);
if (!initMethodsToIterate.isEmpty()) {
for (LifecycleElement element : initMethodsToIterate) {
for (LifecycleMethod element : initMethodsToIterate) {
if (logger.isTraceEnabled()) {
logger.trace("Invoking init method on bean '" + beanName + "': " + element.getMethod());
}
Expand All @@ -372,11 +371,11 @@ public void invokeInitMethods(Object target, String beanName) throws Throwable {
}

public void invokeDestroyMethods(Object target, String beanName) throws Throwable {
Collection<LifecycleElement> checkedDestroyMethods = this.checkedDestroyMethods;
Collection<LifecycleElement> destroyMethodsToUse =
Collection<LifecycleMethod> checkedDestroyMethods = this.checkedDestroyMethods;
Collection<LifecycleMethod> destroyMethodsToUse =
(checkedDestroyMethods != null ? checkedDestroyMethods : this.destroyMethods);
if (!destroyMethodsToUse.isEmpty()) {
for (LifecycleElement element : destroyMethodsToUse) {
for (LifecycleMethod element : destroyMethodsToUse) {
if (logger.isTraceEnabled()) {
logger.trace("Invoking destroy method on bean '" + beanName + "': " + element.getMethod());
}
Expand All @@ -386,26 +385,26 @@ public void invokeDestroyMethods(Object target, String beanName) throws Throwabl
}

public boolean hasDestroyMethods() {
Collection<LifecycleElement> checkedDestroyMethods = this.checkedDestroyMethods;
Collection<LifecycleElement> destroyMethodsToUse =
Collection<LifecycleMethod> checkedDestroyMethods = this.checkedDestroyMethods;
Collection<LifecycleMethod> destroyMethodsToUse =
(checkedDestroyMethods != null ? checkedDestroyMethods : this.destroyMethods);
return !destroyMethodsToUse.isEmpty();
}
}


/**
* Class representing injection information about an annotated method.
* Class representing an annotated init or destroy methods.
*/
private static class LifecycleElement {
private static class LifecycleMethod {

private final Method method;

private final String identifier;

public LifecycleElement(Method method) {
public LifecycleMethod(Method method) {
if (method.getParameterCount() != 0) {
throw new IllegalStateException("Lifecycle method annotation requires a no-arg method: " + method);
throw new IllegalStateException("Lifecycle annotation requires a no-arg method: " + method);
}
this.method = method;
this.identifier = (Modifier.isPrivate(method.getModifiers()) ?
Expand All @@ -422,18 +421,13 @@ public String getIdentifier() {

public void invoke(Object target) throws Throwable {
ReflectionUtils.makeAccessible(this.method);
this.method.invoke(target, (Object[]) null);
this.method.invoke(target);
}

@Override
public boolean equals(@Nullable Object other) {
if (this == other) {
return true;
}
if (!(other instanceof LifecycleElement otherElement)) {
return false;
}
return (this.identifier.equals(otherElement.identifier));
return (this == other || (other instanceof LifecycleMethod that &&
this.identifier.equals(that.identifier)));
}

@Override
Expand Down
Expand Up @@ -72,6 +72,7 @@
*
* @author Phillip Webb
* @author Stephane Nicoll
* @author Sam Brannen
* @since 6.0
*/
class BeanDefinitionPropertiesCodeGenerator {
Expand Down Expand Up @@ -138,6 +139,7 @@ private void addInitDestroyMethods(Builder code,
}

private void addInitDestroyHint(Class<?> beanUserClass, String methodName) {
// TODO Handle fully-qualified method names
Method method = ReflectionUtils.findMethod(beanUserClass, methodName);
if (method != null) {
this.hints.reflection().registerMethod(method, ExecutableMode.INVOKE);
Expand Down
Expand Up @@ -1832,26 +1832,40 @@ protected void invokeInitMethods(String beanName, Object bean, @Nullable RootBea

/**
* Invoke the specified custom init method on the given bean.
* Called by invokeInitMethods.
* <p>Can be overridden in subclasses for custom resolution of init
* methods with arguments.
* <p>Called by {@link #invokeInitMethods(String, Object, RootBeanDefinition)}.
* <p>Can be overridden in subclasses for custom resolution of init methods
* with arguments.
* @see #invokeInitMethods
*/
protected void invokeCustomInitMethod(String beanName, Object bean, RootBeanDefinition mbd, String initMethodName)
throws Throwable {

Class<?> beanClass = bean.getClass();
Class<?> methodDeclaringClass = beanClass;
String methodName = initMethodName;

// Parse fully-qualified method name if necessary.
int indexOfDot = initMethodName.lastIndexOf('.');
if (indexOfDot > 0) {
String className = initMethodName.substring(0, indexOfDot);
methodName = initMethodName.substring(indexOfDot + 1);
if (!beanClass.getName().equals((className))) {
methodDeclaringClass = ClassUtils.forName(className, beanClass.getClassLoader());
}
}

Method initMethod = (mbd.isNonPublicAccessAllowed() ?
BeanUtils.findMethod(bean.getClass(), initMethodName) :
ClassUtils.getMethodIfAvailable(bean.getClass(), initMethodName));
BeanUtils.findMethod(methodDeclaringClass, methodName) :
ClassUtils.getMethodIfAvailable(beanClass, methodName));

if (initMethod == null) {
if (mbd.isEnforceInitMethod()) {
throw new BeanDefinitionValidationException("Could not find an init method named '" +
initMethodName + "' on bean with name '" + beanName + "'");
methodName + "' on bean with name '" + beanName + "'");
}
else {
if (logger.isTraceEnabled()) {
logger.trace("No default init method named '" + initMethodName +
logger.trace("No default init method named '" + methodName +
"' found on bean with name '" + beanName + "'");
}
// Ignore non-existent default lifecycle methods.
Expand All @@ -1860,9 +1874,9 @@ protected void invokeCustomInitMethod(String beanName, Object bean, RootBeanDefi
}

if (logger.isTraceEnabled()) {
logger.trace("Invoking init method '" + initMethodName + "' on bean with name '" + beanName + "'");
logger.trace("Invoking init method '" + methodName + "' on bean with name '" + beanName + "'");
}
Method methodToInvoke = ClassUtils.getInterfaceMethodIfPossible(initMethod, bean.getClass());
Method methodToInvoke = ClassUtils.getInterfaceMethodIfPossible(initMethod, beanClass);

try {
ReflectionUtils.makeAccessible(methodToInvoke);
Expand Down
Expand Up @@ -255,19 +255,32 @@ else if (this.destroyMethodNames != null) {
private Method determineDestroyMethod(String name) {
try {
Class<?> beanClass = this.bean.getClass();
Method destroyMethod = findDestroyMethod(beanClass, name);
Class<?> methodDeclaringClass = beanClass;
String methodName = name;

// Parse fully-qualified method name if necessary.
int indexOfDot = name.lastIndexOf('.');
if (indexOfDot > 0) {
String className = name.substring(0, indexOfDot);
methodName = name.substring(indexOfDot + 1);
if (!beanClass.getName().equals((className))) {
methodDeclaringClass = ClassUtils.forName(className, beanClass.getClassLoader());
}
}

Method destroyMethod = findDestroyMethod(methodDeclaringClass, methodName);
if (destroyMethod != null) {
return destroyMethod;
}
for (Class<?> beanInterface : beanClass.getInterfaces()) {
destroyMethod = findDestroyMethod(beanInterface, name);
destroyMethod = findDestroyMethod(beanInterface, methodName);
if (destroyMethod != null) {
return destroyMethod;
}
}
return null;
}
catch (IllegalArgumentException ex) {
catch (ClassNotFoundException | IllegalArgumentException ex) {
throw new BeanDefinitionValidationException("Could not find unique destroy method on bean with name '" +
this.beanName + ": " + ex.getMessage());
}
Expand Down

0 comments on commit e8bbcd0

Please sign in to comment.