Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make WeightEnum and Weights public + cleanups #7100

Merged
merged 9 commits into from
Feb 2, 2023
2 changes: 1 addition & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface

run_if_test_with_extended = pytest.mark.skipif(
Expand Down
7 changes: 6 additions & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models

# The Weights and WeightsEnum are developer-facing utils that we make public for
# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
# TODO: we could / should document them publicly, but it's not clear where, as
# they're not intended for end users.
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
31 changes: 16 additions & 15 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import importlib
import inspect
import sys
from dataclasses import dataclass, fields
from dataclasses import dataclass
from enum import Enum
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union

from torch import nn

from torchvision._utils import StrEnum

from .._internally_replaced_utils import load_state_dict_from_url


Expand Down Expand Up @@ -65,7 +64,7 @@ def __eq__(self, other: Any) -> bool:
return self.transforms == other.transforms


class WeightsEnum(StrEnum):
class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
Expand All @@ -75,14 +74,11 @@ class WeightsEnum(StrEnum):
value (Weights): The data class entry with the weight information.
"""

def __init__(self, value: Weights):
self._value_ = value

@classmethod
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
obj = cls[obj.replace(cls.__name__ + ".", "")]
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
Expand All @@ -95,12 +91,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"

def __getattr__(self, name):
# Be able to fetch Weights attributes directly
for f in fields(Weights):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
@property
def url(self):
return self.value.url

@property
def transforms(self):
return self.value.transforms

@property
def meta(self):
return self.value.meta


def get_weight(name: str) -> WeightsEnum:
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_weight(name: str) -> WeightsEnum:
if weights_enum is None:
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

return weights_enum.from_str(value_name)
return weights_enum[value_name]


def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
Expand Down