From e49ab00d905fea86a46123c3d752f235fdfd2f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Wed, 13 Nov 2019 16:58:48 +0100 Subject: [PATCH] Provide orNull extensions for WebFlux ServerRequest Closes gh-23761 --- .../server/ServerRequestExtensions.kt | 56 ++++++++++++ .../server/ServerRequestExtensionsTests.kt | 90 +++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt index 7847bc627d0a..3e98680613e9 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensions.kt @@ -21,11 +21,14 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull import kotlinx.coroutines.reactive.awaitSingle import kotlinx.coroutines.reactive.asFlow import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.MediaType import org.springframework.http.codec.multipart.Part +import org.springframework.util.CollectionUtils import org.springframework.util.MultiValueMap import org.springframework.web.server.WebSession import reactor.core.publisher.Flux import reactor.core.publisher.Mono +import java.net.InetSocketAddress import java.security.Principal /** @@ -112,3 +115,56 @@ suspend fun ServerRequest.awaitPrincipal(): Principal? = */ suspend fun ServerRequest.awaitSession(): WebSession = session().awaitSingle() + +/** + * Nullable variant of [ServerRequest.remoteAddress] + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +fun ServerRequest.remoteAddressOrNull(): InetSocketAddress? = remoteAddress().orElse(null) + +/** + * Nullable variant of [ServerRequest.attribute] + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +fun ServerRequest.attributeOrNull(name: String): Any? = attributes()[name] + +/** + * Nullable variant of [ServerRequest.queryParam] + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +fun ServerRequest.queryParamOrNull(name: String): String? { + val queryParamValues = queryParams()[name] + return if (CollectionUtils.isEmpty(queryParamValues)) { + null + } else { + var value: String? = queryParamValues!![0] + if (value == null) { + value = "" + } + value + } +} + +/** + * Nullable variant of [ServerRequest.Headers.contentLength] + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +fun ServerRequest.Headers.contentLengthOrNull(): Long? = + contentLength().run { if (isPresent) asLong else null } + +/** + * Nullable variant of [ServerRequest.Headers.contentType] + * + * @author Sebastien Deleuze + * @since 5.2.2 + */ +fun ServerRequest.Headers.contentTypeOrNull(): MediaType? = + contentType().orElse(null) \ No newline at end of file diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt index 6d1b7960306f..4922b0409704 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/ServerRequestExtensionsTests.kt @@ -23,11 +23,15 @@ import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.MediaType import org.springframework.http.codec.multipart.Part +import org.springframework.util.CollectionUtils import org.springframework.util.MultiValueMap import org.springframework.web.server.WebSession import reactor.core.publisher.Mono +import java.net.InetSocketAddress import java.security.Principal +import java.util.* /** * Mock object based tests for [ServerRequest] Kotlin extensions. @@ -38,6 +42,8 @@ class ServerRequestExtensionsTests { val request = mockk(relaxed = true) + val headers = mockk(relaxed = true) + @Test fun `bodyToMono with reified type parameters`() { request.bodyToMono>() @@ -108,6 +114,90 @@ class ServerRequestExtensionsTests { } } + @Test + fun `remoteAddressOrNull with value`() { + val remoteAddress = InetSocketAddress(1234) + every { request.remoteAddress() } returns Optional.of(remoteAddress) + assertThat(remoteAddress).isEqualTo(request.remoteAddressOrNull()) + verify { request.remoteAddress() } + } + + @Test + fun `remoteAddressOrNull with null`() { + every { request.remoteAddress() } returns Optional.empty() + assertThat(request.remoteAddressOrNull()).isNull() + verify { request.remoteAddress() } + } + + @Test + fun `attributeOrNull with value`() { + every { request.attributes() } returns mapOf("foo" to "bar") + assertThat(request.attributeOrNull("foo")).isEqualTo("bar") + verify { request.attributes() } + } + + @Test + fun `attributeOrNull with null`() { + every { request.attributes() } returns mapOf("foo" to "bar") + assertThat(request.attributeOrNull("baz")).isNull() + verify { request.attributes() } + } + + @Test + fun `queryParamOrNull with value`() { + every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar"))) + assertThat(request.queryParamOrNull("foo")).isEqualTo("bar") + verify { request.queryParams() } + } + + @Test + fun `queryParamOrNull with values`() { + every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar", "bar"))) + assertThat(request.queryParamOrNull("foo")).isEqualTo("bar") + verify { request.queryParams() } + } + + @Test + fun `queryParamOrNull with null value`() { + every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf(null))) + assertThat(request.queryParamOrNull("foo")).isEqualTo("") + verify { request.queryParams() } + } + + @Test + fun `queryParamOrNull with null`() { + every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar"))) + assertThat(request.queryParamOrNull("baz")).isNull() + verify { request.queryParams() } + } + + @Test + fun `contentLengthOrNull with value`() { + every { headers.contentLength() } returns OptionalLong.of(123) + assertThat(headers.contentLengthOrNull()).isEqualTo(123) + verify { headers.contentLength() } + } + + @Test + fun `contentLengthOrNull with null`() { + every { headers.contentLength() } returns OptionalLong.empty() + assertThat(headers.contentLengthOrNull()).isNull() + verify { headers.contentLength() } + } + + @Test + fun `contentTypeOrNull with value`() { + every { headers.contentType() } returns Optional.of(MediaType.APPLICATION_JSON) + assertThat(headers.contentTypeOrNull()).isEqualTo(MediaType.APPLICATION_JSON) + verify { headers.contentType() } + } + + @Test + fun `contentTypeOrNull with null`() { + every { headers.contentType() } returns Optional.empty() + assertThat(headers.contentTypeOrNull()).isNull() + verify { headers.contentType() } + } class Foo }