diff --git a/src/neptune/new/cli/sync.py b/src/neptune/new/cli/sync.py
index 71bbde923..ea1d55d44 100644
--- a/src/neptune/new/cli/sync.py
+++ b/src/neptune/new/cli/sync.py
@@ -89,9 +89,11 @@ def sync_execution(
             lock=threading.RLock(),
         ) as disk_queue:
             while True:
-                batch, version = disk_queue.get_batch(1000)
+                batch = disk_queue.get_batch(1000)
                 if not batch:
                     break
+                version = batch[-1].ver
+                batch = [element.obj for element in batch]
 
                 start_time = time.monotonic()
                 expected_count = len(batch)
diff --git a/src/neptune/new/internal/disk_queue.py b/src/neptune/new/internal/disk_queue.py
index 69763d8f9..96e36e6f8 100644
--- a/src/neptune/new/internal/disk_queue.py
+++ b/src/neptune/new/internal/disk_queue.py
@@ -18,6 +18,7 @@
 import os
 import shutil
 import threading
+from dataclasses import dataclass
 from glob import glob
 from pathlib import Path
 from typing import (
@@ -33,16 +34,23 @@
 from neptune.new.internal.utils.json_file_splitter import JsonFileSplitter
 from neptune.new.internal.utils.sync_offset_file import SyncOffsetFile
 
-__all__ = ["DiskQueue"]
+__all__ = ["QueueElement", "DiskQueue"]
 
 T = TypeVar("T")
 
 _logger = logging.getLogger(__name__)
 
 
-class DiskQueue(Generic[T]):
+@dataclass
+class QueueElement(Generic[T]):
+    obj: T
+    ver: int
+    size: int
+
+class DiskQueue(Generic[T]):
 
     # NOTICE: This class is thread-safe as long as there is only one consumer and one producer.
+    DEFAULT_MAX_BATCH_SIZE_BYTES = 100 * 1024**2
 
     def __init__(
         self,
@@ -51,11 +59,15 @@ def __init__(
         from_dict: Callable[[dict], T],
         lock: threading.RLock,
         max_file_size: int = 64 * 1024**2,
+        max_batch_size_bytes: int = None,
     ):
         self._dir_path = dir_path.resolve()
         self._to_dict = to_dict
         self._from_dict = from_dict
         self._max_file_size = max_file_size
+        self._max_batch_size_bytes = max_batch_size_bytes or int(
+            os.environ.get("NEPTUNE_MAX_BATCH_SIZE_BYTES") or str(self.DEFAULT_MAX_BATCH_SIZE_BYTES)
+        )
 
         try:
             os.makedirs(self._dir_path)
@@ -90,57 +102,64 @@ def put(self, obj: T) -> int:
         self._file_size += len(_json) + 1
         return version
 
-    def get(self) -> Tuple[Optional[T], int]:
+    def get(self) -> Optional[QueueElement[T]]:
         if self._should_skip_to_ack:
             return self._skip_and_get()
         else:
             return self._get()
 
-    def _skip_and_get(self) -> Tuple[Optional[T], int]:
+    def _skip_and_get(self) -> Optional[QueueElement[T]]:
         ack_version = self._last_ack_file.read_local()
-        ver = -1
         while True:
-            obj, next_ver = self._get()
-            if obj is None:
-                return None, ver
-            ver = next_ver
-            if ver > ack_version:
+            top_element = self._get()
+            if top_element is None:
+                return None
+            if top_element.ver > ack_version:
                 self._should_skip_to_ack = False
-                if ver > ack_version + 1:
+                if top_element.ver > ack_version + 1:
                     _logger.warning(
                         "Possible data loss. Last acknowledged operation version: %d, next: %d",
                         ack_version,
-                        ver,
+                        top_element.ver,
                     )
-                return obj, ver
+                return top_element
 
-    def _get(self) -> Tuple[Optional[T], int]:
-        _json = self._reader.get()
+    def _get(self) -> Optional[QueueElement[T]]:
+        _json, size = self._reader.get_with_size()
         if not _json:
             if self._read_file_version >= self._write_file_version:
-                return None, -1
+                return None
             self._reader.close()
             self._read_file_version = self._next_log_file_version(self._read_file_version)
             self._reader = JsonFileSplitter(self._get_log_file(self._read_file_version))
             # It is safe. Max recursion level is 2.
             return self._get()
         try:
-            return self._deserialize(_json)
+            obj, ver = self._deserialize(_json)
+            return QueueElement[T](obj, ver, size)
         except Exception as e:
             raise MalformedOperation from e
 
-    def get_batch(self, size: int) -> Tuple[List[T], int]:
-        first, ver = self.get()
+    def get_batch(self, size: int) -> List[QueueElement[T]]:
+        if self._should_skip_to_ack:
+            first = self._skip_and_get()
+        else:
+            first = self._get()
         if not first:
-            return [], ver
+            return []
+
         ret = [first]
+        cur_batch_size = first.size
         for _ in range(0, size - 1):
-            obj, next_ver = self._get()
-            if not obj:
+            if cur_batch_size >= self._max_batch_size_bytes:
                 break
-            ver = next_ver
-            ret.append(obj)
-        return ret, ver
+            next_obj = self._get()
+            if not next_obj:
+                break
+
+            cur_batch_size += next_obj.size
+            ret.append(next_obj)
+        return ret
 
     def flush(self):
         self._writer.flush()
diff --git a/src/neptune/new/internal/operation_processors/async_operation_processor.py b/src/neptune/new/internal/operation_processors/async_operation_processor.py
index ca5d96022..800aaa42e 100644
--- a/src/neptune/new/internal/operation_processors/async_operation_processor.py
+++ b/src/neptune/new/internal/operation_processors/async_operation_processor.py
@@ -236,10 +236,10 @@ def work(self) -> None:
             self._processor._queue.flush()
             while True:
-                batch, version = self._processor._queue.get_batch(self._batch_size)
+                batch = self._processor._queue.get_batch(self._batch_size)
                 if not batch:
                     return
-                self.process_batch(batch, version)
+                self.process_batch([element.obj for element in batch], batch[-1].ver)
 
         @Daemon.ConnectionRetryWrapper(
             kill_message=(
diff --git a/src/neptune/new/internal/utils/json_file_splitter.py b/src/neptune/new/internal/utils/json_file_splitter.py
index f764c88ff..7e2390fa4 100644
--- a/src/neptune/new/internal/utils/json_file_splitter.py
+++ b/src/neptune/new/internal/utils/json_file_splitter.py
@@ -17,11 +17,13 @@
 from collections import deque
 from io import StringIO
 from json import JSONDecodeError
-from typing import Optional
+from typing import (
+    Optional,
+    Tuple,
+)
 
 
 class JsonFileSplitter:
-
     BUFFER_SIZE = 64 * 1024
     MAX_PART_READ = 8 * 1024
 
@@ -37,11 +39,15 @@ def close(self) -> None:
         self._part_buffer.close()
 
     def get(self) -> Optional[dict]:
+        return (self.get_with_size() or (None, None))[0]
+
+    def get_with_size(self) -> Tuple[Optional[dict], int]:
         if self._parsed_queue:
             return self._parsed_queue.popleft()
         self._read_data()
         if self._parsed_queue:
             return self._parsed_queue.popleft()
+        return None, 0
 
     def _read_data(self):
         if self._part_buffer.tell() < self.MAX_PART_READ:
@@ -64,12 +70,14 @@ def _decode(self, data: str):
         start = self._json_start(data)
         while start is not None:
             try:
-                json_data, start = self._decoder.raw_decode(data, start)
+                json_data, new_start = self._decoder.raw_decode(data, start)
+                size = new_start - start
+                start = new_start
             except JSONDecodeError:
                 self._part_buffer.write(data[start:])
                 break
             else:
-                self._parsed_queue.append(json_data)
+                self._parsed_queue.append((json_data, size))
                 start = self._json_start(data, start)
 
     @staticmethod
diff --git a/tests/neptune/new/internal/test_disk_queue.py b/tests/neptune/new/internal/test_disk_queue.py
index 4115f0fef..fdd2d6b7e 100644
--- a/tests/neptune/new/internal/test_disk_queue.py
+++ b/tests/neptune/new/internal/test_disk_queue.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import json
 import random
 import threading
 import unittest
@@ -21,7 +21,10 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
-from neptune.new.internal.disk_queue import DiskQueue
+from neptune.new.internal.disk_queue import (
+    DiskQueue,
+    QueueElement,
+)
 
 
 class TestDiskQueue(unittest.TestCase):
@@ -33,6 +36,15 @@ def __init__(self, num: int, txt: str):
         def __eq__(self, other):
             return isinstance(other, TestDiskQueue.Obj) and self.num == other.num and self.txt == other.txt
 
+    @staticmethod
+    def get_obj_size_bytes(obj, version) -> int:
+        return len(json.dumps({"obj": obj.__dict__, "version": version}))
+
+    @staticmethod
+    def get_queue_element(obj, version) -> QueueElement[Obj]:
+        obj_size = len(json.dumps({"obj": obj.__dict__, "version": version}))
+        return QueueElement(obj, version, obj_size)
+
     def test_put(self):
         with TemporaryDirectory() as dirpath:
             queue = DiskQueue[TestDiskQueue.Obj](
@@ -44,7 +56,7 @@ def test_put(self):
             obj = TestDiskQueue.Obj(5, "test")
             queue.put(obj)
             queue.flush()
-            self.assertEqual(queue.get(), (obj, 1))
+            self.assertEqual(queue.get(), self.get_queue_element(obj, 1))
             queue.close()
 
     def test_multiple_files(self):
@@ -61,7 +73,8 @@ def test_multiple_files(self):
                 queue.put(obj)
             queue.flush()
             for i in range(1, 101):
-                self.assertEqual(queue.get(), (TestDiskQueue.Obj(i, str(i)), i))
+                obj = TestDiskQueue.Obj(i, str(i))
+                self.assertEqual(queue.get(), self.get_queue_element(obj, i))
             queue.close()
             self.assertTrue(queue._read_file_version > 90)
             self.assertTrue(queue._write_file_version > 90)
@@ -82,22 +95,47 @@ def test_get_batch(self):
             queue.flush()
             self.assertEqual(
                 queue.get_batch(25),
-                ([TestDiskQueue.Obj(i, str(i)) for i in range(1, 26)], 25),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(1, 26)],
             )
             self.assertEqual(
                 queue.get_batch(25),
-                ([TestDiskQueue.Obj(i, str(i)) for i in range(26, 51)], 50),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(26, 51)],
             )
             self.assertEqual(
                 queue.get_batch(25),
-                ([TestDiskQueue.Obj(i, str(i)) for i in range(51, 76)], 75),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(51, 76)],
             )
             self.assertEqual(
                 queue.get_batch(25),
-                ([TestDiskQueue.Obj(i, str(i)) for i in range(76, 91)], 90),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i) for i in range(76, 91)],
             )
             queue.close()
 
+    def test_batch_limit(self):
+        obj_size = self.get_obj_size_bytes(TestDiskQueue.Obj(1, "1"), 1)
+        with TemporaryDirectory() as dirpath:
+            queue = DiskQueue[TestDiskQueue.Obj](
+                Path(dirpath),
+                self._serializer,
+                self._deserializer,
+                threading.RLock(),
+                max_file_size=100,
+                max_batch_size_bytes=obj_size * 3,
+            )
+            for i in range(5):
+                obj = TestDiskQueue.Obj(i, str(i))
+                queue.put(obj)
+            queue.flush()
+
+            self.assertEqual(
+                queue.get_batch(5),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i + 1) for i in range(3)],
+            )
+            self.assertEqual(
+                queue.get_batch(2),
+                [self.get_queue_element(TestDiskQueue.Obj(i, str(i)), i + 1) for i in range(3, 5)],
+            )
+
     def test_resuming_queue(self):
         with TemporaryDirectory() as dirpath:
             queue = DiskQueue[TestDiskQueue.Obj](
@@ -111,7 +149,7 @@ def test_resuming_queue(self):
                 obj = TestDiskQueue.Obj(i, str(i))
                 queue.put(obj)
             queue.flush()
-            _, version = queue.get_batch(random.randrange(300, 400))
+            version = queue.get_batch(random.randrange(300, 400))[-1].ver
             version_to_ack = version - random.randrange(100, 200)
             queue.ack(version_to_ack)
 
@@ -131,7 +169,8 @@ def test_resuming_queue(self):
                 max_file_size=200,
             )
             for i in range(version_to_ack + 1, 501):
-                self.assertEqual(queue.get(), (TestDiskQueue.Obj(i, str(i)), i))
+                obj = TestDiskQueue.Obj(i, str(i))
+                self.assertEqual(queue.get(), self.get_queue_element(obj, i))
             queue.close()
diff --git a/tests/neptune/new/internal/utils/test_json_file_splitter.py b/tests/neptune/new/internal/utils/test_json_file_splitter.py
index 97eccc1ab..a42ed73f2 100644
--- a/tests/neptune/new/internal/utils/test_json_file_splitter.py
+++ b/tests/neptune/new/internal/utils/test_json_file_splitter.py
@@ -133,3 +133,45 @@ def test_big_json(self):
         self.assertEqual(splitter.get(), {})
         self.assertEqual(splitter.get(), None)
         splitter.close()
+
+    def test_data_size(self):
+        object1 = """{
+            "a": 5,
+            "b": "text"
+        }"""
+        object2 = """{
+            "a": 155,
+            "r": "something"
+        }"""
+        object3 = """{
+            "a": {
+                "b": [1, 2, 3]
+            }
+        }"""
+        content1 = """
+        {
+            "a": 5,
+            "b": "text"
+        }
+        {
+            "a": 1"""
+
+        content2 = """55,
+            "r": "something"
+        }
+        {
+            "a": {
+                "b": [1, 2, 3]
+            }
+        }"""
+
+        with create_file(content1) as filename, open(filename, "a") as fp:
+            splitter = JsonFileSplitter(filename)
+            self.assertEqual(splitter.get_with_size(), ({"a": 5, "b": "text"}, len(object1)))
+            self.assertIsNone(splitter.get_with_size()[0])
+            fp.write(content2)
+            fp.flush()
+            self.assertEqual(splitter.get_with_size(), ({"a": 155, "r": "something"}, len(object2)))
+            self.assertEqual(splitter.get_with_size(), ({"a": {"b": [1, 2, 3]}}, len(object3)))
+            self.assertIsNone(splitter.get_with_size()[0])
+            splitter.close()
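Illustrative usage sketch (not part of the patch): with this change, DiskQueue.get_batch returns a list of QueueElement entries instead of an (objects, version) tuple, and it stops collecting once the accumulated JSON size reaches max_batch_size_bytes (overridable via the NEPTUNE_MAX_BATCH_SIZE_BYTES environment variable), so callers unpack the payloads themselves and acknowledge the version of the last element. The directory name, serializer lambdas, and process() helper below are hypothetical example values.

    # Hypothetical consumer of the reworked DiskQueue API; a sketch only, not code from the patch.
    import threading
    from pathlib import Path

    from neptune.new.internal.disk_queue import DiskQueue

    queue = DiskQueue[dict](
        dir_path=Path("example-queue-dir"),  # example path; DiskQueue creates it if missing
        to_dict=lambda obj: obj,             # payloads are plain dicts in this sketch
        from_dict=lambda d: d,
        lock=threading.RLock(),
        max_batch_size_bytes=10 * 1024**2,   # cap a single batch at roughly 10 MB of serialized JSON
    )

    while True:
        batch = queue.get_batch(1000)        # List[QueueElement[dict]]; may be shorter if the size cap is hit
        if not batch:
            break
        payloads = [element.obj for element in batch]
        last_version = batch[-1].ver         # same pattern sync.py and the async processor use above
        process(payloads)                    # placeholder for whatever consumes the batch
        queue.ack(last_version)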