Skip to content

Commit

Permalink
Provide orNull extensions for WebFlux ServerRequest
Browse files Browse the repository at this point in the history
Closes gh-23761
  • Loading branch information
sdeleuze committed Nov 13, 2019
1 parent 22211a0 commit 6fa9871
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
Expand Up @@ -21,11 +21,14 @@ import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactive.awaitSingle
import kotlinx.coroutines.reactive.asFlow
import org.springframework.core.ParameterizedTypeReference
import org.springframework.http.MediaType
import org.springframework.http.codec.multipart.Part
import org.springframework.util.CollectionUtils
import org.springframework.util.MultiValueMap
import org.springframework.web.server.WebSession
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.net.InetSocketAddress
import java.security.Principal

/**
Expand Down Expand Up @@ -112,3 +115,56 @@ suspend fun ServerRequest.awaitPrincipal(): Principal? =
*/
suspend fun ServerRequest.awaitSession(): WebSession =
session().awaitSingle()

/**
* Nullable variant of [ServerRequest.remoteAddress]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.remoteAddressOrNull(): InetSocketAddress? = remoteAddress().orElse(null)

/**
* Nullable variant of [ServerRequest.attribute]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.attributeOrNull(name: String): Any? = attributes()[name]

/**
* Nullable variant of [ServerRequest.queryParam]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.queryParamOrNull(name: String): String? {
val queryParamValues = queryParams()[name]
return if (CollectionUtils.isEmpty(queryParamValues)) {
null
} else {
var value: String? = queryParamValues!![0]
if (value == null) {
value = ""
}
value
}
}

/**
* Nullable variant of [ServerRequest.Headers.contentLength]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.Headers.contentLengthOrNull(): Long? =
contentLength().run { if (isPresent) asLong else null }

/**
* Nullable variant of [ServerRequest.Headers.contentType]
*
* @author Sebastien Deleuze
* @since 5.2.2
*/
fun ServerRequest.Headers.contentTypeOrNull(): MediaType? =
contentType().orElse(null)
Expand Up @@ -23,11 +23,15 @@ import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.ParameterizedTypeReference
import org.springframework.http.MediaType
import org.springframework.http.codec.multipart.Part
import org.springframework.util.CollectionUtils
import org.springframework.util.MultiValueMap
import org.springframework.web.server.WebSession
import reactor.core.publisher.Mono
import java.net.InetSocketAddress
import java.security.Principal
import java.util.*

/**
* Mock object based tests for [ServerRequest] Kotlin extensions.
Expand All @@ -38,6 +42,8 @@ class ServerRequestExtensionsTests {

val request = mockk<ServerRequest>(relaxed = true)

val headers = mockk<ServerRequest.Headers>(relaxed = true)

@Test
fun `bodyToMono with reified type parameters`() {
request.bodyToMono<List<Foo>>()
Expand Down Expand Up @@ -108,6 +114,90 @@ class ServerRequestExtensionsTests {
}
}

@Test
fun `remoteAddressOrNull with value`() {
val remoteAddress = InetSocketAddress(1234)
every { request.remoteAddress() } returns Optional.of(remoteAddress)
assertThat(remoteAddress).isEqualTo(request.remoteAddressOrNull())
verify { request.remoteAddress() }
}

@Test
fun `remoteAddressOrNull with null`() {
every { request.remoteAddress() } returns Optional.empty()
assertThat(request.remoteAddressOrNull()).isNull()
verify { request.remoteAddress() }
}

@Test
fun `attributeOrNull with value`() {
every { request.attributes() } returns mapOf("foo" to "bar")
assertThat(request.attributeOrNull("foo")).isEqualTo("bar")
verify { request.attributes() }
}

@Test
fun `attributeOrNull with null`() {
every { request.attributes() } returns mapOf("foo" to "bar")
assertThat(request.attributeOrNull("baz")).isNull()
verify { request.attributes() }
}

@Test
fun `queryParamOrNull with value`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar")))
assertThat(request.queryParamOrNull("foo")).isEqualTo("bar")
verify { request.queryParams() }
}

@Test
fun `queryParamOrNull with values`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar", "bar")))
assertThat(request.queryParamOrNull("foo")).isEqualTo("bar")
verify { request.queryParams() }
}

@Test
fun `queryParamOrNull with null value`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf(null)))
assertThat(request.queryParamOrNull("foo")).isEqualTo("")
verify { request.queryParams() }
}

@Test
fun `queryParamOrNull with null`() {
every { request.queryParams() } returns CollectionUtils.toMultiValueMap(mapOf("foo" to listOf("bar")))
assertThat(request.queryParamOrNull("baz")).isNull()
verify { request.queryParams() }
}

@Test
fun `contentLengthOrNull with value`() {
every { headers.contentLength() } returns OptionalLong.of(123)
assertThat(headers.contentLengthOrNull()).isEqualTo(123)
verify { headers.contentLength() }
}

@Test
fun `contentLengthOrNull with null`() {
every { headers.contentLength() } returns OptionalLong.empty()
assertThat(headers.contentLengthOrNull()).isNull()
verify { headers.contentLength() }
}

@Test
fun `contentTypeOrNull with value`() {
every { headers.contentType() } returns Optional.of(MediaType.APPLICATION_JSON)
assertThat(headers.contentTypeOrNull()).isEqualTo(MediaType.APPLICATION_JSON)
verify { headers.contentType() }
}

@Test
fun `contentTypeOrNull with null`() {
every { headers.contentType() } returns Optional.empty()
assertThat(headers.contentTypeOrNull()).isNull()
verify { headers.contentType() }
}

class Foo
}

0 comments on commit 6fa9871

Please sign in to comment.