Skip to content

Commit

Permalink
Added retry mechanism on socket timeouts when connecting to the server (
Browse files Browse the repository at this point in the history
  • Loading branch information
barshaul authored and dvora-h committed Jan 25, 2022
1 parent b41f6b7 commit 4bbfac4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
6 changes: 4 additions & 2 deletions redis/connection.py
Expand Up @@ -604,7 +604,9 @@ def connect(self):
if self._sock:
return
try:
sock = self._connect()
sock = self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect(error)
)
except socket.timeout:
raise TimeoutError("Timeout connecting to server")
except OSError as e:
Expand Down Expand Up @@ -721,7 +723,7 @@ def on_connect(self):
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

def disconnect(self):
def disconnect(self, *args):
"Disconnects from the Redis server"
self._parser.on_disconnect()
if self._sock is None:
Expand Down
6 changes: 5 additions & 1 deletion redis/retry.py
@@ -1,3 +1,4 @@
import socket
from time import sleep

from redis.exceptions import ConnectionError, TimeoutError
Expand All @@ -7,7 +8,10 @@ class Retry:
"""Retry a specific number of times after a failure"""

def __init__(
self, backoff, retries, supported_errors=(ConnectionError, TimeoutError)
self,
backoff,
retries,
supported_errors=(ConnectionError, TimeoutError, socket.timeout),
):
"""
Initialize a `Retry` object with a `Backoff` object
Expand Down
50 changes: 49 additions & 1 deletion tests/test_connection.py
@@ -1,10 +1,14 @@
import socket
import types
from unittest import mock
from unittest.mock import patch

import pytest

from redis.backoff import NoBackoff
from redis.connection import Connection
from redis.exceptions import InvalidResponse
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
from redis.retry import Retry
from redis.utils import HIREDIS_AVAILABLE

from .conftest import skip_if_server_version_lt
Expand Down Expand Up @@ -74,3 +78,47 @@ def test_disconnect__close_OSError(self):
mock_sock.shutdown.assert_called_once()
mock_sock.close.assert_called_once()
assert conn._sock is None

def clear(self, conn):
conn.retry_on_error.clear()

def test_retry_connect_on_timeout_error(self):
"""Test that the _connect function is retried in case of a timeout"""
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3))
origin_connect = conn._connect
conn._connect = mock.Mock()

def mock_connect():
# connect only on the last retry
if conn._connect.call_count <= 2:
raise socket.timeout
else:
return origin_connect()

conn._connect.side_effect = mock_connect
conn.connect()
assert conn._connect.call_count == 3
self.clear(conn)

def test_connect_without_retry_on_os_error(self):
"""Test that the _connect function is not being retried in case of a OSError"""
with patch.object(Connection, "_connect") as _connect:
_connect.side_effect = OSError("")
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
with pytest.raises(ConnectionError):
conn.connect()
assert _connect.call_count == 1
self.clear(conn)

def test_connect_timeout_error_without_retry(self):
"""Test that the _connect function is not being retried if retry_on_timeout is
set to False"""
conn = Connection(retry_on_timeout=False)
conn._connect = mock.Mock()
conn._connect.side_effect = socket.timeout

with pytest.raises(TimeoutError) as e:
conn.connect()
assert conn._connect.call_count == 1
assert str(e.value) == "Timeout connecting to server"
self.clear(conn)

0 comments on commit 4bbfac4

Please sign in to comment.