Merge pull request #1089 from neptune-ai/add_limit_do_disk_queue
Add limit do disk queue
pankin397 committed Nov 18, 2022
2 parents 22fffd1 + 00e21dc commit b52f4ca
Showing 6 changed files with 152 additions and 42 deletions.
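
In short, DiskQueue.get_batch now returns a list of QueueElement records (payload, version, serialized size) and stops filling a batch once a configurable byte limit is reached. A minimal consumer sketch of that new contract, assuming a queue of plain dicts — the identity serializers, the payloads, and the 1 KiB limit below are illustrative; only the DiskQueue API names come from this diff:

    import threading
    from pathlib import Path
    from tempfile import TemporaryDirectory

    from neptune.new.internal.disk_queue import DiskQueue

    with TemporaryDirectory() as tmp:
        queue = DiskQueue[dict](
            Path(tmp),
            to_dict=lambda obj: obj,      # payloads are already dicts
            from_dict=lambda d: d,        # ...so deserialization is the identity
            lock=threading.RLock(),
            max_batch_size_bytes=1024,    # stop filling a batch once ~1 KiB has been read
        )
        for i in range(100):
            queue.put({"step": i})
        queue.flush()

        while True:
            batch = queue.get_batch(50)   # List[QueueElement[dict]], possibly shorter than 50
            if not batch:
                break
            payloads = [element.obj for element in batch]
            queue.ack(batch[-1].ver)      # acknowledge progress by the last element's version
        queue.close()

Downstream callers in this commit (sync.py and the async operation processor's work loop) follow the same pattern: unpack element.obj for processing and track progress via batch[-1].ver.
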
4 changes: 3 additions & 1 deletion src/neptune/new/cli/sync.py
@@ -89,9 +89,11 @@ def sync_execution(
lock=threading.RLock(),
) as disk_queue:
while True:
batch, version = disk_queue.get_batch(1000)
batch = disk_queue.get_batch(1000)
if not batch:
break
version = batch[-1].ver
batch = [element.obj for element in batch]

start_time = time.monotonic()
expected_count = len(batch)
69 changes: 44 additions & 25 deletions src/neptune/new/internal/disk_queue.py
@@ -18,6 +18,7 @@
import os
import shutil
import threading
from dataclasses import dataclass
from glob import glob
from pathlib import Path
from typing import (
@@ -33,16 +34,23 @@
from neptune.new.internal.utils.json_file_splitter import JsonFileSplitter
from neptune.new.internal.utils.sync_offset_file import SyncOffsetFile

__all__ = ["DiskQueue"]
__all__ = ["QueueElement", "DiskQueue"]

T = TypeVar("T")

_logger = logging.getLogger(__name__)


class DiskQueue(Generic[T]):
@dataclass
class QueueElement(Generic[T]):
obj: T
ver: int
size: int


class DiskQueue(Generic[T]):
# NOTICE: This class is thread-safe as long as there is only one consumer and one producer.
DEFAULT_MAX_BATCH_SIZE_BYTES = 100 * 1024**2

def __init__(
self,
@@ -51,11 +59,15 @@ def __init__(
from_dict: Callable[[dict], T],
lock: threading.RLock,
max_file_size: int = 64 * 1024**2,
max_batch_size_bytes: int = None,
):
self._dir_path = dir_path.resolve()
self._to_dict = to_dict
self._from_dict = from_dict
self._max_file_size = max_file_size
self._max_batch_size_bytes = max_batch_size_bytes or int(
os.environ.get("NEPTUNE_MAX_BATCH_SIZE_BYTES") or str(self.DEFAULT_MAX_BATCH_SIZE_BYTES)
)

try:
os.makedirs(self._dir_path)
@@ -90,57 +102,64 @@ def put(self, obj: T) -> int:
self._file_size += len(_json) + 1
return version

def get(self) -> Tuple[Optional[T], int]:
def get(self) -> Optional[QueueElement[T]]:
if self._should_skip_to_ack:
return self._skip_and_get()
else:
return self._get()

def _skip_and_get(self) -> Tuple[Optional[T], int]:
def _skip_and_get(self) -> Optional[QueueElement[T]]:
ack_version = self._last_ack_file.read_local()
ver = -1
while True:
obj, next_ver = self._get()
if obj is None:
return None, ver
ver = next_ver
if ver > ack_version:
top_element = self._get()
if top_element is None:
return None
if top_element.ver > ack_version:
self._should_skip_to_ack = False
if ver > ack_version + 1:
if top_element.ver > ack_version + 1:
_logger.warning(
"Possible data loss. Last acknowledged operation version: %d, next: %d",
ack_version,
ver,
top_element.ver,
)
return obj, ver
return top_element

def _get(self) -> Tuple[Optional[T], int]:
_json = self._reader.get()
def _get(self) -> Optional[QueueElement[T]]:
_json, size = self._reader.get_with_size()
if not _json:
if self._read_file_version >= self._write_file_version:
return None, -1
return None
self._reader.close()
self._read_file_version = self._next_log_file_version(self._read_file_version)
self._reader = JsonFileSplitter(self._get_log_file(self._read_file_version))
# It is safe. Max recursion level is 2.
return self._get()
try:
return self._deserialize(_json)
obj, ver = self._deserialize(_json)
return QueueElement[T](obj, ver, size)
except Exception as e:
raise MalformedOperation from e

def get_batch(self, size: int) -> Tuple[List[T], int]:
first, ver = self.get()
def get_batch(self, size: int) -> List[QueueElement[T]]:
if self._should_skip_to_ack:
first = self._skip_and_get()
else:
first = self._get()
if not first:
return [], ver
return []

ret = [first]
cur_batch_size = first.size
for _ in range(0, size - 1):
obj, next_ver = self._get()
if not obj:
if cur_batch_size >= self._max_batch_size_bytes:
break
ver = next_ver
ret.append(obj)
return ret, ver
next_obj = self._get()
if not next_obj:
break

cur_batch_size += next_obj.size
ret.append(next_obj)
return ret

def flush(self):
self._writer.flush()
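
The effective limit used by get_batch follows the resolution shown in __init__ above: an explicit max_batch_size_bytes argument wins, otherwise the NEPTUNE_MAX_BATCH_SIZE_BYTES environment variable is read, and 100 MiB is the fallback; get_batch then cuts a batch off once the accumulated element sizes reach that value. A small sketch of the precedence, using a hypothetical helper (resolve_max_batch_size_bytes is not part of the commit):

    import os

    DEFAULT_MAX_BATCH_SIZE_BYTES = 100 * 1024**2  # 100 MiB, as on DiskQueue


    def resolve_max_batch_size_bytes(explicit_limit=None):
        # Hypothetical helper mirroring DiskQueue.__init__: explicit argument first,
        # then the NEPTUNE_MAX_BATCH_SIZE_BYTES environment variable, then the default.
        return explicit_limit or int(
            os.environ.get("NEPTUNE_MAX_BATCH_SIZE_BYTES") or str(DEFAULT_MAX_BATCH_SIZE_BYTES)
        )


    # Example: NEPTUNE_MAX_BATCH_SIZE_BYTES=1048576 caps every batch at roughly 1 MiB.
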
@@ -236,10 +236,10 @@ def work(self) -> None:
self._processor._queue.flush()

while True:
batch, version = self._processor._queue.get_batch(self._batch_size)
batch = self._processor._queue.get_batch(self._batch_size)
if not batch:
return
self.process_batch(batch, version)
self.process_batch([element.obj for element in batch], batch[-1].ver)

@Daemon.ConnectionRetryWrapper(
kill_message=(
16 changes: 12 additions & 4 deletions src/neptune/new/internal/utils/json_file_splitter.py
@@ -17,11 +17,13 @@
from collections import deque
from io import StringIO
from json import JSONDecodeError
from typing import Optional
from typing import (
Optional,
Tuple,
)


class JsonFileSplitter:

BUFFER_SIZE = 64 * 1024
MAX_PART_READ = 8 * 1024

@@ -37,11 +39,15 @@ def close(self) -> None:
self._part_buffer.close()

def get(self) -> Optional[dict]:
return (self.get_with_size() or (None, None))[0]

def get_with_size(self) -> Tuple[Optional[dict], int]:
if self._parsed_queue:
return self._parsed_queue.popleft()
self._read_data()
if self._parsed_queue:
return self._parsed_queue.popleft()
return None, 0

def _read_data(self):
if self._part_buffer.tell() < self.MAX_PART_READ:
@@ -64,12 +70,14 @@ def _decode(self, data: str):
start = self._json_start(data)
while start is not None:
try:
json_data, start = self._decoder.raw_decode(data, start)
json_data, new_start = self._decoder.raw_decode(data, start)
size = new_start - start
start = new_start
except JSONDecodeError:
self._part_buffer.write(data[start:])
break
else:
self._parsed_queue.append(json_data)
self._parsed_queue.append((json_data, size))
start = self._json_start(data, start)

@staticmethod
59 changes: 49 additions & 10 deletions tests/neptune/new/internal/test_disk_queue.py
@@ -13,15 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import random
import threading
import unittest
from glob import glob
from pathlib import Path
from tempfile import TemporaryDirectory

from neptune.new.internal.disk_queue import DiskQueue
from neptune.new.internal.disk_queue import (
DiskQueue,
QueueElement,
)


class TestDiskQueue(unittest.TestCase):
@@ -33,6 +36,15 @@ def __init__(self, num: int, txt: str):
def __eq__(self, other):
return isinstance(other, TestDiskQueue.Obj) and self.num == other.num and self.txt == other.txt

@staticmethod
def get_obj_size_bytes(obj, version) -> int:
return len(json.dumps({"obj": obj.__dict__, "version": version}))

@staticmethod
def get_queue_element(obj, version) -> QueueElement[Obj]:
obj_size = len(json.dumps({"obj": obj.__dict__, "version": version}))
return QueueElement(obj, version, obj_size)

def test_put(self):
with TemporaryDirectory() as dirpath:
queue = DiskQueue[TestDiskQueue.Obj](
@@ -44,7 +56,7 @@ def test_put(self):
obj = TestDiskQueue.Obj(5, "test")
queue.put(obj)
queue.flush()
self.assertEqual(queue.get(), (obj, 1))
self.assertEqual(queue.get(), self.get_queue_element(obj, 1))
queue.close()

def test_multiple_files(self):
@@ -61,7 +73,8 @@ def test_multiple_files(self):
queue.put(obj)
queue.flush()
for i in range(1, 101):
self.assertEqual(queue.get(), (TestDiskQueue.Obj(i, str(i)), i))
obj = TestDiskQueue.Obj(i, str(i))
self.assertEqual(queue.get(), self.get_queue_element(obj, i))
queue.close()
self.assertTrue(queue._read_file_version > 90)
self.assertTrue(queue._write_file_version > 90)
@@ -82,22 +95,47 @@ def test_get_batch(self):
queue.flush()
self.assertEqual(
queue.get_batch(25),
([TestDiskQueue.Obj(i, str(i)) for i in range(1, 26)], 25),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(1, 26)],
)
self.assertEqual(
queue.get_batch(25),
([TestDiskQueue.Obj(i, str(i)) for i in range(26, 51)], 50),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(26, 51)],
)
self.assertEqual(
queue.get_batch(25),
([TestDiskQueue.Obj(i, str(i)) for i in range(51, 76)], 75),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(51, 76)],
)
self.assertEqual(
queue.get_batch(25),
([TestDiskQueue.Obj(i, str(i)) for i in range(76, 91)], 90),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(76, 91)],
)
queue.close()

def test_batch_limit(self):
obj_size = self.get_obj_size_bytes(TestDiskQueue.Obj(1, "1"), 1)
with TemporaryDirectory() as dirpath:
queue = DiskQueue[TestDiskQueue.Obj](
Path(dirpath),
self._serializer,
self._deserializer,
threading.RLock(),
max_file_size=100,
max_batch_size_bytes=obj_size * 3,
)
for i in range(5):
obj = TestDiskQueue.Obj(i, str(i))
queue.put(obj)
queue.flush()

self.assertEqual(
queue.get_batch(5),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i + 1) for i in range(3)],
)
self.assertEqual(
queue.get_batch(2),
[self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i + 1) for i in range(3, 5)],
)

def test_resuming_queue(self):
with TemporaryDirectory() as dirpath:
queue = DiskQueue[TestDiskQueue.Obj](
@@ -111,7 +149,7 @@ def test_resuming_queue(self):
obj = TestDiskQueue.Obj(i, str(i))
queue.put(obj)
queue.flush()
_, version = queue.get_batch(random.randrange(300, 400))
version = queue.get_batch(random.randrange(300, 400))[-1].ver
version_to_ack = version - random.randrange(100, 200)
queue.ack(version_to_ack)

@@ -131,7 +169,8 @@
max_file_size=200,
)
for i in range(version_to_ack + 1, 501):
self.assertEqual(queue.get(), (TestDiskQueue.Obj(i, str(i)), i))
obj = TestDiskQueue.Obj(i, str(i))
self.assertEqual(queue.get(), self.get_queue_element(obj, i))

queue.close()

42 changes: 42 additions & 0 deletions tests/neptune/new/internal/utils/test_json_file_splitter.py
@@ -133,3 +133,45 @@ def test_big_json(self):
self.assertEqual(splitter.get(), {})
self.assertEqual(splitter.get(), None)
splitter.close()

def test_data_size(self):
object1 = """{
"a": 5,
"b": "text"
}"""
object2 = """{
"a": 155,
"r": "something"
}"""
object3 = """{
"a": {
"b": [1, 2, 3]
}
}"""
content1 = """
{
"a": 5,
"b": "text"
}
{
"a": 1"""

content2 = """55,
"r": "something"
}
{
"a": {
"b": [1, 2, 3]
}
}"""

with create_file(content1) as filename, open(filename, "a") as fp:
splitter = JsonFileSplitter(filename)
self.assertEqual(splitter.get_with_size(), ({"a": 5, "b": "text"}, len(object1)))
self.assertIsNone(splitter.get_with_size()[0])
fp.write(content2)
fp.flush()
self.assertEqual(splitter.get_with_size(), ({"a": 155, "r": "something"}, len(object2)))
self.assertEqual(splitter.get_with_size(), ({"a": {"b": [1, 2, 3]}}, len(object3)))
self.assertIsNone(splitter.get_with_size()[0])
splitter.close()
