From c17aac858eb1186dc8c2827b0c9301b8a5627a36 Mon Sep 17 00:00:00 2001 From: Michele Cardone Date: Wed, 3 Feb 2021 11:17:15 +0100 Subject: [PATCH] Align `FakeConnection` constructor signature to base class. As `FakeConnection` is a subclass of the `Connection` class in `redis-py`, in order to honor OOP substitutability principle, the signature of the constructor should be aligned with the one of the base class. That means it should contain at least all the parameters of the base class. This automatically fixes a problem occurring when the connection object is dynamically instantiated within the context of the default connection pool class in `redis-py`, named `ConnectionPool`, and when, in turn, the latter is dynamically instantiated using its standard `from_url` factory method. In such scenario the instantiation of the connection object fails because the `FakeConnection` class misses both `host` and `port` parameters in its constructor signature. That happens because the `from_url` factory method above, given a canonical Redis URL, guesses several connection parameters values, among which `host`, `username`, `password`, `path`, and `db`, which values override any other corresponding value passed to the factory method itself via the connection variadic arguments. Aligning the `FakeConnection` constructor signature to the one of its base class, by including the `host` and `port` parameters, is helpful in contexts where there's no direct control over the instantiantiation of Redis clients, connections and connection pools objects (e.g. `django-redis` library): whilst preserving current behavior, this avoids to workaround the issue by defining either custom connection factories or custom connection pools. --- fakeredis/_server.py | 50 +++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 556fb4e..549955d 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -1,5 +1,4 @@ import logging -import os import time import threading import math @@ -2649,39 +2648,28 @@ def check_is_ready_for_command(self, timeout): class FakeConnection(redis.Connection): description_format = "FakeConnection" - def __init__(self, server, db=0, username=None, password=None, + def __init__(self, server, host='localhost', port=6379, db=0, password=None, socket_timeout=None, socket_connect_timeout=None, socket_keepalive=False, socket_keepalive_options=None, - socket_type=0, retry_on_timeout=False, - encoding='utf-8', encoding_errors='strict', - decode_responses=False, parser_class=_DummyParser, - socket_read_size=65536, health_check_interval=0, - client_name=None): - self.pid = os.getpid() - self.db = db - self.username = username - self.client_name = client_name - self.password = password - # Allow socket attributes to be passed in and saved even if they aren't used - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type - self.retry_on_timeout = retry_on_timeout - self.encoder = redis.connection.Encoder(encoding, encoding_errors, decode_responses) - self._description_args = {'db': self.db} - self._connect_callbacks = [] - self._buffer_cutoff = 6000 + socket_type=0, retry_on_timeout=False, encoding='utf-8', + encoding_errors='strict', decode_responses=False, + parser_class=_DummyParser, socket_read_size=65536, + health_check_interval=0, client_name=None, username=None): self._server = server - # self._parser isn't used for anything, but some of the - # base class methods depend on it and it's easier not to - # override them. - self._parser = parser_class(socket_read_size=socket_read_size) - self._sock = None - # added in redis==3.3.0 - self.health_check_interval = health_check_interval - self.next_health_check = 0 + self._description_args = {'db': db} + super().__init__( + host=host, port=port, db=db, password=password, + socket_timeout=socket_timeout, + socket_connect_timeout=socket_connect_timeout, + socket_keepalive=socket_keepalive, + socket_keepalive_options=socket_keepalive_options, + socket_type=socket_type, retry_on_timeout=retry_on_timeout, + encoding=encoding, encoding_errors=encoding_errors, + decode_responses=decode_responses, parser_class=parser_class, + socket_read_size=socket_read_size, + health_check_interval=health_check_interval, + client_name=client_name, username=username + ) def connect(self): super().connect()