diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..73f69e0
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/gmed-icml.iml b/.idea/gmed-icml.iml
new file mode 100644
index 0000000..25a5ed4
--- /dev/null
+++ b/.idea/gmed-icml.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..59f68cd
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,75 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..026d02f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..dfff931
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..05dd38b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,39 @@
+# GMED-anonymous-submission
+
+## Links for downloadable datasets:
+- MNIST dataset
+
+ http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+ http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
+ http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
+ http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
+
+- CIFAR-10 and CIFAR-100 datsets
+
+ https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+ https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
+
+
+- mini-ImageNet
+
+ https://data.deepai.org/miniimagenet.zip
+
+ The partition of the datasets into tasks will be performed each time you train the model.
+
+## Running experiments
+Running experiments on MNIST
+
+```
+export stride=0.1 # editing stride (alpha)
+export reg=0.01 # reg strength (beta)
+export dataset="split_mnist"
+export mem=500
+export start_seed=100
+export stop_seed=109
+export method="-" # change to "mirr" for MIR+GMED
+
+./scripts/gme_tune_${dataset}.sh ${dataset}_iter1_stride${stride}_1_0_2${reg}_m${mem} 1 ${stride} 1 0 2 ${mem} ${method} ${start_seed} ${stop_seed} "OUTPUT_DIR=runs EXTERNAL.OCL.REG_STRENGTH=${reg}"
+```
+
+Similary, experiments on other datasets can be run by changing the name of the dataset variable above.
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..76183f7
--- /dev/null
+++ b/config.py
@@ -0,0 +1,11 @@
+class _Config:
+ def __init__(self):
+ self.image_size = 224
+ self.vocab_path = ''
+
+ def update(self, cfg):
+ for k in cfg.__dict__:
+ if k not in self.__dict__:
+ setattr(self, k, cfg.__dict__[k])
+
+cfg = _Config()
\ No newline at end of file
diff --git a/configs/memevolve/permuted_mnist.yaml b/configs/memevolve/permuted_mnist.yaml
new file mode 100644
index 0000000..e6ff274
--- /dev/null
+++ b/configs/memevolve/permuted_mnist.yaml
@@ -0,0 +1,102 @@
+# config file
+MNIST:
+ ACTIVATED: True
+ TASK: 'permute'
+ EPOCH: 1
+ INSTANCE_NUM: 1000
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 100
+ FILTER_SELF: 0
+
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.05
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/configs/memevolve/rotated_mnist.yaml b/configs/memevolve/rotated_mnist.yaml
new file mode 100644
index 0000000..2527edd
--- /dev/null
+++ b/configs/memevolve/rotated_mnist.yaml
@@ -0,0 +1,106 @@
+# config file
+MNIST:
+ ACTIVATED: True
+ TASK: 'rotate'
+ EPOCH: 1
+ INSTANCE_NUM: 1000
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 200
+ FILTER_SELF: 0
+
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ TASK_INCREMENTAL: False
+ TASK_NUM: 20
+ CLASS_NUM: 10
+
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.05
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_cifar10.yaml b/configs/memevolve/split_cifar10.yaml
new file mode 100644
index 0000000..fefd4a6
--- /dev/null
+++ b/configs/memevolve/split_cifar10.yaml
@@ -0,0 +1,107 @@
+# config file
+CIFAR:
+ ACTIVATED: True
+ EPOCH: 1
+ DATASET: "CIFAR10"
+VERX:
+ LOSS1: 0
+ LOSS2: 0
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 500
+ FILTER_SELF: 0
+
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ TASK_NUM: 5
+ CLASS_NUM: 2
+
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.1
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_cifar100.yaml b/configs/memevolve/split_cifar100.yaml
new file mode 100644
index 0000000..eb5bf2e
--- /dev/null
+++ b/configs/memevolve/split_cifar100.yaml
@@ -0,0 +1,110 @@
+# config file
+CIFAR:
+ ACTIVATED: True
+ EPOCH: 1
+ DATASET: "CIFAR100"
+ INSTANCE_NUM: 10000
+VERX:
+ LOSS1: 0
+ LOSS2: 0
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 10000
+ FILTER_SELF: 0
+
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ TASK_INCREMENTAL: False
+ TASK_NUM: 20
+ CLASS_NUM: 5
+
+
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.03
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_mini_imagenet.yaml b/configs/memevolve/split_mini_imagenet.yaml
new file mode 100644
index 0000000..22056d4
--- /dev/null
+++ b/configs/memevolve/split_mini_imagenet.yaml
@@ -0,0 +1,113 @@
+# config file
+CIFAR:
+ ACTIVATED: True
+ MINI_IMAGENET: 1
+ EPOCH: 1
+ INSTANCE_NUM: 3000
+VERX:
+ LOSS1: 0
+ LOSS2: 0
+
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 10000
+ FILTER_SELF: 0
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+ USE_LOSS_1: 1
+ USE_LOSS_2: -1
+ PROJ_REG_LOSS: 0
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ #TASK_INCREMENTAL: True
+ TASK_NUM: 20
+ CLASS_NUM: 5
+ N_ITER: 3
+
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.1
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_mnist.yaml b/configs/memevolve/split_mnist.yaml
new file mode 100644
index 0000000..ad173ae
--- /dev/null
+++ b/configs/memevolve/split_mnist.yaml
@@ -0,0 +1,102 @@
+# config file
+MNIST:
+ ACTIVATED: True
+ TASK: 'split'
+ EPOCH: 1
+ INSTANCE_NUM: 1000
+EXTERNAL:
+ IMAGE_IDS: []
+ OBJECT_NAMES: []
+ # OBJECT_NAMES: []
+ ATTRIBUTES: []
+
+ OBJECT_TOP_K: 100
+ # 0 refers to all
+ ATTRIBUTE_TOP_K: 0
+ RELOAD_SCENE_GRAPH: False
+
+ # Change the directory of your own
+ IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+ TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+ TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+ VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+ VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+ NUM_WORKERS: 0
+ PIN_MEMORY: True
+ SHUFFLE: False
+ # IMAGES PER GPU
+ BATCH_SIZE: 10
+ BALANCE_ATTRIBUTES: False
+ IMAGE_SIZE: 224
+ IMAGE:
+ HEIGHT: 640
+ WIDTH: 640
+ REPLAY:
+ MEM_BS: 10
+ MEM_LIMIT: 100
+ FILTER_SELF: 0
+
+ # 0 for unknown
+ ROI_BOX_HEAD:
+ NUM_ATTR: 13
+
+ OPTIMIZER:
+ ADAM: False
+
+ OCL:
+ ACTIVATED: True
+ SORT_BY_ATTRIBUTES: False
+ ALGO: "VERX"
+ VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+ MIN_SIZE_TRAIN: (640,)
+ MAX_SIZE_TRAIN: 640
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 640
+MODEL:
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+ WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+ BACKBONE:
+ CONV_BODY: "R-50-FPN"
+ RESNETS:
+ BACKBONE_OUT_CHANNELS: 256
+ RPN:
+ USE_FPN: True
+ ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+ PRE_NMS_TOP_N_TRAIN: 1200
+ PRE_NMS_TOP_N_TEST: 1200
+ PRE_NMS_TOP_N_TEST: 200
+ POST_NMS_TOP_N_TEST: 200
+ FPN_POST_NMS_TOP_N_TRAIN: 1000
+ FPN_POST_NMS_TOP_N_TEST: 1000
+ ROI_HEADS:
+ USE_FPN: True
+ # BG_IOU_THRESHOLD: 0.3
+ # FG_IOU_THRESHOLD: 0.5
+ # OHEM: False
+ ROI_BOX_HEAD:
+ POOLER_RESOLUTION: 7
+ POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+ POOLER_SAMPLING_RATIO: 2
+ # MLP_HEAD_DIM: 2048
+ FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+ PREDICTOR: "FPNPredictor"
+ NUM_CLASSES: 1601
+ MASK_ON: False
+TYPE: "float32"
+DATASETS:
+ TRAIN: ("coco_gqa_train",)
+ TEST: ("coco_gqa_val",)
+DATALOADER:
+ SIZE_DIVISIBILITY: 32
+SOLVER:
+ BASE_LR: 0.05
+ MOMENTUM: 0.
+ WEIGHT_DECAY: 0.00001
+ STEPS: (12500, )
+ MAX_ITER: 30000
+ IMS_PER_BATCH: 32
+TEST:
+ IMS_PER_BATCH: 32
diff --git a/data/benchmark_mir.py b/data/benchmark_mir.py
new file mode 100644
index 0000000..ba95f77
--- /dev/null
+++ b/data/benchmark_mir.py
@@ -0,0 +1,579 @@
+"""
+Copyright (c) 2020 Rahaf Aljundi, Lucas Caccia, Eugene Belilovsky, Massimo Caccia
+
+Modified from https://github.com/optimass/Maximally_Interfered_Retrieval/
+"""
+import os
+import torch
+import numpy as np
+from PIL import Image
+import random
+from scipy.ndimage.interpolation import rotate
+from torchvision import datasets, transforms
+import yaml
+
+
+""" Template Dataset with Labels """
+class XYDataset(torch.utils.data.Dataset):
+ def __init__(self, x, y, **kwargs):
+ self.x, self.y = x, y
+
+ # this was to store the inverse permutation in permuted_mnist
+ # so that we could 'unscramble' samples and plot them
+ for name, value in kwargs.items():
+ setattr(self, name, value)
+
+ def __len__(self):
+ return len(self.x)
+
+ def __getitem__(self, idx):
+ x, y = self.x[idx], self.y[idx]
+
+ if type(x) != torch.Tensor:
+ # mini_imagenet
+ # we assume it's a path --> load from file
+ x = self.transform(Image.open(x).convert('RGB'))
+ y = torch.Tensor(1).fill_(y).long().squeeze()
+ elif self.source.startswith('cifar'):
+ #cifar10_mean = (0.5, 0.5, 0.5)
+ #cifar10_std = (0.5, 0.5, 0.5)
+ x = x.float() / 255
+ # transform = transforms.Compose([
+ # transforms.ToPILImage(),
+ # transforms.RandomCrop(32, padding=4),
+ # transforms.RandomHorizontalFlip(),
+ # transforms.RandomRotation(10),
+ # transforms.ToTensor(),
+ # ])
+ # x = transform(x)
+ y = y.long()
+ else:
+ x = x.float() / 255.
+ y = y.long()
+
+ # for some reason mnist does better \in [0,1] than [-1, 1]
+ if self.source == 'mnist' or self.source.startswith('cifar'):
+ return x, y
+ else:
+ return (x - 0.5) * 2, y
+
+
+""" Template Dataset for Continual Learning """
+class CLDataLoader(object):
+ def __init__(self, datasets_per_task, batch_size, train=True):
+ bs = batch_size if train else 64
+
+ self.datasets = datasets_per_task
+ self.loaders = [
+ torch.utils.data.DataLoader(x, batch_size=bs, shuffle=True, drop_last=train, num_workers=0)
+ for x in self.datasets ]
+
+ def __getitem__(self, idx):
+ return self.loaders[idx]
+
+ def __len__(self):
+ return len(self.loaders)
+
+
+class FuzzyCLDataLoader(object):
+ def __init__(self, datasets_per_task, batch_size, train=True):
+ bs = batch_size if train else 64
+ self.raw_datasets = datasets_per_task
+ self.datasets = [_ for _ in datasets_per_task]
+ for i in range(len(self.datasets) - 1):
+ self.datasets[i], self.datasets[i + 1] = self.mix_two_datasets(self.datasets[i], self.datasets[i + 1])
+ self.loaders = [
+ torch.utils.data.DataLoader(x, batch_size=bs, shuffle=True, drop_last=train, num_workers=0)
+ for x in self.datasets ]
+
+ def shuffle(self, x, y):
+ perm = np.random.permutation(len(x))
+ x = x[perm]
+ y = y[perm]
+ return x, y
+
+ def mix_two_datasets(self, a, b, start=0.5):
+ a.x, a.y = self.shuffle(a.x, a.y)
+ b.x, b.y = self.shuffle(b.x, b.y)
+
+ def cmf_examples(i):
+ if i < start * len(a):
+ return 0
+ else:
+ return (1 - start) * len(a) * 0.25 * ((i / len(a) - start) / (1 - start)) ** 2
+
+ s, swaps = 0, []
+ for i in range(len(a)):
+ c = cmf_examples(i)
+ if s < c:
+ swaps.append(i)
+ s += 1
+
+ for idx in swaps:
+ a.x[idx], b.x[len(b) - idx], a.y[idx], b.y[len(b) - idx] = b.x[len(b) - idx], a.x[idx], b.y[len(b) - idx], a.y[idx]
+ return a, b
+
+ def __getitem__(self, idx):
+ return self.loaders[idx]
+
+ def __len__(self):
+ return len(self.loaders)
+
+
+
+class IIDDataset(torch.utils.data.Dataset):
+ def __init__(self, data_loaders, seed=0):
+ self.data_loader = data_loaders
+ self.idx = []
+ for task_id in range(len(data_loaders)):
+ for i in range(len(data_loaders[task_id].dataset)):
+ self.idx.append((task_id, i))
+ random.Random(seed).shuffle(self.idx)
+
+ def __getitem__(self, idx):
+ task_id, instance_id = self.idx[idx]
+ return self.data_loader[task_id].dataset.__getitem__(instance_id)
+
+ def __len__(self):
+ return len(self.idx)
+
+""" Permuted MNIST """
+def get_permuted_mnist(args):
+ #assert not args.use_conv
+ args.multiple_heads = False
+ args.n_classes = 10
+ #if 'mem_size' in args:
+ # args.buffer_size = args.mem_size * args.n_classes
+ args.n_tasks = 10 #if args.n_tasks==-1 else args.n_tasks
+ args.use_conv = False
+ args.input_type = 'binary'
+ args.input_size = [784]
+ #if args.output_loss is None:
+ args.output_loss = 'bernouilli'
+
+ # fetch MNIST
+ train = datasets.MNIST('data/', train=True, download=True)
+ test = datasets.MNIST('data/', train=False, download=True)
+
+ try:
+ train_x, train_y = train.data, train.targets
+ test_x, test_y = test.data, test.targets
+ except:
+ train_x, train_y = train.train_data, train.train_labels
+ test_x, test_y = test.test_data, test.test_labels
+
+ # only select 1000 of train x
+ permutation = np.random.RandomState(0).permutation(train_x.size(0))[:1000]
+ train_x = train_x[permutation]
+ train_y = train_y[permutation]
+
+ train_x = train_x.view(train_x.size(0), -1)
+ test_x = test_x.view(test_x.size(0), -1)
+
+ train_ds, test_ds, inv_perms = [], [], []
+ for task in range(args.n_tasks):
+ perm = torch.arange(train_x.size(-1)) if task == 0 else torch.randperm(train_x.size(-1))
+
+ # build inverse permutations, so we can display samples
+ inv_perm = torch.zeros_like(perm)
+ for i in range(perm.size(0)):
+ inv_perm[perm[i]] = i
+
+ inv_perms += [inv_perm]
+ train_ds += [(train_x[:, perm], train_y)]
+ test_ds += [(test_x[:, perm], test_y)]
+
+ train_ds, val_ds = make_valid_from_train(train_ds)
+
+ train_ds = map(lambda x, y : XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), train_ds, inv_perms)
+ val_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), val_ds, inv_perms)
+ test_ds = map(lambda x, y : XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), test_ds, inv_perms)
+
+ return train_ds, val_ds, test_ds
+
+""" Rotated MNIST """
+def get_rotated_mnist(args):
+ #assert not args.use_conv
+ args.multiple_heads = False
+ args.n_classes = 10
+ #if 'mem_size' in args:
+ # args.buffer_size = args.mem_size * args.n_classes
+ args.n_tasks = 20
+ args.use_conv = False
+ args.input_type = 'binary'
+ args.input_size = [784]
+ #if args.output_loss is None:
+ args.output_loss = 'bernouilli'
+
+ args.min_rot = 0
+ args.max_rot = 180
+ train_ds, test_ds, inv_perms = [], [], []
+ val_ds = []
+ # fetch MNIST
+ to_tensor = transforms.ToTensor()
+
+ def rotate_dataset(x, angle):
+ x_np = np.copy(x.cpu().numpy())
+ x_np = rotate(x_np, angle=angle, axes=(2,1), reshape=False)
+ return torch.from_numpy(x_np).float()
+
+ train = datasets.MNIST('data/', train=True, download=True)
+ test = datasets.MNIST('data/', train=False, download=True)
+
+
+ for task in range(args.n_tasks):
+ #angle = random.random() * (args.max_rot - args.min_rot) + args.min_rot
+ min_rot = 1.0 * task / args.n_tasks * (args.max_rot - args.min_rot) + \
+ args.min_rot
+ max_rot = 1.0 * (task + 1) / args.n_tasks * \
+ (args.max_rot - args.min_rot) + args.min_rot
+ angle = random.random() * (max_rot - min_rot) + min_rot
+
+ rand_perm = np.random.permutation(len(train.data))[:1000]
+ rand_perm_test = np.random.permutation(len(test.data))[:1000]
+
+ try:
+ train_x, train_y = train.data[rand_perm], train.targets[rand_perm]
+ test_x, test_y = test.data[rand_perm_test], test.targets[rand_perm_test]
+ except:
+ train_x, train_y = train.train_data[rand_perm], train.train_labels[rand_perm]
+ test_x, test_y = test.test_data[rand_perm_test], test.test_labels[rand_perm_test]
+
+ train_x, train_y, val_x, val_y = train_x[:950], train_y[:950], train_x[950:], train_y[950:]
+
+ train_x = rotate_dataset(train_x, angle)
+ test_x = rotate_dataset(test_x, angle)
+ val_x = rotate_dataset(val_x, angle)
+ #train_x = train_x.view(train_x.size(0), -1)
+ #test_x = test_x.view(test_x.size(0), -1)
+
+ train_ds += [(train_x, train_y)]
+ test_ds += [(test_x, test_y)]
+ val_ds += [(val_x, val_y)]
+ #train_ds, _ = make_valid_from_train(train_ds, cut=0.99)
+
+ train_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), train_ds)
+ val_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), val_ds)
+ test_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), test_ds)
+
+ return train_ds, val_ds, test_ds
+
+""" Split MNIST into 5 tasks {{0,1}, ... {8,9}} """
+def get_split_mnist(args, cfg):
+ args.multiple_heads = False
+ args.n_classes = 10
+ args.n_tasks = 5 #if args.n_tasks==-1 else args.n_tasks
+ if 'mem_size' in args:
+ args.buffer_size = args.n_tasks * args.mem_size * 2
+ args.use_conv = False
+ args.input_type = 'binary'
+ args.input_size = [1,28,28]
+ #if args.output_loss is None:
+ args.output_loss = 'bernouilli'
+
+ assert args.n_tasks in [5, 10], 'SplitMnist only works with 5 or 10 tasks'
+ assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+
+ # fetch MNIST
+ train = datasets.MNIST('Data/', train=True, download=True)
+ test = datasets.MNIST('Data/', train=False, download=True)
+
+ try:
+ train_x, train_y = train.data, train.targets
+ test_x, test_y = test.data, test.targets
+ except:
+ train_x, train_y = train.train_data, train.train_labels
+ test_x, test_y = test.test_data, test.test_labels
+
+ # sort according to the label
+ out_train = [
+ (x,y) for (x,y) in sorted(zip(train_x, train_y), key=lambda v : v[1]) ]
+
+ out_test = [
+ (x,y) for (x,y) in sorted(zip(test_x, test_y), key=lambda v : v[1]) ]
+
+ train_x, train_y = [
+ torch.stack([elem[i] for elem in out_train]) for i in [0,1] ]
+
+ test_x, test_y = [
+ torch.stack([elem[i] for elem in out_test]) for i in [0,1] ]
+
+ # cast in 3D:
+ train_x = train_x.view(train_x.size(0), 1, train_x.size(1), train_x.size(2))
+ test_x = test_x.view(test_x.size(0), 1, test_x.size(1), test_x.size(2))
+
+ # get indices of class split
+ train_idx = [((train_y + i) % 10).argmax() for i in range(10)]
+ train_idx = [0] + sorted(train_idx)
+
+ test_idx = [((test_y + i) % 10).argmax() for i in range(10)]
+ test_idx = [0] + sorted(test_idx)
+
+ train_ds, test_ds = [], []
+ skip = 10 // args.n_tasks
+ for i in range(0, 10, skip):
+ tr_s, tr_e = train_idx[i], train_idx[i + skip]
+ te_s, te_e = test_idx[i], test_idx[i + skip]
+
+ train_ds += [(train_x[tr_s:tr_e], train_y[tr_s:tr_e])]
+ test_ds += [(test_x[te_s:te_e], test_y[te_s:te_e])]
+
+ if hasattr(cfg, 'NOVAL') and cfg.NOVAL:
+ train_ds, val_ds = train_ds, test_ds
+ print('no validation set')
+ else:
+ train_ds, val_ds = make_valid_from_train(train_ds)
+
+ train_ds = map(lambda x : XYDataset(x[0], x[1], **{'source': 'mnist'}), train_ds)
+ val_ds = map(lambda x : XYDataset(x[0], x[1], **{'source': 'mnist'}), val_ds)
+ test_ds = map(lambda x : XYDataset(x[0], x[1], **{'source': 'mnist'}), test_ds)
+
+ return train_ds, val_ds, test_ds
+
+
+
+""" Split CIFAR10 into 5 tasks {{0,1}, ... {8,9}} """
+def get_split_cifar10(args, cfg):
+ # assert args.n_tasks in [5, 10], 'SplitCifar only works with 5 or 10 tasks'
+ assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+ args.n_tasks = 5
+ args.n_classes = 10
+ #args.buffer_size = args.n_tasks * args.mem_size * 2
+ args.multiple_heads = False
+ args.use_conv = True
+ args.n_classes_per_task = 2
+ args.input_size = [3, 32, 32]
+ args.input_type = 'continuous'
+ # because data is between [-1,1]:
+ # fetch MNIST
+ train = datasets.CIFAR10('Data/', train=True, download=True)
+ test = datasets.CIFAR10('Data/', train=False, download=True)
+
+ try:
+ train_x, train_y = train.data, train.targets
+ test_x, test_y = test.data, test.targets
+ except:
+ train_x, train_y = train.train_data, train.train_labels
+ test_x, test_y = test.test_data, test.test_labels
+
+ # sort according to the label
+ out_train = [
+ (x,y) for (x,y) in sorted(zip(train_x, train_y), key=lambda v : v[1]) ]
+
+ out_test = [
+ (x,y) for (x,y) in sorted(zip(test_x, test_y), key=lambda v : v[1]) ]
+
+ train_x, train_y = [
+ np.stack([elem[i] for elem in out_train]) for i in [0,1] ]
+
+ test_x, test_y = [
+ np.stack([elem[i] for elem in out_test]) for i in [0,1] ]
+
+ train_x = torch.Tensor(train_x).permute(0, 3, 1, 2).contiguous()
+ test_x = torch.Tensor(test_x).permute(0, 3, 1, 2).contiguous()
+
+ train_y = torch.Tensor(train_y)
+ test_y = torch.Tensor(test_y)
+
+ # get indices of class split
+ train_idx = [((train_y + i) % 10).argmax() for i in range(10)]
+ train_idx = [0] + [x + 1 for x in sorted(train_idx)]
+
+ test_idx = [((test_y + i) % 10).argmax() for i in range(10)]
+ test_idx = [0] + [x + 1 for x in sorted(test_idx)]
+
+ train_ds, test_ds = [], []
+ skip = 10 // 5 #args.n_tasks
+ for i in range(0, 10, skip):
+ tr_s, tr_e = train_idx[i], train_idx[i + skip]
+ te_s, te_e = test_idx[i], test_idx[i + skip]
+
+ train_ds += [(train_x[tr_s:tr_e], train_y[tr_s:tr_e])]
+ test_ds += [(test_x[te_s:te_e], test_y[te_s:te_e])]
+ if hasattr(cfg, 'NOVAL') and cfg.NOVAL:
+ train_ds, val_ds = train_ds, test_ds
+ print('no validation set')
+ else:
+ train_ds, val_ds = make_valid_from_train(train_ds)
+ #else:
+ # train_ds, val_ds = train_ds, test_ds
+ train_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar10'}), train_ds)
+ val_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar10'}), val_ds)
+ test_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar10'}), test_ds)
+
+ return train_ds, val_ds, test_ds
+
+""" Split CIFAR100 into 20 tasks {{0,1,2,3,4}, ... {95,96,97,98,99}} """
+def get_split_cifar100(args, cfg):
+ # assert args.n_tasks in [5, 10], 'SplitCifar only works with 5 or 10 tasks'
+ assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+ args.n_tasks = 20
+ args.n_classes = 100
+ #args.buffer_size = args.n_tasks * args.mem_size * 2
+ args.multiple_heads = False
+ args.use_conv = True
+ args.n_classes_per_task = 5
+ args.input_size = [3, 32, 32]
+ args.input_type = 'continuous'
+ # fetch MNIST
+ train = datasets.CIFAR100('Data/', train=True, download=True)
+ test = datasets.CIFAR100('Data/', train=False, download=True)
+
+ try:
+ train_x, train_y = train.data, train.targets
+ test_x, test_y = test.data, test.targets
+ except:
+ train_x, train_y = train.train_data, train.train_labels
+ test_x, test_y = test.test_data, test.test_labels
+
+ # sort according to the label
+ out_train = [
+ (x,y) for (x,y) in sorted(zip(train_x, train_y), key=lambda v : v[1]) ]
+
+ out_test = [
+ (x,y) for (x,y) in sorted(zip(test_x, test_y), key=lambda v : v[1]) ]
+
+ train_x, train_y = [
+ np.stack([elem[i] for elem in out_train]) for i in [0,1] ]
+
+ test_x, test_y = [
+ np.stack([elem[i] for elem in out_test]) for i in [0,1] ]
+
+ train_x = torch.Tensor(train_x).permute(0, 3, 1, 2).contiguous()
+ test_x = torch.Tensor(test_x).permute(0, 3, 1, 2).contiguous()
+ train_y = torch.Tensor(train_y)
+ test_y = torch.Tensor(test_y)
+ train_ds, test_ds = [], []
+
+ with open('data/cifar100-split-online.yaml') as f:
+ label_splits = yaml.load(f)
+ for label_split in label_splits:
+ train_indice = []
+ task_labels = [_[1] for _ in label_split['subsets']]
+ for i in range(len(train_y)):
+ if train_y[i].item() in task_labels:
+ train_indice.append(i)
+ test_indice = []
+ for i in range(len(test_y)):
+ if test_y[i].item() in task_labels:
+ test_indice.append(i)
+ train_ds.append((train_x[train_indice], train_y[train_indice]))
+ test_ds.append((test_x[test_indice], test_y[test_indice]))
+
+ train_ds, val_ds = train_ds, test_ds
+ train_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar100'}), train_ds)
+ val_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar100'}), val_ds)
+ test_ds = map(lambda x : XYDataset(x[0], x[1], **{'source':'cifar100'}), test_ds)
+
+ return train_ds, val_ds, test_ds
+
+def get_miniimagenet(args):
+ ROOT_PATH = 'datasets/miniImagenet/'
+
+ args.use_conv = True
+ args.n_tasks = 20
+ args.n_classes = 100
+ args.multiple_heads = False
+ args.n_classes_per_task = 5
+ args.input_size = (3, 84, 84)
+ label2id = {}
+
+ def get_data(setname):
+ ds_dir = os.path.join(ROOT_PATH, setname)
+ label_dirs = os.listdir(ds_dir)
+ data, labels = [], []
+
+ for label in label_dirs:
+ label_dir = os.path.join(ds_dir, label)
+ for image_file in os.listdir(label_dir):
+ data.append(os.path.join(label_dir, image_file))
+ if label not in label2id:
+ label_id = len(label2id)
+ label2id[label] = label_id
+ label_id = label2id[label]
+ labels.append(label_id)
+ return data, labels
+
+ transform = transforms.Compose([
+ transforms.Resize(84),
+ transforms.CenterCrop(84),
+ transforms.ToTensor(),
+ ])
+
+ train_data, train_label = get_data('train')
+ valid_data, valid_label = get_data('val')
+ test_data, test_label = get_data('test')
+
+ # total of 60k examples for training, the rest for testing
+ all_data = np.array(train_data + valid_data + test_data)
+ all_label = np.array(train_label + valid_label + test_label)
+
+
+ train_ds, test_ds = [], []
+ current_train, current_test = None, None
+
+ cat = lambda x, y: np.concatenate((x, y), axis=0)
+
+ for i in range(args.n_classes):
+ class_indices = np.argwhere(all_label == i).reshape(-1)
+ class_data = all_data[class_indices]
+ class_label = all_label[class_indices]
+ split = int(0.8 * class_data.shape[0])
+
+ data_train, data_test = class_data[:split], class_data[split:]
+ label_train, label_test = class_label[:split], class_label[split:]
+
+ if current_train is None:
+ current_train, current_test = (data_train, label_train), (data_test, label_test)
+ else:
+ current_train = cat(current_train[0], data_train), cat(current_train[1], label_train)
+ current_test = cat(current_test[0], data_test), cat(current_test[1], label_test)
+
+ if i % args.n_classes_per_task == (args.n_classes_per_task - 1):
+ train_ds += [current_train]
+ test_ds += [current_test]
+ current_train, current_test = None, None
+
+ # TODO: remove this
+ ## Facebook actually does 17 tasks (3 to CV)
+ #train_ds = train_ds[:17]
+ #test_ds = test_ds[:17]
+
+ # build masks
+ masks = []
+ task_ids = [None for _ in range(20)]
+ for task, task_data in enumerate(train_ds):
+ labels = np.unique(task_data[1]) #task_data[1].unique().long()
+ assert labels.shape[0] == args.n_classes_per_task
+ mask = torch.zeros(args.n_classes).cuda()
+ mask[labels] = 1
+ masks += [mask]
+ task_ids[task] = labels
+
+ task_ids = torch.from_numpy(np.stack(task_ids)).cuda().long()
+
+ train_ds, val_ds = make_valid_from_train(train_ds)
+ train_ds = map(lambda x, y : XYDataset(x[0], x[1], **{'source':'cifar100', 'mask':y, 'task_ids':task_ids, 'transform':transform}), train_ds, masks)
+ val_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'source': 'cifar100', 'mask': y, 'task_ids': task_ids, 'transform': transform}), val_ds, masks)
+ test_ds = map(lambda x, y : XYDataset(x[0], x[1], **{'source':'cifar100', 'mask':y, 'task_ids':task_ids, 'transform':transform}), test_ds, masks)
+
+ return train_ds, val_ds, test_ds
+
+
+def make_valid_from_train(dataset, cut=0.95):
+ tr_ds, val_ds = [], []
+ for task_ds in dataset:
+ x_t, y_t = task_ds
+
+ # shuffle before splitting
+ perm = torch.randperm(len(x_t))
+ x_t, y_t = x_t[perm], y_t[perm]
+
+ split = int(len(x_t) * cut)
+ x_tr, y_tr = x_t[:split], y_t[:split]
+ x_val, y_val = x_t[split:], y_t[split:]
+
+ tr_ds += [(x_tr, y_tr)]
+ val_ds += [(x_val, y_val)]
+
+ return tr_ds, val_ds
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..f200858
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,144 @@
+import torch
+import yaml
+import numpy as np
+import pickle
+from PIL import Image, ImageDraw
+from torch.utils.data import DataLoader
+from torchvision.transforms import transforms as T
+from tqdm import tqdm
+from yacs.config import CfgNode
+from maskrcnn_benchmark.config import cfg as maskrcnn_cfg
+from utils.tupperware import tupperware
+from utils.build_transforms import build_transforms
+from data.benchmark_mir import CLDataLoader, get_permuted_mnist, get_split_mnist, get_miniimagenet, get_rotated_mnist, \
+ get_split_cifar10, get_split_cifar100, IIDDataset, FuzzyCLDataLoader
+
+from utils.utils import DotDict, get_config_attr
+
+import random
+
+_dataset = {
+
+}
+
+_smnist_loaders = None
+def get_split_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+ fuzzy = get_config_attr(cfg,'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ d = DotDict()
+ global _smnist_loaders
+ if not _smnist_loaders:
+ data = get_split_mnist(d, cfg)
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _smnist_loaders = train_loader, val_loader, test_loader
+ else:
+ train_loader, val_loader, test_loader = _smnist_loaders
+
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
+
+_rmnist_loaders = None
+def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
+ d = DotDict()
+ fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ global _rmnist_loaders
+ if not _rmnist_loaders:
+ data = get_rotated_mnist(d)
+ #train_loader, val_loader, test_loader = [CLDataLoader(elem, batch_size, train=t) \
+ # for elem, t in zip(data, [True, False, False])]
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _rmnist_loaders = train_loader, val_loader, test_loader
+ else:
+ train_loader, val_loader, test_loader = _rmnist_loaders
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
+_pmnist_loaders = None
+def get_permute_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
+ d = DotDict()
+ fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ global _pmnist_loaders
+ if not _pmnist_loaders:
+ data = get_permuted_mnist(d)
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _pmnist_loaders = train_loader, val_loader, test_loader
+ else:
+ train_loader, val_loader, test_loader = _pmnist_loaders
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
+_cache_cifar = None
+def get_split_cifar_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+ d = DotDict()
+ fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ global _cache_cifar
+ if not _cache_cifar:
+ data = get_split_cifar10(d,cfg) #ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _cache_cifar = train_loader, val_loader, test_loader
+ train_loader, val_loader, test_loader = _cache_cifar
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
+_cache_cifar100 = None
+def get_split_cifar100_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+ d = DotDict()
+ fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ global _cache_cifar100
+ if not _cache_cifar100:
+ data = get_split_cifar100(d,cfg) #ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _cache_cifar100 = train_loader, val_loader, test_loader
+ train_loader, val_loader, test_loader = _cache_cifar100
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
+_cache_mini_imagenet = None
+def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+ global _cache_mini_imagenet
+ d = DotDict()
+ fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+ if not _cache_mini_imagenet:
+ data = get_miniimagenet(d)
+ loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+ train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+ for elem, t in zip(data, [True, False, False])]
+ _cache_mini_imagenet = train_loader, val_loader, test_loader
+ train_loader, val_loader, test_loader = _cache_mini_imagenet
+ if split == 'train':
+ return train_loader[filter_obj[0]]
+ elif split == 'val':
+ return val_loader[filter_obj[0]]
+ elif split == 'test':
+ return test_loader[filter_obj[0]]
+
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..5243da1
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,359 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import h5py
+
+import torch
+import json
+from tqdm import tqdm
+from time import sleep
+
+from config import cfg
+from utils.utils import Timer
+from dataloader import get_dataloader
+from sklearn.metrics import f1_score
+from collections import defaultdict
+from torch.nn import functional as F
+
+
+def to_list(tensor):
+ return tensor.cpu().numpy().tolist()
+
+def compute_on_dataset(model, data_loader, local_rank, device, timer=None, output_file='debug.h5'):
+ model.eval()
+ results_dict = {}
+ gt_dict = {}
+ cpu_device = torch.device("cpu")
+ text = "GPU {}".format(local_rank)
+ pbar = tqdm(
+ total=len(data_loader),
+ position=local_rank,
+ desc=text,
+ )
+ with h5py.File(output_file, 'w') as f:
+ box_feat_ds = f.create_dataset(
+ 'bbox_features', shape=(len(data_loader), 150, 1024)
+ )
+
+ box_ds = f.create_dataset(
+ 'bboxes', shape=(len(data_loader), 150, 4)
+ )
+
+ num_boxes_ds = f.create_dataset(
+ 'num_boxes', shape=(len(data_loader), 1)
+ )
+
+ img_size_ds = f.create_dataset(
+ 'image_size', shape=(len(data_loader), 2)
+ )
+
+ info_dict = {}
+
+ for e, out_dict in enumerate(data_loader):
+ images = out_dict['images']
+ targets = out_dict['gt_bboxes']
+ image_ids = out_dict['image_ids']
+ info = out_dict['info']
+
+ images = [image.to(device) for image in images]
+ targets = [target.to(device) for target in targets]
+ if timer:
+ timer.tic()
+ x, _, output = model(images, targets)
+ if timer:
+ if not cfg.MODEL.DEVICE == 'cpu':
+ torch.cuda.synchronize()
+ timer.toc()
+ assert (x.shape[0] == len(targets[0].bbox))
+ output = [o.to(cpu_device) for o in output]
+ info_dict[image_ids[0]] = {}
+ info_dict[image_ids[0]]['idx'] = e
+ info_dict[image_ids[0]]['objects'] = {}
+ bboxes = targets[0].bbox
+ num_boxes = len(targets[0].bbox)
+ for i in range(num_boxes):
+ tmp = {}
+ tmp = info[0][i]
+ tmp['idx'] = i
+ info_dict[image_ids[0]]['objects'][info[0][i]['object_id']] = tmp
+
+ '''
+ for i in range(5):
+ fpn_feat_ds[i][e] = output[i].numpy()
+ '''
+ box_feat_ds[e, :num_boxes, :] = x.cpu().numpy()
+ box_ds[e, :num_boxes, :] = bboxes.cpu().numpy()
+ num_boxes_ds[e] = num_boxes
+ img_size_ds[e] = targets[0].size
+ pbar.update(1)
+ sleep(0.001)
+
+ with open(output_file.replace('.h5', '_map.json'), 'w') as fp:
+ json.dump(info_dict, fp)
+
+
+def inference_step(model, out_dict, device=torch.device('cuda')):
+ images = torch.stack(out_dict['images'])
+ obj_labels = torch.cat(out_dict['object_labels'], -1)
+ attr_labels = torch.cat(out_dict['attribute_labels'], -1)
+ cropped_image = torch.stack(out_dict['cropped_image'])
+ images = images.to(device)
+ obj_labels = obj_labels.to(device)
+ attr_labels = attr_labels.to(device)
+
+ cropped_image = cropped_image.to(device)
+ # loss_dict = model(images, targets)
+ ret_dict = model(bbox_images=cropped_image, spatial_feat=None,
+ attr_labels=attr_labels, obj_labels=obj_labels,
+ images=images)
+ attr_score, obj_score = ret_dict.get('attr_score', None), \
+ ret_dict.get('obj_score', None)
+ if attr_score is not None:
+ attr_score_norm = F.softmax(attr_score, -1)
+ ret_dict['pred_attr_prob'], ret_dict['pred_attr'] = attr_score_norm.max(-1)
+ if obj_score is not None:
+ obj_score_norm = F.softmax(obj_score, -1)
+ ret_dict['pred_obj_prob'], ret_dict['pred_obj'] = obj_score_norm.max(-1)
+ ret_dict['obj_labels'], ret_dict['attr_labels'] = obj_labels, attr_labels
+ return ret_dict
+
+
+def inference(
+ model,
+ current_epoch,
+ current_iter,
+ local_rank,
+ data_loader,
+ dataset_name,
+ device="cuda",
+ max_instance=3200,
+ mute=False,
+ verbose_return=False
+):
+ model.train(False)
+ # convert to a torch.device for efficiency
+ device = torch.device(device)
+ if not mute:
+ logger = logging.getLogger("maskrcnn_benchmark.inference")
+ logger.info("Start evaluation")
+ total_timer = Timer()
+ total_timer.tic()
+ torch.cuda.empty_cache()
+ if not mute:
+ pbar = tqdm(
+ total=len(data_loader),
+ desc="Validation in progress"
+ )
+
+ def to_list(tensor):
+ return tensor.cpu().numpy().tolist()
+ with torch.no_grad():
+ all_pred_obj, all_truth_obj, all_pred_attr, all_truth_attr = [], [], [], []
+ all_image_ids, all_boxes = [], []
+ all_pred_attr_prob = []
+ all_raws = []
+ obj_loss_all, attr_loss_all = 0, 0
+
+ cnt = 0
+ for iteration, out_dict in enumerate(data_loader):
+ if type(max_instance) is int:
+ if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE: break
+ if type(max_instance) is float:
+ if iteration > max_instance * len(data_loader) // model.cfg.EXTERNAL.BATCH_SIZE: break
+ # print(iteration)
+
+ if verbose_return:
+ all_image_ids.extend(out_dict['image_ids'])
+ all_boxes.extend(out_dict['gt_bboxes'])
+ all_raws.extend(out_dict['raw'])
+
+ ret_dict = inference_step(model, out_dict, device)
+ loss_attr, loss_obj, attr_score, obj_score = ret_dict.get('attr_loss', None), \
+ ret_dict.get('obj_loss', None), \
+ ret_dict.get('attr_score', None), \
+ ret_dict.get('obj_score', None)
+
+ if loss_attr is not None:
+ attr_loss_all += loss_attr.item()
+ pred_attr_prob, pred_attr = ret_dict['pred_attr_prob'], ret_dict['pred_attr']
+ all_pred_attr.extend(to_list(pred_attr))
+ all_truth_attr.extend(to_list(ret_dict['attr_labels']))
+ all_pred_attr_prob.extend(to_list(pred_attr_prob))
+ if loss_obj is not None:
+ obj_loss_all += loss_obj.item()
+ _, pred_obj = obj_score.max(-1)
+ all_pred_obj.extend(to_list(pred_obj))
+ all_truth_obj.extend(to_list(ret_dict['obj_labels']))
+ cnt += 1
+ if not mute:
+ pbar.update(1)
+
+ obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
+ attr_f1 = f1_score(all_truth_attr, all_pred_attr, average='micro')
+ obj_loss_all /= (cnt + 1e-10)
+ attr_loss_all /= (cnt + 1e-10)
+ if not mute:
+ logger.info('Epoch: {}\tIteration: {}\tObject f1: {}\tAttr f1:{}\tObject loss:{}\tAttr loss:{}'.
+ format(current_epoch, current_iter, obj_f1, attr_f1, obj_loss_all, attr_loss_all))
+ #compute_on_dataset(model, data_loader, local_rank, device, inference_timer, output_file)
+ # wait for all processes to complete before measuring the time
+ total_time = total_timer.toc()
+ model.train(True)
+ if not verbose_return:
+ return obj_f1, attr_f1, len(all_truth_attr)
+ else:
+ return obj_f1, attr_f1, all_pred_attr, all_truth_attr, all_pred_obj, all_truth_obj, all_image_ids, all_boxes, \
+ all_pred_attr_prob, all_raws
+
+def run_forget_metrics(metric_dict, finished_tasks, all_tasks, key, forget_dict):
+ for task in all_tasks:
+ if len(metric_dict[key][task]) <= 1 or task not in finished_tasks:
+ forget_dict[task].append(-1)
+ else:
+ forget_dict[task].append(max(metric_dict[key][task][:-1]) - metric_dict[key][task][-1])
+
+
+def run_forward_transfer_metrics(metric_dict, seen_tasks, all_tasks, key, ft_dict):
+ for task in all_tasks:
+ if task not in seen_tasks:
+ ft_dict[task].append(metric_dict[key][task][-1])
+ else:
+ ft_dict[task].append(-1)
+
+
+def numericalize_metric_scores(metric_dict):
+ result_dict = defaultdict(list)
+ for t in range(metric_dict['length']):
+ for key in ['attr_acc', 'forget_dict', 'obj_acc']:
+ total_inst = 0
+ total_score = 0
+ for attr in metric_dict[key]:
+ inst_num = metric_dict['inst_num'][attr][t]
+ score = metric_dict[key][attr][t]
+ if score != -1:
+ total_inst += inst_num
+ total_score += score * inst_num
+ avg_score = total_score / (total_inst + 1e-10)
+ result_dict[key].append(avg_score)
+ return result_dict
+
+
+def inference_ocl_attr(
+ model,
+ current_epoch,
+ current_iter,
+ dataset_name,
+ prev_metric_dict,
+ seen_objects,
+ finished_objects,
+ all_objects,
+ max_instance
+):
+ """
+
+ :param model:
+ :param current_epoch:
+ :param current_iter:
+ :param prev_metric_dict: {attr_acc: : [acc1, acc2]}
+ :param filter_objects:
+ :param filter_attrs:
+ :return:
+ """
+ model.train(False)
+ device = torch.device('cuda')
+
+ if not prev_metric_dict:
+ prev_metric_dict = {
+ 'attr_acc': defaultdict(list),
+ 'inst_num': defaultdict(list),
+ 'ft_dict': defaultdict(list),
+ 'forget_dict': defaultdict(list),
+ 'obj_acc': defaultdict(list),
+ 'length': 0
+ }
+
+ pbar = tqdm(
+ total=len(all_objects),
+ desc="Validation in progress"
+ )
+ # only seen objects by this time
+ for obj in all_objects:
+ dataloader = get_dataloader(model.cfg, 'val',False,False,filter_obj=[obj])
+ obj_acc, attr_acc, inst_num = inference(model, current_epoch, current_iter, 0, dataloader, dataset_name,
+ max_instance=max_instance, mute=True)
+
+ prev_metric_dict['attr_acc'][obj].append(attr_acc)
+ prev_metric_dict['inst_num'][obj].append(inst_num)
+ prev_metric_dict['obj_acc'][obj].append(obj_acc)
+ pbar.update(1)
+
+ metric_dict = prev_metric_dict
+ #run_forward_transfer_metrics(metric_dict, seen_objects, all_objects, 'attr_acc', metric_dict['ft_dict'])
+ run_forget_metrics(metric_dict, finished_objects, all_objects, 'attr_acc', metric_dict['forget_dict'])
+ metric_dict['length'] += 1
+
+ numerical_metric_dict = numericalize_metric_scores(metric_dict)
+
+ return metric_dict, numerical_metric_dict
+
+def inference_mean_exemplar(
+ model,
+ current_epoch,
+ current_iter,
+ local_rank,
+ data_loader,
+ dataset_name,
+ device="cuda",
+ max_instance=3200,
+ mute=False,
+):
+ model.train(False)
+ # convert to a torch.device for efficiency
+ device = torch.device(device)
+ if not mute:
+ logger = logging.getLogger("maskrcnn_benchmark.inference")
+ logger.info("Start evaluation")
+ total_timer = Timer()
+ inference_timer = Timer()
+ total_timer.tic()
+ torch.cuda.empty_cache()
+ if not mute:
+ pbar = tqdm(
+ total=len(data_loader),
+ desc="Validation in progress"
+ )
+ with torch.no_grad():
+ all_pred_obj, all_truth_obj, all_pred_attr, all_truth_attr = [], [], [], []
+ obj_loss_all, attr_loss_all = 0, 0
+ cnt = 0
+ for iteration, out_dict in enumerate(data_loader):
+ if type(max_instance) is int:
+ if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE: break
+ if type(max_instance) is float:
+ if iteration > max_instance * len(data_loader) // model.cfg.EXTERNAL.BATCH_SIZE: break
+ # print(iteration)
+ images = torch.stack(out_dict['images'])
+ obj_labels = torch.cat(out_dict['object_labels'], -1)
+ attr_labels = torch.cat(out_dict['attribute_labels'], -1)
+ cropped_image = torch.stack(out_dict['cropped_image'])
+
+ images = images.to(device)
+ obj_labels = obj_labels.to(device)
+ attr_labels = attr_labels.to(device)
+
+ cropped_image = cropped_image.to(device)
+ # loss_dict = model(images, targets)
+ pred_obj = model.mean_of_exemplar_classify(cropped_image)
+
+ all_pred_obj.extend(to_list(pred_obj))
+ all_truth_obj.extend(to_list(obj_labels))
+ cnt += 1
+ if not mute:
+ pbar.update(1)
+
+ obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
+ #attr_f1 = f1_score(all_truth_attr, all_pred_attr, average='micro')
+ obj_loss_all /= (cnt + 1e-10)
+ # wait for all processes to complete before measuring the time
+ total_time = total_timer.toc()
+ model.train(True)
+ return obj_f1, 0, len(all_truth_obj)
\ No newline at end of file
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nets/classifier.py b/nets/classifier.py
new file mode 100644
index 0000000..1411ba2
--- /dev/null
+++ b/nets/classifier.py
@@ -0,0 +1,303 @@
+from .initialization import get_glove_matrix
+from torchvision.models import resnet50, resnet34 #resnet18
+import torch
+from torch import nn
+import torch.nn.functional as F
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel('INFO')
+logger.addHandler(logging.StreamHandler())
+
+
+def get_resnet_features(x, net):
+ if isinstance(net, ResNet):
+ return net(x)
+ else:
+ x = net.conv1(x)
+ x = net.bn1(x)
+ x = net.relu(x)
+ x = net.maxpool(x)
+
+ x = net.layer1(x)
+ x = net.layer2(x)
+ x = net.layer3(x)
+ x = net.layer4(x)
+
+ x = net.avgpool(x)
+ x = torch.flatten(x, 1)
+ return x
+
+class ResNetClassifier(nn.Module):
+ def __init__(self, cfg, depth='34', mlp=3, init=True, ignore_index=0, num_of_datasets=1,
+ num_of_classes=100, task_incremental=False, goal=None, *args, **kwargs):
+ super().__init__()
+
+ self.debug = hasattr(cfg, 'DEBUG') and cfg.DEBUG
+
+ if depth == '34':
+ self.resnet = resnet34(pretrained=False)
+ resnet_feat_size = self.resnet.fc.weight.size(1)
+ elif depth == '18':
+ if goal == 'split_mini_imagenet':
+ self.resnet = ResNet18(input_size=(3,84,84))
+ else:
+ self.resnet = ResNet18()
+ resnet_feat_size = self.resnet.last_hid_size
+ hidden_size = resnet_feat_size
+ mlps = []
+ for _ in range(mlp-1):
+ mlps.append(nn.Linear(hidden_size, hidden_size))
+ mlps.append(nn.ReLU())
+
+ self.mlp_attr = nn.Sequential(
+ *mlps
+ )
+
+ self.task_incremental = task_incremental
+ self.num_of_datasets = num_of_datasets
+ self.num_of_classes = num_of_classes
+
+ if self.debug:
+ self.final_layer = nn.Linear(self.resnet.last_hid_size, num_of_classes * num_of_datasets)
+ else:
+ self.final_layer = nn.Linear(hidden_size, num_of_classes * num_of_datasets)
+
+ self.mlp_obj = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size // 2),
+ nn.ReLU(),
+ nn.Dropout(0.2),
+ nn.Linear(hidden_size // 2, 2000)
+ )
+
+ self.ignore_index = ignore_index
+ self.criterion = F.cross_entropy
+
+ self.cfg = cfg
+ #self.task_specific_params = nn.ModuleList([self.mlp_attr[0], self.mlp_attr[2], self.mlp_attr[4]])
+
+ if init and cfg.MODE == 'train':
+ self.initialize()
+
+ def forward(self, images, bbox_images, spatial_feat, attr_labels, obj_labels, weights=None, reduce=True,
+ task_ids=None):
+ if self.debug:
+ feat = self.resnet.return_hidden(bbox_images)
+ attr_score = self.final_layer(feat)
+ else:
+ feat = get_resnet_features(bbox_images, self.resnet)
+ attr_feat = self.mlp_attr(torch.cat([feat], -1))
+ attr_score = self.final_layer(attr_feat)
+
+ if self.task_incremental:
+ scores = [] # get logits of corresponding tasks
+ for b in range(attr_score.size(0)):
+ scores.append(attr_score[b, task_ids[b] * self.num_of_classes: (task_ids[b] + 1) * self.num_of_classes])
+ attr_score = torch.stack(scores)
+
+ loss_attr, loss_obj = None, None
+ if attr_labels is not None:
+ reduction = 'mean' if reduce else 'none'
+ loss_attr = self.criterion(attr_score, attr_labels, ignore_index=self.ignore_index,
+ reduction=reduction)
+ # if obj_labels is not None:
+ # reduction = 'mean' if reduce else 'none'
+ # loss_obj = self.criterion(obj_score, obj_labels, ignore_index=self.ignore_index,
+ # reduction=reduction)
+
+ return {
+ 'attr_score': attr_score,
+ #'obj_score': obj_score,
+ 'attr_loss': loss_attr,
+ #'obj_loss': loss_obj,
+ 'loss': loss_attr, #loss_obj + loss_attr,
+ 'score': attr_score,
+ 'feat': feat
+ }
+
+ def forward_from_feat(self, feat, attr_labels, reduce=True, **kwargs):
+ reduction = 'mean' if reduce else 'none'
+ attr_score = self.mlp_attr(feat)
+ loss_attr = self.criterion(attr_score, attr_labels, ignore_index=self.ignore_index,
+ reduction=reduction)
+ return {
+ 'attr_score': attr_score,
+ 'attr_loss': loss_attr,
+ 'loss': loss_attr,
+ 'score': attr_score,
+ 'feat': feat
+ }
+
+ def get_obj_features(self, bbox_images):
+ feat = get_resnet_features(bbox_images, self.resnet)
+ for i, module in enumerate(self.mlp_obj._modules.values()):
+ if i == len(self.mlp_obj) - 1: break
+ feat = module(feat)
+ return feat
+
+ def set_task_specific_weights(self, idx, weight):
+ self.task_specific_params[idx].weight = weight
+
+ def _freeze_resnet_if_required(self, cfg):
+ if hasattr(cfg.EXTERNAL, 'FREEZE_RESNET') and cfg.EXTERNAL.FREEZE_RESNET:
+ for param in self.resnet.parameters():
+ param.requires_grad = False
+ logger.info('Resnet frozen')
+
+ def _load_resnet_weights_if_required(self, cfg):
+ if hasattr(cfg.EXTERNAL, 'LOAD_RESNET') and cfg.EXTERNAL.LOAD_RESNET:
+ checkpoint = torch.load(cfg.EXTERNAL.LOAD_RESNET, map_location=torch.device("cpu"))
+ model_state_dict = checkpoint['model']
+ model_state_dict = dict(filter(lambda x: x[0].startswith('resnet'), model_state_dict.items()))
+ model = ResNetClassifier(self.cfg, init=False)
+ model.load_state_dict(model_state_dict, strict=False)
+ self.resnet = model.resnet.to(cfg.MODEL.DEVICE)
+ logger.info('loaded resnet from %s' % cfg.EXTERNAL.LOAD_RESNET)
+
+ def initialize(self):
+ self._freeze_resnet_if_required(self.cfg)
+ self._load_resnet_weights_if_required(self.cfg)
+
+
+class ResNetClassifierWObj(ResNetClassifier):
+ def __init__(self, cfg, init=True, *args, **kwargs):
+ # do not init
+ super().__init__(cfg, init=False, *args, **kwargs)
+ self.cfg = cfg
+ self.resnet = resnet34(pretrained=True)
+ resnet_feat_size = self.resnet.fc.weight.size(1)
+ hidden_size = resnet_feat_size
+ label_embed_size = self.cfg.WOBJ.LABEL_EMBED_DIM
+
+ self.obj_emb = nn.Embedding(2000, label_embed_size)
+ self.mlp_attr = nn.Sequential(
+ nn.Linear(hidden_size + label_embed_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, 101)
+ )
+ self.task_specific_params = nn.ModuleList([self.mlp_attr[0], self.mlp_attr[2], self.mlp_attr[4]])
+ self.criterion = nn.CrossEntropyLoss(ignore_index=0)
+ if init and cfg.MODE == 'train':
+ self.initialize()
+
+ def forward(self, images, bbox_images, spatial_feat, attr_labels, obj_labels, weights=None):
+ obj_emb = self.obj_emb(obj_labels)
+ feat = get_resnet_features(bbox_images, self.resnet)
+ feat = torch.cat([feat, obj_emb], -1)
+
+ if weights is None:
+ attr_score = self.mlp_attr(feat)
+ else: # hypernetwork
+ hidden_1 = F.relu(F.linear(feat, weights[0], weights[1]))
+ hidden_2 = F.relu(F.linear(hidden_1, weights[2], weights[3]))
+ attr_score = F.linear(hidden_2, weights[4], weights[5])
+
+ loss_attr, loss_obj = None, None
+ if attr_labels is not None:
+ loss_attr = self.criterion(attr_score, attr_labels)
+ #if obj_labels is not None:
+ # loss_obj = self.criterion(obj_score, obj_labels)
+
+ return {
+ 'attr_score': attr_score,
+ #'obj_score': obj_score,
+ 'attr_loss': loss_attr,
+ #'obj_loss': loss_obj,
+ 'loss': loss_attr
+ }
+
+ def get_obj_features(self, bbox_images):
+ feat = get_resnet_features(bbox_images, self.resnet)
+ for i, module in enumerate(self.mlp_obj._modules.values()):
+ if i == len(self.mlp_obj) - 1: break
+ feat = module(feat)
+ return feat
+
+ def set_task_specific_weights(self, idx, weight):
+ self.task_specific_params[idx].weight = weight
+
+ def initialize_word_emb(self, vocab, w2v_file):
+ device = self.obj_emb.weight.data.device
+ mat = get_glove_matrix(vocab, w2v_file, self.obj_emb.weight.data.cpu().numpy())
+ mat = torch.from_numpy(mat).float().to(device)
+ self.obj_emb.weight.data = mat
+ logger.info('loaded embedding from {}'.format(w2v_file))
+
+def conv3x3(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_planes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
+ stride=stride, bias=False),
+ nn.BatchNorm2d(self.expansion * planes)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, nf, input_size):
+ super(ResNet, self).__init__()
+ self.in_planes = nf
+ self.input_size = input_size
+
+ self.conv1 = conv3x3(input_size[0], nf * 1)
+ self.bn1 = nn.BatchNorm2d(nf * 1)
+ #self.bn1 = CategoricalConditionalBatchNorm(nf, 2)
+ self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
+
+ # hardcoded for now
+ self.last_hid_size = nf * 8 * block.expansion if input_size[1] in [8,16,21,32,42] else 640
+ #self.linear = nn.Linear(last_hid, num_classes)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def return_hidden(self, x):
+ bsz = x.size(0)
+ #pre_bn = self.conv1(x.view(bsz, 3, 32, 32))
+ #post_bn = self.bn1(pre_bn, 1 if is_real else 0)
+ #out = F.relu(post_bn)
+ out = F.relu(self.bn1(self.conv1(x.view(bsz, *self.input_size))))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ out = out.view(out.size(0), -1)
+ return out
+
+ def forward(self, x):
+ out = self.return_hidden(x)
+ #out = self.linear(out)
+ return out
+
+def ResNet18(nf=20, input_size=(3, 32, 32)):
+ return ResNet(BasicBlock, [2, 2, 2, 2], nf, input_size)
\ No newline at end of file
diff --git a/nets/initialization.py b/nets/initialization.py
new file mode 100644
index 0000000..acee318
--- /dev/null
+++ b/nets/initialization.py
@@ -0,0 +1,36 @@
+import logging
+import numpy as np
+
+logger = logging.getLogger(__name__)
+logger.setLevel('INFO')
+
+def get_glove_matrix(vocab, glove_path, initial_embedding_np):
+ """
+ return a glove embedding matrix
+ :param initial_embedding_np:
+ :return: np array of [V,E]
+ """
+ ef = open(glove_path, 'r', encoding='utf-8')
+ cnt = 0
+ vec_array = initial_embedding_np
+ old_avg = np.average(vec_array)
+ old_std = np.std(vec_array)
+ vec_array = vec_array.astype(np.float32)
+ new_avg, new_std = 0, 0
+
+ for line in ef.readlines():
+ line = line.strip().split(' ')
+ word, vec = line[0].lower(), line[1:]
+ vec = np.array(vec, np.float32)
+ if word in vocab:
+ cnt += 1
+ word_idx = vocab[word]
+ vec_array[word_idx] = vec
+ new_avg += np.average(vec)
+ new_std += np.std(vec)
+ new_avg /= cnt
+ new_std /= cnt
+ ef.close()
+ print('%d known embedding. old mean: %f new mean %f, old std %f new std %f' % (cnt, old_avg,
+ new_avg, old_std, new_std))
+ return vec_array
diff --git a/nets/simplenet.py b/nets/simplenet.py
new file mode 100644
index 0000000..500c99d
--- /dev/null
+++ b/nets/simplenet.py
@@ -0,0 +1,83 @@
+import torch.nn.functional as func
+import torch.nn as nn
+import torch
+
+class GatedDense(nn.Module):
+ def __init__(self, input_size, output_size, activation=None):
+ super(GatedDense, self).__init__()
+
+ self.activation = activation
+ self.sigmoid = nn.Sigmoid()
+ self.h = nn.Linear(input_size, output_size)
+ self.g = nn.Linear(input_size, output_size)
+
+ def forward(self, x):
+ h = self.h(x)
+ if self.activation is not None:
+ h = self.activation( self.h( x ) )
+
+ g = self.sigmoid( self.g( x ) )
+
+ return h * g
+
+class FC2Layers(nn.Module):
+ def __init__(self, **kwargs):
+ super(FC2Layers, self).__init__()
+ layer1_width = 400
+
+ self.ds_idx = 0
+ self.num_of_datasets = kwargs.get("num_of_datasets", 1)
+ self.num_of_classes = kwargs.get("num_of_classes", 10)
+ self.input_size = kwargs.get("input_size", 784)
+ self.task_incremental = kwargs.get('task_incremental', False)
+
+ act = nn.ReLU()
+ self.layer1 = nn.Sequential(
+ nn.Linear(self.input_size, layer1_width),
+ nn.ReLU(),
+ nn.Linear(layer1_width, layer1_width),
+ nn.ReLU()
+ )
+
+ self.last_layer = nn.Linear(layer1_width, self.num_of_classes * self.num_of_datasets)
+ self.criterion = func.cross_entropy
+
+
+ def forward(self, x, y, task_ids=None, reduce=True, from_weights=False, weights=None):
+ x = x.view(-1, self.input_size)
+ if from_weights:
+ out = func.relu(func.linear(x, weights[0], weights[1]))
+ feat = func.relu(func.linear(out, weights[2], weights[3]))
+ out = func.linear(feat, weights[4], weights[5])
+ else:
+ out = self.layer1(x)
+ feat = func.relu(out)
+ out = self.last_layer(feat)
+
+ if self.task_incremental:
+ scores = [] # get logits of corresponding tasks
+ for b in range(x.size(0)):
+ scores.append(feat[b, task_ids[b] * self.num_of_classes: (task_ids[b] + 1) * self.num_of_classes])
+ out = torch.stack(scores)
+
+ if reduce:
+ loss = self.criterion(out, y)
+ else:
+ loss = self.criterion(out, y, reduction='none')
+ return {'score': out, 'loss': loss, 'feat': feat}
+
+ def forward_from_feat(self, feat, y, reduce=True, **kwargs):
+ out = self.last_layer[self.ds_idx](feat)
+ if reduce:
+ loss = self.criterion(out, y)
+ else:
+ loss = self.criterion(out, y, reduction='none')
+ return {'score': out, 'loss': loss, 'feat': feat}
+
+
+
+def mnist_simple_net_400width_classlearning_1024input_10cls_1ds(**kwargs):
+ return FC2Layers(input_size=28 * 28, layer1_width=400, layer2_width=400, **kwargs)
+
+def mnist_simple_net_400width_domainlearning_1024input_10cls_1ds(**kwargs):
+ return FC2Layers(input_size=28 * 28, layer1_width=400, layer2_width=400, **kwargs)
\ No newline at end of file
diff --git a/ocl/__init__.py b/ocl/__init__.py
new file mode 100644
index 0000000..e32d7e5
--- /dev/null
+++ b/ocl/__init__.py
@@ -0,0 +1,5 @@
+from .agem import AGEM
+from .er import ExperienceReplay
+from .naive import NaiveWrapper
+from .ver import FOExperienceEvolve
+from .ver_approx import ExperienceEvolveApprox
\ No newline at end of file
diff --git a/ocl/agem.py b/ocl/agem.py
new file mode 100644
index 0000000..2666414
--- /dev/null
+++ b/ocl/agem.py
@@ -0,0 +1,89 @@
+import torch
+from torch import nn
+from .er import ExperienceReplay, store_grad
+from config import cfg
+from utils.utils import get_config_attr
+
+def overwrite_grad(pp, newgrad, grad_dims):
+ """
+ This is used to overwrite the gradients with a new gradient
+ vector, whenever violations occur.
+ pp: parameters
+ newgrad: corrected gradient
+ grad_dims: list storing number of parameters at each layer
+ """
+ cnt = 0
+ for param in pp():
+ if param.grad is not None:
+ beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
+ en = sum(grad_dims[:cnt + 1])
+ this_grad = newgrad[beg: en].contiguous().view(
+ param.grad.data.size())
+ param.grad.data.copy_(this_grad)
+ cnt += 1
+
+
+class AGEM(ExperienceReplay):
+ def __init__(self, base, optimizer, input_size, cfg, goal):
+ super(AGEM, self).__init__(base, optimizer, input_size, cfg, goal)
+
+ # self.grads = self.grads.cuda()
+ self.violation_count = 0
+ self.agem_k = get_config_attr(cfg, 'EXTERNAL.OCL.AGEM_K', default=256)
+
+ def compute_grad(self, mem_x, mem_y):
+ self.zero_grad()
+ ret_dict = self.forward_net(mem_x, mem_y)
+ # regularize both of the loss
+ loss = ret_dict['loss']
+ loss.backward()
+ grads = torch.Tensor(sum(self.grad_dims)).to(mem_x.device)
+ store_grad(self.parameters, grads, self.grad_dims)
+ return grads
+
+ def fix_grad(self, mem_grads):
+ # check whether the current gradient interfere with the average gradients
+ grads = torch.Tensor(sum(self.grad_dims)).to(mem_grads.device)
+ store_grad(self.parameters, grads, self.grad_dims)
+ dotp = torch.dot(grads, mem_grads)
+ if dotp < 0:
+ # project the grads back to the mem_grads
+ # g_new = g - g^Tg_{ref} / g_{ref}^Tg_{ref} * g_{ref}
+ new_grad = grads - (torch.dot(grads, mem_grads) / (torch.dot(mem_grads, mem_grads) + 1e-12)) * mem_grads
+ overwrite_grad(self.parameters, new_grad, self.grad_dims)
+ return 1
+ else:
+ return 0
+
+ def observe(self, x, y, task_ids=None, extra=None, optimize=True):
+ # recover image, feat from x
+ self.optimizer.zero_grad()
+ mem_x, mem_y, _ = self.sample_mem_batch(x.device, k=self.agem_k)
+
+ if mem_x is not None:
+ # calculate gradients on the memory batch
+ mem_grads = self.compute_grad(mem_x, mem_y)
+
+ # backward on the current minibatch
+ self.optimizer.zero_grad()
+ batch_size = x.size(0)
+ ret_dict = \
+ self.forward_net(x, y)
+ for b in range(batch_size):
+ #self.update_mem(x[b], y[b], task_ids[b] if task_ids is not None else None)
+ if type(y) is tuple:
+ y_ = [_[b] for _ in y]
+ else:
+ y_ = y[b]
+ self.update_mem(x[b], y_, task_ids[b] if task_ids is not None else None,
+ x_feat=None)
+ loss = ret_dict['loss']
+ if optimize:
+ loss.backward()
+ if mem_x is not None:
+ violated = self.fix_grad(mem_grads)
+ self.violation_count += violated
+
+ self.optimizer.step()
+
+ return ret_dict
diff --git a/ocl/er.py b/ocl/er.py
new file mode 100644
index 0000000..1ef838b
--- /dev/null
+++ b/ocl/er.py
@@ -0,0 +1,774 @@
+import torch
+import numpy as np
+import pickle
+from utils.utils import get_config_attr
+import copy
+from .naive import NaiveWrapper
+from torch.optim import SGD, Adam
+from collections import defaultdict
+
+
+import math
+try:
+ from pytorch_transformers import AdamW
+except ImportError:
+ AdamW = None
+
+
+def y_to_np(y):
+ if type(y) is tuple:
+ return tuple(x.item() for x in y)
+ else:
+ return y.cpu().numpy()
+
+
+def y_to_cpu(y):
+ if torch.is_tensor(y):
+ y = y.cpu()
+ else:
+ y = [_.cpu() for _ in y]
+ return y
+
+
+def index_select(l, indices, device):
+ ret = []
+ for i in indices:
+ if type(l[i]) is np.ndarray:
+ x = torch.from_numpy(l[i]).to(device)
+ ret.append(x.unsqueeze(0))
+ else:
+ if type(l[i]) is list:
+ item = []
+ for j in range(len(l[i])):
+ if type(l[i][j]) is np.ndarray:
+ item.append(torch.from_numpy(l[i][j]))
+ else:
+ item.append(l[i][j])
+ ret.append(item)
+ else:
+ ret.append(l[i])
+ return ret
+
+
+def concat_with_padding(l):
+ if l is None or l[0] is None: return None
+ if type(l[0]) in [list, tuple]:
+ ret = [torch.cat(t, 0) for t in zip(*l)]
+ else:
+ if len(l[0].size()) == 2: # potentially requires padding
+ max_length = max([x.size(1) for x in l])
+ ret = []
+ for x in l:
+ pad = torch.zeros(x.size(0), max_length - x.size(1)).long().to(x.device)
+ x_pad = torch.cat([x, pad], -1)
+ ret.append(x_pad)
+ ret = torch.cat(ret, 0)
+ else:
+ ret = torch.cat(l, 0)
+ return ret
+
+
+def store_grad(pp, grads, grad_dims):
+ """
+ This stores parameter gradients of past tasks.
+ pp: parameters
+ grads: gradients
+ grad_dims: list with number of parameters per layers
+ tid: task id
+ """
+ # store the gradients
+ grads.fill_(0.0)
+ cnt = 0
+ for param in pp():
+ if param.grad is not None:
+ beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
+ en = sum(grad_dims[:cnt + 1])
+ grads[beg: en].copy_(param.grad.data.view(-1))
+ cnt += 1
+
+class ExperienceReplay(NaiveWrapper):
+ def __init__(self, base, optimizer, input_size, cfg, goal):
+ super().__init__(base, optimizer, input_size, cfg, goal)
+ self.net = base
+ self.optimizer = optimizer
+ self.mem_limit = cfg.EXTERNAL.REPLAY.MEM_LIMIT
+ self.mem_bs = cfg.EXTERNAL.REPLAY.MEM_BS
+ self.input_size = input_size
+ self.reservoir, self.example_seen = None, None
+ self.reset_mem()
+ self.mem_occupied = {}
+
+ self.seen_tasks = []
+ self.balanced = False
+ self.policy = get_config_attr(cfg, 'EXTERNAL.OCL.POLICY', default='reservoir', totype=str)
+ self.mir_k = get_config_attr(cfg, 'EXTERNAL.REPLAY.MIR_K', default=10, totype=int)
+ self.mir = get_config_attr(cfg, 'EXTERNAL.OCL.MIR', default=0, totype=int)
+
+ self.mir_agg = get_config_attr(cfg, 'EXTERNAL.OCL.MIR_AGG', default='avg', totype=str)
+
+ self.concat_replay = get_config_attr(cfg, 'EXTERNAL.OCL.CONCAT', default=0, totype=int)
+ self.separate_replay = get_config_attr(cfg, 'EXTERNAL.OCL.SEPARATE', default=0, totype=int)
+ self.mem_augment = get_config_attr(cfg, 'EXTERNAL.OCL.MEM_AUG', default=0, totype=int)
+ self.legacy_aug = get_config_attr(cfg, 'EXTERNAL.OCL.LEGACY_AUG', default=0, totype=int)
+ self.use_hflip_aug = get_config_attr(cfg,'EXTERNAL.OCL.USE_HFLIP_AUG',default=1,totype=int)
+ self.padding_aug = get_config_attr(cfg,'EXTERNAL.OCL.PADDING_AUG',default=-1,totype=int)
+ self.rot_aug = get_config_attr(cfg,'EXTERNAL.OCL.ROT_AUG',default=-1,totype=int)
+
+ self.lb_reservoir = get_config_attr(cfg,'EXTERNAL.OCL.LB_RESERVOIR', default=0)
+
+ self.cfg = cfg
+ self.grad_dims = []
+
+ for param in self.parameters():
+ self.grad_dims.append(param.data.numel())
+
+ def get_config_padding(self, pad, rot):
+ if self.padding_aug != -1:
+ pad = self.padding_aug
+ if self.rot_aug != -1:
+ rot = self.rot_aug
+ return pad, rot
+
+ def reset_mem(self):
+ self.reservoir = {'x': np.zeros((self.mem_limit, self.input_size)),
+ 'y': [None] * self.mem_limit,
+ 'y_extra': [None] * self.mem_limit,
+ 'replay_time': np.zeros(self.mem_limit),
+ 'loss_stats': np.zeros(self.mem_limit),
+ 'label_cnts': defaultdict(int),
+ }
+ self.example_seen = 0
+
+ def update_mem(self, *args, **kwargs):
+ if self.policy == 'balanced':
+ return self.update_mem_balanced(*args, **kwargs)
+ elif self.policy == 'reservoir':
+ return self.update_mem_reservoir(*args, **kwargs)
+ elif self.policy == 'clus':
+ return self.update_mem_kmeans(*args, **kwargs)
+ else:
+ raise ValueError
+
+ def reinit_mem(self, xsize):
+ self.input_size = xsize
+ self.reset_mem()
+
+ def forward(self, *args, **kwargs):
+ return self.net(*args, **kwargs)
+
+ def update_mem_reservoir(self, x, y, y_extra=None, loss_x=None, *args, **kwargs):
+ if self.example_seen == 0 and self.reservoir['x'].shape[-1] != x.shape[-1]:
+ self.reinit_mem(x.shape[-1])
+
+ x = x.cpu().numpy()
+ if type(y) not in [list, tuple]:
+ y = y_to_np(y)
+ #if y.shape == (1,):
+ # y = y[0]
+ else:
+ y = y_to_cpu(y)
+
+ if type(y_extra) not in [list, tuple]:
+ y_extra = y_to_np(y_extra)
+ elif y_extra is not None:
+ y_extra = y_to_cpu(y_extra)
+
+ if self.example_seen < self.mem_limit:
+ self.reservoir['x'][self.example_seen] = x
+ self.reservoir['y'][self.example_seen] = y
+ self.reservoir['y_extra'][self.example_seen] = y_extra
+ #self.reservoir['replay_time'][self.example_seen] = 0
+ self.reservoir['label_cnts'][y.item()] += 1
+ j = self.example_seen
+ else:
+
+ #
+ #else:
+ j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen)
+ if j < self.mem_limit:
+ if self.lb_reservoir:
+ j = self.get_loss_aware_balanced_reservoir_sampling_index()
+ self.reservoir['label_cnts'][self.reservoir['y'][j].item()] -= 1
+ self.reservoir['x'][j] = x
+ self.reservoir['y'][j] = y
+ self.reservoir['y_extra'][j] = y_extra
+ #self.reservoir['replay_time'][j] = 0
+ self.reservoir['label_cnts'][y.item()] += 1
+
+ #if loss_x is not None:
+ # self.reservoir['loss_stats'][j] = loss_x if not torch.is_tensor(loss_x) else loss_x.item()
+ #self.reservoir['loss_stat_steps'][j] = [self.example_seen]
+ #self.reservoir['forget'] = 0
+ self.example_seen += 1
+
+ def update_loss_states(self, loss, indices):
+ for i in range(len(indices)):
+ self.reservoir['loss_stats'][indices[i]] = loss[i].item()
+
+ def get_loss_aware_balanced_reservoir_sampling_index(self):
+ # assumes the mem is full
+ _random = np.random.RandomState(self.example_seen + self.cfg.SEED)
+ s_balance = np.array([self.reservoir['label_cnts'][mem_y.item()] for mem_y in self.reservoir['y']])
+ s_loss = -self.reservoir['loss_stats']
+ alpha = abs(s_balance.sum()) / s_loss.sum()
+ s = s_loss * alpha + s_balance
+ probs = s / s.sum()
+ idx = _random.choice(len(probs), p=probs)
+ return idx
+
+ def update_mem_kmeans(self, x, y, y_extra=None, x_feat=None, **kwargs):
+ if 'x_clus' not in self.reservoir:
+ self.reservoir['x_clus'] = np.zeros((self.mem_limit, x_feat.shape[0]))
+ self.reservoir['x_cnt'] = [None] * self.mem_limit
+ self.reservoir['x_feat'] = np.zeros((self.mem_limit, x_feat.shape[0]))
+ x = x.cpu().numpy()
+
+ if type(y) not in [list, tuple]:
+ y = y_to_np(y)
+ else:
+ y = y_to_cpu(y)
+
+ if self.example_seen < self.mem_limit:
+ self.reservoir['x'][self.example_seen] = x
+ self.reservoir['y'][self.example_seen] = y
+ self.reservoir['y_extra'][self.example_seen] = y_extra
+ self.reservoir['x_clus'][self.example_seen] = x_feat
+ self.reservoir['x_cnt'][self.example_seen] = 0
+ self.reservoir['x_feat'][self.example_seen] = x_feat
+ else:
+ # compute L2 distance
+ center_i, center_d = -1, 1e10
+ for mb in range(self.mem_limit):
+ dist = np.sum(np.square(x_feat - self.reservoir['x_clus'][mb]))
+ if dist < center_d:
+ center_i = mb
+ center_d = dist
+ cnt = self.reservoir['x_cnt'][center_i]
+ self.reservoir['x_clus'][center_i] = (self.reservoir['x_clus'][center_i] * cnt + x_feat) / (cnt + 1)
+ self.reservoir['x_cnt'][center_i] += 1
+
+ d_mem = np.sum(np.square(x_feat - self.reservoir['x_clus'][center_i]))
+ d_mem_best = np.sum(np.square(self.reservoir['x_feat'][center_i] - self.reservoir['x_clus'][center_i]))
+ if d_mem < d_mem_best:
+ self.reservoir['x'][center_i] = x
+ self.reservoir['y'][center_i] = y
+ self.reservoir['y_extra'][center_i] = y_extra
+ self.reservoir['x_feat'][center_i] = x_feat
+ self.example_seen += 1
+
+ def compute_offset(self, task_id, n):
+ idx = self.seen_tasks.index(task_id)
+ return int(idx / n * self.mem_limit), int((idx + 1) / n * self.mem_limit)
+
+ def reallocate_memory(self, old_mem, old_occ):
+ new_mem = {'x': np.zeros((self.mem_limit, self.input_size)),
+ 'y': [None] * self.mem_limit,
+ 'y_extra': [None] * self.mem_limit}
+ new_occ = {}
+ n_tasks = len(self.seen_tasks)
+ for task in self.seen_tasks:
+ old_offset_start, old_offset_stop = self.compute_offset(task, n_tasks)
+ new_offset_start, new_offset_stop = self.compute_offset(task, n_tasks + 1)
+ i = 0
+ while new_offset_start + i < new_offset_stop:
+ new_offset = new_offset_start + i
+ old_offset = old_offset_start + i
+ for k in old_mem:
+ new_mem[k][new_offset] = old_mem[k][old_offset]
+ i += 1
+ new_occ[task] = old_occ[task] # min(old_occ[task], new_offset_stop - new_offset_start)
+ return new_mem, new_occ
+
+ def update_mem_balanced(self, x, y):
+ y_attr, y_obj = y_to_np(y)
+ y_obj = y_obj.item()
+ y_attr = y_attr.item()
+ x = x.cpu().numpy()
+
+ if y_obj not in self.seen_tasks:
+ # reallocate memory by expanding seen task by 1
+ new_mem, new_occ = self.reallocate_memory(self.reservoir, self.mem_occupied)
+ self.reservoir = new_mem
+ self.mem_occupied = new_occ
+ self.mem_occupied[y_obj] = 0
+ self.seen_tasks.append(y_obj)
+ offset_start, offset_stop = self.compute_offset(y_obj, len(self.seen_tasks))
+
+ if self.mem_occupied[y_obj] < offset_stop - offset_start:
+ pos = self.mem_occupied[y_obj] + offset_start
+ self.reservoir['x'][pos] = x
+ self.reservoir['y'][pos] = y
+ # self.reservoir['y_attr'][pos] = y_attr
+ else:
+ j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.mem_occupied[y_obj])
+ if j < offset_stop - offset_start:
+ pos = j + offset_start
+ self.reservoir['x'][pos] = x
+ self.reservoir['y'][pos] = y
+ # self.reservoir['y_attr'][pos] = y_attr
+
+ self.mem_occupied[y_obj] += 1
+ self.example_seen += 1
+
+ def get_available_index(self):
+ l = []
+ for idx, t in enumerate(self.seen_tasks):
+ offset_start, offset_stop = self.compute_offset(t, len(self.seen_tasks))
+ for i in range(offset_start, min(offset_start + self.mem_occupied[t], offset_stop)):
+ l.append(i)
+ return l
+
+ def get_random(self, seed=1):
+ random_state = None
+ for i in range(seed):
+ if random_state is None:
+ random_state = np.random.RandomState(self.example_seen + self.cfg.SEED)
+ else:
+ random_state = np.random.RandomState(random_state.randint(0, int(1e5)))
+ return random_state
+
+ def store_cache(self):
+ # for i, param in enumerate(self.net.parameters()):
+ # self.parameter_cache[i].copy_(param.data)
+ self.cache = copy.deepcopy(self.net.state_dict())
+
+ def load_cache(self):
+ self.net.load_state_dict(self.cache)
+ self.net.zero_grad()
+
+ def get_loss_and_pseudo_update(self, x, y, task_ids):
+ ret_dict_d = self.forward_net(x, y, task_ids)
+ self.optimizer.zero_grad()
+ ret_dict_d['loss'].backward(retain_graph=False)
+ if isinstance(self.optimizer, torch.optim.SGD):
+ step_wo_state_update_sgd(self.optimizer, amp=1.)
+ elif isinstance(self.optimizer, torch.optim.Adam):
+ step_wo_state_update_adam(self.optimizer, amp=1.)
+ elif isinstance(self.optimizer, AdamW):
+ step_wo_state_update_adamw(self.optimizer)
+ else:
+ raise NotImplementedError
+ return ret_dict_d
+
+ def decide_mir_mem(self, x, y, task_ids, mir_k, cand_x, cand_y, cand_task_ids, indices, mir_least):
+ if cand_x.size(0) < mir_k:
+ return cand_x, cand_y, cand_task_ids, indices
+ else:
+ self.store_cache()
+ if type(cand_y[0]) not in [list, tuple]:
+ cand_y = concat_with_padding(cand_y)
+ else:
+ cand_y = [torch.stack(_).to(x.device) for _ in zip(*cand_y)]
+ with torch.no_grad():
+ ret_dict_mem_before = self.forward_net(cand_x, cand_y, reduce=False, task_ids=cand_task_ids)
+ ret_dict_d = self.get_loss_and_pseudo_update(x, y, task_ids)
+ with torch.no_grad():
+ ret_dict_mem_after = self.forward_net(cand_x, cand_y, reduce=False, task_ids=cand_task_ids)
+ loss_increase = ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']
+ with torch.no_grad():
+ if self.goal == 'captioning':
+ if self.mir_agg == 'avg':
+ loss_increase_by_ts = loss_increase.view(cand_x.size(0), -1).sum(1)
+ mask_num_by_ts = (cand_y[2] != -1).sum(1).float() + 1e-10
+ loss_increase = loss_increase_by_ts / mask_num_by_ts
+ elif self.mir_agg == 'max':
+ loss_increase, _ = loss_increase.view(cand_x.size(0), -1).max(1)
+
+ _, topi = loss_increase.topk(mir_k, largest=not mir_least)
+ if type(cand_y) is not list:
+ mem_x, mem_y, mem_task_ids = cand_x[topi], cand_y[topi], cand_task_ids[topi]
+ else:
+ mem_x, mem_task_ids = cand_x[topi], cand_task_ids[topi]
+ mem_y = [_[topi] for _ in cand_y]
+
+ self.load_cache()
+ return mem_x, mem_y, mem_task_ids, indices[topi.cpu()]
+
+
+ def sample_mem_batch(self, device, return_indices=False, k=None, seed=1,
+ mir=False, input_x=None, input_y=None, input_task_ids=None, mir_k=0,
+ skip_task=None, mir_least=False):
+ random_state = self.get_random(seed)
+ if k is None:
+ k = self.mem_bs
+
+ if not self.balanced:
+ # reservoir
+ n_max = min(self.mem_limit, self.example_seen)
+ available_indices = [_ for _ in range(n_max)]
+ if skip_task is not None and get_config_attr(self.cfg, 'EXTERNAL.REPLAY.FILTER_SELF', default=0, mute=True):
+ available_indices = list(filter(lambda x: self.reservoir['y_extra'][x] != skip_task, available_indices))
+ if not available_indices:
+ if return_indices:
+ return None, None, None
+ else:
+ return None, None, None
+ elif len(available_indices) < k:
+ indices = np.arange(n_max)
+ else:
+ indices = random_state.choice(available_indices, k, replace=False)
+ else:
+ available_index = self.get_available_index()
+ if len(available_index) == 0:
+ if return_indices:
+ return None, None, None
+ else:
+ return None, None, None
+ elif len(available_index) < k:
+ indices = np.array(available_index)
+ else:
+ indices = random_state.choice(available_index, k, replace=False)
+ x = self.reservoir['x'][indices]
+ x = torch.from_numpy(x).to(device).float()
+
+ y = index_select(self.reservoir['y'], indices, device) # [ [...], [...] ]
+ y_extra = index_select(self.reservoir['y_extra'], indices, device)
+ if type(y[0]) not in [list, tuple]:
+ y_pad = concat_with_padding(y)
+ else:
+ y_pad = [torch.stack(_).to(device) for _ in zip(*y)]
+ y_extra = concat_with_padding(y_extra)
+
+ if mir:
+ x, y_pad, y_extra, indices = self.decide_mir_mem(input_x, input_y, input_task_ids, mir_k,
+ x, y, y_extra, indices, mir_least)
+
+ if not return_indices:
+ return x, y_pad, y_extra
+ else:
+ return (x, indices), y_pad, y_extra
+
+ def observe(self, x, y, task_ids=None, extra=None, optimize=True):
+ # recover image, feat from x
+ if task_ids is None:
+ task_ids = torch.zeros(x.size(0)).to(x.device).long()
+
+ if self.mir:
+ mem_x, mem_y, mem_task_ids = self.sample_mem_batch(x.device, input_x=x, input_y=y, input_task_ids=task_ids,
+ mir_k=self.mir_k, mir=self.mir,
+ skip_task=task_ids[0].item())
+ else:
+ mem_x, mem_y, mem_task_ids = self.sample_mem_batch(x.device, skip_task=task_ids[0].item())
+
+ batch_size = x.size(0)
+ if mem_x is not None and not self.separate_replay and not self.goal == 'captioning': # a dirty fix to prevent oom
+ if not self.mem_augment:
+ combined_x = torch.cat([x, mem_x], 0)
+ combined_y = concat_with_padding([y, mem_y])
+ combined_task_ids = concat_with_padding([task_ids, mem_task_ids])
+ else:
+ aug_mem_x = self.transform_image_batch(mem_x)
+ combined_x = torch.cat([x,mem_x,aug_mem_x],0)
+ combined_y = concat_with_padding([y,mem_y,mem_y])
+ combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids])
+ else:
+ combined_x, combined_y, combined_task_ids = x, y, task_ids
+
+ ret_dict = self.forward_net(combined_x, combined_y, combined_task_ids,
+ reduce=self.concat_replay or self.separate_replay)
+
+ loss_tmp = ret_dict['loss']
+ if optimize:
+ # loss = loss_tmp.mean()
+ # print(loss.item())
+ if self.concat_replay or self.separate_replay:
+ loss = ret_dict['loss']
+ else:
+ loss = loss_tmp[: x.size(0)].mean()
+ if mem_x is not None:
+ loss += loss_tmp[x.size(0):].mean()
+
+ self.optimizer.zero_grad()
+
+ if self.concat_replay and mem_x is not None:
+ loss = loss / 2
+
+ loss.backward()
+
+ #if mem_x is None or (not self.separate_replay and not self.goal == 'captioning'):
+ if not self.concat_replay or mem_x is None:
+ self.optimizer.step()
+
+ if (self.separate_replay or self.goal == 'captioning') and mem_x is not None:
+ ret_dict_mem = self.forward_net(mem_x, mem_y, mem_task_ids, reduce=True)
+
+ if not self.concat_replay:
+ self.optimizer.zero_grad()
+
+ if self.concat_replay:
+ ret_dict_mem['loss'] = ret_dict_mem['loss'] / 2
+
+ ret_dict_mem['loss'].backward()
+ self.optimizer.step()
+ ret_dict['loss'] = (ret_dict['loss'] + ret_dict_mem['loss']) / 2
+
+ for b in range(batch_size): # x.size(0)
+ if type(y) is tuple:
+ y_ = [_[b] for _ in y]
+ else:
+ y_ = y[b]
+ self.update_mem(x[b], y_, task_ids[b] if task_ids is not None else None,
+ x_feat=None)
+ return ret_dict
+
+ def dump_reservoir(self, path, verbose=False):
+ f = open(path, 'wb')
+ pickle.dump({
+ 'reservoir_x': self.reservoir['x'] if verbose else None,
+ 'reservoir_y': self.reservoir['y'],
+ 'reservoir_y_extra': self.reservoir['y_extra'],
+ 'mem_occupied': self.mem_occupied,
+ 'example_seen': self.example_seen,
+ 'seen_tasks': self.seen_tasks,
+ 'balanced': self.balanced
+ }, f)
+ f.close()
+
+ def load_reservoir(self, path):
+ try:
+ f = open(path, 'rb')
+ dic = pickle.load(f)
+ for k in dic:
+ setattr(self, k, dic[k])
+ f.close()
+ return dic
+ except FileNotFoundError:
+ print('no replay buffer dump file')
+ return {}
+
+ def load_reservoir_from_dic(self, dic):
+ for k in dic:
+ setattr(self, k, dic[k])
+
+ def get_reservoir(self):
+ return {'reservoir': self.reservoir, 'mem_occupied': self.mem_occupied,
+ 'example_seen': self.example_seen, 'seen_tasks': self.seen_tasks,
+ 'balanced': self.balanced}
+
+ def mean_of_exemplar_classify(self, cropped_image_inp):
+ if not hasattr(self, 'mean_exemplar_vec'):
+ mean_exemplar_vec = []
+ for task in self.seen_tasks:
+ offset_start, offset_stop = self.compute_offset(task, len(self.seen_tasks))
+ x = self.reservoir['x'][offset_start: min(offset_stop, offset_start + self.mem_occupied[task])]
+ x = torch.from_numpy(x).float().to(cropped_image_inp.device)
+ cropped_image = x[:, : 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE] \
+ .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+ feat = self.net.get_obj_features(cropped_image)
+ mean_feat = feat.mean(0)
+ mean_exemplar_vec.append(mean_feat)
+ self.mean_exemplar_vec = torch.stack(mean_exemplar_vec) # [C, H]
+
+ feat = self.net.get_obj_features(cropped_image_inp) # [B, H]
+ mean_exemplar_vec_expand = self.mean_exemplar_vec.unsqueeze(0).expand(feat.size(0), -1, -1) # [B,C,H]
+ feat_expand = feat.unsqueeze(1).expand(-1, self.mean_exemplar_vec.size(0), -1) # [B,C,H]
+ dist = torch.sum((mean_exemplar_vec_expand - feat_expand) ** 2, -1) # [B,C]
+ dist = torch.sqrt(dist)
+ _, pred = dist.min(-1) # [B]
+
+ pred_index_fix = torch.zeros(pred.size()).to(pred.device)
+ for b in range(pred.size(0)):
+ pred_index_fix[b] = self.seen_tasks[pred[b]]
+
+ return pred_index_fix
+
+
+def step_wo_state_update_adam(adam, closure=None, amp=1.):
+ """Performs a single optimization step. Do not update optimizer states
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ if type(adam) is not Adam:
+ raise ValueError
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in adam.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+ amsgrad = group['amsgrad']
+
+ state = adam.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad.add_(group['weight_decay'], p.data)
+
+ # Decay the first and second moment running average coefficient
+ exp_avg = exp_avg.mul(beta1).add(1 - beta1, grad)
+ exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(1 - beta2, grad, grad)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+ else:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 * amp
+
+ p.data.addcdiv_(-step_size, exp_avg, denom)
+
+ return loss
+
+
+def step_wo_state_update_sgd(sgd, closure=None, amp=1.):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in sgd.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ d_p = p.grad.data
+ if weight_decay != 0:
+ d_p.add_(weight_decay, p.data)
+ if momentum != 0:
+ param_state = sgd.state[p]
+ if 'momentum_buffer' not in param_state:
+ buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
+ else:
+ buf = param_state['momentum_buffer']
+ buf.mul_(momentum).add_(1 - dampening, d_p)
+ if nesterov:
+ d_p = d_p.add(momentum, buf)
+ else:
+ d_p = buf
+
+ p.data.add_(-group['lr'] * amp, d_p)
+
+ return loss
+
+
+def step_wo_state_update_adamw(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
+ state = self.state[p]
+
+ # State initialization
+ #if len(state) == 0:
+ # state['step'] = 0
+ # # Exponential moving average of gradient values
+ # state['exp_avg'] = torch.zeros_like(p.data)
+ # # Exponential moving average of squared gradient values
+ # state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ #state['step'] += 1
+
+ # Decay the first and second moment running average coefficient
+ # In-place operations to update the averages at the same time
+ exp_avg = exp_avg.mul(beta1).add(1.0 - beta1, grad)
+ exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(1.0 - beta2, grad, grad)
+ denom = exp_avg_sq.sqrt().add(group['eps'])
+
+ step_size = group['lr']
+ if group['correct_bias']: # No bias correction for Bert
+ bias_correction1 = 1.0 - beta1 ** state['step']
+ bias_correction2 = 1.0 - beta2 ** state['step']
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+ p.data.addcdiv_(-step_size, exp_avg, denom)
+
+ # Just adding the square of the weights to the loss function is *not*
+ # the correct way of using L2 regularization/weight decay with Adam,
+ # since that will interact with the m and v parameters in strange ways.
+ #
+ # Instead we want to decay the weights in a manner that doesn't interact
+ # with the m/v parameters. This is equivalent to adding the square
+ # of the weights to the loss with plain (non-momentum) SGD.
+ # Add weight decay at the end (fixed version)
+ if group['weight_decay'] > 0.0:
+ p.data.add_(-group['lr'] * group['weight_decay'], p.data)
+
+ return loss
+
+def get_updated_weights_sgd(sgd, closure=None, amp=1.):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+ weights = []
+ for group in sgd.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+
+ for p in group['params']:
+ if p.grad is None:
+ weights.append(p.data)
+ d_p = p.grad
+ if weight_decay != 0:
+ d_p = d_p.add(weight_decay, p.data)
+ if momentum != 0:
+ param_state = sgd.state[p]
+ if 'momentum_buffer' not in param_state:
+ buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
+ else:
+ buf = param_state['momentum_buffer']
+ buf.mul_(momentum).add_(1 - dampening, d_p)
+ if nesterov:
+ d_p = d_p.add(momentum, buf)
+ else:
+ d_p = buf
+ weights.append(p.data - group['lr'] * amp * d_p)
+ return weights
\ No newline at end of file
diff --git a/ocl/naive.py b/ocl/naive.py
new file mode 100644
index 0000000..f523206
--- /dev/null
+++ b/ocl/naive.py
@@ -0,0 +1,100 @@
+from torch import nn
+from torch import optim
+import torch
+from utils.utils import get_config_attr
+import numpy as np
+
+class NaiveWrapper(nn.Module):
+ def __init__(self, base, optimizer, input_size, cfg, goal, **kwargs):
+ super().__init__()
+ self.net = base
+ self.optimizer = optimizer
+ self.input_size = input_size
+ self.cfg = cfg
+ self.goal = goal
+
+ if 'caption' in self.goal:
+ self.clip_grad = True
+ self.use_image_feat = get_config_attr(self.cfg, 'EXTERNAL.USE_IMAGE_FEAT', default=0)
+ self.spatial_feat_shape = (2048, 7, 7)
+ self.bbox_feat_shape = (100, 2048)
+ self.bbox_shape = (100,4)
+ self.rev_optim = None
+ if hasattr(self.net, 'rev_update_modules'):
+ self.rev_optim = optim.Adam(lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999),
+ params=self.net.rev_update_modules.parameters())
+
+ def forward(self, *args, **kwargs):
+ return self.net(*args, **kwargs)
+
+ def forward_net(self, x, y, task_ids=None, **kwargs):
+ if self.goal in ['classification']:
+ if type(y) is tuple:
+ attr_labels, obj_labels = y
+ else:
+ attr_labels, obj_labels = y, None
+ cropped_image = x[:, : 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE] \
+ .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+ images = x[:, 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE:] \
+ .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+ ret_dict = self.net(
+ bbox_images=cropped_image, spatial_feat=None,
+ images=images,
+ attr_labels=attr_labels,
+ obj_labels=obj_labels,
+ task_ids=task_ids
+ **kwargs
+ )
+ elif self.goal == 'captioning':
+ # legacy
+ if self.use_image_feat:
+ images = x.view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+ captions, caption_lens, labels = y
+ ret_dict = self.net(images=images, captions=captions, caption_lens=caption_lens, labels=labels, **kwargs)
+ else:
+ # rebuild spatial and packed inputs
+ batch_size = x.size(0)
+ bfeat_dim = np.prod(self.bbox_feat_shape)
+ # bbox_feats, spatial_feats, bboxes = x[:, :bfeat_dim].view(batch_size, *self.bbox_feat_shape), \
+ # x[:,bfeat_dim:bfeat_dim + sfeat_dim].view(batch_size, *self.spatial_feat_shape), \
+ # x[:,bfeat_dim + sfeat_dim:].view(batch_size, *self.bbox_shape)
+ bbox_feats, bboxes = x[:, :bfeat_dim].view(batch_size, *self.bbox_feat_shape), \
+ x[:,bfeat_dim:].view(batch_size, *self.bbox_shape)
+ captions, caption_lens, labels = y
+ ret_dict = self.net(bbox_feats=bbox_feats, captions=captions, caption_lens=caption_lens, labels=labels,
+ bboxes=bboxes, **kwargs)
+
+ elif self.goal == 'split_mnist' or self.goal == 'permute_mnist' or self.goal == 'rotated_mnist':
+ images = x.view(x.size(0), -1)
+ ret_dict = self.net(images, y, task_ids=task_ids, **kwargs)
+ elif self.goal == 'split_cifar':
+ images = x.view(-1, 3, 32, 32)
+ ret_dict = self.net(bbox_images=images, spatial_feat=None, images=None, attr_labels=y, obj_labels=None,
+ task_ids=task_ids, **kwargs)
+ elif self.goal == 'split_mini_imagenet':
+ images = x.view(-1,3,84,84)
+ ret_dict = self.net(bbox_images=images, spatial_feat=None, images=None, attr_labels=y, obj_labels=None,
+ task_ids=task_ids, **kwargs)
+ else:
+ raise ValueError
+ return ret_dict
+
+ def observe(self, x, y, task_ids, optimize=True):
+ # if deprecated is not None:
+ # y = (y, deprecated)
+ # recover image, feat from x
+ self.optimizer.zero_grad()
+ ret_dict = \
+ self.forward_net(x, y, task_ids)
+
+ loss = ret_dict['loss']
+ if optimize:
+ loss.backward()
+ #if self.clip_grad:
+ # torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
+ self.optimizer.step()
+
+ return ret_dict
+
+ def initialize_word_emb(self, *args, **kwargs):
+ self.net.initialize_word_emb(*args, **kwargs)
diff --git a/ocl/utils.py b/ocl/utils.py
new file mode 100644
index 0000000..b59903e
--- /dev/null
+++ b/ocl/utils.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+def get_glove_matrix(vocab, glove_path, initial_embedding_np):
+ """
+ return a glove embedding matrix
+ :param initial_embedding_np:
+ :return: np array of [V,E]
+ """
+ ef = open(glove_path, 'r', encoding='utf-8')
+ cnt = 0
+ vec_array = initial_embedding_np
+ old_avg = np.average(vec_array)
+ old_std = np.std(vec_array)
+ vec_array = vec_array.astype(np.float32)
+ new_avg, new_std = 0, 0
+
+ indices = []
+
+ for line in ef.readlines():
+ line = line.strip().split(' ')
+ word, vec = line[0].lower(), line[1:]
+ vec = np.array(vec, np.float32)
+ if word in vocab:
+ cnt += 1
+ word_idx = vocab[word]
+ if word_idx < vec_array.shape[0]:
+ vec_array[word_idx] = vec
+ new_avg += np.average(vec)
+ new_std += np.std(vec)
+ indices.append(word_idx)
+ new_avg /= cnt
+ new_std /= cnt
+ ef.close()
+ print('%d known embedding. old mean: %f new mean %f, old std %f new std %f' % (cnt, old_avg,
+ new_avg, old_std, new_std))
+ # scale the added embeddings
+ vec_array[indices] *= old_avg / new_avg
+
+ return vec_array
\ No newline at end of file
diff --git a/ocl/ver.py b/ocl/ver.py
new file mode 100644
index 0000000..1f00c0a
--- /dev/null
+++ b/ocl/ver.py
@@ -0,0 +1,141 @@
+from .er import *
+import math
+from torch.optim import Adam
+import copy
+from utils.utils import get_config_attr
+
+class FOExperienceEvolve(ExperienceReplay):
+ def __init__(self, base, optimizer, input_size, cfg, goal):
+ super().__init__(base, optimizer, input_size, cfg, goal)
+ # self.reservoir = {'x': np.zeros((self.mem_limit, input_size)),
+ # 'y': [None] * self.mem_limit,
+ # 'y_extra': [None] * self.mem_limit,
+ # 'x_origin': np.zeros((self.mem_limit, input_size)),
+ # 'x_edit_state': [None] * self.mem_limit,
+ # 'loss_stats': [None] * self.mem_limit,
+ # 'loss_stat_steps': [None] * self.mem_limit,
+ # 'forget': [None] * self.mem_limit,
+ # 'support': [None] * self.mem_limit
+ # }
+ self.itf_cnt = 0
+ self.total_cnt = 0
+ self.grad_iter = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_ITER', default=1)
+ self.grad_stride = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_STRIDE', default=10.)
+ self.edit_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_DECAY', default=0.)
+ self.no_write_back = get_config_attr(cfg, 'EXTERNAL.OCL.NO_WRITE_BACK', default=0)
+ self.reservoir['age'] = np.zeros(self.mem_limit)
+
+ def get_mem_ages(self, indices, astype):
+ ages = self.reservoir['age'][indices]
+ if torch.is_tensor(astype):
+ ages = torch.from_numpy(ages).float().to(astype.device)
+ return ages
+
+ def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False):
+ sequential = True
+ global total_cnt, itf_cnt
+ self.optimizer.zero_grad()
+ mem_x_indices, mem_x_origin, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True)
+
+ batch_size = x.size(0)
+ if mem_x_indices is None:
+ combined_x, combined_y, combined_task_ids = self.sample_mem_batch()
+ else:
+ mem_x, indices = mem_x_indices
+ self.store_cache()
+ for i in range(self.grad_iter):
+ # evaluate loss on mem_x, mem_y
+ mem_x.requires_grad = True
+ mem_x.grad = None
+ mem_x_origin.requires_grad = True
+
+ # evaluate grad of l wrt mem
+ self.optimizer.zero_grad()
+ ret_dict_mem_before = self.forward_net(mem_x, mem_y, reduce=False, task_ids=task_ids)
+ # grad_l = -torch.autograd.grad(torch.sum(ret_dict_mem_origin_before['loss']), mem_x_origin, retain_graph=True)[0]
+
+ # train the model on D
+ if not sequential:
+ self.get_loss_and_pseudo_update(x, y, task_ids)
+ else:
+ self.train(False)
+ for b in range(batch_size):
+ x_b = x[b].unsqueeze(0)
+ if type(y) in [tuple, list]:
+ y_b = [_[b].unsqueeze(0) for _ in y]
+ else:
+ y_b = y[b].unsqueeze(0)
+ ret_dict_db = self.forward_net(x_b, y_b)
+ self.optimizer.zero_grad()
+ ret_dict_db['loss'].backward()
+ if isinstance(self.optimizer, torch.optim.SGD):
+ step_wo_state_update_sgd(self.optimizer, amp=1.)
+ elif isinstance(self.optimizer, torch.optim.Adam):
+ step_wo_state_update_adam(self.optimizer, amp=1.)
+ else:
+ raise NotImplementedError
+ self.train(True)
+
+ ret_dict_mem_after = self.forward_net(mem_x, mem_y, reduce=False)
+ if 'mask_cnts' not in ret_dict_mem_after:
+ loss_increase = (ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']).mean()
+ else:
+ loss_increase = (ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']).sum() / \
+ (sum(ret_dict_mem_after['mask_cnts']) + 1e-10)
+ loss_increase.backward(retain_graph=False)
+ grad_delta = mem_x.grad
+
+ self.load_cache()
+
+ mem_ages = self.get_mem_ages(mem_x_indices, astype=mem_x)
+ stride_decayed = (1 - self.edit_decay) ** mem_ages
+
+ proposed_mem_x = mem_x + self.grad_stride * stride_decayed.view(-1,1) * grad_delta
+ proposed_mem_x.detach_()
+
+ mem_x = proposed_mem_x
+ mem_x = mem_x.detach()
+ self.evolve_mem(mem_x, indices)
+
+ # load cached parameters back
+ self.load_cache()
+ combined_x = torch.cat([x, mem_x], 0)
+ combined_y = concat_with_padding([y, mem_y])
+
+ ret_dict = self.forward_net(combined_x, combined_y)
+
+ for b in range(batch_size):
+ if type(y) is tuple:
+ self.update_mem(x[b], [_[b] for _ in y], extra[b] if extra is not None else None)
+ else:
+ self.update_mem(x[b], y[b], extra[b] if extra is not None else None)
+
+ loss = ret_dict['loss']
+ if optimize:
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+
+ return ret_dict
+
+ def evolve_mem(self, x, indices):
+ for i, idx in enumerate(indices):
+ #if self.reservoir['age'][idx] < 10:
+ self.reservoir['x'][idx] = x[i].cpu().numpy()
+ self.reservoir['age'][idx] += 1
+
+
+
+total_cnt, itf_cnt = 0, 0
+def proj_grad(a, b, binary, always_proj):
+ # project b to the direction of a
+
+ dotp = torch.dot(a,b)
+ if dotp >= 0:
+ if not always_proj:
+ return b
+ else:
+ return b - (torch.dot(-a,b) / torch.dot(a,a)) * a
+ else:
+ if binary: return torch.zeros_like(a)
+ else: return b - (torch.dot(a,b) / torch.dot(a,a)) * a
diff --git a/ocl/ver_approx.py b/ocl/ver_approx.py
new file mode 100644
index 0000000..d076403
--- /dev/null
+++ b/ocl/ver_approx.py
@@ -0,0 +1,365 @@
+from .ver import *
+from torch.nn import functional as F
+import random
+
+
+class ExperienceEvolveApprox(FOExperienceEvolve):
+ def __init__(self, base, optimizer, input_size, cfg, goal):
+ super().__init__(base, optimizer, input_size, cfg, goal)
+ self.edit_least = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_LEAST', default=0)
+ self.edit_random = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_RANDOM', default=0)
+
+ self.edit_interfere = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_INTERFERE', default=1)
+ self.edit_replace = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_REPLACE', default=0)
+ self.replace_reweight = get_config_attr(cfg, 'EXTERNAL.OCL.REPLACE_REWEIGHT', default=0)
+ self.use_relu = get_config_attr(cfg, 'EXTERNAL.OCL.USE_RELU', default=0)
+ self.reg_strength = get_config_attr(cfg, 'EXTERNAL.OCL.REG_STRENGTH', default=0.1)
+ self.always_proj = get_config_attr(cfg, 'EXTERNAL.OCL.ALWAYS_PROJ', default=0)
+
+ self.edit_mir_k = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_K', default=-1)
+
+ self.hal_mem = get_config_attr(cfg, 'EXTERNAL.OCL.HAL_MEM', default=0)
+ self.post_edit_mem_aug = get_config_attr(cfg,'EXTERNAL.OCL.POST_EDIT_MEM_AUG', default=1)
+ self.edit_aug_mem = get_config_attr(cfg,'EXTERNAL.OCL.EDIT_AUG_MEM', default=0)
+ self.double_weight = get_config_attr(cfg,'EXTERNAL.OCL.DOUBLE_WEIGHT', default=0)
+
+ if self.edit_mir_k == -1:
+ self.edit_mir_k = self.mir_k
+
+ def sample_mem_batch_same_task(self, device, task_id_or_label, return_indices=False, mem_k=None, seed=0, use_same_label=False):
+ if mem_k is None:
+ mem_k = self.mem_bs
+ if use_same_label:
+ label = task_id_or_label
+ else:
+ task_id = task_id_or_label
+
+ n_max = min(self.mem_limit, self.example_seen)
+ indices = []
+ for i in range(n_max):
+ if use_same_label:
+ if self.reservoir['y'][i] == label:
+ indices.append(i)
+ else:
+ if self.reservoir['y_extra'][i] == task_id:
+ indices.append(i)
+
+ # reservoir
+ if not indices:
+ return None, None, None
+ elif len(indices) >= mem_k:
+ indices = np.random.RandomState(seed * self.example_seen + self.cfg.SEED).\
+ choice(indices, mem_k, replace=False)
+
+ x = self.reservoir['x'][indices]
+ #x_origin = self.reservoir['x_origin'][indices]
+
+ x = torch.from_numpy(x).to(device).float()
+ #x_origin = torch.from_numpy(x_origin).to(device).float()
+ y = index_select(self.reservoir['y'], indices, device) # [ [...], [...] ]
+ y_extra = index_select(self.reservoir['y_extra'], indices, device)
+ y_extra = concat_with_padding(y_extra)
+ if type(y[0]) not in [list, tuple]:
+ y_pad = concat_with_padding(y)
+ else:
+ y_pad = [torch.stack(_).to(device) for _ in zip(*y)]
+
+ if not return_indices:
+ return x, y_pad, y_extra
+ else:
+ return (x, indices), y_pad, y_extra
+
+ def clear_mem_grad(self, mem_x):
+ mem_x.detach_()
+ mem_x.grad = None
+ mem_x.requires_grad = True
+
+
+ def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False):
+ n_iter = get_config_attr(self.cfg, 'EXTERNAL.OCL.N_ITER', default=1, mute=True)
+ batch_size = x.size(0)
+ self.store_cache()
+ for i_iter in range(n_iter):
+ if not self.mir:
+ mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True, seed=i_iter + 1)
+ if self.edit_random: # select another batch for editing
+ edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
+ k=self.mir_k, seed=i_iter + 2)
+ else:
+ edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids
+ else:
+ mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True,
+ input_x=x, input_y=y, input_task_ids=task_ids,
+ mir_k=self.mir_k, mir=self.mir,
+ skip_task=task_ids[0].item(),
+ seed=i_iter + 1
+ )
+ if self.edit_least:
+ edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
+ input_x=x, input_y=y, input_task_ids=task_ids,
+ mir_k=self.edit_mir_k, mir=self.mir,
+ skip_task=task_ids[0].item(),
+ mir_least=True,
+ seed=i_iter + 2
+ )
+ elif self.edit_random:
+ edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True,
+ k=self.edit_mir_k, seed=i_iter + 2)
+ else:
+ edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids
+ self.optimizer.zero_grad()
+
+ edit_x_val_indices, edit_y_val, _ = self.sample_mem_batch_same_task(x.device, task_ids.cpu().numpy()[0],
+ return_indices=True, seed=i_iter + 2,
+ mem_k=self.mir_k if self.mir else self.mem_bs)
+ edit_task_ids_val = task_ids
+ #mem_x_val_indices, _, mem_y_val, mem_task_ids_val = self.sample_mem_batch(x.device, return_indices=True, seed=1)
+
+ if edit_x_indices is None:
+ combined_x, combined_y, combined_task_ids = x, y, task_ids
+ else:
+ mem_x, mem_indices = mem_x_indices
+ edit_x, edit_indices = edit_x_indices
+ if edit_x_val_indices is not None:
+ edit_x_val, indices_val = edit_x_val_indices
+ else:
+ edit_x_val, indices_val = None, None
+ if self.mem_augment:
+ aug_mem_x = self.transform_image_batch(mem_x)
+ if i_iter == 0 and self.edit_interfere:
+ if self.hal_mem:
+ train_x, train_y, _ = self.sample_mem_batch(x.device, return_indices=False,
+ seed=i_iter + 3)
+ else:
+ train_x, train_y = x, y
+
+ edit_x, mem_x = self.edit_mem_interfere(train_x, train_y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids,
+ edit_x_val, edit_y_val, edit_task_ids_val, edit_indices)
+ if not self.no_write_back:
+ self.evolve_mem(edit_x, edit_indices)
+
+ # load cached parameters back
+ self.load_cache()
+ if not self.mem_augment:
+ combined_x = torch.cat([x, mem_x], 0)
+ combined_y = concat_with_padding([y, mem_y])
+ combined_task_ids = concat_with_padding([task_ids, mem_task_ids])
+ else:
+ if self.post_edit_mem_aug:
+ aug_mem_x = self.transform_image_batch(mem_x)
+ if self.edit_aug_mem:
+ aug_mem_x, _ = self.edit_mem_interfere(train_x, train_y, task_ids, aug_mem_x, mem_y, aug_mem_x,
+ mem_y, edit_task_ids,
+ edit_x_val, edit_y_val, edit_task_ids_val, edit_indices)
+ if self.double_weight:
+ combined_x = torch.cat([x, mem_x, aug_mem_x], 0)
+ combined_y = concat_with_padding([y, mem_y, mem_y])
+ combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids])
+ else:
+ combined_x = torch.cat([x, mem_x, aug_mem_x], 0)
+ combined_y = concat_with_padding([y, mem_y, mem_y])
+ combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids])
+
+ ret_dict = self.forward_net(combined_x, combined_y, task_ids=combined_task_ids, reduce=False)
+ loss_tmp = ret_dict['loss']
+ if optimize:
+ loss = loss_tmp[: x.size(0)].mean()
+ if mem_x_indices is not None:
+ loss += loss_tmp[x.size(0):].mean() #* (2 if self.double_weight else 1)
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ if self.lb_reservoir and mem_x_indices is not None:
+ self.update_loss_states(loss_tmp[x.size(0):], mem_indices)
+ for b in range(batch_size):
+ if type(y) is tuple:
+ self.update_mem(x[b], [_[b] for _ in y], task_ids[b])
+ else:
+ self.update_mem(x[b], y[b], task_ids[b])
+
+ return ret_dict
+
+ def edit_mem_interfere(self, x, y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids,
+ edit_x_val, edit_y_val, edit_task_ids_val, edit_indices):
+ """
+ Edit memory so that they are more inter
+ :param x:
+ :param y:
+ :param task_ids:
+ :param mem_x:
+ :param mem_y:
+ :param edit_x:
+ :param edit_y:
+ :param edit_task_ids:
+ :param edit_x_val:
+ :param edit_y_val:
+ :param edit_task_ids_val:
+ :return:
+ """
+ device = x.device
+ # only edit at the first iter
+
+ for i in range(self.grad_iter):
+ # evaluate loss on edit_x, edit_y
+ self.clear_mem_grad(edit_x)
+ # evaluate grad of l wrt edit
+ ret_dict_edit_before = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids)
+ # train the model on D
+ grad_reg = -torch.autograd.grad(torch.sum(ret_dict_edit_before['loss']),
+ edit_x, retain_graph=True)[0]
+ ret_dict_edit_before['loss'].backward()
+ edit_x_grad1 = edit_x.grad
+
+ self.clear_mem_grad(edit_x)
+
+ for _ in range(1):
+ ret_dict_d = self.forward_net(x, y, task_ids=task_ids)
+ self.optimizer.zero_grad()
+ ret_dict_d['loss'].backward(retain_graph=False)
+ if isinstance(self.optimizer, torch.optim.SGD):
+ step_wo_state_update_sgd(self.optimizer, amp=1.)
+ elif isinstance(self.optimizer, torch.optim.Adam):
+ step_wo_state_update_adam(self.optimizer, amp=1.)
+ else:
+ raise NotImplementedError
+
+ ret_dict_edit_after = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids)
+ if 'mask_cnts' not in ret_dict_edit_after:
+ loss_increase = ret_dict_edit_after['loss'] - ret_dict_edit_before['loss']
+ else:
+ loss_increase = (ret_dict_edit_after['loss'] - ret_dict_edit_before['loss']).sum() / \
+ (sum(ret_dict_edit_after['mask_cnts']) + 1e-10)
+
+ #if self.use_relu:
+ # loss_increase = F.relu(loss_increase)
+ ret_dict_edit_after['loss'].backward()
+ edit_x_grad2 = edit_x.grad
+ grad_delta = edit_x_grad2 - edit_x_grad1
+
+ grad_delta_2 = 0
+
+ self.clear_mem_grad(edit_x)
+ self.load_cache()
+
+ total_grad = 0
+
+ if self.cfg.EXTERNAL.OCL.USE_LOSS_1:
+ if self.cfg.EXTERNAL.OCL.USE_LOSS_1 == 1:
+ total_grad += self.cfg.EXTERNAL.OCL.USE_LOSS_1 * grad_delta
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -2: # random tied direction uniform norm
+ random_vecs = self.get_random_grad(grad_delta, edit_indices)
+ total_grad += random_vecs
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -3: # random tied direction keep norm
+ random_vecs = self.get_direction_perturbed_grad(grad_delta, edit_indices)
+ total_grad += random_vecs
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -4: # random untied direction keep norm
+ random_vecs = self.get_direction_perturbed_grad_untied(grad_delta, edit_indices)
+ total_grad += random_vecs
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -5: # negated gradient
+ neg_vecs = self.get_neg_gradients(grad_delta, edit_indices)
+ total_grad += neg_vecs
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -6: # adversarial-continuous
+ total_grad += edit_x_grad1
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -7: # untied, fixed
+ random_vecs = self.get_direction_perturbed_grad_untied_fixed(grad_delta, edit_indices)
+ total_grad += random_vecs
+ elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -8: # PGD
+ total_grad += torch.sign(edit_x_grad1)
+ else:
+ raise ValueError
+
+ if self.cfg.EXTERNAL.OCL.USE_LOSS_2:
+ total_grad += self.cfg.EXTERNAL.OCL.USE_LOSS_2 * grad_delta_2
+
+ if type(total_grad) is not int: # has grad update
+ #if self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG:
+ if self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG == 1:
+ for b in range(total_grad.size(0)):
+ total_grad[b] = proj_grad(-grad_reg[b], total_grad[b], binary=False, always_proj=self.always_proj)
+ elif self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG == 2:
+ for b in range(total_grad.size(0)):
+ total_grad[b] = -grad_reg[b] * self.reg_strength + total_grad[b]
+
+ mem_ages = self.get_mem_ages(edit_indices, astype=edit_x)
+ stride_decayed = (1 - self.edit_decay) ** mem_ages
+
+ for b in range(total_grad.size(0)):
+ edit_x[b] = edit_x[b] + self.grad_stride * stride_decayed[b] * total_grad[b]
+ edit_x = edit_x.detach()
+ mem_x = mem_x.detach()
+
+ return edit_x, mem_x
+
+ # code 2: edit memory to a random direction tied to each input example
+ def get_random_grad(self, grad, indices):
+ if not hasattr(self, 'random_dirs'):
+ self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).uniform_(-1,1)
+ random_vecs = []
+ for i, indice in enumerate(indices):
+ random_vec = self.random_dirs[indice].to(grad.device)
+ random_vec = random_vec / random_vec.norm()
+ random_vecs.append(random_vec)
+ random_vecs = torch.stack(random_vecs)
+ return random_vecs
+
+ # code 3: perturbate editing direction - keep norm
+ def get_direction_perturbed_grad(self, grad, indices):
+ if not hasattr(self, 'random_dirs'):
+ self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).uniform_(-1, 1)
+ random_vecs = []
+ for i, indice in enumerate(indices):
+ random_vec = self.random_dirs[indice].to(grad.device)
+ random_vec = random_vec / random_vec.norm() * grad[i].norm()
+ random_vecs.append(random_vec)
+ random_vecs = torch.stack(random_vecs)
+ return random_vecs
+
+ # code 4: perturbate editing direction - not tied
+ def get_direction_perturbed_grad_untied(self, grad, indices):
+ self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).to(grad.device).uniform_(-1, 1)
+ random_vecs = []
+ for i, indice in enumerate(indices):
+ random_vec = self.random_dirs[indice]
+ random_vec = random_vec / random_vec.norm() * grad[i].norm()
+ random_vecs.append(random_vec)
+ random_vecs = torch.stack(random_vecs)
+ return random_vecs
+
+ # code 5: flip the update direction
+ def get_neg_gradients(self, grad, indices):
+ return -grad
+
+ # code 7: untied and fixed
+ def get_direction_perturbed_grad_untied_fixed(self, grad, indices):
+ self.random_dirs = torch.zeros(len(indices), *grad[0].size()).to(grad.device).normal_(0, 1)
+ random_vecs = []
+ for i, indice in enumerate(indices):
+ random_vec = self.random_dirs[i]
+ random_vec = random_vec / random_vec.norm()
+ random_vecs.append(random_vec)
+ random_vecs = torch.stack(random_vecs)
+ return random_vecs
+
+ def to_mem_type(self, x, y, y_extra):
+ x = x.cpu().numpy()
+ if type(y) not in [list, tuple]:
+ y = y_to_np(y)
+ else:
+ y = y_to_cpu(y)
+ if type(y_extra) not in [list, tuple]:
+ y_extra = y_to_np(y_extra)
+ else:
+ y_extra = y_to_cpu(y_extra)
+ return x, y, y_extra
+
+ def indices_to_examples(self, indices, device):
+ cand_x = self.reservoir['x'][indices]
+ cand_x = torch.from_numpy(cand_x).to(device).float()
+ cand_y = index_select(self.reservoir['y'], indices, device) # [ [...], [...] ]
+ cand_y_extra = index_select(self.reservoir['y_extra'], indices, device)
+ if type(cand_y[0]) not in [list, tuple]:
+ cand_y_pad = concat_with_padding(cand_y)
+ else:
+ cand_y_pad = [torch.stack(_).to(device) for _ in zip(*cand_y)]
+ cand_y_extra = concat_with_padding(cand_y_extra)
+ return cand_x, cand_y_pad, cand_y_extra
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2aec31b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,126 @@
+absl-py==0.9.0
+apex==0.1
+astunparse==1.6.3
+attrs==19.3.0
+backcall==0.1.0
+bleach==3.1.0
+boto3==1.11.9
+botocore==1.14.9
+cachetools==4.1.0
+certifi==2019.11.28
+cffi==1.13.2
+chardet==3.0.4
+cityscapesScripts==1.1.0
+Click==7.0
+cycler==0.10.0
+Cython==0.29.14
+decorator==4.4.1
+defusedxml==0.6.0
+docutils==0.15.2
+easydict==1.9
+entrypoints==0.3
+gast==0.3.3
+google-auth==1.15.0
+google-auth-oauthlib==0.4.1
+google-pasta==0.2.0
+grpcio==1.29.0
+h5py==2.10.0
+idna==2.8
+imagecodecs==2020.2.18
+imageio==2.8.0
+importlib-metadata==1.5.0
+inflect==4.1.0
+ipykernel==5.1.4
+ipython==7.10.2
+ipython-genutils==0.2.0
+jaraco.itertools==5.0.0
+jedi==0.15.1
+Jinja2==2.11.1
+jmespath==0.9.4
+joblib==0.14.1
+jsonlines==1.2.0
+jsonschema==3.2.0
+jupyter-client==5.3.4
+jupyter-core==4.6.1
+Keras-Preprocessing==1.1.2
+kiwisolver==1.1.0
+Markdown==3.2.2
+MarkupSafe==1.1.1
+-e git+https://github.com/facebookresearch/maskrcnn-benchmark.git@57eec25b75144d9fb1a6857f32553e1574177daf#egg=maskrcnn_benchmark
+matplotlib==3.1.2
+mistune==0.8.4
+mkl-fft==1.0.15
+mkl-random==1.1.0
+mkl-service==2.3.0
+more-itertools==8.2.0
+nbconvert==5.6.1
+nbformat==5.0.4
+networkx==2.4
+ninja==1.9.0.post1
+nltk==3.4.5
+notebook==6.0.3
+numpy==1.18.0
+oauthlib==3.1.0
+olefile==0.46
+opencv-python==4.1.2.30
+opt-einsum==3.2.1
+pandas==1.0.3
+pandocfilters==1.4.2
+parso==0.5.2
+pexpect==4.7.0
+pickleshare==0.7.5
+Pillow==6.2.1
+prometheus-client==0.7.1
+prompt-toolkit==3.0.2
+protobuf==3.10.0
+ptyprocess==0.6.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycocotools==2.0
+pycparser==2.19
+Pygments==2.5.2
+pygtrie==2.3.3
+pyparsing==2.4.6
+pyrsistent==0.15.7
+python-dateutil==2.8.1
+pytorch-transformers==1.2.0
+pytz==2020.1
+PyWavelets==1.1.1
+PyYAML==5.2
+pyzmq==18.1.1
+quadprog==0.1.7
+recordclass==0.12.0.1
+regex==2019.8.19
+requests==2.22.0
+requests-oauthlib==1.3.0
+rsa==4.0
+s3transfer==0.3.2
+sacremoses==0.0.38
+scikit-image==0.17.2
+scikit-learn==0.22
+scipy==1.3.2
+Send2Trash==1.5.0
+sentencepiece==0.1.85
+six==1.13.0
+stanfordnlp==0.2.0
+tensorboard==2.2.1
+tensorboard-plugin-wit==1.6.0.post3
+tensorboardX==2.0
+tensorflow==2.2.0
+tensorflow-estimator==2.2.0
+termcolor==1.1.0
+terminado==0.8.3
+testpath==0.4.4
+tifffile==2020.5.11
+torch==1.1.0
+torchvision==0.2.2
+tornado==6.0.3
+tqdm==4.41.0
+traitlets==4.3.3
+urllib3==1.25.8
+wcwidth==0.1.7
+webencodings==0.5.1
+Werkzeug==1.0.1
+wrapt==1.12.1
+yacs==0.1.6
+zipp==2.1.0
diff --git a/scripts/gmed_mini_imagenet.sh b/scripts/gmed_mini_imagenet.sh
new file mode 100644
index 0000000..337f56e
--- /dev/null
+++ b/scripts/gmed_mini_imagenet.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+min_seed=5
+max_seed=9
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+extra_args=""
+
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+#if [[ -n "$9" ]]; then
+# mem_bs=${9}
+#else
+# mem_bs=25
+#fi
+mem_bs=25
+
+seed=$min_seed
+
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == 'relu' ]]
+then
+ extra="EXTERNAL.OCL.USE_RELU=1"
+fi
+
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config configs/memevolve/verx_mini_imagenet.yaml --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/gmed_permuted_mnist.sh b/scripts/gmed_permuted_mnist.sh
new file mode 100644
index 0000000..9eac682
--- /dev/null
+++ b/scripts/gmed_permuted_mnist.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+min_seed=0
+max_seed=4
+seed=$min_seed
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+extra_args=""
+
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+#if [[ -n "$9" ]]; then
+# mem_bs=${9}
+#else
+# mem_bs=25
+#fi
+mem_bs=50
+
+seed=$min_seed
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == "edit_replace" ]]
+then
+ extra="EXTERNAL.OCL.EDIT_INTERFERE=0 EXTERNAL.OCL.EDIT_REPLACE=1"
+fi
+if [[ ${extra_opt} == "rrw" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1"
+fi
+if [[ ${extra_opt} == "rrwr" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2"
+fi
+if [[ ${extra_opt} == "supp_proj" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=1"
+fi
+if [[ ${extra_opt} == "supp_reg" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=2"
+fi
+
+config_file="configs/memevolve/ver_permute_approx.yaml"
+if [[ ${extra_opt} == "hal" ]]
+then
+ config_file="configs/baselines/hal_permuted_mnist.yaml"
+fi
+
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/gmed_rotate_mnist.sh b/scripts/gmed_rotate_mnist.sh
new file mode 100644
index 0000000..66cf003
--- /dev/null
+++ b/scripts/gmed_rotate_mnist.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+
+min_seed=0
+max_seed=9
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+extra_arg=""
+
+
+#if [[ -n "$9" ]]; then
+# mem_bs=${9}
+#else
+# mem_bs=50
+#fi
+mem_bs=50
+
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+
+seed=$min_seed
+
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == "supp_proj" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=1"
+fi
+if [[ ${extra_opt} == "supp_reg" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=2"
+fi
+if [[ ${extra_opt} == "supp_proj_meta" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=3"
+fi
+if [[ ${extra_opt} == "supp_reg_meta" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=4"
+fi
+
+config_file="configs/memevolve/ver_permute_approx.yaml"
+if [[ ${extra_opt} == "hal" ]]
+then
+ config_file="configs/baselines/hal_permuted_mnist.yaml"
+fi
+
+
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config configs/memevolve/ver_rotate_approx.yaml --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/gmed_split_cifar10.sh b/scripts/gmed_split_cifar10.sh
new file mode 100644
index 0000000..e328ca0
--- /dev/null
+++ b/scripts/gmed_split_cifar10.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+min_seed=0
+max_seed=4
+seed=$min_seed
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+extra_args=""
+
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+#if [[ -n "$9" ]]; then
+# mem_bs=${9}
+#else
+# mem_bs=25
+#fi
+mem_bs=50
+
+seed=$min_seed
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == "rrw" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1"
+fi
+if [[ ${extra_opt} == "rrwr" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2"
+fi
+if [[ ${extra_opt} == "supp_proj" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=1"
+fi
+if [[ ${extra_opt} == "supp_reg" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=2"
+fi
+if [[ ${extra_opt} == 'relu' ]]
+then
+ extra="EXTERNAL.OCL.USE_RELU=1"
+fi
+
+config_file="configs/memevolve/verx_cifar.yaml"
+if [[ ${extra_opt} == "hal" ]]
+then
+ config_file="configs/baselines/hal_split_cifar.yaml"
+fi
+
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/gmed_split_cifar100.sh b/scripts/gmed_split_cifar100.sh
new file mode 100644
index 0000000..b1f34b4
--- /dev/null
+++ b/scripts/gmed_split_cifar100.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+
+min_seed=0
+max_seed=4
+seed=$min_seed
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+extra_args=""
+
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+#if [[ -n "$9" ]]; then
+# mem_bs=${9}
+#else
+# mem_bs=25
+#fi
+mem_bs=50
+
+seed=$min_seed
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == "rrw" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1"
+fi
+if [[ ${extra_opt} == "rrwr" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2"
+fi
+if [[ ${extra_opt} == "supp_proj" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=1"
+fi
+if [[ ${extra_opt} == "supp_reg" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=2"
+fi
+if [[ ${extra_opt} == 'relu' ]]
+then
+ extra="EXTERNAL.OCL.USE_RELU=1"
+fi
+
+config_file="configs/memevolve/verx_cifar100.yaml"
+
+if [[ ${extra_opt} == "hal" ]]
+then
+ config_file="configs/baselines/hal_split_cifar100.yaml"
+fi
+
+if [[ ${extra_opt} == "cndpm" ]]
+then
+ config_file="configs/baselines/cndpm_split_cifar100.yaml"
+fi
+
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/gmed_split_mnist.sh b/scripts/gmed_split_mnist.sh
new file mode 100755
index 0000000..0924f8b
--- /dev/null
+++ b/scripts/gmed_split_mnist.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+
+min_seed=0
+max_seed=4
+seed=$min_seed
+model_name=${1}
+grad_iter=${2}
+grad_stride=${3}
+use_loss_1=${4}
+use_loss_2=${5}
+proj_loss_reg=${6}
+mem_size=${7}
+extra_opt=${8}
+extra=""
+
+mem_bs=50
+if [[ -n "$9" ]]; then
+ min_seed=${9}
+fi
+if [[ -n "${10}" ]]; then
+ max_seed=${10}
+fi
+
+seed=$min_seed
+
+if [[ -n "${11}" ]]; then
+ extra_args=${11}
+fi
+
+
+if [[ ${extra_opt} == "mir" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1"
+fi
+if [[ ${extra_opt} == "mirl" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1"
+fi
+if [[ ${extra_opt} == "mirr" ]]
+then
+ extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1"
+fi
+if [[ ${extra_opt} == "edit_replace" ]]
+then
+ extra="EXTERNAL.OCL.EDIT_INTERFERE=0 EXTERNAL.OCL.EDIT_REPLACE=1"
+fi
+if [[ ${extra_opt} == "rrw" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1"
+fi
+if [[ ${extra_opt} == "rrwr" ]]
+then
+ extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2"
+fi
+if [[ ${extra_opt} == "supp_proj" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=1"
+fi
+if [[ ${extra_opt} == "supp_reg" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=2"
+fi
+if [[ ${extra_opt} == "supp_proj_meta" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=3"
+fi
+if [[ ${extra_opt} == "supp_reg_meta" ]]
+then
+ extra="EXTERNAL.OCL.REG_SUPPORTIVE=4"
+fi
+
+config_file="configs/memevolve/ver_approx.yaml"
+
+if [[ ${extra_opt} == "hal" ]]
+then
+ config_file="configs/baselines/hal_split_mnist.yaml"
+fi
+while(( $seed<= $max_seed ))
+do
+ python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args}
+ let "seed++"
+done
\ No newline at end of file
diff --git a/scripts/tune_params.py b/scripts/tune_params.py
new file mode 100644
index 0000000..d348c32
--- /dev/null
+++ b/scripts/tune_params.py
@@ -0,0 +1,96 @@
+import os
+import logging
+import argparse
+import random
+
+logger = logging.getLogger(__name__)
+
+ds_to_name = {'split_cifar': 'cifar10_5tasks_2class_ci_verx', 'mini_imagenet': 'mini_imagenet_ci_verx',
+ 'rotated_mnist': 'rotated_mnist_verx', 'split_mnist':'split_mnist_verx',
+ 'permuted_mnist': 'pm_verx'}
+ds_to_config = {'split_cifar': 'verx_cifar.yaml', 'mini_imagenet': 'verx_mini_imagenet.yaml',
+ 'rotated_mnist': 'ver_rotate_approx.yaml', 'split_mnist':'ver_approx.yaml',
+ 'permuted_mnist': 'ver_permute_approx.yaml'}
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--start_seed', type=int, default=0)
+ parser.add_argument('--stop_seed', type=int, default=5)
+ parser.add_argument('--dataset')
+ parser.add_argument('--strides', type=float, nargs='+', default=[0.1,0.05,0.5,1.0])
+ parser.add_argument('--proj_flags', type=int, nargs='+', default=[2])
+ parser.add_argument('--reg_strengths', type=float, nargs='+', default=[0.001,0.01,0.1])
+ parser.add_argument('--lr', type=float, nargs='*', default=[0])
+ parser.add_argument('--iters', type=int, nargs='+')
+ parser.add_argument('--mem_sizes', type=int, nargs='+')
+ parser.add_argument('--use_relu', type=int, default=0)
+ parser.add_argument('--epoch', type=int, default=3)
+ parser.add_argument('--batch_size', type=int, default=10)
+ parser.add_argument('--task_num', type=int, default=3)
+ parser.add_argument('--extra', default='')
+ parser.add_argument('--name', default='')
+ parser.add_argument('--overwrite', action='store_true')
+ parser.add_argument('--mir', action='store_true')
+ parser.add_argument('--rand', action='store_true')
+
+ args = parser.parse_args()
+
+ logger.addHandler(logging.FileHandler('logs/{}_{}.txt'.format(args.dataset, random.randint(0,10000000))))
+ logger.setLevel(logging.INFO)
+
+ strides = args.strides
+ mem_sizes = args.mem_sizes
+ iters = args.iters
+
+ use_loss_1 = 1 if not args.rand else -2
+
+ extra_auto = ''
+ if args.mir:
+ extra_auto = 'EXTERNAL.OCL.MIR=1 EXTERNAL.REPLAY.MEM_BS=10 EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.EDIT_RANDOM=1'
+
+ for stride in strides:
+ for mem_size in mem_sizes:
+ for proj_flag in args.proj_flags:
+ for reg_strength in args.reg_strengths:
+ for lr in args.lr:
+
+ for iter_n in iters:
+ for seed in range(args.start_seed, args.stop_seed):
+ name_mark = ''
+ if args.use_relu:
+ name_mark += 'relu'
+ if args.mir:
+ name_mark += 'mir'
+
+ name = '{}_iter{}_stride_{}_{}_0_{}{}_m{}_{}_{}_{}'.format(ds_to_name[args.dataset], iter_n, stride, use_loss_1,
+ proj_flag, reg_strength, mem_size, seed,
+ name_mark,
+ args.name)
+ if lr != 0:
+ name += '_lr{}_epoch{}'.format(lr, args.epoch)
+ result_file = 'runs/{}_{}/result_tune_k{}.json'.format(name, seed, args.task_num)
+ if not args.overwrite and os.path.isfile(result_file):
+ print('result file exists, skiping')
+ else:
+ command = 'python train.py --name {name} --tune --config configs/memevolve/{config} --seed {seed} ' \
+ '--cfg EXTERNAL.OCL.GRAD_ITER={grad_iter} EXTERNAL.OCL.GRAD_STRIDE={grad_stride} '\
+ 'EXTERNAL.OCL.USE_LOSS_1={use_loss_1} EXTERNAL.OCL.USE_LOSS_2=0 EXTERNAL.OCL.USE_RELU=0 '\
+ 'EXTERNAL.OCL.PROJ_LOSS_REG={proj_flag} EXTERNAL.REPLAY.MEM_LIMIT={mem_size} ' \
+ 'EXTERNAL.OCL.REG_STRENGTH={reg_strength} ' \
+ 'EXTERNAL.OCL.USE_RELU={use_relu} ' \
+ 'EXTERNAL.EPOCH={epoch} ' \
+ 'EXTERNAL.BATCH_SIZE={batch_size} ' \
+ 'EXTERNAL.OCL.TASK_NUM={task_num} ' \
+ '{extra} ' \
+ '{extra_auto} '\
+ .format(**{'name': name, 'config': ds_to_config[args.dataset], 'seed':seed,
+ 'grad_iter':iter_n, 'grad_stride': stride, 'mem_size': mem_size,
+ 'extra': args.extra, 'proj_flag': proj_flag, 'reg_strength': reg_strength,
+ 'use_relu': args.use_relu, 'epoch': args.epoch, 'extra_auto': extra_auto,
+ 'batch_size': args.batch_size, 'task_num': args.task_num, 'use_loss_1': use_loss_1})
+ if lr != 0:
+ command += ' SOLVER.BASE_LR={} '.format(lr)
+ logger.info(command)
+ exit_code = os.system(command)
+ logger.info('Exit code {}'.format(exit_code))
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..effb24a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,232 @@
+import argparse
+import os
+import sys
+
+import torch
+import numpy as np
+import random
+from trainer_benchmark import ocl_train_mnist, ocl_train_cifar
+from utils.utils import get_exp_id, set_config_attr, get_config_attr
+from yacs.config import CfgNode
+
+from nets.classifier import ResNetClassifier, ResNetClassifierWObj
+from nets.simplenet import mnist_simple_net_400width_classlearning_1024input_10cls_1ds
+
+
+from ocl import NaiveWrapper, ExperienceReplay, AGEM, ExperienceEvolveApprox
+
+
+def train(cfg, local_rank, distributed, tune=False):
+ is_ocl = hasattr(cfg.EXTERNAL.OCL, 'ALGO') and cfg.EXTERNAL.OCL.ALGO != 'PLAIN'
+ task_incremental = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_INCREMENTAL', default=False)
+
+ cfg.TUNE = tune
+
+ algo = cfg.EXTERNAL.OCL.ALGO
+ if hasattr(cfg,'MNIST'):
+ if cfg.MNIST.TASK == 'split':
+ goal = 'split_mnist'
+ elif cfg.MNIST.TASK == 'permute':
+ goal = 'permute_mnist'
+ elif cfg.MNIST.TASK == 'rotate':
+ goal = 'rotated_mnist'
+ if hasattr(cfg, 'CIFAR'):
+ goal = 'split_cifar'
+ if get_config_attr(cfg, 'CIFAR.DATASET', default="") == 'CIFAR100':
+ goal = 'split_cifar100'
+ if get_config_attr(cfg, 'CIFAR.MINI_IMAGENET', default=0):
+ goal = 'split_mini_imagenet'
+
+
+
+ if hasattr(cfg,'MNIST'):
+
+ num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+ num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
+ base_model = mnist_simple_net_400width_classlearning_1024input_10cls_1ds(num_of_datasets=num_of_datasets,
+ num_of_classes=num_of_classes,
+ task_incremental=task_incremental)
+
+ base_model.cfg = cfg
+ elif hasattr(cfg, 'CIFAR'):
+ if goal == 'split_cifar':
+ num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+ num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
+ elif goal == 'split_cifar100':
+ num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+ num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
+ elif goal == 'split_mini_imagenet':
+ num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+ num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int)
+
+
+ base_model = ResNetClassifier(cfg, depth='18', mlp=1, ignore_index=-100, num_of_datasets=num_of_datasets,
+ num_of_classes=num_of_classes, task_incremental=task_incremental, goal=goal)
+ base_model.cfg = cfg
+ else:
+ base_model = ResNetClassifier(cfg)
+
+ device = torch.device(cfg.MODEL.DEVICE)
+ base_model.to(device)
+ if cfg.EXTERNAL.OPTIMIZER.ADAM:
+ optimizer = torch.optim.Adam(
+ filter(lambda x: x.requires_grad, base_model.parameters()),
+ lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999)
+ )
+ else:
+ optimizer = torch.optim.SGD(
+ filter(lambda x: x.requires_grad, base_model.parameters()),
+ lr=cfg.SOLVER.BASE_LR
+ )
+
+ # algorithm specific model wrapper
+ x_size = 3 * 2 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2 if goal == 'classification' else \
+ 3 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2
+ if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist': x_size = 28 * 28
+ if goal == 'split_cifar' or goal == 'split_cifar100': x_size = 3 * 32 * 32
+ if goal == 'split_mini_imagenet': x_size = 3 * 84 * 84
+
+ if algo == 'ER':
+ model = ExperienceReplay(base_model, optimizer, x_size, base_model.cfg, goal)
+ elif algo == 'VERX':
+ model = ExperienceEvolveApprox(base_model, optimizer, x_size, base_model.cfg, goal)
+ elif algo == 'AGEM':
+ model = AGEM(base_model, optimizer, x_size, base_model.cfg, goal)
+ elif algo == 'naive':
+ model = NaiveWrapper(base_model, optimizer, x_size, base_model.cfg, goal)
+ model.to(device)
+
+ use_mixed_precision = cfg.DTYPE == "float16"
+ arguments = {"iteration": 0, "global_step": 0, "epoch": 0}
+ output_dir = cfg.OUTPUT_DIR
+ writer = None
+ epoch_num = 1
+ for e in range(epoch_num):
+ print("epoch")
+
+ arguments['iteration'] = 0
+ epoch = arguments['epoch']
+ if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist':
+ ocl_train_mnist(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
+ elif goal == 'split_cifar' or goal == 'split_cifar100' or goal == 'split_mini_imagenet':
+ ocl_train_cifar(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune)
+ else:
+ raise NotImplementedError
+ arguments['epoch'] += 1
+
+ with open(os.path.join(output_dir, 'model.bin'),'wb') as wf:
+ torch.save(model.state_dict(), wf)
+ # else:
+ # break
+ if is_ocl and hasattr(model, 'dump_reservoir') and args.dump_reservoir:
+ model.dump_reservoir(os.path.join(cfg.OUTPUT_DIR, 'mem_dump.pkl'), verbose=args.dump_reservoir_verbose)
+ return model
+
+def set_cfg_from_args(args, cfg):
+ cfg_params = args.cfg
+ if cfg_params is None: return
+ for param in cfg_params:
+ k, v = param.split('=')
+ set_config_attr(cfg, k, v)
+
+def count_params(m: torch.nn.Module, only_trainable: bool = False):
+ """
+ returns the total number of parameters used by `m` (only counting
+ shared parameters once); if `only_trainable` is True, then only
+ includes parameters with `requires_grad = True`
+ """
+ parameters = m.parameters()
+ if only_trainable:
+ parameters = list(p for p in parameters if p.requires_grad)
+ unique = dict((p.data_ptr(), p) for p in parameters).values()
+ return sum(p.numel() for p in unique)
+
+
+def main(args):
+ if '%id' in args.name:
+ exp_name = args.name.replace('%id', get_exp_id())
+ else:
+ exp_name = args.name
+
+ combined_cfg = CfgNode(new_allowed=True)
+ combined_cfg.merge_from_file(args.config)
+ cfg = combined_cfg
+ cfg.EXTERNAL.EXPERIMENT_NAME = exp_name
+ cfg.SEED = args.seed
+ cfg.DEBUG = args.debug
+
+ set_cfg_from_args(args, cfg)
+
+ output_dir = get_config_attr(cfg, 'OUTPUT_DIR', default='')
+ if output_dir == '.': output_dir = 'runs/'
+ cfg.OUTPUT_DIR = os.path.join(output_dir,
+ '{}_{}'.format(cfg.EXTERNAL.EXPERIMENT_NAME, cfg.SEED))
+ cfg.MODE = 'train'
+
+ # cfg.freeze()
+
+ num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
+ distributed = num_gpus > 1
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+
+ if distributed:
+ torch.cuda.set_device(local_rank)
+ torch.distributed.init_process_group(
+ backend="nccl", init_method="env://"
+ )
+
+ output_dir = cfg.OUTPUT_DIR
+
+ # save overloaded model config in the output directory
+ model = train(cfg, local_rank, distributed, tune=args.tune)
+
+ output_args_path = os.path.join(output_dir, 'args.txt')
+ wf = open(output_args_path, 'w')
+ wf.write(' '.join(sys.argv))
+ wf.close()
+
+def seed_everything(seed):
+ '''
+ :param seed:
+ :param device:
+ :return:
+ '''
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # some cudnn methods can be random even after fixing the seed
+ # unless you tell it to be deterministic
+ torch.backends.cudnn.deterministic = True
+
+def seed_everything_old(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--local_rank', type=int)
+ parser.add_argument('--name', type=str, default='%id')
+ parser.add_argument('--config', type=str, default='config.yaml')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--seed_old', action='store_true')
+ parser.add_argument('--n_runs', type=int, default=1)
+ parser.add_argument('--tune', action='store_true')
+ parser.add_argument('--cfg', nargs='*')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--dump_reservoir', action='store_true')
+ parser.add_argument('--dump_reservoir_verbose', action='store_true')
+ parser.add_argument('--single_word', action='store')
+
+ args = parser.parse_args()
+ if not args.seed_old:
+ seed_everything(args.seed)
+ else:
+ seed_everything_old(args.seed)
+ n_runs = args.n_runs
+ for i in range(n_runs):
+ main(args)
diff --git a/trainer_benchmark.py b/trainer_benchmark.py
new file mode 100644
index 0000000..db24761
--- /dev/null
+++ b/trainer_benchmark.py
@@ -0,0 +1,242 @@
+#from trainer import *
+import logging
+import torch
+
+from tqdm import tqdm
+import json, os
+
+from inference import to_list, f1_score
+from dataloader import get_split_mnist_dataloader, get_permute_mnist_dataloader,\
+ get_split_cifar_dataloader, get_split_cifar100_dataloader, get_split_mini_imagenet_dataloader,\
+ get_rotated_mnist_dataloader, IIDDataset
+from utils.utils import get_config_attr
+from torch.utils.data import DataLoader
+
+def exp_decay_lr(optimizer, step, total_step, init_lr):
+ gamma = (1 / 6) ** (step / total_step)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = init_lr * gamma
+
+def ocl_train_mnist(model, optimizer, checkpointer, device, arguments, writer, epoch,
+ goal='split', tune=False):
+ logger = logging.getLogger("maskrcnn_benchmark.trainer")
+ logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
+ model.train()
+ cfg = model.cfg
+ pbar = tqdm(
+ position=0,
+ desc='GPU: 0'
+ )
+
+ num_instances = cfg.MNIST.INSTANCE_NUM
+
+ if goal == 'split_mnist':
+ task_num = 5
+ loader_func = get_split_mnist_dataloader
+ elif goal == 'permute_mnist':
+ task_num = 10
+ loader_func = get_permute_mnist_dataloader
+ elif goal == 'rotated_mnist':
+ task_num = 20
+ loader_func = get_rotated_mnist_dataloader
+ else:
+ raise ValueError
+
+ if tune:
+ task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int, default=3)
+
+ num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
+ total_step = task_num * 1000
+ base_lr = get_config_attr(cfg,'SOLVER.BASE_LR',totype=float)
+ # whether iid
+ iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
+ do_exp_lr_decay = get_config_attr(cfg,'EXTERNAL.OCL.EXP_LR_DECAY',0)
+
+ all_accs = []
+ best_avg_accs = []
+ step = 0
+ for task_id in range(task_num):
+ if iid:
+ if task_id != 0: break
+ data_loaders = [loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
+ max_instance=num_instances) for task_id in range(task_num)]
+ data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE)
+ num_instances *= task_num
+ else:
+ data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
+ max_instance=num_instances)
+
+ best_avg_acc = -1
+ #model.net.set_task(task_id) # choose the classifier head if the model supports
+ for epoch in range(num_epoch):
+ seen = 0
+ for i, data in enumerate(data_loader):
+ if seen >= num_instances: break
+ inputs, labels = data
+ inputs, labels = (inputs.to(device), labels.to(device))
+ task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device)
+ inputs = inputs.flatten(1)
+ model.observe(inputs, labels, task_ids=task_ids)
+ step += 1
+ if do_exp_lr_decay:
+ exp_decay_lr(optimizer, step, total_step, base_lr)
+
+ seen += labels.size(0)
+ # run evaluation
+ with torch.no_grad():
+ if iid:
+ accs, _, avg_acc = inference_mnist(model, task_num, loader_func, device, tune=tune)
+ else:
+ accs, _, avg_acc = inference_mnist(model, task_id + 1, loader_func, device, tune=tune)
+ logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc))
+ for i, acc in enumerate(accs):
+ logger.info('::Val Task {}\t Acc {}'.format(i, acc))
+ all_accs.append(accs)
+ if avg_acc > best_avg_acc:
+ best_avg_acc = avg_acc
+ else:
+ break
+ best_avg_accs.append(best_avg_acc)
+ file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
+ result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w')
+ json.dump({'all_accs': all_accs, 'avg_acc': avg_acc}, result_file, indent=4)
+ result_file.close()
+
+
+def inference_mnist(model, max_task, loader_func, device, tune=False):
+ model.train(False)
+ accs, instance_nums = [], []
+ for val_task_id in range(0, max_task):
+ #task_id = 0
+ all_pred, all_truth = [], []
+ val_data_loader = loader_func(model.cfg, 'test' if not tune else 'val', [val_task_id],
+ batch_size=model.cfg.EXTERNAL.BATCH_SIZE)
+ print('-------len val data loader {}-------'.format(len(val_data_loader)))
+ for i, data in enumerate(val_data_loader):
+
+ inputs, labels = data
+ inputs, labels = (inputs.to(device), labels.to(device))
+ task_ids = torch.LongTensor([val_task_id] * labels.size(0)).to(inputs.device)
+ ret_dict = model(inputs, labels, task_ids=task_ids)
+ score = ret_dict['score']
+ _, pred = torch.max(score, -1)
+ all_pred.extend(to_list(pred))
+ all_truth.extend(to_list(labels))
+ acc = f1_score(all_truth, all_pred, average='micro')
+ accs.append(acc)
+ instance_nums.append(len(all_pred))
+ total_instance_num = sum(instance_nums)
+ model.train(True)
+ return accs, instance_nums, sum([x * y / total_instance_num for x,y in zip(accs, instance_nums)])
+
+
+def ocl_train_cifar(model, optimizer, checkpointer, device, arguments, writer, epoch,
+ goal='split_cifar', tune=False):
+ logger = logging.getLogger("maskrcnn_benchmark.trainer")
+ logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
+ model.train()
+ cfg = model.cfg
+
+ num_epoch = cfg.CIFAR.EPOCH
+ if goal == 'split_cifar':
+ loader_func = get_split_cifar_dataloader
+ total_step = 4750
+ elif goal == 'split_cifar100':
+ loader_func = get_split_cifar100_dataloader
+ total_step = 25000
+ else:
+ loader_func = get_split_mini_imagenet_dataloader
+ total_step = 22500
+ max_instance = cfg.CIFAR.INSTANCE_NUM if hasattr(cfg.CIFAR, 'INSTANCE_NUM') else 1e10
+ if not tune:
+ task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+ else:
+ task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int)
+
+
+ do_exp_lr_decay = get_config_attr(cfg,'EXTERNAL.OCL.EXP_LR_DECAY',0)
+ base_lr = get_config_attr(cfg,'SOLVER.BASE_LR',totype=float)
+ step = 0
+
+ num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1)
+ all_accs = []
+ best_avg_accs = []
+ iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool)
+ for task_id in range(task_num):
+ if iid:
+ if task_id != 0: break
+ data_loaders = [loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE,
+ max_instance=max_instance) for task_id in range(task_num)]
+ data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE)
+ max_instance *= task_num
+ else:
+ data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE, max_instance=max_instance)
+ pbar = tqdm(
+ position=0,
+ desc='GPU: 0',
+ total=len(data_loader)
+ )
+ best_avg_acc = -1
+ for epoch in range(num_epoch):
+ seen = 0
+ for i, data in enumerate(data_loader):
+ if seen >= max_instance: break
+ pbar.update(1)
+ inputs, labels = data
+ inputs, labels = (inputs.to(device), labels.to(device))
+ inputs = inputs.flatten(1)
+ task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device)
+ model.observe(inputs, labels, task_ids)
+ seen += inputs.size(0)
+ if do_exp_lr_decay:
+ exp_decay_lr(optimizer, step, total_step, base_lr)
+ step += 1
+ # # run evaluation
+ with torch.no_grad():
+ if iid:
+ accs, _, avg_acc = inference_cifar(model, task_num, loader_func, device, goal, tune=tune)
+ else:
+ accs, _, avg_acc = inference_cifar(model, task_id + 1, loader_func, device, goal, tune=tune)
+ logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc))
+ for i, acc in enumerate(accs):
+ logger.info('::Val Task {}\t Acc {}'.format(i, acc))
+ all_accs.append(accs)
+ if avg_acc > best_avg_acc:
+ best_avg_acc = avg_acc
+ else:
+ break
+ best_avg_accs.append(best_avg_acc)
+ file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num)
+ result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w')
+ json.dump({'all_accs': all_accs, 'avg_acc': avg_acc, 'best_avg_accs': best_avg_accs}, result_file, indent=4)
+ result_file.close()
+ return best_avg_accs
+
+def inference_cifar(model, max_task, loader_func, device, goal, tune=False):
+ accs, instance_nums = [], []
+ model.train(False)
+ for val_task_id in range(0, max_task):
+ all_pred, all_truth = [], []
+ val_data_loader = loader_func(model.cfg, 'test' if not tune else 'val', [val_task_id], batch_size=model.cfg.EXTERNAL.BATCH_SIZE)
+ for i, data in enumerate(val_data_loader):
+ inputs, labels = data
+ inputs, labels = (inputs.to(device), labels.to(device))
+ inputs = inputs.view(-1, 3, 32, 32) if goal == 'split_cifar' or goal == 'split_cifar100' else inputs.view(-1, 3, 84, 84)
+ task_ids = torch.LongTensor([val_task_id] * labels.size(0)).to(inputs.device)
+ if model.cfg.EXTERNAL.OCL.ALGO == 'CNDPM':
+ score = model(inputs)
+ else:
+ ret_dict = model(bbox_images=inputs, spatial_feat=None, attr_labels=labels,
+ obj_labels=None, images=None, task_ids=task_ids)
+ score = ret_dict['score']
+ _, pred = torch.max(score, -1)
+ all_pred.extend(to_list(pred))
+ all_truth.extend(to_list(labels))
+ acc = f1_score(all_truth, all_pred, average='micro')
+ accs.append(acc)
+ instance_nums.append(len(all_pred))
+ total_instance_num = sum(instance_nums)
+ model.train(True)
+ return accs, instance_nums, sum([x * y / total_instance_num for x,y in zip(accs, instance_nums)])
+
+
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..b7be57f
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,2 @@
+import sys
+sys.path.append('.')
\ No newline at end of file
diff --git a/utils/build_transforms.py b/utils/build_transforms.py
new file mode 100644
index 0000000..3a02647
--- /dev/null
+++ b/utils/build_transforms.py
@@ -0,0 +1,45 @@
+# from maskrcnn_benchmark.data.transforms import transforms as T
+from torchvision.transforms import transforms as T
+
+def build_transforms(cfg, split="train"):
+ if split=="train":
+ min_size = min(cfg.EXTERNAL.IMAGE.HEIGHT,cfg.EXTERNAL.IMAGE.WIDTH)
+ max_size = max(cfg.EXTERNAL.IMAGE.HEIGHT,cfg.EXTERNAL.IMAGE.WIDTH)
+ flip_horizontal_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN
+ flip_vertical_prob = cfg.INPUT.VERTICAL_FLIP_PROB_TRAIN
+ brightness = cfg.INPUT.BRIGHTNESS
+ contrast = cfg.INPUT.CONTRAST
+ saturation = cfg.INPUT.SATURATION
+ hue = cfg.INPUT.HUE
+ else:
+ min_size = min(cfg.EXTERNAL.IMAGE.HEIGHT, cfg.EXTERNAL.IMAGE.WIDTH)
+ max_size = max(cfg.EXTERNAL.IMAGE.HEIGHT, cfg.EXTERNAL.IMAGE.WIDTH)
+ flip_horizontal_prob = 0.0
+ flip_vertical_prob = 0.0
+ brightness = 0.0
+ contrast = 0.0
+ saturation = 0.0
+ hue = 0.0
+
+ to_bgr255 = cfg.INPUT.TO_BGR255
+ normalize_transform = T.Normalize(
+ mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD,
+ )
+ color_jitter = T.ColorJitter(
+ brightness=brightness,
+ contrast=contrast,
+ saturation=saturation,
+ hue=hue,
+ )
+
+ transform = T.Compose(
+ [
+ color_jitter,
+ T.Resize(min_size, max_size),
+ T.RandomHorizontalFlip(flip_horizontal_prob),
+ T.RandomVerticalFlip(flip_vertical_prob),
+ T.ToTensor(),
+ normalize_transform,
+ ]
+ )
+ return transform
\ No newline at end of file
diff --git a/utils/tupperware.py b/utils/tupperware.py
new file mode 100644
index 0000000..8961239
--- /dev/null
+++ b/utils/tupperware.py
@@ -0,0 +1,73 @@
+from collections import UserDict
+import collections
+from recordclass import recordclass
+
+def tupperware(mapping):
+ """ Convert mappings to 'tupperwares' recursively.
+ Lets you use dicts like they're JavaScript Object Literals (~=JSON)...
+ It recursively turns mappings (dictionaries) into namedtuples.
+ Thus, you can cheaply create an object whose attributes are accessible
+ by dotted notation (all the way down).
+ Use cases:
+ * Fake objects (useful for dependency injection when you're making
+ fakes/stubs that are simpler than proper mocks)
+ * Storing data (like fixtures) in a structured way, in Python code
+ (data whose initial definition reads nicely like JSON). You could do
+ this with dictionaries, but namedtuples are immutable, and their
+ dotted notation can be clearer in some contexts.
+ .. doctest::
+ >>> t = tupperware({
+ ... 'foo': 'bar',
+ ... 'baz': {'qux': 'quux'},
+ ... 'tito': {
+ ... 'tata': 'tutu',
+ ... 'totoro': 'tots',
+ ... 'frobnicator': ['this', 'is', 'not', 'a', 'mapping']
+ ... }
+ ... })
+ >>> t # doctest: +ELLIPSIS
+ Tupperware(tito=Tupperware(...), foo='bar', baz=Tupperware(qux='quux'))
+ >>> t.tito # doctest: +ELLIPSIS
+ Tupperware(frobnicator=[...], tata='tutu', totoro='tots')
+ >>> t.tito.tata
+ 'tutu'
+ >>> t.tito.frobnicator
+ ['this', 'is', 'not', 'a', 'mapping']
+ >>> t.foo
+ 'bar'
+ >>> t.baz.qux
+ 'quux'
+ Args:
+ mapping: An object that might be a mapping. If it's a mapping, convert
+ it (and all of its contents that are mappings) to namedtuples
+ (called 'Tupperwares').
+ Returns:
+ A tupperware (a namedtuple (of namedtuples (of namedtuples (...)))).
+ If argument is not a mapping, it just returns it (this enables the
+ recursion).
+ """
+
+ if (isinstance(mapping, collections.Mapping) and
+ not isinstance(mapping, ProtectedDict)):
+ for key, value in mapping.items():
+ mapping[key] = tupperware(value)
+ return namedtuple_from_mapping(mapping)
+ return mapping
+
+
+def namedtuple_from_mapping(mapping, name="Tupperware"):
+ # this_namedtuple_maker = collections.namedtuple(name, mapping.keys())
+ this_namedtuple_maker = recordclass(name, mapping.keys())
+ return this_namedtuple_maker(**mapping)
+
+
+class ProtectedDict(UserDict):
+ """ A class that exists just to tell `tupperware` not to eat it.
+ `tupperware` eats all dicts you give it, recursively; but what if you
+ actually want a dictionary in there? This will stop it. Just do
+ ProtectedDict({...}) or ProtectedDict(kwarg=foo).
+ """
+
+
+def tupperware_from_kwargs(**kwargs):
+ return tupperware(kwargs)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..5ba1d07
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,291 @@
+from collections import Counter
+import uuid, time, datetime, os, torch, logging
+from collections import OrderedDict
+import numpy as np
+
+def get_top_k_by_frequency(vocab, top_k=None):
+ '''
+ Get the top k elements from a list by frequency
+ :param vocab: the list
+ :parm top_k: top_k
+ :return: top k elements
+ '''
+ counter = Counter(vocab)
+
+ sorted_vocab = sorted(
+ [t for t in counter],
+ key=counter.get,
+ reverse=True
+ )
+
+ if top_k:
+ return sorted_vocab[:top_k]
+
+ return sorted_vocab
+
+def get_exp_id():
+ return uuid.uuid4().hex[:6]
+
+
+class Timer(object):
+ def __init__(self):
+ self.reset()
+
+ @property
+ def average_time(self):
+ return self.total_time / self.calls if self.calls > 0 else 0.0
+
+ def tic(self):
+ # using time.time instead of time.clock because time time.clock
+ # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.add(time.time() - self.start_time)
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def add(self, time_diff):
+ self.diff = time_diff
+ self.total_time += self.diff
+ self.calls += 1
+
+ def reset(self):
+ self.total_time = 0.0
+ self.calls = 0
+ self.start_time = 0.0
+ self.diff = 0.0
+
+ def avg_time_str(self):
+ time_str = str(datetime.timedelta(seconds=self.average_time))
+ return time_str
+
+ class Checkpointer(object):
+ def __init__(
+ self,
+ model,
+ optimizer=None,
+ scheduler=None,
+ save_dir="",
+ save_to_disk=None,
+ logger=None,
+ ):
+ self.model = model
+ self.optimizer = optimizer
+ self.scheduler = scheduler
+ self.save_dir = save_dir
+ self.save_to_disk = save_to_disk
+ if logger is None:
+ logger = logging.getLogger(__name__)
+ self.logger = logger
+
+ def save(self, name, **kwargs):
+ if not self.save_dir:
+ return
+
+ if not self.save_to_disk:
+ return
+
+ data = {}
+ data["model"] = self.model.state_dict()
+ if self.optimizer is not None:
+ data["optimizer"] = self.optimizer.state_dict()
+ if self.scheduler is not None:
+ data["scheduler"] = self.scheduler.state_dict()
+ data.update(kwargs)
+
+ save_file = os.path.join(self.save_dir, "{}.pth".format(name))
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ torch.save(data, save_file)
+ self.tag_last_checkpoint(save_file)
+
+ def load(self, f=None, use_latest=True):
+ if self.has_checkpoint() and use_latest:
+ # override argument with existing checkpoint
+ f = self.get_checkpoint_file()
+ if not f:
+ # no checkpoint could be found
+ self.logger.info("No checkpoint found. Initializing model from scratch")
+ return {}
+ self.logger.info("Loading checkpoint from {}".format(f))
+ checkpoint = self._load_file(f)
+ self._load_model(checkpoint)
+ if "optimizer" in checkpoint and self.optimizer:
+ self.logger.info("Loading optimizer from {}".format(f))
+ self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
+ if "scheduler" in checkpoint and self.scheduler:
+ self.logger.info("Loading scheduler from {}".format(f))
+ self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
+
+ # return any further checkpoint data
+ return checkpoint
+
+ def has_checkpoint(self):
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ return os.path.exists(save_file)
+
+ def get_checkpoint_file(self):
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ try:
+ with open(save_file, "r") as f:
+ last_saved = f.read()
+ last_saved = last_saved.strip()
+ except IOError:
+ # if file doesn't exist, maybe because it has just been
+ # deleted by a separate process
+ last_saved = ""
+ return last_saved
+
+ def tag_last_checkpoint(self, last_filename):
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ with open(save_file, "w") as f:
+ f.write(last_filename)
+
+ def _load_file(self, f):
+ return torch.load(f, map_location=torch.device("cpu"))
+
+ def _load_model(self, checkpoint):
+ load_state_dict(self.model, checkpoint.pop("model"))
+
+def load_state_dict(model, loaded_state_dict):
+ model_state_dict = model.state_dict()
+ # if the state_dict comes from a model that was wrapped in a
+ # DataParallel or DistributedDataParallel during serialization,
+ # remove the "module" prefix before performing the matching
+ loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
+ align_and_update_state_dicts(model_state_dict, loaded_state_dict)
+
+ # use strict loading
+ model.load_state_dict(model_state_dict)
+
+def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
+ """
+ Strategy: suppose that the models that we will create will have prefixes appended
+ to each of its keys, for example due to an extra level of nesting that the original
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+ res2.conv1.weight. We thus want to match both parameters together.
+ For that, we look for each model weight, look among all loaded keys if there is one
+ that is a suffix of the current weight name, and use it if that's the case.
+ If multiple matches exist, take the one with longest size
+ of the corresponding name. For example, for the same model as before, the pretrained
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+ """
+ current_keys = sorted(list(model_state_dict.keys()))
+ loaded_keys = sorted(list(loaded_state_dict.keys()))
+ # get a matrix of string matches, where each (i, j) entry correspond to the size of the
+ # loaded_key string, if it matches
+ match_matrix = [
+ len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
+ ]
+ match_matrix = torch.as_tensor(match_matrix).view(
+ len(current_keys), len(loaded_keys)
+ )
+ max_match_size, idxs = match_matrix.max(1)
+ # remove indices that correspond to no-match
+ idxs[max_match_size == 0] = -1
+
+ # used for logging
+ max_size = max([len(key) for key in current_keys]) if current_keys else 1
+ max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
+ log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
+ logger = logging.getLogger(__name__)
+ for idx_new, idx_old in enumerate(idxs.tolist()):
+ if idx_old == -1:
+ continue
+ key = current_keys[idx_new]
+ key_old = loaded_keys[idx_old]
+ model_state_dict[key] = loaded_state_dict[key_old]
+ logger.info(
+ log_str_template.format(
+ key,
+ max_size,
+ key_old,
+ max_size_loaded,
+ tuple(loaded_state_dict[key_old].shape),
+ )
+ )
+
+
+def strip_prefix_if_present(state_dict, prefix):
+ keys = sorted(state_dict.keys())
+ if not all(key.startswith(prefix) for key in keys):
+ return state_dict
+ stripped_state_dict = OrderedDict()
+ for key, value in state_dict.items():
+ stripped_state_dict[key.replace(prefix, "")] = value
+ return stripped_state_dict
+
+def get_config_attr(cfg, attr_string, default=None, totype=None, mute=False):
+ try:
+ attrs = attr_string.split('.')
+ obj = cfg
+ for s in attrs:
+ obj = getattr(obj, s)
+ if totype is None:
+ return type(default)(obj)
+ else:
+ if totype is bool and obj not in ['True','False',True,False]:
+ raise ValueError('malformed boolean input: {}, {}'.format(obj,type(obj)))
+ if totype is bool and obj in ['False',False]:
+ return False
+ return totype(obj)
+ except AttributeError:
+ if not mute:
+ print('Warning: attribute {} not found. Default: {}'.format(attr_string, default))
+ return default
+
+class DotDict(dict):
+ __getattr__ = dict.__getitem__
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+ def __init__(self, **dct):
+ for key, value in dct.items():
+ #if hasattr(value, 'keys'):
+ # value = DotDict(**value)
+ self[key] = value
+
+
+def set_config_attr(cfg, attr_key, attr_value):
+ attrs = attr_key.split('.')
+ obj = cfg
+ for attr in attrs[:-1]:
+ if not hasattr(obj, attr):
+ setattr(obj, attr, DotDict())
+ obj = getattr(obj, attr)
+
+
+ try:
+ attr_value = int(attr_value)
+ except ValueError:
+ try:
+ attr_value = float(attr_value)
+ except ValueError:
+ pass
+ pass
+
+ if attr_value == 'True': attr_value = True
+ if attr_value == 'False': attr_value = False
+
+ setattr(obj, attrs[-1], attr_value)
+
+
+def filter_outliers(l):
+ q1 = np.quantile(l, 0.25)
+ q2 = np.quantile(l, 0.75)
+ iqr = q2 - q1
+ lb, rb = q1 - 1.5 * iqr, q2 + 1.5 * iqr
+ l = [x for x in l if lb <= x <= rb]
+ return l
+
+def set_cfg_from_args(args, cfg):
+ cfg_params = args.cfg
+ if cfg_params is None: return
+ for param in cfg_params:
+ k, v = param.split('=')
+ set_config_attr(cfg, k, v)