diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fb37271fa..cef7c4e042 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Operations for file (de)compression - Usecase example notebook for "Evaluating Model Performance" - Tests for all notebooks auto created and run via ``test_notebooks.py`` +- Support for additional layers in pytorch pretrained models via Python API ### Changed - Calls to hashlib now go through helper functions - Build docs using `dffml service dev docs` diff --git a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py index cbcb991417..0d35b91260 100644 --- a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py +++ b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py @@ -1,4 +1,6 @@ import sys +from typing import Union + import torch.nn as nn from torchvision import models @@ -26,7 +28,9 @@ class PyTorchPreTrainedModelConfig(PyTorchModelConfig): add_layers: bool = field( "Replace the last layer of the pretrained model", default=False, ) - layers: dict = field( + layers: Union[ + dict, nn.ModuleDict, nn.Sequential, nn.ModuleList, nn.Module + ] = field( "Extra layers to replace the last layer of the pretrained model", default=None, ) @@ -60,10 +64,20 @@ def createModel(self): param.require_grad = self.parent.config.trainable if self.parent.config.add_layers: - layers = [ - create_layer(value) - for key, value in self.parent.config.layers.items() - ] + if self.parent.config.layers.__class__.__base__.__name__ in [ + "ModuleDict", + "Sequential", + "ModuleList", + "Module", + ]: + layers = nn.Sequential() + for name, module in self.parent.config.layers.named_children(): + layers.add_module(name, module) + else: + layers = [ + create_layer(value) + for key, value in self.parent.config.layers.items() + ] if self.parent.LAST_LAYER_TYPE == "classifier_sequential": if len(layers) > 1: diff --git a/requirements-dev.txt b/requirements-dev.txt index f3f10ad56f..5dd63f5bfc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,7 +13,7 @@ jsbeautifier>=1.14.0 twine # Test requirements httptest>=0.0.15 -Pillow>=7.1.2 +Pillow>=8.3.1 pre-commit ipykernel matplotlib \ No newline at end of file