Skip to content

Commit

Permalink
Make get_model_builder public (#6560)
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Sep 12, 2022
1 parent a67cc87 commit cac4e22
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
15 changes: 15 additions & 0 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ def test_get_model(name, model_class):
assert isinstance(models.get_model(name), model_class)


@pytest.mark.parametrize(
"name, model_fn",
[
("resnet50", models.resnet50),
("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
("raft_large", models.optical_flow.raft_large),
("quantized_resnet50", models.quantization.resnet50),
("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
("mvit_v1_b", models.video.mvit_v1_b),
],
)
def test_get_model_builder(name, model_fn):
assert models.get_model_builder(name) == model_fn


@pytest.mark.parametrize(
"name, weight",
[
Expand Down
4 changes: 2 additions & 2 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from _utils_internal import get_relative_path
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from torchvision import models
from torchvision.models._api import find_model, list_models
from torchvision.models import get_model_builder, list_models


ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"


def list_model_fns(module):
return [find_model(name) for name in list_models(module)]
return [get_model_builder(name) for name in list_models(module)]


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
from .vision_transformer import *
from .swin_transformer import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_weights, get_weight, list_models
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
19 changes: 15 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .._internally_replaced_utils import load_state_dict_from_url


__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_weights", "get_weight", "list_models"]
__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]


@dataclass
Expand Down Expand Up @@ -127,7 +127,7 @@ def get_model_weights(name: Union[Callable, str]) -> W:
Returns:
weights_enum (W): The weights enum class associated with the model.
"""
model = find_model(name) if isinstance(name, str) else name
model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model))


Expand Down Expand Up @@ -199,7 +199,18 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models)


def find_model(name: str) -> Callable[..., M]:
def get_model_builder(name: str) -> Callable[..., M]:
"""
Gets the model name and returns the model builder method.
.. betastatus:: function
Args:
name (str): The name under which the model is registered.
Returns:
fn (Callable): The model builder method.
"""
name = name.lower()
try:
fn = BUILTIN_MODELS[name]
Expand All @@ -221,5 +232,5 @@ def get_model(name: str, **config: Any) -> M:
Returns:
model (nn.Module): The initialized model.
"""
fn = find_model(name)
fn = get_model_builder(name)
return fn(**config)

0 comments on commit cac4e22

Please sign in to comment.