Skip to content

Commit

Permalink
model: pytorch: Add support for additional layers via Python API
Browse files Browse the repository at this point in the history
Previously only supplying a dict which would then be converted to
PyTorch objects was supported. Now PyTorch objects can be supplied
directly.

Fixes: #1147
Related: #840
Related: #1151
  • Loading branch information
mhash1m committed Jul 7, 2021
1 parent ac181a0 commit e90f1e7
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions model/pytorch/dffml_model_pytorch/pytorch_pretrained.py
@@ -1,4 +1,6 @@
import sys
from typing import Union

import torch.nn as nn
from torchvision import models

Expand Down Expand Up @@ -26,7 +28,9 @@ class PyTorchPreTrainedModelConfig(PyTorchModelConfig):
add_layers: bool = field(
"Replace the last layer of the pretrained model", default=False,
)
layers: dict = field(
layers: Union[
dict, nn.ModuleDict, nn.Sequential, nn.ModuleList, nn.Module
] = field(
"Extra layers to replace the last layer of the pretrained model",
default=None,
)
Expand Down Expand Up @@ -60,10 +64,20 @@ def createModel(self):
param.require_grad = self.parent.config.trainable

if self.parent.config.add_layers:
layers = [
create_layer(value)
for key, value in self.parent.config.layers.items()
]
if self.parent.config.layers.__class__.__base__.__name__ in [
"ModuleDict",
"Sequential",
"ModuleList",
"Module",
]:
layers = nn.Sequential()
for name, module in self.parent.config.layers.named_children():
layers.add_module(name, module)
else:
layers = [
create_layer(value)
for key, value in self.parent.config.layers.items()
]

if self.parent.LAST_LAYER_TYPE == "classifier_sequential":
if len(layers) > 1:
Expand Down

0 comments on commit e90f1e7

Please sign in to comment.