Skip to content

Commit

Permalink
[fbsync] Make WeightEnum and Weights public + cleanups (#7100)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D43116103

fbshipit-source-id: 8283b5f1a94dd934f9f7c87e6ec5492733ed910a
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Feb 9, 2023
1 parent 49fb060 commit 034fd7a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface

run_if_test_with_extended = pytest.mark.skipif(
Expand Down
7 changes: 6 additions & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,9 @@
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models

# The Weights and WeightsEnum are developer-facing utils that we make public for
# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
# TODO: we could / should document them publicly, but it's not clear where, as
# they're not intended for end users.
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
31 changes: 16 additions & 15 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import importlib
import inspect
import sys
from dataclasses import dataclass, fields
from dataclasses import dataclass
from enum import Enum
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union

from torch import nn

from torchvision._utils import StrEnum

from .._internally_replaced_utils import load_state_dict_from_url


Expand Down Expand Up @@ -65,7 +64,7 @@ def __eq__(self, other: Any) -> bool:
return self.transforms == other.transforms


class WeightsEnum(StrEnum):
class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
Expand All @@ -75,14 +74,11 @@ class WeightsEnum(StrEnum):
value (Weights): The data class entry with the weight information.
"""

def __init__(self, value: Weights):
self._value_ = value

@classmethod
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
obj = cls[obj.replace(cls.__name__ + ".", "")]
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
Expand All @@ -95,12 +91,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"

def __getattr__(self, name):
# Be able to fetch Weights attributes directly
for f in fields(Weights):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
@property
def url(self):
return self.value.url

@property
def transforms(self):
return self.value.transforms

@property
def meta(self):
return self.value.meta


def get_weight(name: str) -> WeightsEnum:
Expand Down Expand Up @@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum:
if weights_enum is None:
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

return weights_enum.from_str(value_name)
return weights_enum[value_name]


def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
Expand Down

0 comments on commit 034fd7a

Please sign in to comment.