diff --git a/mockwebserver/api/mockwebserver3.api b/mockwebserver/api/mockwebserver3.api index 6c0b263acb26..2e32e3c41a82 100644 --- a/mockwebserver/api/mockwebserver3.api +++ b/mockwebserver/api/mockwebserver3.api @@ -86,6 +86,7 @@ public final class mockwebserver3/MockWebServer : java/io/Closeable { public final fun getBodyLimit ()J public final fun getDispatcher ()Lmockwebserver3/Dispatcher; public final fun getHostName ()Ljava/lang/String; + public final fun getInetSocketAddress ()Ljava/net/InetSocketAddress; public final fun getPort ()I public final fun getProtocolNegotiationEnabled ()Z public final fun getRequestCount ()I diff --git a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt index 4f5a65ad60a5..4c2732385e6f 100644 --- a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt +++ b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt @@ -152,10 +152,16 @@ class MockWebServer : Closeable { val hostName: String get() { before() - return inetSocketAddress!!.address.canonicalHostName + return _inetSocketAddress!!.address.canonicalHostName } - private var inetSocketAddress: InetSocketAddress? = null + private var _inetSocketAddress: InetSocketAddress? = null + + val inetSocketAddress: InetSocketAddress + get() { + before() + return InetSocketAddress(hostName, portField) + } /** * True if ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 or @@ -206,7 +212,7 @@ class MockWebServer : Closeable { fun toProxyAddress(): Proxy { before() - val address = InetSocketAddress(inetSocketAddress!!.address.canonicalHostName, port) + val address = InetSocketAddress(_inetSocketAddress!!.address.canonicalHostName, port) return Proxy(Proxy.Type.HTTP, address) } @@ -383,7 +389,7 @@ class MockWebServer : Closeable { if (started) return started = true - this.inetSocketAddress = inetSocketAddress + this._inetSocketAddress = inetSocketAddress serverSocket = serverSocketFactory!!.createServerSocket() diff --git a/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt b/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt new file mode 100644 index 000000000000..d58d08ee50f1 --- /dev/null +++ b/okhttp-testing-support/src/main/kotlin/okhttp3/SpecificHostSocketFactory.kt @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Socket +import java.net.SocketAddress + +/** + * A [SocketFactory] that redirects connections to [defaultAddress] or specific overridden address via [set]. + */ +class SpecificHostSocketFactory( + val defaultAddress: InetSocketAddress? +) : DelegatingSocketFactory(getDefault()) { + private val hostMapping = mutableMapOf() + + /** Sets the [real] address for [requested]. */ + operator fun set( + requested: InetAddress, + real: InetSocketAddress + ) { + hostMapping[requested] = real + } + + override fun createSocket(): Socket { + return object : Socket() { + override fun connect(endpoint: SocketAddress?, timeout: Int) { + val requested = (endpoint as InetSocketAddress) + val inetSocketAddress = hostMapping[requested.address] ?: defaultAddress ?: requested + super.connect(inetSocketAddress, timeout) + } + } + } +} diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt index c7c7422fc39e..4685de667166 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RouteDatabase.kt @@ -24,18 +24,22 @@ import okhttp3.Route * preferred. */ class RouteDatabase { - private val failedRoutes = mutableSetOf() + private val _failedRoutes = mutableSetOf() + + val failedRoutes: Set + @Synchronized get() = _failedRoutes.toSet() + /** Records a failure connecting to [failedRoute]. */ @Synchronized fun failed(failedRoute: Route) { - failedRoutes.add(failedRoute) + _failedRoutes.add(failedRoute) } /** Records success connecting to [route]. */ @Synchronized fun connected(route: Route) { - failedRoutes.remove(route) + _failedRoutes.remove(route) } /** Returns true if [route] has failed recently and should be avoided. */ - @Synchronized fun shouldPostpone(route: Route): Boolean = route in failedRoutes + @Synchronized fun shouldPostpone(route: Route): Boolean = route in _failedRoutes } diff --git a/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt new file mode 100644 index 000000000000..21789543772f --- /dev/null +++ b/okhttp/src/jvmTest/java/okhttp3/RouteFailureTest.kt @@ -0,0 +1,275 @@ +/* + * Copyright (C) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.SocketPolicy +import mockwebserver3.junit5.internal.MockWebServerInstance +import okhttp3.internal.http2.ErrorCode +import okhttp3.testing.PlatformRule +import okhttp3.tls.internal.TlsUtil.localhost +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class RouteFailureTest { + private lateinit var socketFactory: SpecificHostSocketFactory + private lateinit var client: OkHttpClient + + @RegisterExtension + val platform = PlatformRule() + + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + + private lateinit var server1: MockWebServer + private lateinit var server2: MockWebServer + + private var listener = RecordingEventListener() + + private val handshakeCertificates = localhost() + + val dns = FakeDns() + + val ipv4 = InetAddress.getByName("203.0.113.1") + val ipv6 = InetAddress.getByName("2001:db8:ffff:ffff:ffff:ffff:ffff:1") + + val refusedStream = MockResponse() + .setHttp2ErrorCode(ErrorCode.REFUSED_STREAM.httpCode) + .setSocketPolicy(SocketPolicy.RESET_STREAM_AT_START) + val bodyResponse = MockResponse().setBody("body") + + @BeforeEach + fun setUp( + server: MockWebServer, + @MockWebServerInstance("server2") server2: MockWebServer + ) { + this.server1 = server + this.server2 = server2 + + socketFactory = SpecificHostSocketFactory(InetSocketAddress(server.hostName, server.port)) + + client = clientTestRule.newClientBuilder() + .dns(dns) + .socketFactory(socketFactory) + .eventListenerFactory(clientTestRule.wrap(listener)) + .build() + } + + @Test + fun http2OneBadHostOneGoodNoRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = false + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + assertThat(server2.requestCount).isEqualTo(0) + } + + @Test + fun http2OneBadHostOneGoodRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertBody("body") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + // TODO check if we expect a second request to server1, before attempting server2 + assertThat(server1.requestCount).isEqualTo(2) + assertThat(server2.requestCount).isEqualTo(1) + } + + @Test + fun http2OneBadHostOneGoodNoRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = false + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + assertThat(server2.requestCount).isEqualTo(0) + } + + @Test + fun http2OneBadHostOneGoodRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + server2.enqueue(bodyResponse) + + dns[server1.hostName] = listOf(ipv6, ipv4) + socketFactory[ipv6] = server1.inetSocketAddress + socketFactory[ipv4] = server2.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertBody("body") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + // TODO check if we expect a second request to server1, before attempting server2 + assertThat(server1.requestCount).isEqualTo(2) + assertThat(server2.requestCount).isEqualTo(1) + } + + @Test + fun http2OneBadHostRetryOnConnectionFailure() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + + dns[server1.hostName] = listOf(ipv6) + socketFactory[ipv6] = server1.inetSocketAddress + + client = client.newBuilder() + .fastFallback(false) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + } + + @Test + fun http2OneBadHostRetryOnConnectionFailureFastFallback() { + enableProtocol(Protocol.HTTP_2) + + val request = Request(server1.url("/")) + + server1.enqueue(refusedStream) + server1.enqueue(refusedStream) + + dns[server1.hostName] = listOf(ipv6) + socketFactory[ipv6] = server1.inetSocketAddress + + client = client.newBuilder() + .fastFallback(true) + .apply { + retryOnConnectionFailure = true + } + .build() + + executeSynchronously(request) + .assertFailureMatches("stream was reset: REFUSED_STREAM") + + assertThat(client.routeDatabase.failedRoutes).isEmpty() + assertThat(server1.requestCount).isEqualTo(1) + } + + private fun enableProtocol(protocol: Protocol) { + enableTls() + client = client.newBuilder() + .protocols(listOf(protocol, Protocol.HTTP_1_1)) + .build() + server1.protocols = client.protocols + server2.protocols = client.protocols + } + + private fun enableTls() { + client = client.newBuilder() + .sslSocketFactory( + handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager + ) + .hostnameVerifier(RecordingHostnameVerifier()) + .build() + server1.useHttps(handshakeCertificates.sslSocketFactory()) + server2.useHttps(handshakeCertificates.sslSocketFactory()) + } + + private fun executeSynchronously(request: Request): RecordedResponse { + val call = client.newCall(request) + return try { + val response = call.execute() + val bodyString = response.body.string() + RecordedResponse(request, response, null, bodyString, null) + } catch (e: IOException) { + RecordedResponse(request, null, null, null, e) + } + } +}