Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(sdk): add many tests for InternalApi.upload_file #4539

Merged
Merged 16 commits on Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions tests/unit_tests/conftest.py
Expand Up @@ -119,6 +119,12 @@ def copy_asset_fn(
# --------------------------------


@pytest.fixture
def mock_responses():
    """Yield an active ``responses.RequestsMock`` for the duration of a test.

    The mock is entered as a context manager, so it is torn down when the
    test that requested this fixture finishes.
    """
    with responses.RequestsMock() as requests_mock:
        yield requests_mock


@pytest.fixture(scope="function", autouse=True)
def unset_global_objects():
from wandb.sdk.lib.module import unset_globals
Expand Down
288 changes: 287 additions & 1 deletion tests/unit_tests/test_internal_api.py
Expand Up @@ -2,12 +2,16 @@
import hashlib
import os
import tempfile
from typing import Optional
from pathlib import Path
from typing import Callable, Mapping, Optional, Type
from unittest.mock import Mock, call

import pytest
import requests
import responses
from wandb.apis import internal
from wandb.errors import CommError
from wandb.sdk.lib import retry


def test_agent_heartbeat_with_no_agent_id_fails():
Expand Down Expand Up @@ -72,3 +76,285 @@ def test_download_write_file_fetches_iff_file_checksum_mismatched(
assert response is not None
else:
assert response is None


@pytest.fixture
def some_file(tmp_path: Path):
    """Create ``some_file.txt`` containing "some text" under ``tmp_path``.

    Returns the path to the freshly written file.
    """
    file_path = tmp_path / "some_file.txt"
    file_path.write_text("some text")
    return file_path


class TestUploadFile:
    """Tests for ``InternalApi.upload_file`` against mocked HTTP endpoints.

    Every test opens the file to upload with a ``with`` block so the handle
    is closed deterministically, even when the upload is expected to raise.
    """

    class TestSimple:
        """Basic request construction, success, and error translation."""

        def test_adds_headers_to_request(
            self, mock_responses: responses.RequestsMock, some_file: Path
        ):
            """``extra_headers`` are present on the outgoing PUT request."""
            response_callback = Mock(return_value=(200, {}, "success!"))
            mock_responses.add_callback(
                "PUT", "http://example.com/upload-dst", response_callback
            )
            with some_file.open("rb") as f:
                internal.InternalApi().upload_file(
                    "http://example.com/upload-dst",
                    f,
                    extra_headers={"X-Test": "test"},
                )
            # The callback's first positional argument is the prepared request.
            assert response_callback.call_args[0][0].headers["X-Test"] == "test"

        def test_returns_response_on_success(
            self, mock_responses: responses.RequestsMock, some_file: Path
        ):
            """A 200 response is returned to the caller unchanged."""
            mock_responses.add(
                "PUT", "http://example.com/upload-dst", status=200, body="success!"
            )
            with some_file.open("rb") as f:
                resp = internal.InternalApi().upload_file(
                    "http://example.com/upload-dst", f
                )
            assert resp.content == b"success!"

        @pytest.mark.parametrize(
            "status,transient", [(400, False), (500, True), (502, True)]
        )
        def test_returns_transient_error_on_transient_statuscodes(
            self,
            mock_responses: responses.RequestsMock,
            some_file: Path,
            status: int,
            transient: bool,
        ):
            """5xx statuses raise ``TransientError``; 400 raises ``HTTPError``."""
            mock_responses.add(
                "PUT", "http://example.com/upload-dst", status=status, body="failure!"
            )
            expected_errtype = (
                retry.TransientError if transient else requests.exceptions.HTTPError
            )
            with some_file.open("rb") as f, pytest.raises(expected_errtype):
                internal.InternalApi().upload_file("http://example.com/upload-dst", f)

        @pytest.mark.parametrize(
            "error",
            [requests.exceptions.ConnectionError(), requests.exceptions.Timeout()],
        )
        def test_returns_transient_error_on_network_errors(
            self,
            mock_responses: responses.RequestsMock,
            some_file: Path,
            error: Exception,
        ):
            """Connection errors and timeouts raise ``TransientError``."""
            mock_responses.add("PUT", "http://example.com/upload-dst", body=error)
            with some_file.open("rb") as f, pytest.raises(retry.TransientError):
                internal.InternalApi().upload_file("http://example.com/upload-dst", f)

    class TestProgressCallback:
        """The ``callback`` argument is invoked as the request body is read."""

        def test_smoke(self, mock_responses: responses.RequestsMock, some_file: Path):
            """Reading the whole body at once produces a single progress call."""
            file_contents = "some text"
            some_file.write_text(file_contents)

            def response_callback(request: requests.models.PreparedRequest):
                assert request.body.read() == file_contents.encode()
                return (200, {}, "success!")

            mock_responses.add_callback(
                "PUT", "http://example.com/upload-dst", response_callback
            )

            progress_callback = Mock()
            with some_file.open("rb") as f:
                internal.InternalApi().upload_file(
                    "http://example.com/upload-dst",
                    f,
                    callback=progress_callback,
                )

            assert progress_callback.call_args_list == [
                call(len(file_contents), len(file_contents))
            ]

        def test_handles_multiple_calls(
            self, mock_responses: responses.RequestsMock, some_file: Path
        ):
            """Each chunked read reports (bytes just read, running total)."""
            some_file.write_text("12345")

            def response_callback(request: requests.models.PreparedRequest):
                assert request.body.read(2) == b"12"
                assert request.body.read(2) == b"34"
                assert request.body.read() == b"5"
                # A read at EOF still triggers a progress call, with 0 new bytes.
                assert request.body.read() == b""
                return (200, {}, "success!")

            mock_responses.add_callback(
                "PUT", "http://example.com/upload-dst", response_callback
            )

            progress_callback = Mock()
            with some_file.open("rb") as f:
                internal.InternalApi().upload_file(
                    "http://example.com/upload-dst",
                    f,
                    callback=progress_callback,
                )

            assert progress_callback.call_args_list == [
                call(2, 2),
                call(2, 4),
                call(1, 5),
                call(0, 5),
            ]

        @pytest.mark.parametrize(
            "failure",
            [
                requests.exceptions.Timeout(),
                requests.exceptions.ConnectionError(),
                (500, {}, ""),
            ],
        )
        def test_rewinds_on_failure(
            self, mock_responses: responses.RequestsMock, some_file: Path, failure
        ):
            """A failed upload reports a negative delta that resets the total to 0."""
            some_file.write_text("1234567")

            def response_callback(request: requests.models.PreparedRequest):
                assert request.body.read(2) == b"12"
                assert request.body.read(2) == b"34"
                return failure

            mock_responses.add_callback(
                "PUT", "http://example.com/upload-dst", response_callback
            )

            progress_callback = Mock()
            # Every parametrized failure is one that TestSimple shows is raised
            # as TransientError. Expecting that type (rather than Exception)
            # keeps an AssertionError from response_callback from being
            # silently swallowed here.
            with some_file.open("rb") as f, pytest.raises(retry.TransientError):
                internal.InternalApi().upload_file(
                    "http://example.com/upload-dst",
                    f,
                    callback=progress_callback,
                )

            assert progress_callback.call_args_list == [
                call(2, 2),
                call(2, 4),
                call(-4, 0),
            ]

    @pytest.mark.parametrize(
        "request_headers,response,expected_errtype",
        [
            (
                {"x-amz-meta-md5": "1234"},
                (400, {}, "blah blah RequestTimeout blah blah"),
                retry.TransientError,
            ),
            (
                {"x-amz-meta-md5": "1234"},
                (400, {}, "non-timeout-related error message"),
                requests.RequestException,
            ),
            (
                {"x-amz-meta-md5": "1234"},
                requests.exceptions.ConnectionError(),
                retry.TransientError,
            ),
            (
                {},
                (400, {}, "blah blah RequestTimeout blah blah"),
                requests.RequestException,
            ),
        ],
    )
    def test_transient_failure_on_special_aws_request_timeout(
        self,
        mock_responses: responses.RequestsMock,
        some_file: Path,
        request_headers: Mapping[str, str],
        response,
        expected_errtype: Type[Exception],
    ):
        """A 400 mentioning RequestTimeout is transient only with the AWS header."""
        mock_responses.add_callback(
            "PUT", "http://example.com/upload-dst", lambda _: response
        )
        with some_file.open("rb") as f, pytest.raises(expected_errtype):
            internal.InternalApi().upload_file(
                "http://example.com/upload-dst",
                f,
                extra_headers=request_headers,
            )

    class TestAzure:
        """Uploads carrying the ``x-ms-blob-type`` header."""

        MAGIC_HEADERS = {"x-ms-blob-type": "SomeBlobType"}

        @pytest.mark.parametrize(
            "request_headers,uses_azure_lib",
            [
                ({}, False),
                (MAGIC_HEADERS, True),
            ],
        )
        def test_uses_azure_lib_if_available(
            self,
            mock_responses: responses.RequestsMock,
            some_file: Path,
            request_headers: Mapping[str, str],
            uses_azure_lib: bool,
        ):
            """The Azure blob client is used only when the magic header is set."""
            api = internal.InternalApi()

            if uses_azure_lib:
                # Swap in a mock module so no real Azure client is exercised.
                api._azure_blob_module = Mock()
            else:
                mock_responses.add("PUT", "http://example.com/upload-dst")

            with some_file.open("rb") as f:
                api.upload_file(
                    "http://example.com/upload-dst",
                    f,
                    extra_headers=request_headers,
                )

            if uses_azure_lib:
                api._azure_blob_module.BlobClient.from_blob_url().upload_blob.assert_called_once()
            else:
                assert len(mock_responses.calls) == 1

        @pytest.mark.parametrize(
            "response,expected_errtype,check_err",
            [
                (
                    (400, {}, "my-reason"),
                    requests.RequestException,
                    lambda e: e.response.status_code == 400 and "my-reason" in str(e),
                ),
                (
                    (500, {}, "my-reason"),
                    retry.TransientError,
                    lambda e: (
                        e.exception.response.status_code == 500
                        and "my-reason" in str(e.exception)
                    ),
                ),
                (
                    requests.exceptions.ConnectionError("my-reason"),
                    retry.TransientError,
                    lambda e: "my-reason" in str(e.exception),
                ),
            ],
        )
        def test_translates_azure_err_to_normal_err(
            self,
            mock_responses: responses.RequestsMock,
            some_file: Path,
            response,
            expected_errtype: Type[Exception],
            check_err: Callable[[Exception], bool],
        ):
            """Failures on the magic-header path raise the usual error types."""
            mock_responses.add_callback(
                "PUT", "https://example.com/foo/bar/baz", lambda _: response
            )
            with some_file.open("rb") as f, pytest.raises(expected_errtype) as e:
                internal.InternalApi().upload_file(
                    "https://example.com/foo/bar/baz",
                    f,
                    extra_headers=self.MAGIC_HEADERS,
                )

            # check_err inspects the raised error (or, for TransientError, the
            # exception it carries in ``.exception``).
            assert check_err(e.value), e.value
3 changes: 1 addition & 2 deletions wandb/sdk/internal/internal_api.py
Expand Up @@ -1883,7 +1883,6 @@ def upload_file_azure(
response = requests.models.Response()
response.status_code = e.response.status_code
response.headers = e.response.headers
response.raw = e.response.internal_response
speezepearson marked this conversation as resolved.
Show resolved Hide resolved
raise requests.exceptions.RequestException(e.message, response=response)
else:
raise requests.exceptions.ConnectionError(e.message)
Expand Down Expand Up @@ -1932,7 +1931,7 @@ def upload_file(
is_aws_retryable = (
"x-amz-meta-md5" in extra_headers
and status_code == 400
and "RequestTimeout" in response_content
speezepearson marked this conversation as resolved.
Show resolved Hide resolved
and "RequestTimeout" in str(response_content)
)
# We need to rewind the file for the next retry (the file passed in is seeked to 0)
progress.rewind()
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/internal/progress.py
Expand Up @@ -59,7 +59,7 @@ def read(self, size=-1):
return bites

def rewind(self) -> None:
    """Undo all reported progress and seek the wrapped file back to the start.

    The callback is invoked exactly once with ``(-self.bytes_read, 0)``:
    a negative delta that cancels every byte reported so far, and a new
    total of zero. ``self.bytes_read`` is then reset and the underlying
    file is repositioned at offset 0.
    """
    # Report the rollback before clearing the counter, since the delta is
    # computed from the current bytes_read value.
    self.callback(-self.bytes_read, 0)
    self.bytes_read = 0
    self.file.seek(0)

Expand Down