[FEEDBACK] Model Registration beta API #6365

Open
datumbox opened this issue Aug 3, 2022 · 5 comments

Comments

@datumbox
Contributor

datumbox commented Aug 3, 2022

🚀 Feedback Request

This issue is dedicated to collecting community feedback on the Model Registration API. Please review the dedicated RFC and blog post, where we describe the API in detail and give an overview of its features.

We would love to get your thoughts, comments and input in order to finalize the API and include it in the upcoming release of TorchVision.

@dataplayer12

It would be great if list_models could return only the models matching a regex, or at least support wildcard searches such as list_models("resnet*").
timm has this functionality, and it really boosts productivity.
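A minimal sketch of the requested wildcard filtering, using Python's standard fnmatch module on a plain list of model names. The helper filter_model_names, its include/exclude parameters and the example names are hypothetical and only illustrate the idea; they are not part of TorchVision's API.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List, Optional, Union

Patterns = Union[str, Iterable[str]]


def _as_list(patterns: Optional[Patterns]) -> List[str]:
    # Accept a single pattern string or any iterable of patterns
    if patterns is None:
        return []
    if isinstance(patterns, str):
        return [patterns]
    return list(patterns)


def filter_model_names(
    names: Iterable[str],
    include: Optional[Patterns] = None,
    exclude: Optional[Patterns] = None,
) -> List[str]:
    """Hypothetical helper: keep names matching any `include` pattern and no `exclude` pattern."""
    include_list = _as_list(include)
    exclude_list = _as_list(exclude)
    selected = []
    for name in names:
        # fnmatchcase is case-sensitive on every platform, unlike fnmatch
        if include_list and not any(fnmatchcase(name, p) for p in include_list):
            continue
        if any(fnmatchcase(name, p) for p in exclude_list):
            continue
        selected.append(name)
    return sorted(selected)


# Example names, for illustration only
names = ["resnet18", "resnet50", "resnext50_32x4d", "vit_b_16", "wide_resnet50_2"]
print(filter_model_names(names, include="resnet*"))
# ['resnet18', 'resnet50']
print(filter_model_names(names, include="*res*", exclude="wide_*"))
# ['resnet18', 'resnet50', 'resnext50_32x4d']
```

Shell-style wildcards cover the "resnet*" use case with less surprise than full regular expressions, where a pattern such as "*" is not even valid.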

@TeodorPoncu
Contributor

I agree with @dataplayer12. While going through some tests in the prototype area, I noticed that the fetch-all behaviour might cause issues with constructors of future models that do not follow the same initialisation scheme:

@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_fn, model_mode, dev):
    # A simple test to make sure the model can do forward pass and jit scriptable
    set_rng_seed(0)
    # Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
    # get the idea from test_models.test_raft
    corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2)
    corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2)
    model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)

Furthermore, this gets tricky for models that produce different output shapes or a different number of outputs even though they "solve" the same task:

preds = model(img1, img2, num_iters=num_iters)
depth_pred = preds[-1]
assert len(preds) == num_iters, "Number of predictions should be the same as model.num_iters"
assert depth_pred.shape == torch.Size(
[1, 1, 64, 64]
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"

The backward-compatible fix I have in mind is fairly non-intrusive and simple.

We could change find_model to something like:

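import fnmatch
from typing import Callable, Optional

# BUILTIN_MODELS (name -> builder) and the TypeVar M are assumed to come from torchvision's models/_api.py
def find_model(name: str, pattern: str = "*") -> Optional[Callable[..., M]]:
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    # Return None if the name does not match the wildcard pattern ("*" matches everything)
    if not fnmatch.fnmatch(name, pattern):
        return None
    return fn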

Then we could change list_model_fns to something like:

from typing import Callable, List

def list_model_fns(module, pattern: str = "*") -> List[Callable[..., M]]:
    # Keep only the builders whose names match the wildcard pattern
    model_fns = [find_model(name, pattern) for name in list_models(module)]
    return [fn for fn in model_fns if fn is not None]
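To try the proposed behaviour without a TorchVision checkout, here is a self-contained sketch of the same two functions against a toy registry. _TOY_REGISTRY and the three builder functions are hypothetical stand-ins for TorchVision's private registry, the registry is scanned directly instead of calling list_models, and the module argument is dropped for brevity.

```python
from fnmatch import fnmatchcase
from typing import Callable, Dict, List, Optional


# Hypothetical builders standing in for real model constructors
def raft_stereo_base() -> str:
    return "raft_stereo_base"


def raft_stereo_realtime() -> str:
    return "raft_stereo_realtime"


def crestereo_base() -> str:
    return "crestereo_base"


# Hypothetical stand-in for the private name -> builder registry
_TOY_REGISTRY: Dict[str, Callable[..., str]] = {
    fn.__name__: fn for fn in (raft_stereo_base, raft_stereo_realtime, crestereo_base)
}


def find_model(name: str, pattern: str = "*") -> Optional[Callable[..., str]]:
    name = name.lower()
    try:
        fn = _TOY_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None
    # Registered, but filtered out when the name does not match the pattern
    return fn if fnmatchcase(name, pattern) else None


def list_model_fns(pattern: str = "*") -> List[Callable[..., str]]:
    candidates = (find_model(name, pattern) for name in sorted(_TOY_REGISTRY))
    return [fn for fn in candidates if fn is not None]


print([fn.__name__ for fn in list_model_fns("raft_stereo*")])
# ['raft_stereo_base', 'raft_stereo_realtime']
```

With this shape, a test could parametrize over list_model_fns("raft_stereo*") and skip builders that use a different initialisation scheme, while calls without a pattern keep returning every model.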

Besides letting users select a specific family of models, I believe this would improve the developer experience when writing tests or other utilities, while keeping the existing API unchanged.

The alternative, from a developer-experience standpoint, would be to pass each model builder individually via the function arguments or the parametrize decorator whenever we cannot assume that all models in a module behave in exactly the same way.

@adamjstewart
Contributor

We're trying to adopt the new API in TorchGeo but it isn't clear how the registration API works for weights that are not built into torchvision. We list our own WeightsEnums but torchvision.models.list_models doesn't know anything about them and list_models(module=torchgeo.models) doesn't work. According to the blog:

The model registration methods are kept private on purpose as we currently focus only on supporting the built-in models of TorchVision.

So it's possible this is by design. Guess I'll just wait for them to become public and copy-n-paste all the code for now...

@NicolasHug
Member

NicolasHug commented Jan 25, 2023

Thanks for the feedback @adamjstewart.

The registration functions are private right now because they weren't intended to work for external packages. What kind of workflow would you like to enable? It seems like it would work like this for torchgeo users:

from torchvision.models import list_models

list_models(module=torchgeo.models)

which IMHO seems awkward; torchgeo users probably just want to use something like torchgeo.models.list_models, and the fact that it (may) rely on torchvision should just be an implementation detail, not something exposed to users.

IIRC from the design stage, we introduced the module= parameter because some models have the same name in the torchvision.models and torchvision.models.quantized namespaces; we had to introduce module to disambiguate, but it's probably not something we would have done otherwise. It may give the impression that we intend to support arbitrary packages, but that wasn't the original intention.

We're still open to making those public if we can find a nice/easy/useful way to do so, but for now I think a good old copy-n-paste is your best strat :)

@adamjstewart
Contributor

It seems like it would work like this for torchgeo users:

from torchvision.models import list_models

list_models(module=torchgeo.models)

I would love it if that syntax worked, but it doesn't:

>>> import torchgeo.models
>>> from torchvision.models import list_models
>>> list_models(module=torchgeo.models)
[]
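The empty list is consistent with list_models(module=...) only filtering the entries of TorchVision's own private registry by the package their builders were defined in. The sketch below reproduces that behaviour with a toy registry and fake modules. The filtering rule (comparing the parent package of each builder's __module__ with module.__name__) is an assumption about how TorchVision implements it, and every name in the sketch is hypothetical.

```python
import types
from typing import Callable, Dict, List, Optional

# Toy stand-in for a library's private registry (name -> builder)
_REGISTRY: Dict[str, Callable[..., object]] = {}


def _make_builder(qualified_module: str, name: str) -> Callable[..., object]:
    # Create a dummy builder that looks as if it was defined in `qualified_module`
    def builder() -> str:
        return f"{qualified_module}.{name}"

    builder.__module__ = qualified_module
    builder.__name__ = name
    return builder


# Same builder name in two namespaces, registered under distinct keys
_REGISTRY["resnet50"] = _make_builder("visionlib.models.resnet", "resnet50")
_REGISTRY["quantized_resnet50"] = _make_builder("visionlib.models.quantization.resnet", "resnet50")


def list_models(module: Optional[types.ModuleType] = None) -> List[str]:
    # Keep entries whose builder's parent package is exactly `module`
    return sorted(
        key
        for key, fn in _REGISTRY.items()
        if module is None or fn.__module__.rsplit(".", 1)[0] == module.__name__
    )


models = types.ModuleType("visionlib.models")
quantization = types.ModuleType("visionlib.models.quantization")
external = types.ModuleType("geolib.models")  # never registered anything

print(list_models(models))  # ['resnet50']
print(list_models(quantization))  # ['quantized_resnet50']
print(list_models(external))  # []
```

Under this assumption, passing an external package only narrows the built-in registry; it never discovers builders that were not registered there, so the result is empty rather than an error.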
