diff --git a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py index cbcb991417..0d35b91260 100644 --- a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py +++ b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py @@ -1,4 +1,6 @@ import sys +from typing import Union + import torch.nn as nn from torchvision import models @@ -26,7 +28,9 @@ class PyTorchPreTrainedModelConfig(PyTorchModelConfig): add_layers: bool = field( "Replace the last layer of the pretrained model", default=False, ) - layers: dict = field( + layers: Union[ + dict, nn.ModuleDict, nn.Sequential, nn.ModuleList, nn.Module + ] = field( "Extra layers to replace the last layer of the pretrained model", default=None, ) @@ -60,10 +64,20 @@ def createModel(self): param.require_grad = self.parent.config.trainable if self.parent.config.add_layers: - layers = [ - create_layer(value) - for key, value in self.parent.config.layers.items() - ] + if self.parent.config.layers.__class__.__base__.__name__ in [ + "ModuleDict", + "Sequential", + "ModuleList", + "Module", + ]: + layers = nn.Sequential() + for name, module in self.parent.config.layers.named_children(): + layers.add_module(name, module) + else: + layers = [ + create_layer(value) + for key, value in self.parent.config.layers.items() + ] if self.parent.LAST_LAYER_TYPE == "classifier_sequential": if len(layers) > 1: