Skip to content

Commit

Permalink
bug/Permit deleting db entries with leading slashes (#218)
Browse files Browse the repository at this point in the history
* poetry lock

* Formatting

* Avoid object has no attribute '_refresh_timer'

When called without a refresh function we end up getting

    AttributeError: 'AsyncDatabase' object has no attribute '_refresh_timer'

on attempted close. Fix this.

* Revert #172

This was a good first attempt, but still prevented users from deleting the keys they'd created

* Switch delete method to use new multipart/form-data delete method that supports deleting keys with leading slashes

* Adjusting tests to reflect the gap in getting slashy keys
  • Loading branch information
blast-hardcheese committed May 3, 2024
1 parent f8651c7 commit 0f7ac60
Show file tree
Hide file tree
Showing 9 changed files with 962 additions and 712 deletions.
1,582 changes: 898 additions & 684 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/replit/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""CLI for interacting with your Repl's DB. Written as top-level script."""

import json

import click
Expand Down
1 change: 1 addition & 0 deletions src/replit/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Interface with the Replit Database."""

from typing import Any

from . import default_db
Expand Down
27 changes: 11 additions & 16 deletions src/replit/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aiohttp_retry import ExponentialRetry, RetryClient # type: ignore
import requests
from requests.adapters import HTTPAdapter, Retry
from urllib3.filepost import encode_multipart_formdata


def to_primitive(o: Any) -> Any:
Expand Down Expand Up @@ -62,18 +63,6 @@ def dumps(val: Any) -> str:
_dumps = dumps


def _sanitize_key(key: str) -> str:
"""Strip slashes from the beginning of keys.
Args:
key (str): The key to strip
Returns:
str: The stripped key
"""
return key.lstrip("/")


class AsyncDatabase:
"""Async interface for Replit Database.
Expand Down Expand Up @@ -123,6 +112,8 @@ def __init__(
if self._get_db_url:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()
else:
self._refresh_timer = None
watched_thread = threading.main_thread()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()
Expand Down Expand Up @@ -228,7 +219,6 @@ async def set_bulk_raw(self, values: Dict[str, str]) -> None:
Args:
values (Dict[str, str]): The key-value pairs to set.
"""
values = {_sanitize_key(k): v for k, v in values.items()}
async with self.client.post(self.db_url, data=values) as response:
response.raise_for_status()

Expand All @@ -241,8 +231,9 @@ async def delete(self, key: str) -> None:
Raises:
KeyError: Key does not exist
"""
body, content_type = encode_multipart_formdata({"key": key})
async with self.client.delete(
self.db_url + "/" + urllib.parse.quote(key)
self.db_url, data=body, headers={"Content-Type": content_type}
) as response:
if response.status == 404:
raise KeyError(key)
Expand Down Expand Up @@ -550,6 +541,8 @@ def __init__(
if self._get_db_url:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()
else:
self._refresh_timer = None
watched_thread = threading.main_thread()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()
Expand Down Expand Up @@ -685,7 +678,6 @@ def set_bulk_raw(self, values: Dict[str, str]) -> None:
Args:
values (Dict[str, str]): The key-value pairs to set.
"""
values = {_sanitize_key(k): v for k, v in values.items()}
r = self.sess.post(self.db_url, data=values)
r.raise_for_status()

Expand All @@ -698,7 +690,10 @@ def __delitem__(self, key: str) -> None:
Raises:
KeyError: Key is not set
"""
r = self.sess.delete(self.db_url + "/" + urllib.parse.quote(key))
body, content_type = encode_multipart_formdata({"key": key})
r = self.sess.delete(
self.db_url, data=body, headers={"Content-Type": content_type}
)
if r.status_code == 404:
raise KeyError(key)

Expand Down
1 change: 1 addition & 0 deletions src/replit/database/default_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A module containing the default database."""

import os
import os.path
from typing import Any, Optional
Expand Down
1 change: 1 addition & 0 deletions src/replit/database/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A module containing a database proxy implementation."""

from typing import Any
from urllib.parse import quote

Expand Down
1 change: 1 addition & 0 deletions src/replit/info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Information about your repl."""

import os
from typing import Optional

Expand Down
1 change: 1 addition & 0 deletions src/replit/web/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for working with user mappings."""

from collections.abc import Mapping, MutableMapping
from typing import Any, Iterator, Optional

Expand Down
59 changes: 47 additions & 12 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,30 +129,48 @@ async def test_slash_keys(self) -> None:
"""Test that slash keys work."""
k = "/key"
# set
await self.db.set(k,"val1")
self.assertEqual(await self.db.get(k), "val1")
await self.db.set(k, "val1")
# TODO: Getting slash keys is currently not supported
# See https://github.com/replit/replit-py/pull/218#discussion_r1588295348
# self.assertEqual(await self.db.get(k), "val1")
self.assertEqual(list(await self.db.list("/")), [k])
await self.db.delete(k)
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
await self.db.get(k)
# set_raw
await self.db.set_raw(k,"val1")
self.assertEqual(await self.db.get_raw(k), "val1")
await self.db.set_raw(k, "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(await self.db.get_raw(k), "val1")
self.assertEqual(list(await self.db.list("/")), [k])
await self.db.delete(k)
# TODO: Getting slash keys is currently not supported.
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
await self.db.get(k)
# set_bulk
await self.db.set_bulk({k: "val1"})
self.assertEqual(await self.db.get(k), "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(await self.db.get(k), "val1")
self.assertEqual(list(await self.db.list("/")), [k])
await self.db.delete(k)
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
await self.db.get(k)
# set_bulk_raw
await self.db.set_bulk_raw({k: "val1"})
self.assertEqual(await self.db.get_raw(k), "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(await self.db.get_raw(k), "val1")
self.assertEqual(list(await self.db.list("/")), [k])
await self.db.delete(k)
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
await self.db.get(k)


class TestDatabase(unittest.TestCase):
"""Tests for replit.database.Database."""

Expand Down Expand Up @@ -291,26 +309,43 @@ def test_slash_keys(self) -> None:
"""Test that slash keys work."""
k = "/key"
# set
self.db.set(k,"val1")
self.assertEqual(self.db[k], "val1")
self.db.set(k, "val1")
# TODO: Getting slash keys is currently not supported
# See https://github.com/replit/replit-py/pull/218#discussion_r1588295348
# self.assertEqual(self.db[k], "val1")
self.assertEqual(list(self.db.keys()), [k])
del self.db[k]
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
self.db[k]
# set_raw
self.db.set_raw(k,"val1")
self.assertEqual(self.db.get_raw(k), "val1")
self.db.set_raw(k, "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(self.db.get_raw(k), "val1")
self.assertEqual(list(self.db.keys()), [k])
del self.db[k]
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
self.db[k]
# set_bulk
self.db.set_bulk({k: "val1"})
self.assertEqual(self.db.get(k), "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(self.db.get(k), "val1")
self.assertEqual(list(self.db.keys()), [k])
del self.db[k]
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
self.db[k]
# set_bulk_raw
self.db.set_bulk_raw({k: "val1"})
self.assertEqual(self.db.get_raw(k), "val1")
# TODO: Getting slash keys is currently not supported
# self.assertEqual(self.db.get_raw(k), "val1")
self.assertEqual(list(self.db.keys()), [k])
del self.db[k]
# TODO: Getting slash keys is currently not supported
# KeyError is the same though, so it can stay.
with self.assertRaises(KeyError):
self.db[k]

0 comments on commit 0f7ac60

Please sign in to comment.