Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model: pytorch: pretrained: Add support for additional layers Python API #1148

Merged
merged 2 commits into from Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Operations for file (de)compression
- Usecase example notebook for "Evaluating Model Performance"
- Tests for all notebooks auto created and run via ``test_notebooks.py``
- Support for additional layers in pytorch pretrained models via Python API
### Changed
- Calls to hashlib now go through helper functions
- Build docs using `dffml service dev docs`
Expand Down
24 changes: 19 additions & 5 deletions model/pytorch/dffml_model_pytorch/pytorch_pretrained.py
@@ -1,4 +1,6 @@
import sys
from typing import Union

import torch.nn as nn
from torchvision import models

Expand Down Expand Up @@ -26,7 +28,9 @@ class PyTorchPreTrainedModelConfig(PyTorchModelConfig):
add_layers: bool = field(
"Replace the last layer of the pretrained model", default=False,
)
layers: dict = field(
layers: Union[
dict, nn.ModuleDict, nn.Sequential, nn.ModuleList, nn.Module
] = field(
"Extra layers to replace the last layer of the pretrained model",
default=None,
)
Expand Down Expand Up @@ -60,10 +64,20 @@ def createModel(self):
param.require_grad = self.parent.config.trainable

if self.parent.config.add_layers:
layers = [
create_layer(value)
for key, value in self.parent.config.layers.items()
]
if self.parent.config.layers.__class__.__base__.__name__ in [
"ModuleDict",
"Sequential",
"ModuleList",
"Module",
]:
layers = nn.Sequential()
for name, module in self.parent.config.layers.named_children():
layers.add_module(name, module)
else:
layers = [
create_layer(value)
for key, value in self.parent.config.layers.items()
]

if self.parent.LAST_LAYER_TYPE == "classifier_sequential":
if len(layers) > 1:
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Expand Up @@ -13,7 +13,7 @@ jsbeautifier>=1.14.0
twine
# Test requirements
httptest>=0.0.15
Pillow>=7.1.2
Pillow>=8.3.1
pre-commit
ipykernel
matplotlib