Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make WeightEnum and Weights public + cleanups #7100

Merged
merged 9 commits into from
Feb 2, 2023
12 changes: 6 additions & 6 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import sys
from dataclasses import dataclass, fields
from enum import Enum
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -38,7 +39,8 @@ class Weights:
meta: Dict[str, Any]


class WeightsEnum(StrEnum):
# TODO: can't be pickled? https://github.com/pytorch/vision/issues/7099
class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
Expand All @@ -48,14 +50,11 @@ class WeightsEnum(StrEnum):
value (Weights): The data class entry with the weight information.
"""

def __init__(self, value: Weights):
self._value_ = value

@classmethod
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
obj = cls[obj.replace(cls.__name__ + ".", "")]
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
Expand All @@ -76,6 +75,7 @@ def __getattr__(self, name):
return super().__getattr__(name)

def __deepcopy__(self, memodict=None):
# TODO: this isn't a real deep-copy? https://github.com/pytorch/vision/pull/6883#discussion_r1011372296
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
return self


Expand Down Expand Up @@ -112,7 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
if weights_enum is None:
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

return weights_enum.from_str(value_name)
return weights_enum[value_name]


def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
Expand Down