Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Christian Bourjau <cbourjau@users.noreply.github.com>
  • Loading branch information
cbourjau and justinchuby committed Apr 15, 2024
1 parent b6233e9 commit fe257c4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
7 changes: 3 additions & 4 deletions onnx/backend/test/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@

import json
import os
from typing import List, Optional

from onnx.backend.test.case.test_case import TestCase


def load_model_tests(
data_dir: str | os.PathLike,
kind: Optional[str] = None,
) -> List[TestCase]:
kind: str | None = None,
) -> list[TestCase]:
"""Load model test cases from on-disk data files."""
supported_kinds = os.listdir(data_dir)
if kind not in supported_kinds:
Expand All @@ -32,7 +31,7 @@ def load_model_tests(
if os.path.exists(os.path.join(case_dir, "model.onnx")):
url = None
model_name = test_name[len("test_")]
model_dir: Optional[str] = case_dir
model_dir: str | None = case_dir
else:
with open(os.path.join(case_dir, "data.json")) as f:
data = json.load(f)
Expand Down
10 changes: 5 additions & 5 deletions onnx/backend/test/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,22 @@ def __init__(
# {category: {name: func}}
self._test_items: dict[str, dict[str, TestItem]] = defaultdict(dict)

for rt in load_model_tests(data_dir=str(test_data_dir), kind="node"):
for rt in load_model_tests(data_dir=test_data_dir, kind="node"):
self._add_model_test(rt, "Node")

for rt in load_model_tests(data_dir=str(test_data_dir), kind="real"):
for rt in load_model_tests(data_dir=test_data_dir, kind="real"):
self._add_model_test(rt, "Real")

for rt in load_model_tests(data_dir=str(test_data_dir), kind="simple"):
for rt in load_model_tests(data_dir=test_data_dir, kind="simple"):
self._add_model_test(rt, "Simple")

for ct in load_model_tests(
data_dir=str(test_data_dir), kind="pytorch-converted"
data_dir=test_data_dir, kind="pytorch-converted"
):
self._add_model_test(ct, "PyTorchConverted")

for ot in load_model_tests(
data_dir=str(test_data_dir), kind="pytorch-operator"
data_dir=test_data_dir, kind="pytorch-operator"
):
self._add_model_test(ot, "PyTorchOperator")

Expand Down
4 changes: 2 additions & 2 deletions onnx/backend/test/stat_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path
import pathlib
from typing import IO, Any, Dict, List, Sequence

from onnx import AttributeProto, defs, load
Expand Down Expand Up @@ -154,7 +154,7 @@ def gen_model_test_coverage(
attrs: Dict[str, Dict[str, List[Any]]] = {}
model_paths: List[Any] = []
for rt in load_model_tests(
data_dir=str(Path(__file__).parent / "data"), kind="real"
data_dir=path.Path(__file__).parent / "data", kind="real"
):
if rt.url.startswith("onnx/backend/test/data/light/"):
# testing local files
Expand Down
4 changes: 2 additions & 2 deletions tests/backend/test_backend_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import platform
import unittest
from pathlib import Path
import pathlib
from typing import Any

import numpy
Expand All @@ -27,7 +27,7 @@
ort: Any = None # type: ignore[no-redef]
ort_version: Any = None # type: ignore[no-redef]

TEST_DATA_DIR = Path(__file__).parent.parent.parent / "onnx/backend/test/data"
TEST_DATA_DIR = pathlib.Path(__file__).parent.parent.parent / "onnx/backend/test/data"

# The following just executes a backend based on InferenceSession through the backend test

Expand Down

0 comments on commit fe257c4

Please sign in to comment.