diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java index e722bafc8ace..cd909f4baf45 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageBuilder.java @@ -99,6 +99,7 @@ public MessageBuilder removeHeaders(String... headerPatterns) { this.headerAccessor.removeHeaders(headerPatterns); return this; } + /** * Remove the value for the given header name. */ @@ -153,7 +154,11 @@ public Message build() { } MessageHeaders headersToUse = this.headerAccessor.toMessageHeaders(); if (this.payload instanceof Throwable) { - return (Message) new ErrorMessage((Throwable) this.payload, headersToUse); + Message originalMessage = null; + if (this.originalMessage != null && this.originalMessage instanceof ErrorMessage) { + originalMessage = ((ErrorMessage) this.originalMessage).getOriginalMessage(); + } + return (Message) new ErrorMessage((Throwable) this.payload, headersToUse, originalMessage); } else { return new GenericMessage<>(this.payload, headersToUse); @@ -165,6 +170,10 @@ public Message build() { * Create a builder for a new {@link Message} instance pre-populated with all of the * headers copied from the provided message. The payload of the provided Message will * also be used as the payload for the new message. + * + * If the provided message is an {@link ErrorMessage} - the + * {@link ErrorMessage#originalMessage} link will be provided to the new instance. + * * @param message the Message from which the payload and all headers will be copied */ public static MessageBuilder fromMessage(Message message) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java index e0a0b8ccb067..3e26ba42c2c9 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.messaging.support; +import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.Map; @@ -107,6 +108,20 @@ public void createFromMessage() { assertThat(message2.getHeaders().get("foo")).isEqualTo("bar"); } + @Test + public void createErrorMessageFromErrorMessage() { + Message originalMessage = MessageBuilder.withPayload("test") + .setHeader("foo", "bar").build(); + RuntimeException errorPayload = new RuntimeException(); + ErrorMessage errorMessage1 = new ErrorMessage(errorPayload, Collections.singletonMap("baz", "42"), originalMessage); + Message errorMessage2 = MessageBuilder.fromMessage(errorMessage1).build(); + assertThat(errorMessage2).isExactlyInstanceOf(ErrorMessage.class); + ErrorMessage actual = (ErrorMessage) errorMessage2; + assertThat(actual.getPayload()).isSameAs(errorPayload); + assertThat(actual.getHeaders().get("baz")).isEqualTo("42"); + assertThat(actual.getOriginalMessage()).isSameAs(originalMessage); + } + @Test public void createIdRegenerated() { Message message1 = MessageBuilder.withPayload("test") @@ -119,20 +134,20 @@ public void createIdRegenerated() { @Test public void testRemove() { Message message1 = MessageBuilder.withPayload(1) - .setHeader("foo", "bar").build(); + .setHeader("foo", "bar").build(); Message message2 = MessageBuilder.fromMessage(message1) - .removeHeader("foo") - .build(); + .removeHeader("foo") + .build(); assertThat(message2.getHeaders().containsKey("foo")).isFalse(); } @Test public void testSettingToNullRemoves() { Message message1 = MessageBuilder.withPayload(1) - .setHeader("foo", "bar").build(); + .setHeader("foo", "bar").build(); Message message2 = MessageBuilder.fromMessage(message1) - .setHeader("foo", null) - .build(); + .setHeader("foo", null) + .build(); assertThat(message2.getHeaders().containsKey("foo")).isFalse(); } @@ -192,7 +207,7 @@ public void testBuildMessageWithDefaultMutability() { assertThatIllegalStateException().isThrownBy(() -> accessor.setHeader("foo", "bar")) - .withMessageContaining("Already immutable"); + .withMessageContaining("Already immutable"); assertThat(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class)).isSameAs(accessor); }