From e90f1e7de7db8b3ea715d613c57a56290709c6ed Mon Sep 17 00:00:00 2001 From: mHash1m Date: Sat, 3 Jul 2021 00:00:06 +0500 Subject: [PATCH] model: pytorch: Add support for additional layers via Python API Previously only supplying a dict which would then be converted to PyTorch objects was supported. Now PyTorch objects can be supplied directly. Fixes: #1147 Related: #840 Related: #1151 --- .../dffml_model_pytorch/pytorch_pretrained.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py index cbcb991417..0d35b91260 100644 --- a/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py +++ b/model/pytorch/dffml_model_pytorch/pytorch_pretrained.py @@ -1,4 +1,6 @@ import sys +from typing import Union + import torch.nn as nn from torchvision import models @@ -26,7 +28,9 @@ class PyTorchPreTrainedModelConfig(PyTorchModelConfig): add_layers: bool = field( "Replace the last layer of the pretrained model", default=False, ) - layers: dict = field( + layers: Union[ + dict, nn.ModuleDict, nn.Sequential, nn.ModuleList, nn.Module + ] = field( "Extra layers to replace the last layer of the pretrained model", default=None, ) @@ -60,10 +64,20 @@ def createModel(self): param.require_grad = self.parent.config.trainable if self.parent.config.add_layers: - layers = [ - create_layer(value) - for key, value in self.parent.config.layers.items() - ] + if self.parent.config.layers.__class__.__base__.__name__ in [ + "ModuleDict", + "Sequential", + "ModuleList", + "Module", + ]: + layers = nn.Sequential() + for name, module in self.parent.config.layers.named_children(): + layers.add_module(name, module) + else: + layers = [ + create_layer(value) + for key, value in self.parent.config.layers.items() + ] if self.parent.LAST_LAYER_TYPE == "classifier_sequential": if len(layers) > 1: