Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make WeightEnum and Weights public + cleanups #7100

Merged
merged 9 commits into from
Feb 2, 2023
2 changes: 1 addition & 1 deletion test/test_extended_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface

run_if_test_with_extended = pytest.mark.skipif(
Expand Down
6 changes: 5 additions & 1 deletion torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models

# We're making Weights and WeightsEnum public for packages like torchgeo who are
# interested in using them https://github.com/pytorch/vision/issues/7094
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that is required as a code comment? I mean we don't explain for all the other things as well. Aren't we just taking them as part of the public API now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're sort of special in the sense that they're more developer tools intended to be used by downstream libraries, rather than your typical transform / model which are made for end users.
I don't think we have ever made any developer tool public yet, and I added the comment to indicate that this is not done by mistake. I won't fight to keep it though :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I didn't consider that. Maybe change the wording to something like "These are public ..."? My initial understanding was that this was meant as a comment for the reviewer or the PR not someone arbitrary looking at the code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the comment, hopefully the intent is clear now!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I like the wording better now.

# TODO: we could / should document them publicly as well?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Maybe even in this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following up on my comment above about the fact that these are the first developer-facing tools, do you have thoughts on where / how we should document them?
I typically think it shouldn't be part of the already quite dense models doc page.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you are right. With our new maintainer guide, downstream libraries are safe in the sense that we put them under our BC policy. My only concern is discoverability. If we leave it as is, people that want something like this maybe don't know that it is there? Or do we expect them to look into our implementation anyway? cc @adamjstewart

Copy link
Member Author

@NicolasHug NicolasHug Feb 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's a first, I'm tempted to leave them undocumented for now and see how it goes. Developers who need those would check the code IMO (or not, rightfully assuming that they're public since they can be imported from torchvision.models).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a TorchGeo perspective, we're fine with having to read the source code to figure out how to use the weights. But our docs try to link to the base class docs, so if the base class isn't documented we have to add an ignore. So long term, it would be cool if it was documented, but not a hard requirement.

from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
27 changes: 15 additions & 12 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import sys
from dataclasses import dataclass, fields
from enum import Enum
from functools import partial
from inspect import signature
from types import ModuleType
Expand Down Expand Up @@ -65,7 +66,7 @@ def __eq__(self, other: Any) -> bool:
return self.transforms == other.transforms


class WeightsEnum(StrEnum):
class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
Expand All @@ -75,14 +76,11 @@ class WeightsEnum(StrEnum):
value (Weights): The data class entry with the weight information.
"""

def __init__(self, value: Weights):
self._value_ = value

@classmethod
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
obj = cls[obj.replace(cls.__name__ + ".", "")]
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
Expand All @@ -95,12 +93,17 @@ def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"

def __getattr__(self, name):
# Be able to fetch Weights attributes directly
for f in fields(Weights):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
@property
def url(self):
return self.value.url

@property
def transforms(self):
return self.value.transforms

@property
def meta(self):
return self.value.meta


def get_weight(name: str) -> WeightsEnum:
Expand Down Expand Up @@ -136,7 +139,7 @@ def get_weight(name: str) -> WeightsEnum:
if weights_enum is None:
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

return weights_enum.from_str(value_name)
return weights_enum[value_name]


def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
Expand Down