Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance when sending large messages #4119

Merged
merged 16 commits into from
Aug 18, 2022
55 changes: 44 additions & 11 deletions wandb/sdk/lib/sock_client.py
Expand Up @@ -2,7 +2,7 @@
import struct
import threading
import time
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, List, Optional, TYPE_CHECKING
import uuid

from wandb.proto import wandb_server_pb2 as spb
Expand All @@ -21,7 +21,9 @@ class SockClientClosedError(Exception):

class SockClient:
_sock: socket.socket
_data: bytes
_buffer_list: List[bytes]
_buffer_lens: List[int]
_buffer_total: int
_sockid: str
_retry_delay: float
_lock: "threading.Lock"
Expand All @@ -30,7 +32,9 @@ class SockClient:
HEADLEN = 1 + 4

def __init__(self) -> None:
self._data = b""
self._buffer_list = []
self._buffer_lens = []
self._buffer_total = 0
# TODO: use safe uuid's (python3.7+) or emulate this
self._sockid = uuid.uuid4().hex
self._retry_delay = 0.1
Expand Down Expand Up @@ -149,20 +153,48 @@ def send_record_publish(self, record: "pb.Record") -> None:
server_req.record_publish.CopyFrom(record)
self.send_server_request(server_req)

def _buffer_get(self, start: int, end: int, advance: bool = False) -> bytes:
if advance:
data = b"".join(self._buffer_list)
raubitsj marked this conversation as resolved.
Show resolved Hide resolved
requested = data[start:end]
leftover = data[end:]
raubitsj marked this conversation as resolved.
Show resolved Hide resolved
self._buffer_total = len(leftover)
self._buffer_list = [leftover]
self._buffer_lens = [self._buffer_total]
return requested

buffers = []
need = end
for buf_len, buf_data in zip(self._buffer_lens, self._buffer_list):
buffers.append(buf_data if need >= buf_len else buf_data[:need])
need -= buf_len
if need <= 0:
break
data = b"".join(buffers)
requested = data[start:end]

return requested

def _buffer_append(self, data: bytes, data_len: int) -> None:
self._buffer_list.append(data)
self._buffer_lens.append(data_len)
self._buffer_total += data_len

def _extract_packet_bytes(self) -> Optional[bytes]:
# Do we have enough data to read the header?
len_data = len(self._data)
start_offset = self.HEADLEN
if len_data >= start_offset:
header = self._data[:start_offset]
if self._buffer_total >= start_offset:
# header = self._data[:start_offset]
header = self._buffer_get(0, start_offset)
fields = struct.unpack("<BI", header)
magic, dlength = fields
assert magic == ord("W")
# Do we have enough data to read the full record?
end_offset = self.HEADLEN + dlength
if len_data >= end_offset:
rec_data = self._data[start_offset:end_offset]
self._data = self._data[end_offset:]
if self._buffer_total >= end_offset:
rec_data = self._buffer_get(start_offset, end_offset, advance=True)
# rec_data = self._data[start_offset:end_offset]
# self._data = self._data[end_offset:]
return rec_data
return None

Expand Down Expand Up @@ -193,11 +225,12 @@ def _read_packet_bytes(self, timeout: int = None) -> Optional[bytes]:
finally:
if timeout:
self._sock.settimeout(None)
if len(data) == 0:
data_len = len(data)
if data_len == 0:
# socket.recv() will return 0 bytes if socket was shutdown
# caller will handle this condition like other connection problems
raise SockClientClosedError()
self._data += data
self._buffer_append(data, data_len)
return None

def read_server_request(self) -> Optional[spb.ServerRequest]:
Expand Down