Skip to content

Commit

Permalink
fix(test): Fixes for "TestClient" changes
Browse files Browse the repository at this point in the history
Seems that client is optional according to the ASGI spec.
https://asgi.readthedocs.io/en/latest/specs/www.html

With Starlette 0.35 the TestClient connection  scope is None for "client".
encode/starlette#2377

Signed-off-by: moson <moson@archlinux.org>
  • Loading branch information
moson-mo committed Jan 19, 2024
1 parent 22e1577 commit 2fcd793
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 16 deletions.
3 changes: 2 additions & 1 deletion aurweb/models/ban.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aurweb import db, schema
from aurweb.models.declarative import Base
from aurweb.util import get_client_ip


class Ban(Base):
Expand All @@ -14,6 +15,6 @@ def __init__(self, **kwargs):


def is_banned(request: Request):
ip = request.client.host
ip = get_client_ip(request)
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()
2 changes: 1 addition & 1 deletion aurweb/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def login(self, request: Request, password: str) -> str:
try:
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
self.LastLoginIPAddress = util.get_client_ip(request)
if not self.session:
sid = generate_unique_sid()
self.session = db.create(
Expand Down
7 changes: 4 additions & 3 deletions aurweb/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aurweb import aur_logging, config, db, time
from aurweb.aur_redis import redis_connection
from aurweb.models import ApiRateLimit
from aurweb.util import get_client_ip

logger = aur_logging.get_logger(__name__)

Expand All @@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
now = time.utcnow()
time_to_delete = now - window_length

host = request.client.host
host = get_client_ip(request)
window_key = f"ratelimit-ws:{host}"
requests_key = f"ratelimit:{host}"

Expand Down Expand Up @@ -55,7 +56,7 @@ def retry_create(record: ApiRateLimit, now: int, host: str) -> ApiRateLimit:
record.Requests += 1
return record

host = request.client.host
host = get_client_ip(request)
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host)

Expand Down Expand Up @@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
record = update_ratelimit(request, pipeline)

# Get cache value, else None.
host = request.client.host
host = get_client_ip(request)
pipeline.get(f"ratelimit:{host}")
requests = pipeline.execute()[0]

Expand Down
6 changes: 4 additions & 2 deletions aurweb/routers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
conn.execute(
Users.update()
.where(Users.c.ID == user_id)
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
.values(
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
)
)

return sid
Expand Down Expand Up @@ -110,7 +112,7 @@ async def authenticate(
Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session.
"""
if is_ip_banned(conn, request.client.host):
if is_ip_banned(conn, util.get_client_ip(request)):
_ = get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
Expand Down
2 changes: 1 addition & 1 deletion aurweb/users/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def invalid_password(


def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host
host = util.get_client_ip(request)
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar():
raise ValidationError(
Expand Down
8 changes: 8 additions & 0 deletions aurweb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,11 @@ def hash_query(query: Query):
return sha1(
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
).hexdigest()


def get_client_ip(request: fastapi.Request) -> str:
"""
Returns the client's IP address for a Request.
Falls back to 'no-client' is request.client is None
"""
return request.client.host if request.client else "no-client"
5 changes: 3 additions & 2 deletions test/test_accounts_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):


def test_post_register_error_ip_banned(client: TestClient):
# 'testclient' is used as request.client.host via FastAPI TestClient.
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())

with client as request:
response = post_register(request)
Expand Down
12 changes: 6 additions & 6 deletions test/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def pipeline():
redis = redis_connection()
pipeline = redis.pipeline()

# The 'testclient' host is used when requesting the app
# via fastapi.testclient.TestClient.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
pipeline.execute()

yield pipeline
Expand Down Expand Up @@ -760,8 +760,8 @@ def test_rpc_ratelimit(
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)

# Delete the cached records.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
one, two = pipeline.execute()
assert one and two

Expand Down

0 comments on commit 2fcd793

Please sign in to comment.