Skip to content

Commit

Permalink
Reverting last commit and adding backward compatibility to 'username'…
Browse files Browse the repository at this point in the history
… and 'password' inside on_connect function
  • Loading branch information
barshaul committed Oct 27, 2022
1 parent eebe6d3 commit c68a715
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 80 deletions.
88 changes: 12 additions & 76 deletions redis/connection.py
Expand Up @@ -8,7 +8,6 @@
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse

from redis.backoff import NoBackoff
Expand Down Expand Up @@ -527,10 +526,8 @@ def __init__(
)

self.credential_provider = credential_provider
if username or password:
# Keep backward compatibility by creating a static credential provider
# for the passed username and password
self.credential_provider = StaticCredentialProvider(username, password)
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
Expand Down Expand Up @@ -563,38 +560,6 @@ def __init__(
self._connect_callbacks = []
self._buffer_cutoff = 6000

@property
def password(self) -> Optional[str]:
if self.credential_provider is not None:
return self.credential_provider.password
else:
return None

@password.setter
def password(self, value: Optional[str]):
if value is None:
# Delete the credential provider
self.credential_provider = None
return
if self.credential_provider is not None:
self.credential_provider.password = value
else:
self.credential_provider = StaticCredentialProvider(password=value)

@property
def username(self) -> Optional[str]:
if self.credential_provider is not None:
return self.credential_provider.username
else:
return None

@username.setter
def username(self, value: Optional[str]):
if self.credential_provider is not None:
self.credential_provider.username = value
else:
self.credential_provider = StaticCredentialProvider(username=value)

def __repr__(self):
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
return f"{self.__class__.__name__}<{repr_args}>"
Expand Down Expand Up @@ -721,9 +686,14 @@ def on_connect(self):
"Initialize the connection, authenticate and select a database"
self._parser.on_connect(self)

# if credentials provider is set, authenticate
if self.credential_provider is not None:
auth_args = self.credential_provider.get_credentials()
# if credential provider or username and/or password are set, authenticate
if self.credential_provider or (self.username or self.password):
cred_provider = (
self.credential_provider
if self.credential_provider
else StaticCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
Expand Down Expand Up @@ -1118,10 +1088,8 @@ def __init__(
"2. 'credential_provider'"
)
self.credential_provider = credential_provider
if username or password:
# Keep backward compatibility by creating a static credential provider
# for the passed username and password
self.credential_provider = StaticCredentialProvider(username, password)
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
Expand Down Expand Up @@ -1150,38 +1118,6 @@ def __init__(
self._connect_callbacks = []
self._buffer_cutoff = 6000

@property
def password(self) -> Optional[str]:
if self.credential_provider is not None:
return self.credential_provider.password
else:
return None

@password.setter
def password(self, value: Optional[str]):
if value is None:
# Delete the credential provider
self.credential_provider = None
return
if self.credential_provider is not None:
self.credential_provider.password = value
else:
self.credential_provider = StaticCredentialProvider(password=value)

@property
def username(self) -> Optional[str]:
if self.credential_provider is not None:
return self.credential_provider.username
else:
return None

@username.setter
def username(self, value: Optional[str]):
if self.credential_provider is not None:
self.credential_provider.username = value
else:
self.credential_provider = StaticCredentialProvider(username=value)

def repr_pieces(self):
pieces = [("path", self.path), ("db", self.db)]
if self.client_name:
Expand Down
4 changes: 0 additions & 4 deletions tests/test_credentials.py
Expand Up @@ -199,14 +199,10 @@ def test_change_username_password_on_existing_connection(self, r, request):
init_acl_user(r, request, new_username, new_password)
conn.password = new_password
conn.username = new_username
assert conn.credential_provider.password == new_password
assert conn.credential_provider.username == new_username
conn.send_command("PING")
assert str_if_bytes(conn.read_response()) == "PONG"
conn.username = None
assert conn.credential_provider.username == ""
conn.password = None
assert conn.credential_provider is None


class TestStaticCredentialProvider:
Expand Down

0 comments on commit c68a715

Please sign in to comment.