Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Christian Bourjau <cbourjau@users.noreply.github.com>
  • Loading branch information
cbourjau and justinchuby committed Apr 15, 2024
1 parent 477fcf1 commit 6976c1a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion onnx/backend/test/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def load_model_tests(
data_dir: str,
data_dir: str | os.PathLike,
kind: Optional[str] = None,
) -> List[TestCase]:
"""Load model test cases from on-disk data files."""
Expand Down
2 changes: 1 addition & 1 deletion onnx/backend/test/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Runner:
def __init__(
self,
backend: type[Backend],
test_data_dir: Path,
test_data_dir: str | os.PathLike,
parent_module: str | None = None,
test_kwargs: dict | None = None,
) -> None:
Expand Down
10 changes: 5 additions & 5 deletions tests/backend/test_backend_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib
import platform
import sys
import unittest
from pathlib import Path
from typing import Any

import numpy
import packaging

import onnx.backend.base
import onnx.backend.test
Expand All @@ -18,9 +19,8 @@
from onnx import ModelProto
from onnx.backend.base import Device, DeviceType
from onnx.reference import ReferenceEvaluator
from onnx.test import version_utils

TEST_DATA_DIR = Path(__file__).parent.parent.parent / "onnx/backend/test/data"
_TEST_DATA_DIR = pathlib.Path(__file__).parent.parent.parent / "onnx/backend/test/data"

# The following just executes a backend based on ReferenceEvaluator through the backend test

Expand Down Expand Up @@ -93,7 +93,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):


backend_test = onnx.backend.test.BackendTest(
ReferenceEvaluatorBackend, test_data_dir=TEST_DATA_DIR, parent_module=__name__
ReferenceEvaluatorBackend, test_data_dir=_TEST_DATA_DIR, parent_module=__name__
)

if os.getenv("APPVEYOR"):
Expand Down Expand Up @@ -198,7 +198,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
backend_test.exclude("test_qlinearmatmul_3D_int8_float32_cpu")

# op_dft and op_stft requires numpy >= 1.21.5
if version_utils.numpy_older_than("1.21.5"):
if packaging.version.parse(numpy.__version__) < packaging.version.parse("1.21.5"):
backend_test.exclude("test_stft")
backend_test.exclude("test_stft_with_window")
backend_test.exclude("test_stft_cpu")
Expand Down
6 changes: 3 additions & 3 deletions tests/backend/test_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import itertools
import os
import pathlib
import platform
import unittest
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple

import numpy
Expand All @@ -30,7 +30,7 @@
# We don't enable report in this test because the report collection logic itself
# fails when models are mal-formed.

TEST_DATA_DIR = Path(__file__).parent.parent.parent / "onnx/backend/test/data"
_TEST_DATA_DIR = pathlib.Path(__file__).parent.parent.parent / "onnx/backend/test/data"


class DummyBackend(onnx.backend.base.Backend):
Expand Down Expand Up @@ -115,7 +115,7 @@ def do_enforce_test_coverage_safelist(model: ModelProto) -> bool:

backend_test = onnx.backend.test.BackendTest(
DummyBackend,
test_data_dir=TEST_DATA_DIR,
test_data_dir=_TEST_DATA_DIR,
parent_module=__name__,
test_kwargs=test_kwargs,
)
Expand Down

0 comments on commit 6976c1a

Please sign in to comment.