diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a251190
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.idea/
+__pycache__/
+*.pyc
+.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..05dd38b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,39 @@
+# GMED-anonymous-submission
+
+## Links for downloadable datasets:
+- MNIST dataset
+
+  http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+  http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
+  http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
+  http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
+
+- CIFAR-10 and CIFAR-100 datasets
+
+  https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
+  https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
+
+
+- mini-ImageNet
+
+  https://data.deepai.org/miniimagenet.zip
+
+  The partition of the datasets into tasks is performed each time you train the model.
+
+## Running experiments
+Running experiments on MNIST:
+
+```
+export stride=0.1      # editing stride (alpha)
+export reg=0.01        # regularization strength (beta)
+export dataset="split_mnist"
+export mem=500
+export start_seed=100
+export stop_seed=109
+export method="-"      # change to "mirr" for MIR+GMED
+
+./scripts/gme_tune_${dataset}.sh ${dataset}_iter1_stride${stride}_1_0_2${reg}_m${mem} 1 ${stride} 1 0 2 ${mem} ${method} ${start_seed} ${stop_seed} "OUTPUT_DIR=runs EXTERNAL.OCL.REG_STRENGTH=${reg}"
+```
+
+Similarly, experiments on the other datasets can be run by changing the `dataset` variable above.
+
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..76183f7
--- /dev/null
+++ b/config.py
@@ -0,0 +1,11 @@
+class _Config:
+    def __init__(self):
+        self.image_size = 224
+        self.vocab_path = ''
+
+    def update(self, cfg):
+        for k in cfg.__dict__:
+            if k not in self.__dict__:
+                setattr(self, k, cfg.__dict__[k])
+
+cfg = _Config()
\ No newline at end of file
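The README's `stride` (alpha) and `reg` (beta) knobs control GMED's memory-editing step. As a reading aid, here is a minimal sketch of that step, assuming the standard GMED update (move a replayed example along the gradient of its estimated interference after a virtual model update, with a regularizer that keeps the edited loss low); `gmed_edit` and its arguments are illustrative names, not the repo's API:

```
import copy
import torch

def gmed_edit(model, criterion, x_mem, y_mem, x_new, y_new, alpha=0.1, beta=0.01, lr=0.05):
    """Return an edited copy of x_mem; alpha = editing stride, beta = reg strength."""
    x_edit = x_mem.clone().requires_grad_(True)

    # look-ahead: a virtual SGD step on the incoming batch
    lookahead = copy.deepcopy(model)
    loss_new = criterion(lookahead(x_new), y_new)
    grads = torch.autograd.grad(loss_new, lookahead.parameters())
    with torch.no_grad():
        for p, g in zip(lookahead.parameters(), grads):
            p -= lr * g

    # interference: how much the virtual update raises loss on the memory batch,
    # minus beta times the current loss so edited examples stay learnable
    loss_before = criterion(model(x_edit), y_mem)
    loss_after = criterion(lookahead(x_edit), y_mem)
    objective = (loss_after - loss_before) - beta * loss_before
    grad_x, = torch.autograd.grad(objective, x_edit)
    return (x_edit + alpha * grad_x).detach()
```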
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet + WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50 + BACKBONE: + CONV_BODY: "R-50-FPN" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 1200 + PRE_NMS_TOP_N_TEST: 1200 + PRE_NMS_TOP_N_TEST: 200 + POST_NMS_TOP_N_TEST: 200 + FPN_POST_NMS_TOP_N_TRAIN: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + ROI_HEADS: + USE_FPN: True + # BG_IOU_THRESHOLD: 0.3 + # FG_IOU_THRESHOLD: 0.5 + # OHEM: False + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + POOLER_SAMPLING_RATIO: 2 + # MLP_HEAD_DIM: 2048 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + NUM_CLASSES: 1601 + MASK_ON: False +TYPE: "float32" +DATASETS: + TRAIN: ("coco_gqa_train",) + TEST: ("coco_gqa_val",) +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.05 + MOMENTUM: 0. + WEIGHT_DECAY: 0.00001 + STEPS: (12500, ) + MAX_ITER: 30000 + IMS_PER_BATCH: 32 +TEST: + IMS_PER_BATCH: 32 diff --git a/configs/memevolve/rotated_mnist.yaml b/configs/memevolve/rotated_mnist.yaml new file mode 100644 index 0000000..2527edd --- /dev/null +++ b/configs/memevolve/rotated_mnist.yaml @@ -0,0 +1,106 @@ +# config file +MNIST: + ACTIVATED: True + TASK: 'rotate' + EPOCH: 1 + INSTANCE_NUM: 1000 +EXTERNAL: + IMAGE_IDS: [] + OBJECT_NAMES: [] + # OBJECT_NAMES: [] + ATTRIBUTES: [] + + OBJECT_TOP_K: 100 + # 0 refers to all + ATTRIBUTE_TOP_K: 0 + RELOAD_SCENE_GRAPH: False + + # Change the directory of your own + IMG_DIR: "/home/xisen/online-concept-learning/GQA/images" + TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json" + TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl" + VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json" + VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl" + + NUM_WORKERS: 0 + PIN_MEMORY: True + SHUFFLE: False + # IMAGES PER GPU + BATCH_SIZE: 10 + BALANCE_ATTRIBUTES: False + IMAGE_SIZE: 224 + IMAGE: + HEIGHT: 640 + WIDTH: 640 + REPLAY: + MEM_BS: 10 + MEM_LIMIT: 200 + FILTER_SELF: 0 + + # 0 for unknown + ROI_BOX_HEAD: + NUM_ATTR: 13 + + OPTIMIZER: + ADAM: False + + OCL: + ACTIVATED: True + SORT_BY_ATTRIBUTES: False + ALGO: "VERX" + TASK_INCREMENTAL: False + TASK_NUM: 20 + CLASS_NUM: 10 + + VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl" +INPUT: + MIN_SIZE_TRAIN: (640,) + MAX_SIZE_TRAIN: 640 + MIN_SIZE_TEST: 640 + MAX_SIZE_TEST: 640 +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet + WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50 + BACKBONE: + CONV_BODY: "R-50-FPN" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 1200 + PRE_NMS_TOP_N_TEST: 1200 + PRE_NMS_TOP_N_TEST: 200 + POST_NMS_TOP_N_TEST: 200 + FPN_POST_NMS_TOP_N_TRAIN: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + ROI_HEADS: + USE_FPN: True + # BG_IOU_THRESHOLD: 0.3 + # FG_IOU_THRESHOLD: 0.5 + # OHEM: False + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + POOLER_SAMPLING_RATIO: 2 + # MLP_HEAD_DIM: 2048 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + NUM_CLASSES: 1601 + 
diff --git a/configs/memevolve/rotated_mnist.yaml b/configs/memevolve/rotated_mnist.yaml
new file mode 100644
index 0000000..2527edd
--- /dev/null
+++ b/configs/memevolve/rotated_mnist.yaml
@@ -0,0 +1,106 @@
+# config file
+MNIST:
+  ACTIVATED: True
+  TASK: 'rotate'
+  EPOCH: 1
+  INSTANCE_NUM: 1000
+EXTERNAL:
+  IMAGE_IDS: []
+  OBJECT_NAMES: []
+  # OBJECT_NAMES: []
+  ATTRIBUTES: []
+
+  OBJECT_TOP_K: 100
+  # 0 refers to all
+  ATTRIBUTE_TOP_K: 0
+  RELOAD_SCENE_GRAPH: False
+
+  # Change the directories to your own
+  IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+  TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+  TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+  VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+  VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+  NUM_WORKERS: 0
+  PIN_MEMORY: True
+  SHUFFLE: False
+  # IMAGES PER GPU
+  BATCH_SIZE: 10
+  BALANCE_ATTRIBUTES: False
+  IMAGE_SIZE: 224
+  IMAGE:
+    HEIGHT: 640
+    WIDTH: 640
+  REPLAY:
+    MEM_BS: 10
+    MEM_LIMIT: 200
+    FILTER_SELF: 0
+
+  # 0 for unknown
+  ROI_BOX_HEAD:
+    NUM_ATTR: 13
+
+  OPTIMIZER:
+    ADAM: False
+
+  OCL:
+    ACTIVATED: True
+    SORT_BY_ATTRIBUTES: False
+    ALGO: "VERX"
+    TASK_INCREMENTAL: False
+    TASK_NUM: 20
+    CLASS_NUM: 10
+
+  VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+  MIN_SIZE_TRAIN: (640,)
+  MAX_SIZE_TRAIN: 640
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TEST: 640
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+  WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 1200
+    # PRE_NMS_TOP_N_TEST: 1200  # duplicate key, overridden by the value below
+    PRE_NMS_TOP_N_TEST: 200
+    POST_NMS_TOP_N_TEST: 200
+    FPN_POST_NMS_TOP_N_TRAIN: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+    # BG_IOU_THRESHOLD: 0.3
+    # FG_IOU_THRESHOLD: 0.5
+    # OHEM: False
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    # MLP_HEAD_DIM: 2048
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 1601
+  MASK_ON: False
+TYPE: "float32"
+DATASETS:
+  TRAIN: ("coco_gqa_train",)
+  TEST: ("coco_gqa_val",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.05
+  MOMENTUM: 0.
+  WEIGHT_DECAY: 0.00001
+  STEPS: (12500, )
+  MAX_ITER: 30000
+  IMS_PER_BATCH: 32
+TEST:
+  IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_cifar10.yaml b/configs/memevolve/split_cifar10.yaml
new file mode 100644
index 0000000..fefd4a6
--- /dev/null
+++ b/configs/memevolve/split_cifar10.yaml
@@ -0,0 +1,107 @@
+# config file
+CIFAR:
+  ACTIVATED: True
+  EPOCH: 1
+  DATASET: "CIFAR10"
+VERX:
+  LOSS1: 0
+  LOSS2: 0
+EXTERNAL:
+  IMAGE_IDS: []
+  OBJECT_NAMES: []
+  # OBJECT_NAMES: []
+  ATTRIBUTES: []
+
+  OBJECT_TOP_K: 100
+  # 0 refers to all
+  ATTRIBUTE_TOP_K: 0
+  RELOAD_SCENE_GRAPH: False
+
+  # Change the directories to your own
+  IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+  TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+  TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+  VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+  VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+  NUM_WORKERS: 0
+  PIN_MEMORY: True
+  SHUFFLE: False
+  # IMAGES PER GPU
+  BATCH_SIZE: 10
+  BALANCE_ATTRIBUTES: False
+  IMAGE_SIZE: 224
+  IMAGE:
+    HEIGHT: 640
+    WIDTH: 640
+  REPLAY:
+    MEM_BS: 10
+    MEM_LIMIT: 500
+    FILTER_SELF: 0
+
+  # 0 for unknown
+  ROI_BOX_HEAD:
+    NUM_ATTR: 13
+
+  OPTIMIZER:
+    ADAM: False
+
+  OCL:
+    ACTIVATED: True
+    SORT_BY_ATTRIBUTES: False
+    ALGO: "VERX"
+    TASK_NUM: 5
+    CLASS_NUM: 2
+
+  VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+  MIN_SIZE_TRAIN: (640,)
+  MAX_SIZE_TRAIN: 640
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TEST: 640
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+  WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 1200
+    # PRE_NMS_TOP_N_TEST: 1200  # duplicate key, overridden by the value below
+    PRE_NMS_TOP_N_TEST: 200
+    POST_NMS_TOP_N_TEST: 200
+    FPN_POST_NMS_TOP_N_TRAIN: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+    # BG_IOU_THRESHOLD: 0.3
+    # FG_IOU_THRESHOLD: 0.5
+    # OHEM: False
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    # MLP_HEAD_DIM: 2048
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 1601
+  MASK_ON: False
+TYPE: "float32"
+DATASETS:
+  TRAIN: ("coco_gqa_train",)
+  TEST: ("coco_gqa_val",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.1
+  MOMENTUM: 0.
+  WEIGHT_DECAY: 0.00001
+  STEPS: (12500, )
+  MAX_ITER: 30000
+  IMS_PER_BATCH: 32
+TEST:
+  IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_cifar100.yaml b/configs/memevolve/split_cifar100.yaml
new file mode 100644
index 0000000..eb5bf2e
--- /dev/null
+++ b/configs/memevolve/split_cifar100.yaml
@@ -0,0 +1,110 @@
+# config file
+CIFAR:
+  ACTIVATED: True
+  EPOCH: 1
+  DATASET: "CIFAR100"
+  INSTANCE_NUM: 10000
+VERX:
+  LOSS1: 0
+  LOSS2: 0
+EXTERNAL:
+  IMAGE_IDS: []
+  OBJECT_NAMES: []
+  # OBJECT_NAMES: []
+  ATTRIBUTES: []
+
+  OBJECT_TOP_K: 100
+  # 0 refers to all
+  ATTRIBUTE_TOP_K: 0
+  RELOAD_SCENE_GRAPH: False
+
+  # Change the directories to your own
+  IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+  TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+  TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+  VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+  VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+  NUM_WORKERS: 0
+  PIN_MEMORY: True
+  SHUFFLE: False
+  # IMAGES PER GPU
+  BATCH_SIZE: 10
+  BALANCE_ATTRIBUTES: False
+  IMAGE_SIZE: 224
+  IMAGE:
+    HEIGHT: 640
+    WIDTH: 640
+  REPLAY:
+    MEM_BS: 10
+    MEM_LIMIT: 10000
+    FILTER_SELF: 0
+
+  # 0 for unknown
+  ROI_BOX_HEAD:
+    NUM_ATTR: 13
+
+  OPTIMIZER:
+    ADAM: False
+
+  OCL:
+    ACTIVATED: True
+    SORT_BY_ATTRIBUTES: False
+    ALGO: "VERX"
+    TASK_INCREMENTAL: False
+    TASK_NUM: 20
+    CLASS_NUM: 5
+
+
+  VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+  MIN_SIZE_TRAIN: (640,)
+  MAX_SIZE_TRAIN: 640
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TEST: 640
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+  WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 1200
+    # PRE_NMS_TOP_N_TEST: 1200  # duplicate key, overridden by the value below
+    PRE_NMS_TOP_N_TEST: 200
+    POST_NMS_TOP_N_TEST: 200
+    FPN_POST_NMS_TOP_N_TRAIN: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+    # BG_IOU_THRESHOLD: 0.3
+    # FG_IOU_THRESHOLD: 0.5
+    # OHEM: False
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    # MLP_HEAD_DIM: 2048
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 1601
+  MASK_ON: False
+TYPE: "float32"
+DATASETS:
+  TRAIN: ("coco_gqa_train",)
+  TEST: ("coco_gqa_val",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.03
+  MOMENTUM: 0.
+  WEIGHT_DECAY: 0.00001
+  STEPS: (12500, )
+  MAX_ITER: 30000
+  IMS_PER_BATCH: 32
+TEST:
+  IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_mini_imagenet.yaml b/configs/memevolve/split_mini_imagenet.yaml
new file mode 100644
index 0000000..22056d4
--- /dev/null
+++ b/configs/memevolve/split_mini_imagenet.yaml
@@ -0,0 +1,113 @@
+# config file
+CIFAR:
+  ACTIVATED: True
+  MINI_IMAGENET: 1
+  EPOCH: 1
+  INSTANCE_NUM: 3000
+VERX:
+  LOSS1: 0
+  LOSS2: 0
+
+EXTERNAL:
+  IMAGE_IDS: []
+  OBJECT_NAMES: []
+  # OBJECT_NAMES: []
+  ATTRIBUTES: []
+
+  OBJECT_TOP_K: 100
+  # 0 refers to all
+  ATTRIBUTE_TOP_K: 0
+  RELOAD_SCENE_GRAPH: False
+
+  # Change the directories to your own
+  IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+  TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+  TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+  VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+  VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+  NUM_WORKERS: 0
+  PIN_MEMORY: True
+  SHUFFLE: False
+  # IMAGES PER GPU
+  BATCH_SIZE: 10
+  BALANCE_ATTRIBUTES: False
+  IMAGE_SIZE: 224
+  IMAGE:
+    HEIGHT: 640
+    WIDTH: 640
+  REPLAY:
+    MEM_BS: 10
+    MEM_LIMIT: 10000
+    FILTER_SELF: 0
+  # 0 for unknown
+  ROI_BOX_HEAD:
+    NUM_ATTR: 13
+
+  OPTIMIZER:
+    ADAM: False
+    USE_LOSS_1: 1
+    USE_LOSS_2: -1
+    PROJ_REG_LOSS: 0
+
+  OCL:
+    ACTIVATED: True
+    SORT_BY_ATTRIBUTES: False
+    ALGO: "VERX"
+    #TASK_INCREMENTAL: True
+    TASK_NUM: 20
+    CLASS_NUM: 5
+    N_ITER: 3
+
+  VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+  MIN_SIZE_TRAIN: (640,)
+  MAX_SIZE_TRAIN: 640
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TEST: 640
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+  WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 1200
+    # PRE_NMS_TOP_N_TEST: 1200  # duplicate key, overridden by the value below
+    PRE_NMS_TOP_N_TEST: 200
+    POST_NMS_TOP_N_TEST: 200
+    FPN_POST_NMS_TOP_N_TRAIN: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+    # BG_IOU_THRESHOLD: 0.3
+    # FG_IOU_THRESHOLD: 0.5
+    # OHEM: False
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    # MLP_HEAD_DIM: 2048
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 1601
+  MASK_ON: False
+TYPE: "float32"
+DATASETS:
+  TRAIN: ("coco_gqa_train",)
+  TEST: ("coco_gqa_val",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.1
+  MOMENTUM: 0.
+  WEIGHT_DECAY: 0.00001
+  STEPS: (12500, )
+  MAX_ITER: 30000
+  IMS_PER_BATCH: 32
+TEST:
+  IMS_PER_BATCH: 32
diff --git a/configs/memevolve/split_mnist.yaml b/configs/memevolve/split_mnist.yaml
new file mode 100644
index 0000000..ad173ae
--- /dev/null
+++ b/configs/memevolve/split_mnist.yaml
@@ -0,0 +1,102 @@
+# config file
+MNIST:
+  ACTIVATED: True
+  TASK: 'split'
+  EPOCH: 1
+  INSTANCE_NUM: 1000
+EXTERNAL:
+  IMAGE_IDS: []
+  OBJECT_NAMES: []
+  # OBJECT_NAMES: []
+  ATTRIBUTES: []
+
+  OBJECT_TOP_K: 100
+  # 0 refers to all
+  ATTRIBUTE_TOP_K: 0
+  RELOAD_SCENE_GRAPH: False
+
+  # Change the directories to your own
+  IMG_DIR: "/home/xisen/online-concept-learning/GQA/images"
+  TRAIN_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/train_sceneGraphs.json"
+  TRAIN_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/train_sceneGraphs_dump.pkl"
+  VAL_SCENE_GRAPH_PATH: "/home/xisen/online-concept-learning/GQA/sceneGraphs/val_sceneGraphs.json"
+  VAL_SCENE_GRAPH_DUMP_PATH: "/home/xisen/online-concept-learning/GQADump/sceneGraphs/val_sceneGraphs_dump.pkl"
+
+  NUM_WORKERS: 0
+  PIN_MEMORY: True
+  SHUFFLE: False
+  # IMAGES PER GPU
+  BATCH_SIZE: 10
+  BALANCE_ATTRIBUTES: False
+  IMAGE_SIZE: 224
+  IMAGE:
+    HEIGHT: 640
+    WIDTH: 640
+  REPLAY:
+    MEM_BS: 10
+    MEM_LIMIT: 100
+    FILTER_SELF: 0
+
+  # 0 for unknown
+  ROI_BOX_HEAD:
+    NUM_ATTR: 13
+
+  OPTIMIZER:
+    ADAM: False
+
+  OCL:
+    ACTIVATED: True
+    SORT_BY_ATTRIBUTES: False
+    ALGO: "VERX"
+  VOCAB: "/home/xisen/online-concept-learning/vocab/vocab_gqa_full.pkl"
+INPUT:
+  MIN_SIZE_TRAIN: (640,)
+  MAX_SIZE_TRAIN: 640
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TEST: 640
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  # WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" # ResNet-50 on Imagenet
+  WEIGHT: "./pretrained-models/e2e_faster_rcnn_R_50_FPN_1x_trimmed.pth" # faster-rcnn-r50
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 1200
+    # PRE_NMS_TOP_N_TEST: 1200  # duplicate key, overridden by the value below
+    PRE_NMS_TOP_N_TEST: 200
+    POST_NMS_TOP_N_TEST: 200
+    FPN_POST_NMS_TOP_N_TRAIN: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+    # BG_IOU_THRESHOLD: 0.3
+    # FG_IOU_THRESHOLD: 0.5
+    # OHEM: False
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    # MLP_HEAD_DIM: 2048
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 1601
+  MASK_ON: False
+TYPE: "float32"
+DATASETS:
+  TRAIN: ("coco_gqa_train",)
+  TEST: ("coco_gqa_val",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.05
+  MOMENTUM: 0.
+  WEIGHT_DECAY: 0.00001
+  STEPS: (12500, )
+  MAX_ITER: 30000
+  IMS_PER_BATCH: 32
+TEST:
+  IMS_PER_BATCH: 32
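All six configs carry the shadowed `PRE_NMS_TOP_N_TEST` entry flagged in the comments above; plain `yaml.safe_load` would silently keep the last value. A small, self-contained check (not part of the repo) that fails loudly on duplicate mapping keys:

```
import yaml

class StrictLoader(yaml.SafeLoader):
    pass

def _no_duplicates(loader, node, deep=False):
    mapping = {}
    for key_node, value_node in node.value:
        key = loader.construct_object(key_node, deep=deep)
        if key in mapping:
            raise ValueError(f"duplicate key {key!r} at line {key_node.start_mark.line + 1}")
        mapping[key] = loader.construct_object(value_node, deep=deep)
    return mapping

StrictLoader.add_constructor(
    yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _no_duplicates)

with open("configs/memevolve/split_mnist.yaml") as f:
    config = yaml.load(f, Loader=StrictLoader)
```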
diff --git a/data/benchmark_mir.py b/data/benchmark_mir.py
new file mode 100644
index 0000000..ba95f77
--- /dev/null
+++ b/data/benchmark_mir.py
@@ -0,0 +1,579 @@
+"""
+Copyright (c) 2020 Rahaf Aljundi, Lucas Caccia, Eugene Belilovsky, Massimo Caccia
+
+Modified from https://github.com/optimass/Maximally_Interfered_Retrieval/
+"""
+import os
+import torch
+import numpy as np
+from PIL import Image
+import random
+from scipy.ndimage.interpolation import rotate
+from torchvision import datasets, transforms
+import yaml
+
+
+""" Template Dataset with Labels """
+class XYDataset(torch.utils.data.Dataset):
+    def __init__(self, x, y, **kwargs):
+        self.x, self.y = x, y
+
+        # this was to store the inverse permutation in permuted_mnist
+        # so that we could 'unscramble' samples and plot them
+        for name, value in kwargs.items():
+            setattr(self, name, value)
+
+    def __len__(self):
+        return len(self.x)
+
+    def __getitem__(self, idx):
+        x, y = self.x[idx], self.y[idx]
+
+        if type(x) != torch.Tensor:
+            # mini_imagenet
+            # we assume it's a path --> load from file
+            x = self.transform(Image.open(x).convert('RGB'))
+            y = torch.Tensor(1).fill_(y).long().squeeze()
+        elif self.source.startswith('cifar'):
+            #cifar10_mean = (0.5, 0.5, 0.5)
+            #cifar10_std = (0.5, 0.5, 0.5)
+            x = x.float() / 255
+            # transform = transforms.Compose([
+            #     transforms.ToPILImage(),
+            #     transforms.RandomCrop(32, padding=4),
+            #     transforms.RandomHorizontalFlip(),
+            #     transforms.RandomRotation(10),
+            #     transforms.ToTensor(),
+            # ])
+            # x = transform(x)
+            y = y.long()
+        else:
+            x = x.float() / 255.
+            y = y.long()
+
+        # for some reason mnist does better \in [0,1] than [-1, 1]
+        if self.source == 'mnist' or self.source.startswith('cifar'):
+            return x, y
+        else:
+            return (x - 0.5) * 2, y
+
+
+""" Template Dataset for Continual Learning """
+class CLDataLoader(object):
+    def __init__(self, datasets_per_task, batch_size, train=True):
+        bs = batch_size if train else 64
+
+        self.datasets = datasets_per_task
+        self.loaders = [
+            torch.utils.data.DataLoader(x, batch_size=bs, shuffle=True, drop_last=train, num_workers=0)
+            for x in self.datasets]
+
+    def __getitem__(self, idx):
+        return self.loaders[idx]
+
+    def __len__(self):
+        return len(self.loaders)
+
+
+class FuzzyCLDataLoader(object):
+    def __init__(self, datasets_per_task, batch_size, train=True):
+        bs = batch_size if train else 64
+        self.raw_datasets = datasets_per_task
+        self.datasets = [_ for _ in datasets_per_task]
+        for i in range(len(self.datasets) - 1):
+            self.datasets[i], self.datasets[i + 1] = self.mix_two_datasets(self.datasets[i], self.datasets[i + 1])
+        self.loaders = [
+            torch.utils.data.DataLoader(x, batch_size=bs, shuffle=True, drop_last=train, num_workers=0)
+            for x in self.datasets]
+
+    def shuffle(self, x, y):
+        perm = np.random.permutation(len(x))
+        x = x[perm]
+        y = y[perm]
+        return x, y
+
+    def mix_two_datasets(self, a, b, start=0.5):
+        a.x, a.y = self.shuffle(a.x, a.y)
+        b.x, b.y = self.shuffle(b.x, b.y)
+
+        def cmf_examples(i):
+            if i < start * len(a):
+                return 0
+            else:
+                return (1 - start) * len(a) * 0.25 * ((i / len(a) - start) / (1 - start)) ** 2
+
+        s, swaps = 0, []
+        for i in range(len(a)):
+            c = cmf_examples(i)
+            if s < c:
+                swaps.append(i)
+                s += 1
+
+        for idx in swaps:
+            a.x[idx], b.x[len(b) - idx], a.y[idx], b.y[len(b) - idx] = \
+                b.x[len(b) - idx], a.x[idx], b.y[len(b) - idx], a.y[idx]
+        return a, b
+
+    def __getitem__(self, idx):
+        return self.loaders[idx]
+
+    def __len__(self):
+        return len(self.loaders)
+
+
+class IIDDataset(torch.utils.data.Dataset):
+    def __init__(self, data_loaders, seed=0):
+        self.data_loader = data_loaders
+        self.idx = []
+        for task_id in range(len(data_loaders)):
+            for i in range(len(data_loaders[task_id].dataset)):
+                self.idx.append((task_id, i))
+        random.Random(seed).shuffle(self.idx)
+
+    def __getitem__(self, idx):
+        task_id, instance_id = self.idx[idx]
+        return self.data_loader[task_id].dataset.__getitem__(instance_id)
+
+    def __len__(self):
+        return len(self.idx)
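`FuzzyCLDataLoader` blurs task boundaries by swapping examples between consecutive tasks; `cmf_examples` is a quadratic ramp that is zero for the first half of a task and grows afterwards, so swaps concentrate near the boundary. A quick, illustrative way to see the swap schedule it induces:

```
n, start = 1000, 0.5

def cmf_examples(i):
    if i < start * n:
        return 0
    return (1 - start) * n * 0.25 * ((i / n - start) / (1 - start)) ** 2

s, swaps = 0, []
for i in range(n):
    if s < cmf_examples(i):
        swaps.append(i)
        s += 1

print(len(swaps), swaps[:3])  # ~125 swapped indices, all in the second half
```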
+
+""" Permuted MNIST """
+def get_permuted_mnist(args):
+    #assert not args.use_conv
+    args.multiple_heads = False
+    args.n_classes = 10
+    #if 'mem_size' in args:
+    #    args.buffer_size = args.mem_size * args.n_classes
+    args.n_tasks = 10 #if args.n_tasks==-1 else args.n_tasks
+    args.use_conv = False
+    args.input_type = 'binary'
+    args.input_size = [784]
+    #if args.output_loss is None:
+    args.output_loss = 'bernouilli'
+
+    # fetch MNIST
+    train = datasets.MNIST('data/', train=True, download=True)
+    test = datasets.MNIST('data/', train=False, download=True)
+
+    try:
+        train_x, train_y = train.data, train.targets
+        test_x, test_y = test.data, test.targets
+    except:
+        train_x, train_y = train.train_data, train.train_labels
+        test_x, test_y = test.test_data, test.test_labels
+
+    # only select 1000 of train x
+    permutation = np.random.RandomState(0).permutation(train_x.size(0))[:1000]
+    train_x = train_x[permutation]
+    train_y = train_y[permutation]
+
+    train_x = train_x.view(train_x.size(0), -1)
+    test_x = test_x.view(test_x.size(0), -1)
+
+    train_ds, test_ds, inv_perms = [], [], []
+    for task in range(args.n_tasks):
+        perm = torch.arange(train_x.size(-1)) if task == 0 else torch.randperm(train_x.size(-1))
+
+        # build inverse permutations, so we can display samples
+        inv_perm = torch.zeros_like(perm)
+        for i in range(perm.size(0)):
+            inv_perm[perm[i]] = i
+
+        inv_perms += [inv_perm]
+        train_ds += [(train_x[:, perm], train_y)]
+        test_ds += [(test_x[:, perm], test_y)]
+
+    train_ds, val_ds = make_valid_from_train(train_ds)
+
+    train_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), train_ds, inv_perms)
+    val_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), val_ds, inv_perms)
+    test_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'inv_perm': y, 'source': 'mnist'}), test_ds, inv_perms)
+
+    return train_ds, val_ds, test_ds
+
+""" Rotated MNIST """
+def get_rotated_mnist(args):
+    #assert not args.use_conv
+    args.multiple_heads = False
+    args.n_classes = 10
+    #if 'mem_size' in args:
+    #    args.buffer_size = args.mem_size * args.n_classes
+    args.n_tasks = 20
+    args.use_conv = False
+    args.input_type = 'binary'
+    args.input_size = [784]
+    #if args.output_loss is None:
+    args.output_loss = 'bernouilli'
+
+    args.min_rot = 0
+    args.max_rot = 180
+    train_ds, test_ds, inv_perms = [], [], []
+    val_ds = []
+    # fetch MNIST
+    to_tensor = transforms.ToTensor()
+
+    def rotate_dataset(x, angle):
+        x_np = np.copy(x.cpu().numpy())
+        x_np = rotate(x_np, angle=angle, axes=(2, 1), reshape=False)
+        return torch.from_numpy(x_np).float()
+
+    train = datasets.MNIST('data/', train=True, download=True)
+    test = datasets.MNIST('data/', train=False, download=True)
+
+
+    for task in range(args.n_tasks):
+        #angle = random.random() * (args.max_rot - args.min_rot) + args.min_rot
+        min_rot = 1.0 * task / args.n_tasks * (args.max_rot - args.min_rot) + \
+                  args.min_rot
+        max_rot = 1.0 * (task + 1) / args.n_tasks * \
+                  (args.max_rot - args.min_rot) + args.min_rot
+        angle = random.random() * (max_rot - min_rot) + min_rot
+
+        rand_perm = np.random.permutation(len(train.data))[:1000]
+        rand_perm_test = np.random.permutation(len(test.data))[:1000]
+
+        try:
+            train_x, train_y = train.data[rand_perm], train.targets[rand_perm]
+            test_x, test_y = test.data[rand_perm_test], test.targets[rand_perm_test]
+        except:
+            train_x, train_y = train.train_data[rand_perm], train.train_labels[rand_perm]
+            test_x, test_y = test.test_data[rand_perm_test], test.test_labels[rand_perm_test]
+
+        train_x, train_y, val_x, val_y = train_x[:950], train_y[:950], train_x[950:], train_y[950:]
+
+        train_x = rotate_dataset(train_x, angle)
+        test_x = rotate_dataset(test_x, angle)
+        val_x = rotate_dataset(val_x, angle)
+        #train_x = train_x.view(train_x.size(0), -1)
+        #test_x = test_x.view(test_x.size(0), -1)
+
+        train_ds += [(train_x, train_y)]
+        test_ds += [(test_x, test_y)]
+        val_ds += [(val_x, val_y)]
+    #train_ds, _ = make_valid_from_train(train_ds, cut=0.99)
+
+    train_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), train_ds)
+    val_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), val_ds)
+    test_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), test_ds)
+
+    return train_ds, val_ds, test_ds
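In `get_rotated_mnist`, task `t` draws its rotation angle uniformly from the `t`-th equal slice of `[min_rot, max_rot]`. A small illustration of the per-task windows that arithmetic produces:

```
n_tasks, min_rot, max_rot = 20, 0.0, 180.0
for task in range(3):  # first three of the twenty tasks
    lo = task / n_tasks * (max_rot - min_rot) + min_rot
    hi = (task + 1) / n_tasks * (max_rot - min_rot) + min_rot
    print(f"task {task}: angle in [{lo:.1f}, {hi:.1f}) degrees")
# task 0: [0.0, 9.0), task 1: [9.0, 18.0), task 2: [18.0, 27.0)
```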
+
+""" Split MNIST into 5 tasks {{0,1}, ... {8,9}} """
+def get_split_mnist(args, cfg):
+    args.multiple_heads = False
+    args.n_classes = 10
+    args.n_tasks = 5 #if args.n_tasks==-1 else args.n_tasks
+    if 'mem_size' in args:
+        args.buffer_size = args.n_tasks * args.mem_size * 2
+    args.use_conv = False
+    args.input_type = 'binary'
+    args.input_size = [1, 28, 28]
+    #if args.output_loss is None:
+    args.output_loss = 'bernouilli'
+
+    assert args.n_tasks in [5, 10], 'SplitMnist only works with 5 or 10 tasks'
+    assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+
+    # fetch MNIST
+    train = datasets.MNIST('Data/', train=True, download=True)
+    test = datasets.MNIST('Data/', train=False, download=True)
+
+    try:
+        train_x, train_y = train.data, train.targets
+        test_x, test_y = test.data, test.targets
+    except:
+        train_x, train_y = train.train_data, train.train_labels
+        test_x, test_y = test.test_data, test.test_labels
+
+    # sort according to the label
+    out_train = [
+        (x, y) for (x, y) in sorted(zip(train_x, train_y), key=lambda v: v[1])]
+
+    out_test = [
+        (x, y) for (x, y) in sorted(zip(test_x, test_y), key=lambda v: v[1])]
+
+    train_x, train_y = [
+        torch.stack([elem[i] for elem in out_train]) for i in [0, 1]]
+
+    test_x, test_y = [
+        torch.stack([elem[i] for elem in out_test]) for i in [0, 1]]
+
+    # cast in 3D:
+    train_x = train_x.view(train_x.size(0), 1, train_x.size(1), train_x.size(2))
+    test_x = test_x.view(test_x.size(0), 1, test_x.size(1), test_x.size(2))
+
+    # get indices of class split
+    train_idx = [((train_y + i) % 10).argmax() for i in range(10)]
+    train_idx = [0] + sorted(train_idx)
+
+    test_idx = [((test_y + i) % 10).argmax() for i in range(10)]
+    test_idx = [0] + sorted(test_idx)
+
+    train_ds, test_ds = [], []
+    skip = 10 // args.n_tasks
+    for i in range(0, 10, skip):
+        tr_s, tr_e = train_idx[i], train_idx[i + skip]
+        te_s, te_e = test_idx[i], test_idx[i + skip]
+
+        train_ds += [(train_x[tr_s:tr_e], train_y[tr_s:tr_e])]
+        test_ds += [(test_x[te_s:te_e], test_y[te_s:te_e])]
+
+    if hasattr(cfg, 'NOVAL') and cfg.NOVAL:
+        train_ds, val_ds = train_ds, test_ds
+        print('no validation set')
+    else:
+        train_ds, val_ds = make_valid_from_train(train_ds)
+
+    train_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), train_ds)
+    val_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), val_ds)
+    test_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'mnist'}), test_ds)
+
+    return train_ds, val_ds, test_ds
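The `((train_y + i) % 10).argmax()` trick above recovers class boundaries from label-sorted data: `(train_y + i) % 10` attains its maximum (9) exactly on the block of class `9 - i`, so the argmax lands inside that class's block. Note this leans on how `torch.argmax` breaks ties within a constant block, which has changed across PyTorch versions, presumably one reason for the `'Use Pytorch 1.x!'` assertion. A toy illustration (behavior shown assumes first-occurrence tie-breaking):

```
import torch

train_y = torch.arange(10).repeat_interleave(3)  # sorted labels 0,0,0,1,1,1,...,9,9,9
for i in range(3):
    print(i, ((train_y + i) % 10).argmax().item())
# the peak value 9 sits on the block of class 9 - i, so the printed indices
# (27, 24, 21 with first-occurrence ties) mark those class boundaries
```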
+
+
+
+""" Split CIFAR10 into 5 tasks {{0,1}, ... {8,9}} """
+def get_split_cifar10(args, cfg):
+    # assert args.n_tasks in [5, 10], 'SplitCifar only works with 5 or 10 tasks'
+    assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+    args.n_tasks = 5
+    args.n_classes = 10
+    #args.buffer_size = args.n_tasks * args.mem_size * 2
+    args.multiple_heads = False
+    args.use_conv = True
+    args.n_classes_per_task = 2
+    args.input_size = [3, 32, 32]
+    args.input_type = 'continuous'
+    # because data is between [-1,1]:
+    # fetch CIFAR-10
+    train = datasets.CIFAR10('Data/', train=True, download=True)
+    test = datasets.CIFAR10('Data/', train=False, download=True)
+
+    try:
+        train_x, train_y = train.data, train.targets
+        test_x, test_y = test.data, test.targets
+    except:
+        train_x, train_y = train.train_data, train.train_labels
+        test_x, test_y = test.test_data, test.test_labels
+
+    # sort according to the label
+    out_train = [
+        (x, y) for (x, y) in sorted(zip(train_x, train_y), key=lambda v: v[1])]
+
+    out_test = [
+        (x, y) for (x, y) in sorted(zip(test_x, test_y), key=lambda v: v[1])]
+
+    train_x, train_y = [
+        np.stack([elem[i] for elem in out_train]) for i in [0, 1]]
+
+    test_x, test_y = [
+        np.stack([elem[i] for elem in out_test]) for i in [0, 1]]
+
+    train_x = torch.Tensor(train_x).permute(0, 3, 1, 2).contiguous()
+    test_x = torch.Tensor(test_x).permute(0, 3, 1, 2).contiguous()
+
+    train_y = torch.Tensor(train_y)
+    test_y = torch.Tensor(test_y)
+
+    # get indices of class split
+    train_idx = [((train_y + i) % 10).argmax() for i in range(10)]
+    train_idx = [0] + [x + 1 for x in sorted(train_idx)]
+
+    test_idx = [((test_y + i) % 10).argmax() for i in range(10)]
+    test_idx = [0] + [x + 1 for x in sorted(test_idx)]
+
+    train_ds, test_ds = [], []
+    skip = 10 // 5 #args.n_tasks
+    for i in range(0, 10, skip):
+        tr_s, tr_e = train_idx[i], train_idx[i + skip]
+        te_s, te_e = test_idx[i], test_idx[i + skip]
+
+        train_ds += [(train_x[tr_s:tr_e], train_y[tr_s:tr_e])]
+        test_ds += [(test_x[te_s:te_e], test_y[te_s:te_e])]
+    if hasattr(cfg, 'NOVAL') and cfg.NOVAL:
+        train_ds, val_ds = train_ds, test_ds
+        print('no validation set')
+    else:
+        train_ds, val_ds = make_valid_from_train(train_ds)
+    #else:
+    #    train_ds, val_ds = train_ds, test_ds
+    train_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar10'}), train_ds)
+    val_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar10'}), val_ds)
+    test_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar10'}), test_ds)
+
+    return train_ds, val_ds, test_ds
+
+""" Split CIFAR100 into 20 tasks {{0,1,2,3,4}, ... {95,96,97,98,99}} """
+def get_split_cifar100(args, cfg):
+    # assert args.n_tasks in [5, 10], 'SplitCifar only works with 5 or 10 tasks'
+    assert '1.' in str(torch.__version__)[:2], 'Use Pytorch 1.x!'
+    args.n_tasks = 20
+    args.n_classes = 100
+    #args.buffer_size = args.n_tasks * args.mem_size * 2
+    args.multiple_heads = False
+    args.use_conv = True
+    args.n_classes_per_task = 5
+    args.input_size = [3, 32, 32]
+    args.input_type = 'continuous'
+    # fetch CIFAR-100
+    train = datasets.CIFAR100('Data/', train=True, download=True)
+    test = datasets.CIFAR100('Data/', train=False, download=True)
+
+    try:
+        train_x, train_y = train.data, train.targets
+        test_x, test_y = test.data, test.targets
+    except:
+        train_x, train_y = train.train_data, train.train_labels
+        test_x, test_y = test.test_data, test.test_labels
+
+    # sort according to the label
+    out_train = [
+        (x, y) for (x, y) in sorted(zip(train_x, train_y), key=lambda v: v[1])]
+
+    out_test = [
+        (x, y) for (x, y) in sorted(zip(test_x, test_y), key=lambda v: v[1])]
+
+    train_x, train_y = [
+        np.stack([elem[i] for elem in out_train]) for i in [0, 1]]
+
+    test_x, test_y = [
+        np.stack([elem[i] for elem in out_test]) for i in [0, 1]]
+
+    train_x = torch.Tensor(train_x).permute(0, 3, 1, 2).contiguous()
+    test_x = torch.Tensor(test_x).permute(0, 3, 1, 2).contiguous()
+    train_y = torch.Tensor(train_y)
+    test_y = torch.Tensor(test_y)
+    train_ds, test_ds = [], []
+
+    with open('data/cifar100-split-online.yaml') as f:
+        label_splits = yaml.load(f, Loader=yaml.FullLoader)  # explicit Loader avoids PyYAML's unsafe-load warning
+    for label_split in label_splits:
+        train_indice = []
+        task_labels = [_[1] for _ in label_split['subsets']]
+        for i in range(len(train_y)):
+            if train_y[i].item() in task_labels:
+                train_indice.append(i)
+        test_indice = []
+        for i in range(len(test_y)):
+            if test_y[i].item() in task_labels:
+                test_indice.append(i)
+        train_ds.append((train_x[train_indice], train_y[train_indice]))
+        test_ds.append((test_x[test_indice], test_y[test_indice]))
+
+    train_ds, val_ds = train_ds, test_ds
+    train_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar100'}), train_ds)
+    val_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar100'}), val_ds)
+    test_ds = map(lambda x: XYDataset(x[0], x[1], **{'source': 'cifar100'}), test_ds)
+
+    return train_ds, val_ds, test_ds
+
+def get_miniimagenet(args):
+    ROOT_PATH = 'datasets/miniImagenet/'
+
+    args.use_conv = True
+    args.n_tasks = 20
+    args.n_classes = 100
+    args.multiple_heads = False
+    args.n_classes_per_task = 5
+    args.input_size = (3, 84, 84)
+    label2id = {}
+
+    def get_data(setname):
+        ds_dir = os.path.join(ROOT_PATH, setname)
+        label_dirs = os.listdir(ds_dir)
+        data, labels = [], []
+
+        for label in label_dirs:
+            label_dir = os.path.join(ds_dir, label)
+            for image_file in os.listdir(label_dir):
+                data.append(os.path.join(label_dir, image_file))
+                if label not in label2id:
+                    label_id = len(label2id)
+                    label2id[label] = label_id
+                label_id = label2id[label]
+                labels.append(label_id)
+        return data, labels
+
+    transform = transforms.Compose([
+        transforms.Resize(84),
+        transforms.CenterCrop(84),
+        transforms.ToTensor(),
+    ])
+
+    train_data, train_label = get_data('train')
+    valid_data, valid_label = get_data('val')
+    test_data, test_label = get_data('test')
+
+    # total of 60k examples for training, the rest for testing
+    all_data = np.array(train_data + valid_data + test_data)
+    all_label = np.array(train_label + valid_label + test_label)
+
+
+    train_ds, test_ds = [], []
+    current_train, current_test = None, None
+
+    cat = lambda x, y: np.concatenate((x, y), axis=0)
+
+    for i in range(args.n_classes):
+        class_indices = np.argwhere(all_label == i).reshape(-1)
+        class_data = all_data[class_indices]
+        class_label = all_label[class_indices]
+        split = int(0.8 * class_data.shape[0])
+
+        data_train, data_test = class_data[:split], class_data[split:]
+        label_train, label_test = class_label[:split], class_label[split:]
+
+        if current_train is None:
+            current_train, current_test = (data_train, label_train), (data_test, label_test)
+        else:
+            current_train = cat(current_train[0], data_train), cat(current_train[1], label_train)
+            current_test = cat(current_test[0], data_test), cat(current_test[1], label_test)
+
+        if i % args.n_classes_per_task == (args.n_classes_per_task - 1):
+            train_ds += [current_train]
+            test_ds += [current_test]
+            current_train, current_test = None, None
+
+    # TODO: remove this
+    ## Facebook actually does 17 tasks (3 to CV)
+    #train_ds = train_ds[:17]
+    #test_ds = test_ds[:17]
+
+    # build masks
+    masks = []
+    task_ids = [None for _ in range(20)]
+    for task, task_data in enumerate(train_ds):
+        labels = np.unique(task_data[1]) #task_data[1].unique().long()
+        assert labels.shape[0] == args.n_classes_per_task
+        mask = torch.zeros(args.n_classes).cuda()
+        mask[labels] = 1
+        masks += [mask]
+        task_ids[task] = labels
+
+    task_ids = torch.from_numpy(np.stack(task_ids)).cuda().long()
+
+    train_ds, val_ds = make_valid_from_train(train_ds)
+    train_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'source': 'cifar100', 'mask': y, 'task_ids': task_ids, 'transform': transform}), train_ds, masks)
+    val_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'source': 'cifar100', 'mask': y, 'task_ids': task_ids, 'transform': transform}), val_ds, masks)
+    test_ds = map(lambda x, y: XYDataset(x[0], x[1], **{'source': 'cifar100', 'mask': y, 'task_ids': task_ids, 'transform': transform}), test_ds, masks)
+
+    return train_ds, val_ds, test_ds
+
+
+def make_valid_from_train(dataset, cut=0.95):
+    tr_ds, val_ds = [], []
+    for task_ds in dataset:
+        x_t, y_t = task_ds
+
+        # shuffle before splitting
+        perm = torch.randperm(len(x_t))
+        x_t, y_t = x_t[perm], y_t[perm]
+
+        split = int(len(x_t) * cut)
+        x_tr, y_tr = x_t[:split], y_t[:split]
+        x_val, y_val = x_t[split:], y_t[split:]
+
+        tr_ds += [(x_tr, y_tr)]
+        val_ds += [(x_val, y_val)]
+
+    return tr_ds, val_ds
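A minimal usage sketch that wires these builders into per-task loaders, mirroring what `dataloader.py` below does. It assumes `DotDict` behaves like a dict with attribute access (the builders fill in `n_tasks`, `input_size`, and so on) and that PyTorch 1.x is installed, per the assertion above:

```
from utils.utils import DotDict
from data.benchmark_mir import get_split_mnist, CLDataLoader

args = DotDict()                  # populated in place by get_split_mnist
cfg = DotDict(NOVAL=True)         # reuse the test split instead of carving out val
train_ds, val_ds, test_ds = get_split_mnist(args, cfg)
train_loaders = CLDataLoader(list(train_ds), batch_size=10, train=True)

for task_id in range(len(train_loaders)):
    x, y = next(iter(train_loaders[task_id]))
    print(task_id, x.shape, y.unique().tolist())  # two MNIST classes per task
```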
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..f200858
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,144 @@
+import torch
+import yaml
+import numpy as np
+import pickle
+from PIL import Image, ImageDraw
+from torch.utils.data import DataLoader
+from torchvision.transforms import transforms as T
+from tqdm import tqdm
+from yacs.config import CfgNode
+from maskrcnn_benchmark.config import cfg as maskrcnn_cfg
+from utils.tupperware import tupperware
+from utils.build_transforms import build_transforms
+from data.benchmark_mir import CLDataLoader, get_permuted_mnist, get_split_mnist, get_miniimagenet, get_rotated_mnist, \
+    get_split_cifar10, get_split_cifar100, IIDDataset, FuzzyCLDataLoader
+
+from utils.utils import DotDict, get_config_attr
+
+import random
+
+_dataset = {
+
+}
+
+_smnist_loaders = None
+def get_split_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    d = DotDict()
+    global _smnist_loaders
+    if not _smnist_loaders:
+        data = get_split_mnist(d, cfg)
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _smnist_loaders = train_loader, val_loader, test_loader
+    else:
+        train_loader, val_loader, test_loader = _smnist_loaders
+
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
+
+
+_rmnist_loaders = None
+def get_rotated_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
+    d = DotDict()
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    global _rmnist_loaders
+    if not _rmnist_loaders:
+        data = get_rotated_mnist(d)
+        #train_loader, val_loader, test_loader = [CLDataLoader(elem, batch_size, train=t) \
+        #                                         for elem, t in zip(data, [True, False, False])]
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _rmnist_loaders = train_loader, val_loader, test_loader
+    else:
+        train_loader, val_loader, test_loader = _rmnist_loaders
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
+
+_pmnist_loaders = None
+def get_permute_mnist_dataloader(cfg, split='train', filter_obj=None, batch_size=128, task_num=10, *args, **kwargs):
+    d = DotDict()
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    global _pmnist_loaders
+    if not _pmnist_loaders:
+        data = get_permuted_mnist(d)
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _pmnist_loaders = train_loader, val_loader, test_loader
+    else:
+        train_loader, val_loader, test_loader = _pmnist_loaders
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
+
+_cache_cifar = None
+def get_split_cifar_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+    d = DotDict()
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    global _cache_cifar
+    if not _cache_cifar:
+        data = get_split_cifar10(d, cfg) #ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _cache_cifar = train_loader, val_loader, test_loader
+    train_loader, val_loader, test_loader = _cache_cifar
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
+
+_cache_cifar100 = None
+def get_split_cifar100_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+    d = DotDict()
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    global _cache_cifar100
+    if not _cache_cifar100:
+        data = get_split_cifar100(d, cfg) #ds_cifar10and100(batch_size=batch_size, num_workers=0, cfg=cfg, **kwargs)
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _cache_cifar100 = train_loader, val_loader, test_loader
+    train_loader, val_loader, test_loader = _cache_cifar100
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
+
+_cache_mini_imagenet = None
+def get_split_mini_imagenet_dataloader(cfg, split='train', filter_obj=None, batch_size=128, *args, **kwargs):
+    global _cache_mini_imagenet
+    d = DotDict()
+    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
+    if not _cache_mini_imagenet:
+        data = get_miniimagenet(d)
+        loader_cls = CLDataLoader if not fuzzy else FuzzyCLDataLoader
+        train_loader, val_loader, test_loader = [loader_cls(elem, batch_size, train=t) \
+                                                 for elem, t in zip(data, [True, False, False])]
+        _cache_mini_imagenet = train_loader, val_loader, test_loader
+    train_loader, val_loader, test_loader = _cache_mini_imagenet
+    if split == 'train':
+        return train_loader[filter_obj[0]]
+    elif split == 'val':
+        return val_loader[filter_obj[0]]
+    elif split == 'test':
+        return test_loader[filter_obj[0]]
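The six loader functions above repeat one pattern: build the benchmark once, wrap it in `CLDataLoader` (or `FuzzyCLDataLoader`), cache it module-globally, then index by task. A hedged sketch of that pattern collapsed into a single factory (illustrative refactor, not part of the repo):

```
_loader_cache = {}

def make_cl_dataloader(name, build, cfg, split='train', filter_obj=None, batch_size=128):
    """build: zero-argument callable returning (train, val, test) dataset lists."""
    fuzzy = get_config_attr(cfg, 'EXTERNAL.OCL.FUZZY', default=0, mute=True)
    if name not in _loader_cache:
        loader_cls = FuzzyCLDataLoader if fuzzy else CLDataLoader
        _loader_cache[name] = [loader_cls(elem, batch_size, train=t)
                               for elem, t in zip(build(), [True, False, False])]
    train, val, test = _loader_cache[name]
    return {'train': train, 'val': val, 'test': test}[split][filter_obj[0]]

# e.g. get_split_mnist_dataloader(cfg, 'val', [2]) would become:
# make_cl_dataloader('smnist', lambda: get_split_mnist(DotDict(), cfg), cfg, 'val', [2])
```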
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..5243da1
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,359 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import h5py
+
+import torch
+import json
+from tqdm import tqdm
+from time import sleep
+
+from config import cfg
+from utils.utils import Timer
+from dataloader import get_dataloader
+from sklearn.metrics import f1_score
+from collections import defaultdict
+from torch.nn import functional as F
+
+
+def to_list(tensor):
+    return tensor.cpu().numpy().tolist()
+
+def compute_on_dataset(model, data_loader, local_rank, device, timer=None, output_file='debug.h5'):
+    model.eval()
+    results_dict = {}
+    gt_dict = {}
+    cpu_device = torch.device("cpu")
+    text = "GPU {}".format(local_rank)
+    pbar = tqdm(
+        total=len(data_loader),
+        position=local_rank,
+        desc=text,
+    )
+    with h5py.File(output_file, 'w') as f:
+        box_feat_ds = f.create_dataset(
+            'bbox_features', shape=(len(data_loader), 150, 1024)
+        )
+
+        box_ds = f.create_dataset(
+            'bboxes', shape=(len(data_loader), 150, 4)
+        )
+
+        num_boxes_ds = f.create_dataset(
+            'num_boxes', shape=(len(data_loader), 1)
+        )
+
+        img_size_ds = f.create_dataset(
+            'image_size', shape=(len(data_loader), 2)
+        )
+
+        info_dict = {}
+
+        for e, out_dict in enumerate(data_loader):
+            images = out_dict['images']
+            targets = out_dict['gt_bboxes']
+            image_ids = out_dict['image_ids']
+            info = out_dict['info']
+
+            images = [image.to(device) for image in images]
+            targets = [target.to(device) for target in targets]
+            if timer:
+                timer.tic()
+            x, _, output = model(images, targets)
+            if timer:
+                if not cfg.MODEL.DEVICE == 'cpu':
+                    torch.cuda.synchronize()
+                timer.toc()
+            assert (x.shape[0] == len(targets[0].bbox))
+            output = [o.to(cpu_device) for o in output]
+            info_dict[image_ids[0]] = {}
+            info_dict[image_ids[0]]['idx'] = e
+            info_dict[image_ids[0]]['objects'] = {}
+            bboxes = targets[0].bbox
+            num_boxes = len(targets[0].bbox)
+            for i in range(num_boxes):
+                tmp = {}
+                tmp = info[0][i]
+                tmp['idx'] = i
+                info_dict[image_ids[0]]['objects'][info[0][i]['object_id']] = tmp
+
+            '''
+            for i in range(5):
+                fpn_feat_ds[i][e] = output[i].numpy()
+            '''
+            box_feat_ds[e, :num_boxes, :] = x.cpu().numpy()
+            box_ds[e, :num_boxes, :] = bboxes.cpu().numpy()
+            num_boxes_ds[e] = num_boxes
+            img_size_ds[e] = targets[0].size
+            pbar.update(1)
+            sleep(0.001)
+
+    with open(output_file.replace('.h5', '_map.json'), 'w') as fp:
+        json.dump(info_dict, fp)
+
+
+def inference_step(model, out_dict, device=torch.device('cuda')):
+    images = torch.stack(out_dict['images'])
+    obj_labels = torch.cat(out_dict['object_labels'], -1)
+    attr_labels = torch.cat(out_dict['attribute_labels'], -1)
+    cropped_image = torch.stack(out_dict['cropped_image'])
+    images = images.to(device)
+    obj_labels = obj_labels.to(device)
+    attr_labels = attr_labels.to(device)
+
+    cropped_image = cropped_image.to(device)
+    # loss_dict = model(images, targets)
+    ret_dict = model(bbox_images=cropped_image, spatial_feat=None,
+                     attr_labels=attr_labels, obj_labels=obj_labels,
+                     images=images)
+    attr_score, obj_score = ret_dict.get('attr_score', None), \
+                            ret_dict.get('obj_score', None)
+    if attr_score is not None:
+        attr_score_norm = F.softmax(attr_score, -1)
+        ret_dict['pred_attr_prob'], ret_dict['pred_attr'] = attr_score_norm.max(-1)
+    if obj_score is not None:
+        obj_score_norm = F.softmax(obj_score, -1)
+        ret_dict['pred_obj_prob'], ret_dict['pred_obj'] = obj_score_norm.max(-1)
+    ret_dict['obj_labels'], ret_dict['attr_labels'] = obj_labels, attr_labels
+    return ret_dict
+
+
+def inference(
+        model,
+        current_epoch,
+        current_iter,
+        local_rank,
+        data_loader,
+        dataset_name,
+        device="cuda",
+        max_instance=3200,
+        mute=False,
+        verbose_return=False
+):
+    model.train(False)
+    # convert to a torch.device for efficiency
+    device = torch.device(device)
+    if not mute:
+        logger = logging.getLogger("maskrcnn_benchmark.inference")
+        logger.info("Start evaluation")
+    total_timer = Timer()
+    total_timer.tic()
+    torch.cuda.empty_cache()
+    if not mute:
+        pbar = tqdm(
+            total=len(data_loader),
+            desc="Validation in progress"
+        )
+
+    def to_list(tensor):
+        return tensor.cpu().numpy().tolist()
+
+    with torch.no_grad():
+        all_pred_obj, all_truth_obj, all_pred_attr, all_truth_attr = [], [], [], []
+        all_image_ids, all_boxes = [], []
+        all_pred_attr_prob = []
+        all_raws = []
+        obj_loss_all, attr_loss_all = 0, 0
+
+        cnt = 0
+        for iteration, out_dict in enumerate(data_loader):
+            if type(max_instance) is int:
+                if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE: break
+            if type(max_instance) is float:
+                if iteration > max_instance * len(data_loader) // model.cfg.EXTERNAL.BATCH_SIZE: break
+            # print(iteration)
+
+            if verbose_return:
+                all_image_ids.extend(out_dict['image_ids'])
+                all_boxes.extend(out_dict['gt_bboxes'])
+                all_raws.extend(out_dict['raw'])
+
+            ret_dict = inference_step(model, out_dict, device)
+            loss_attr, loss_obj, attr_score, obj_score = ret_dict.get('attr_loss', None), \
+                                                         ret_dict.get('obj_loss', None), \
+                                                         ret_dict.get('attr_score', None), \
+                                                         ret_dict.get('obj_score', None)
+
+            if loss_attr is not None:
+                attr_loss_all += loss_attr.item()
+                pred_attr_prob, pred_attr = ret_dict['pred_attr_prob'], ret_dict['pred_attr']
+                all_pred_attr.extend(to_list(pred_attr))
+                all_truth_attr.extend(to_list(ret_dict['attr_labels']))
+                all_pred_attr_prob.extend(to_list(pred_attr_prob))
+            if loss_obj is not None:
+                obj_loss_all += loss_obj.item()
+                _, pred_obj = obj_score.max(-1)
+                all_pred_obj.extend(to_list(pred_obj))
+                all_truth_obj.extend(to_list(ret_dict['obj_labels']))
+            cnt += 1
+            if not mute:
+                pbar.update(1)
+
+        obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
+        attr_f1 = f1_score(all_truth_attr, all_pred_attr, average='micro')
+        obj_loss_all /= (cnt + 1e-10)
+        attr_loss_all /= (cnt + 1e-10)
+        if not mute:
+            logger.info('Epoch: {}\tIteration: {}\tObject f1: {}\tAttr f1:{}\tObject loss:{}\tAttr loss:{}'.
+                        format(current_epoch, current_iter, obj_f1, attr_f1, obj_loss_all, attr_loss_all))
+    #compute_on_dataset(model, data_loader, local_rank, device, inference_timer, output_file)
+    # wait for all processes to complete before measuring the time
+    total_time = total_timer.toc()
+    model.train(True)
+    if not verbose_return:
+        return obj_f1, attr_f1, len(all_truth_attr)
+    else:
+        return obj_f1, attr_f1, all_pred_attr, all_truth_attr, all_pred_obj, all_truth_obj, all_image_ids, all_boxes, \
+               all_pred_attr_prob, all_raws
+
+def run_forget_metrics(metric_dict, finished_tasks, all_tasks, key, forget_dict):
+    for task in all_tasks:
+        if len(metric_dict[key][task]) <= 1 or task not in finished_tasks:
+            forget_dict[task].append(-1)
+        else:
+            forget_dict[task].append(max(metric_dict[key][task][:-1]) - metric_dict[key][task][-1])
+
+
+def run_forward_transfer_metrics(metric_dict, seen_tasks, all_tasks, key, ft_dict):
+    for task in all_tasks:
+        if task not in seen_tasks:
+            ft_dict[task].append(metric_dict[key][task][-1])
+        else:
+            ft_dict[task].append(-1)
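A worked example of the forgetting score computed by `run_forget_metrics` above: for a task whose accuracy history so far is `[0.9, 0.8, 0.6]`, forgetting at the current step is the best past accuracy minus the current one; tasks evaluated only once, or not yet finished, get the sentinel `-1`, which `numericalize_metric_scores` below skips when averaging.

```
history = [0.9, 0.8, 0.6]
forgetting = max(history[:-1]) - history[-1]
print(forgetting)  # 0.3 accuracy points lost relative to the task's best
```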
+
+
+def numericalize_metric_scores(metric_dict):
+    result_dict = defaultdict(list)
+    for t in range(metric_dict['length']):
+        for key in ['attr_acc', 'forget_dict', 'obj_acc']:
+            total_inst = 0
+            total_score = 0
+            for attr in metric_dict[key]:
+                inst_num = metric_dict['inst_num'][attr][t]
+                score = metric_dict[key][attr][t]
+                if score != -1:
+                    total_inst += inst_num
+                    total_score += score * inst_num
+            avg_score = total_score / (total_inst + 1e-10)
+            result_dict[key].append(avg_score)
+    return result_dict
+
+
+def inference_ocl_attr(
+        model,
+        current_epoch,
+        current_iter,
+        dataset_name,
+        prev_metric_dict,
+        seen_objects,
+        finished_objects,
+        all_objects,
+        max_instance
+):
+    """
+    :param model:
+    :param current_epoch:
+    :param current_iter:
+    :param prev_metric_dict: {attr_acc: : [acc1, acc2]}
+    :param filter_objects:
+    :param filter_attrs:
+    :return:
+    """
+    model.train(False)
+    device = torch.device('cuda')
+
+    if not prev_metric_dict:
+        prev_metric_dict = {
+            'attr_acc': defaultdict(list),
+            'inst_num': defaultdict(list),
+            'ft_dict': defaultdict(list),
+            'forget_dict': defaultdict(list),
+            'obj_acc': defaultdict(list),
+            'length': 0
+        }
+
+    pbar = tqdm(
+        total=len(all_objects),
+        desc="Validation in progress"
+    )
+    # only seen objects by this time
+    for obj in all_objects:
+        dataloader = get_dataloader(model.cfg, 'val', False, False, filter_obj=[obj])
+        obj_acc, attr_acc, inst_num = inference(model, current_epoch, current_iter, 0, dataloader, dataset_name,
+                                                max_instance=max_instance, mute=True)
+
+        prev_metric_dict['attr_acc'][obj].append(attr_acc)
+        prev_metric_dict['inst_num'][obj].append(inst_num)
+        prev_metric_dict['obj_acc'][obj].append(obj_acc)
+        pbar.update(1)
+
+    metric_dict = prev_metric_dict
+    #run_forward_transfer_metrics(metric_dict, seen_objects, all_objects, 'attr_acc', metric_dict['ft_dict'])
+    run_forget_metrics(metric_dict, finished_objects, all_objects, 'attr_acc', metric_dict['forget_dict'])
+    metric_dict['length'] += 1
+
+    numerical_metric_dict = numericalize_metric_scores(metric_dict)
+
+    return metric_dict, numerical_metric_dict
+
+def inference_mean_exemplar(
+        model,
+        current_epoch,
+        current_iter,
+        local_rank,
+        data_loader,
+        dataset_name,
+        device="cuda",
+        max_instance=3200,
+        mute=False,
+):
+    model.train(False)
+    # convert to a torch.device for efficiency
+    device = torch.device(device)
+    if not mute:
+        logger = logging.getLogger("maskrcnn_benchmark.inference")
+        logger.info("Start evaluation")
+    total_timer = Timer()
+    inference_timer = Timer()
+    total_timer.tic()
+    torch.cuda.empty_cache()
+    if not mute:
+        pbar = tqdm(
+            total=len(data_loader),
+            desc="Validation in progress"
+        )
+    with torch.no_grad():
+        all_pred_obj, all_truth_obj, all_pred_attr, all_truth_attr = [], [], [], []
+        obj_loss_all, attr_loss_all = 0, 0
+        cnt = 0
+        for iteration, out_dict in enumerate(data_loader):
+            if type(max_instance) is int:
+                if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE: break
+            if type(max_instance) is float:
+                if iteration > max_instance * len(data_loader) // model.cfg.EXTERNAL.BATCH_SIZE: break
+            # print(iteration)
+            images = torch.stack(out_dict['images'])
+            obj_labels = torch.cat(out_dict['object_labels'], -1)
+            attr_labels = torch.cat(out_dict['attribute_labels'], -1)
+            cropped_image = torch.stack(out_dict['cropped_image'])
+
+            images = images.to(device)
+            obj_labels = obj_labels.to(device)
+            attr_labels = attr_labels.to(device)
+
+            cropped_image = cropped_image.to(device)
+            # loss_dict = model(images, targets)
+            pred_obj = model.mean_of_exemplar_classify(cropped_image)
+
+            all_pred_obj.extend(to_list(pred_obj))
+            all_truth_obj.extend(to_list(obj_labels))
+            cnt += 1
+            if not mute:
+                pbar.update(1)
+
+    obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
+    #attr_f1 = f1_score(all_truth_attr, all_pred_attr, average='micro')
+    obj_loss_all /= (cnt + 1e-10)
+    # wait for all processes to complete before measuring the time
+    total_time = total_timer.toc()
+    model.train(True)
+    return obj_f1, 0, len(all_truth_obj)
\ No newline at end of file
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000..e69de29
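`inference_mean_exemplar` above relies on `model.mean_of_exemplar_classify`, whose body is not shown in this diff. A hedged sketch of the usual iCaRL-style rule that name suggests, classifying each input by the nearest class mean of normalized exemplar features (function and argument names here are illustrative):

```
import torch
import torch.nn.functional as F

def mean_of_exemplar_classify(features, exemplar_feats, exemplar_labels):
    """features: (B, D); exemplar_feats: (N, D); exemplar_labels: (N,) long."""
    feats = F.normalize(features, dim=-1)
    classes = exemplar_labels.unique(sorted=True)
    means = torch.stack([
        F.normalize(exemplar_feats[exemplar_labels == c].mean(0), dim=-1)
        for c in classes
    ])                                    # (C, D): one normalized mean per class
    dists = torch.cdist(feats, means)     # (B, C) Euclidean distances
    return classes[dists.argmin(-1)]      # nearest-mean prediction per input
```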
diff --git a/nets/classifier.py b/nets/classifier.py
new file mode 100644
index 0000000..1411ba2
--- /dev/null
+++ b/nets/classifier.py
@@ -0,0 +1,303 @@
+from .initialization import get_glove_matrix
+from torchvision.models import resnet50, resnet34 #resnet18
+import torch
+from torch import nn
+import torch.nn.functional as F
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel('INFO')
+logger.addHandler(logging.StreamHandler())
+
+
+def get_resnet_features(x, net):
+    if isinstance(net, ResNet):
+        return net(x)
+    else:
+        x = net.conv1(x)
+        x = net.bn1(x)
+        x = net.relu(x)
+        x = net.maxpool(x)
+
+        x = net.layer1(x)
+        x = net.layer2(x)
+        x = net.layer3(x)
+        x = net.layer4(x)
+
+        x = net.avgpool(x)
+        x = torch.flatten(x, 1)
+        return x
+
+class ResNetClassifier(nn.Module):
+    def __init__(self, cfg, depth='34', mlp=3, init=True, ignore_index=0, num_of_datasets=1,
+                 num_of_classes=100, task_incremental=False, goal=None, *args, **kwargs):
+        super().__init__()
+
+        self.debug = hasattr(cfg, 'DEBUG') and cfg.DEBUG
+
+        if depth == '34':
+            self.resnet = resnet34(pretrained=False)
+            resnet_feat_size = self.resnet.fc.weight.size(1)
+        elif depth == '18':
+            if goal == 'split_mini_imagenet':
+                self.resnet = ResNet18(input_size=(3, 84, 84))
+            else:
+                self.resnet = ResNet18()
+            resnet_feat_size = self.resnet.last_hid_size
+        hidden_size = resnet_feat_size
+        mlps = []
+        for _ in range(mlp - 1):
+            mlps.append(nn.Linear(hidden_size, hidden_size))
+            mlps.append(nn.ReLU())
+
+        self.mlp_attr = nn.Sequential(
+            *mlps
+        )
+
+        self.task_incremental = task_incremental
+        self.num_of_datasets = num_of_datasets
+        self.num_of_classes = num_of_classes
+
+        if self.debug:
+            self.final_layer = nn.Linear(self.resnet.last_hid_size, num_of_classes * num_of_datasets)
+        else:
+            self.final_layer = nn.Linear(hidden_size, num_of_classes * num_of_datasets)
+
+        self.mlp_obj = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size // 2),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(hidden_size // 2, 2000)
+        )
+
+        self.ignore_index = ignore_index
+        self.criterion = F.cross_entropy
+
+        self.cfg = cfg
+        #self.task_specific_params = nn.ModuleList([self.mlp_attr[0], self.mlp_attr[2], self.mlp_attr[4]])
+
+        if init and cfg.MODE == 'train':
+            self.initialize()
+
+    def forward(self, images, bbox_images, spatial_feat, attr_labels, obj_labels, weights=None, reduce=True,
+                task_ids=None):
+        if self.debug:
+            feat = self.resnet.return_hidden(bbox_images)
+            attr_score = self.final_layer(feat)
+        else:
+            feat = get_resnet_features(bbox_images, self.resnet)
+            attr_feat = self.mlp_attr(torch.cat([feat], -1))
+            attr_score = self.final_layer(attr_feat)
+
+        if self.task_incremental:
+            scores = []  # get logits of corresponding tasks
+            for b in range(attr_score.size(0)):
+                scores.append(attr_score[b, task_ids[b] * self.num_of_classes: (task_ids[b] + 1) * self.num_of_classes])
+            attr_score = torch.stack(scores)
+
+        loss_attr, loss_obj = None, None
+        if attr_labels is not None:
+            reduction = 'mean' if reduce else 'none'
+            loss_attr = self.criterion(attr_score, attr_labels, ignore_index=self.ignore_index,
+                                       reduction=reduction)
+        # if obj_labels is not None:
+        #     reduction = 'mean' if reduce else 'none'
+        #     loss_obj = self.criterion(obj_score, obj_labels, ignore_index=self.ignore_index,
+        #                               reduction=reduction)
+
+        return {
+            'attr_score': attr_score,
+            #'obj_score': obj_score,
+            'attr_loss': loss_attr,
+            #'obj_loss': loss_obj,
+            'loss': loss_attr, #loss_obj + loss_attr,
+            'score': attr_score,
+            'feat': feat
+        }
+
+    def forward_from_feat(self, feat, attr_labels, reduce=True, **kwargs):
+        reduction = 'mean' if reduce else 'none'
+        attr_score = self.mlp_attr(feat)
+        loss_attr = self.criterion(attr_score, attr_labels, ignore_index=self.ignore_index,
+                                   reduction=reduction)
+        return {
+            'attr_score': attr_score,
+            'attr_loss': loss_attr,
+            'loss': loss_attr,
+            'score': attr_score,
+            'feat': feat
+        }
+
+    def get_obj_features(self, bbox_images):
+        feat = get_resnet_features(bbox_images, self.resnet)
+        for i, module in enumerate(self.mlp_obj._modules.values()):
+            if i == len(self.mlp_obj) - 1: break
+            feat = module(feat)
+        return feat
+
+    def set_task_specific_weights(self, idx, weight):
+        self.task_specific_params[idx].weight = weight
+
+    def _freeze_resnet_if_required(self, cfg):
+        if hasattr(cfg.EXTERNAL, 'FREEZE_RESNET') and cfg.EXTERNAL.FREEZE_RESNET:
+            for param in self.resnet.parameters():
+                param.requires_grad = False
+            logger.info('Resnet frozen')
+
+    def _load_resnet_weights_if_required(self, cfg):
+        if hasattr(cfg.EXTERNAL, 'LOAD_RESNET') and cfg.EXTERNAL.LOAD_RESNET:
+            checkpoint = torch.load(cfg.EXTERNAL.LOAD_RESNET, map_location=torch.device("cpu"))
+            model_state_dict = checkpoint['model']
+            model_state_dict = dict(filter(lambda x: x[0].startswith('resnet'), model_state_dict.items()))
+            model = ResNetClassifier(self.cfg, init=False)
+            model.load_state_dict(model_state_dict, strict=False)
+            self.resnet = model.resnet.to(cfg.MODEL.DEVICE)
+            logger.info('loaded resnet from %s' % cfg.EXTERNAL.LOAD_RESNET)
+
+    def initialize(self):
+        self._freeze_resnet_if_required(self.cfg)
+        self._load_resnet_weights_if_required(self.cfg)
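The task-incremental branch in `ResNetClassifier.forward` selects each sample's task-specific logit block with a per-sample Python loop. An equivalent vectorized form using `gather` (illustrative; assumes `attr_score` is `(B, num_of_classes * num_of_datasets)` and `task_ids` is a `(B,)` long tensor):

```
import torch

def select_task_logits(attr_score, task_ids, num_of_classes):
    offsets = task_ids.unsqueeze(1) * num_of_classes                         # (B, 1)
    idx = offsets + torch.arange(num_of_classes, device=attr_score.device)   # (B, C)
    return attr_score.gather(1, idx)                                         # (B, C)
```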
self.resnet.fc.weight.size(1) + hidden_size = resnet_feat_size + label_embed_size = self.cfg.WOBJ.LABEL_EMBED_DIM + + self.obj_emb = nn.Embedding(2000, label_embed_size) + self.mlp_attr = nn.Sequential( + nn.Linear(hidden_size + label_embed_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, 101) + ) + self.task_specific_params = nn.ModuleList([self.mlp_attr[0], self.mlp_attr[2], self.mlp_attr[4]]) + self.criterion = nn.CrossEntropyLoss(ignore_index=0) + if init and cfg.MODE == 'train': + self.initialize() + + def forward(self, images, bbox_images, spatial_feat, attr_labels, obj_labels, weights=None): + obj_emb = self.obj_emb(obj_labels) + feat = get_resnet_features(bbox_images, self.resnet) + feat = torch.cat([feat, obj_emb], -1) + + if weights is None: + attr_score = self.mlp_attr(feat) + else: # hypernetwork + hidden_1 = F.relu(F.linear(feat, weights[0], weights[1])) + hidden_2 = F.relu(F.linear(hidden_1, weights[2], weights[3])) + attr_score = F.linear(hidden_2, weights[4], weights[5]) + + loss_attr, loss_obj = None, None + if attr_labels is not None: + loss_attr = self.criterion(attr_score, attr_labels) + #if obj_labels is not None: + # loss_obj = self.criterion(obj_score, obj_labels) + + return { + 'attr_score': attr_score, + #'obj_score': obj_score, + 'attr_loss': loss_attr, + #'obj_loss': loss_obj, + 'loss': loss_attr + } + + def get_obj_features(self, bbox_images): + feat = get_resnet_features(bbox_images, self.resnet) + for i, module in enumerate(self.mlp_obj._modules.values()): + if i == len(self.mlp_obj) - 1: break + feat = module(feat) + return feat + + def set_task_specific_weights(self, idx, weight): + self.task_specific_params[idx].weight = weight + + def initialize_word_emb(self, vocab, w2v_file): + device = self.obj_emb.weight.data.device + mat = get_glove_matrix(vocab, w2v_file, self.obj_emb.weight.data.cpu().numpy()) + mat = torch.from_numpy(mat).float().to(device) + self.obj_emb.weight.data = mat + logger.info('loaded embedding from {}'.format(w2v_file)) + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, + stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, nf, input_size): + super(ResNet, self).__init__() + self.in_planes = nf + self.input_size = input_size + + self.conv1 = conv3x3(input_size[0], nf * 1) + self.bn1 = nn.BatchNorm2d(nf * 1) + #self.bn1 = CategoricalConditionalBatchNorm(nf, 2) + self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) + + # hardcoded 
for now + self.last_hid_size = nf * 8 * block.expansion if input_size[1] in [8,16,21,32,42] else 640 + #self.linear = nn.Linear(last_hid, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def return_hidden(self, x): + bsz = x.size(0) + #pre_bn = self.conv1(x.view(bsz, 3, 32, 32)) + #post_bn = self.bn1(pre_bn, 1 if is_real else 0) + #out = F.relu(post_bn) + out = F.relu(self.bn1(self.conv1(x.view(bsz, *self.input_size)))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + return out + + def forward(self, x): + out = self.return_hidden(x) + #out = self.linear(out) + return out + +def ResNet18(nf=20, input_size=(3, 32, 32)): + return ResNet(BasicBlock, [2, 2, 2, 2], nf, input_size) \ No newline at end of file diff --git a/nets/initialization.py b/nets/initialization.py new file mode 100644 index 0000000..acee318 --- /dev/null +++ b/nets/initialization.py @@ -0,0 +1,36 @@ +import logging +import numpy as np + +logger = logging.getLogger(__name__) +logger.setLevel('INFO') + +def get_glove_matrix(vocab, glove_path, initial_embedding_np): + """ + return a glove embedding matrix + :param initial_embedding_np: + :return: np array of [V,E] + """ + ef = open(glove_path, 'r', encoding='utf-8') + cnt = 0 + vec_array = initial_embedding_np + old_avg = np.average(vec_array) + old_std = np.std(vec_array) + vec_array = vec_array.astype(np.float32) + new_avg, new_std = 0, 0 + + for line in ef.readlines(): + line = line.strip().split(' ') + word, vec = line[0].lower(), line[1:] + vec = np.array(vec, np.float32) + if word in vocab: + cnt += 1 + word_idx = vocab[word] + vec_array[word_idx] = vec + new_avg += np.average(vec) + new_std += np.std(vec) + new_avg /= cnt + new_std /= cnt + ef.close() + print('%d known embedding. 
old mean: %f new mean %f, old std %f new std %f' % (cnt, old_avg,
+                                                        new_avg, old_std, new_std))
+    return vec_array
diff --git a/nets/simplenet.py b/nets/simplenet.py
new file mode 100644
index 0000000..500c99d
--- /dev/null
+++ b/nets/simplenet.py
@@ -0,0 +1,83 @@
+import torch.nn.functional as func
+import torch.nn as nn
+import torch
+
+class GatedDense(nn.Module):
+    def __init__(self, input_size, output_size, activation=None):
+        super(GatedDense, self).__init__()
+
+        self.activation = activation
+        self.sigmoid = nn.Sigmoid()
+        self.h = nn.Linear(input_size, output_size)
+        self.g = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+        h = self.h(x)
+        if self.activation is not None:
+            h = self.activation(h)
+
+        g = self.sigmoid(self.g(x))
+
+        return h * g
+
+class FC2Layers(nn.Module):
+    def __init__(self, **kwargs):
+        super(FC2Layers, self).__init__()
+        layer1_width = 400
+
+        self.ds_idx = 0
+        self.num_of_datasets = kwargs.get("num_of_datasets", 1)
+        self.num_of_classes = kwargs.get("num_of_classes", 10)
+        self.input_size = kwargs.get("input_size", 784)
+        self.task_incremental = kwargs.get('task_incremental', False)
+
+        self.layer1 = nn.Sequential(
+            nn.Linear(self.input_size, layer1_width),
+            nn.ReLU(),
+            nn.Linear(layer1_width, layer1_width),
+            nn.ReLU()
+        )
+
+        self.last_layer = nn.Linear(layer1_width, self.num_of_classes * self.num_of_datasets)
+        self.criterion = func.cross_entropy
+
+    def forward(self, x, y, task_ids=None, reduce=True, from_weights=False, weights=None):
+        x = x.view(-1, self.input_size)
+        if from_weights:
+            out = func.relu(func.linear(x, weights[0], weights[1]))
+            feat = func.relu(func.linear(out, weights[2], weights[3]))
+            out = func.linear(feat, weights[4], weights[5])
+        else:
+            out = self.layer1(x)
+            feat = func.relu(out)  # no-op on layer1's ReLU output; kept for parity with the from_weights branch
+            out = self.last_layer(feat)
+
+        if self.task_incremental:
+            scores = []  # slice out the logits of each example's own task
+            for b in range(x.size(0)):
+                scores.append(out[b, task_ids[b] * self.num_of_classes: (task_ids[b] + 1) * self.num_of_classes])
+            out = torch.stack(scores)
+
+        if reduce:
+            loss = self.criterion(out, y)
+        else:
+            loss = self.criterion(out, y, reduction='none')
+        return {'score': out, 'loss': loss, 'feat': feat}
+
+    def forward_from_feat(self, feat, y, reduce=True, **kwargs):
+        out = self.last_layer(feat)
+        if reduce:
+            loss = self.criterion(out, y)
+        else:
+            loss = self.criterion(out, y, reduction='none')
+        return {'score': out, 'loss': loss, 'feat': feat}
+
+
+
+def mnist_simple_net_400width_classlearning_1024input_10cls_1ds(**kwargs):
+    return FC2Layers(input_size=28 * 28, layer1_width=400, layer2_width=400, **kwargs)
+
+def mnist_simple_net_400width_domainlearning_1024input_10cls_1ds(**kwargs):
+    return FC2Layers(input_size=28 * 28, layer1_width=400, layer2_width=400, **kwargs)
\ No newline at end of file
diff --git a/ocl/__init__.py b/ocl/__init__.py
new file mode 100644
index 0000000..e32d7e5
--- /dev/null
+++ b/ocl/__init__.py
@@ -0,0 +1,5 @@
+from .agem import AGEM
+from .er import ExperienceReplay
+from .naive import NaiveWrapper
+from .ver import FOExperienceEvolve
+from .ver_approx import ExperienceEvolveApprox
\ No newline at end of file
diff --git a/ocl/agem.py b/ocl/agem.py
new file mode 100644
index 0000000..2666414
--- /dev/null
+++ b/ocl/agem.py
@@ -0,0 +1,89 @@
+import torch
+from torch import nn
+from .er import ExperienceReplay, store_grad
+from config import cfg
+from utils.utils import get_config_attr
+
+def overwrite_grad(pp, newgrad, grad_dims):
+    """
+
This is used to overwrite the gradients with a new gradient + vector, whenever violations occur. + pp: parameters + newgrad: corrected gradient + grad_dims: list storing number of parameters at each layer + """ + cnt = 0 + for param in pp(): + if param.grad is not None: + beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) + en = sum(grad_dims[:cnt + 1]) + this_grad = newgrad[beg: en].contiguous().view( + param.grad.data.size()) + param.grad.data.copy_(this_grad) + cnt += 1 + + +class AGEM(ExperienceReplay): + def __init__(self, base, optimizer, input_size, cfg, goal): + super(AGEM, self).__init__(base, optimizer, input_size, cfg, goal) + + # self.grads = self.grads.cuda() + self.violation_count = 0 + self.agem_k = get_config_attr(cfg, 'EXTERNAL.OCL.AGEM_K', default=256) + + def compute_grad(self, mem_x, mem_y): + self.zero_grad() + ret_dict = self.forward_net(mem_x, mem_y) + # regularize both of the loss + loss = ret_dict['loss'] + loss.backward() + grads = torch.Tensor(sum(self.grad_dims)).to(mem_x.device) + store_grad(self.parameters, grads, self.grad_dims) + return grads + + def fix_grad(self, mem_grads): + # check whether the current gradient interfere with the average gradients + grads = torch.Tensor(sum(self.grad_dims)).to(mem_grads.device) + store_grad(self.parameters, grads, self.grad_dims) + dotp = torch.dot(grads, mem_grads) + if dotp < 0: + # project the grads back to the mem_grads + # g_new = g - g^Tg_{ref} / g_{ref}^Tg_{ref} * g_{ref} + new_grad = grads - (torch.dot(grads, mem_grads) / (torch.dot(mem_grads, mem_grads) + 1e-12)) * mem_grads + overwrite_grad(self.parameters, new_grad, self.grad_dims) + return 1 + else: + return 0 + + def observe(self, x, y, task_ids=None, extra=None, optimize=True): + # recover image, feat from x + self.optimizer.zero_grad() + mem_x, mem_y, _ = self.sample_mem_batch(x.device, k=self.agem_k) + + if mem_x is not None: + # calculate gradients on the memory batch + mem_grads = self.compute_grad(mem_x, mem_y) + + # backward on the current minibatch + self.optimizer.zero_grad() + batch_size = x.size(0) + ret_dict = \ + self.forward_net(x, y) + for b in range(batch_size): + #self.update_mem(x[b], y[b], task_ids[b] if task_ids is not None else None) + if type(y) is tuple: + y_ = [_[b] for _ in y] + else: + y_ = y[b] + self.update_mem(x[b], y_, task_ids[b] if task_ids is not None else None, + x_feat=None) + loss = ret_dict['loss'] + if optimize: + loss.backward() + if mem_x is not None: + violated = self.fix_grad(mem_grads) + self.violation_count += violated + + self.optimizer.step() + + return ret_dict diff --git a/ocl/er.py b/ocl/er.py new file mode 100644 index 0000000..1ef838b --- /dev/null +++ b/ocl/er.py @@ -0,0 +1,774 @@ +import torch +import numpy as np +import pickle +from utils.utils import get_config_attr +import copy +from .naive import NaiveWrapper +from torch.optim import SGD, Adam +from collections import defaultdict + + +import math +try: + from pytorch_transformers import AdamW +except ImportError: + AdamW = None + + +def y_to_np(y): + if type(y) is tuple: + return tuple(x.item() for x in y) + else: + return y.cpu().numpy() + + +def y_to_cpu(y): + if torch.is_tensor(y): + y = y.cpu() + else: + y = [_.cpu() for _ in y] + return y + + +def index_select(l, indices, device): + ret = [] + for i in indices: + if type(l[i]) is np.ndarray: + x = torch.from_numpy(l[i]).to(device) + ret.append(x.unsqueeze(0)) + else: + if type(l[i]) is list: + item = [] + for j in range(len(l[i])): + if type(l[i][j]) is np.ndarray: + 
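+                        # Memory entries may be stored as numpy arrays (via
+                        # y_to_np) or as already-materialized tensors/values
+                        # (via y_to_cpu); the numpy case is converted back to
+                        # a tensor here so both cases collate uniformly.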
item.append(torch.from_numpy(l[i][j])) + else: + item.append(l[i][j]) + ret.append(item) + else: + ret.append(l[i]) + return ret + + +def concat_with_padding(l): + if l is None or l[0] is None: return None + if type(l[0]) in [list, tuple]: + ret = [torch.cat(t, 0) for t in zip(*l)] + else: + if len(l[0].size()) == 2: # potentially requires padding + max_length = max([x.size(1) for x in l]) + ret = [] + for x in l: + pad = torch.zeros(x.size(0), max_length - x.size(1)).long().to(x.device) + x_pad = torch.cat([x, pad], -1) + ret.append(x_pad) + ret = torch.cat(ret, 0) + else: + ret = torch.cat(l, 0) + return ret + + +def store_grad(pp, grads, grad_dims): + """ + This stores parameter gradients of past tasks. + pp: parameters + grads: gradients + grad_dims: list with number of parameters per layers + tid: task id + """ + # store the gradients + grads.fill_(0.0) + cnt = 0 + for param in pp(): + if param.grad is not None: + beg = 0 if cnt == 0 else sum(grad_dims[:cnt]) + en = sum(grad_dims[:cnt + 1]) + grads[beg: en].copy_(param.grad.data.view(-1)) + cnt += 1 + +class ExperienceReplay(NaiveWrapper): + def __init__(self, base, optimizer, input_size, cfg, goal): + super().__init__(base, optimizer, input_size, cfg, goal) + self.net = base + self.optimizer = optimizer + self.mem_limit = cfg.EXTERNAL.REPLAY.MEM_LIMIT + self.mem_bs = cfg.EXTERNAL.REPLAY.MEM_BS + self.input_size = input_size + self.reservoir, self.example_seen = None, None + self.reset_mem() + self.mem_occupied = {} + + self.seen_tasks = [] + self.balanced = False + self.policy = get_config_attr(cfg, 'EXTERNAL.OCL.POLICY', default='reservoir', totype=str) + self.mir_k = get_config_attr(cfg, 'EXTERNAL.REPLAY.MIR_K', default=10, totype=int) + self.mir = get_config_attr(cfg, 'EXTERNAL.OCL.MIR', default=0, totype=int) + + self.mir_agg = get_config_attr(cfg, 'EXTERNAL.OCL.MIR_AGG', default='avg', totype=str) + + self.concat_replay = get_config_attr(cfg, 'EXTERNAL.OCL.CONCAT', default=0, totype=int) + self.separate_replay = get_config_attr(cfg, 'EXTERNAL.OCL.SEPARATE', default=0, totype=int) + self.mem_augment = get_config_attr(cfg, 'EXTERNAL.OCL.MEM_AUG', default=0, totype=int) + self.legacy_aug = get_config_attr(cfg, 'EXTERNAL.OCL.LEGACY_AUG', default=0, totype=int) + self.use_hflip_aug = get_config_attr(cfg,'EXTERNAL.OCL.USE_HFLIP_AUG',default=1,totype=int) + self.padding_aug = get_config_attr(cfg,'EXTERNAL.OCL.PADDING_AUG',default=-1,totype=int) + self.rot_aug = get_config_attr(cfg,'EXTERNAL.OCL.ROT_AUG',default=-1,totype=int) + + self.lb_reservoir = get_config_attr(cfg,'EXTERNAL.OCL.LB_RESERVOIR', default=0) + + self.cfg = cfg + self.grad_dims = [] + + for param in self.parameters(): + self.grad_dims.append(param.data.numel()) + + def get_config_padding(self, pad, rot): + if self.padding_aug != -1: + pad = self.padding_aug + if self.rot_aug != -1: + rot = self.rot_aug + return pad, rot + + def reset_mem(self): + self.reservoir = {'x': np.zeros((self.mem_limit, self.input_size)), + 'y': [None] * self.mem_limit, + 'y_extra': [None] * self.mem_limit, + 'replay_time': np.zeros(self.mem_limit), + 'loss_stats': np.zeros(self.mem_limit), + 'label_cnts': defaultdict(int), + } + self.example_seen = 0 + + def update_mem(self, *args, **kwargs): + if self.policy == 'balanced': + return self.update_mem_balanced(*args, **kwargs) + elif self.policy == 'reservoir': + return self.update_mem_reservoir(*args, **kwargs) + elif self.policy == 'clus': + return self.update_mem_kmeans(*args, **kwargs) + else: + raise ValueError + + def reinit_mem(self, 
xsize): + self.input_size = xsize + self.reset_mem() + + def forward(self, *args, **kwargs): + return self.net(*args, **kwargs) + + def update_mem_reservoir(self, x, y, y_extra=None, loss_x=None, *args, **kwargs): + if self.example_seen == 0 and self.reservoir['x'].shape[-1] != x.shape[-1]: + self.reinit_mem(x.shape[-1]) + + x = x.cpu().numpy() + if type(y) not in [list, tuple]: + y = y_to_np(y) + #if y.shape == (1,): + # y = y[0] + else: + y = y_to_cpu(y) + + if type(y_extra) not in [list, tuple]: + y_extra = y_to_np(y_extra) + elif y_extra is not None: + y_extra = y_to_cpu(y_extra) + + if self.example_seen < self.mem_limit: + self.reservoir['x'][self.example_seen] = x + self.reservoir['y'][self.example_seen] = y + self.reservoir['y_extra'][self.example_seen] = y_extra + #self.reservoir['replay_time'][self.example_seen] = 0 + self.reservoir['label_cnts'][y.item()] += 1 + j = self.example_seen + else: + + # + #else: + j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.example_seen) + if j < self.mem_limit: + if self.lb_reservoir: + j = self.get_loss_aware_balanced_reservoir_sampling_index() + self.reservoir['label_cnts'][self.reservoir['y'][j].item()] -= 1 + self.reservoir['x'][j] = x + self.reservoir['y'][j] = y + self.reservoir['y_extra'][j] = y_extra + #self.reservoir['replay_time'][j] = 0 + self.reservoir['label_cnts'][y.item()] += 1 + + #if loss_x is not None: + # self.reservoir['loss_stats'][j] = loss_x if not torch.is_tensor(loss_x) else loss_x.item() + #self.reservoir['loss_stat_steps'][j] = [self.example_seen] + #self.reservoir['forget'] = 0 + self.example_seen += 1 + + def update_loss_states(self, loss, indices): + for i in range(len(indices)): + self.reservoir['loss_stats'][indices[i]] = loss[i].item() + + def get_loss_aware_balanced_reservoir_sampling_index(self): + # assumes the mem is full + _random = np.random.RandomState(self.example_seen + self.cfg.SEED) + s_balance = np.array([self.reservoir['label_cnts'][mem_y.item()] for mem_y in self.reservoir['y']]) + s_loss = -self.reservoir['loss_stats'] + alpha = abs(s_balance.sum()) / s_loss.sum() + s = s_loss * alpha + s_balance + probs = s / s.sum() + idx = _random.choice(len(probs), p=probs) + return idx + + def update_mem_kmeans(self, x, y, y_extra=None, x_feat=None, **kwargs): + if 'x_clus' not in self.reservoir: + self.reservoir['x_clus'] = np.zeros((self.mem_limit, x_feat.shape[0])) + self.reservoir['x_cnt'] = [None] * self.mem_limit + self.reservoir['x_feat'] = np.zeros((self.mem_limit, x_feat.shape[0])) + x = x.cpu().numpy() + + if type(y) not in [list, tuple]: + y = y_to_np(y) + else: + y = y_to_cpu(y) + + if self.example_seen < self.mem_limit: + self.reservoir['x'][self.example_seen] = x + self.reservoir['y'][self.example_seen] = y + self.reservoir['y_extra'][self.example_seen] = y_extra + self.reservoir['x_clus'][self.example_seen] = x_feat + self.reservoir['x_cnt'][self.example_seen] = 0 + self.reservoir['x_feat'][self.example_seen] = x_feat + else: + # compute L2 distance + center_i, center_d = -1, 1e10 + for mb in range(self.mem_limit): + dist = np.sum(np.square(x_feat - self.reservoir['x_clus'][mb])) + if dist < center_d: + center_i = mb + center_d = dist + cnt = self.reservoir['x_cnt'][center_i] + self.reservoir['x_clus'][center_i] = (self.reservoir['x_clus'][center_i] * cnt + x_feat) / (cnt + 1) + self.reservoir['x_cnt'][center_i] += 1 + + d_mem = np.sum(np.square(x_feat - self.reservoir['x_clus'][center_i])) + d_mem_best = np.sum(np.square(self.reservoir['x_feat'][center_i] - 
self.reservoir['x_clus'][center_i])) + if d_mem < d_mem_best: + self.reservoir['x'][center_i] = x + self.reservoir['y'][center_i] = y + self.reservoir['y_extra'][center_i] = y_extra + self.reservoir['x_feat'][center_i] = x_feat + self.example_seen += 1 + + def compute_offset(self, task_id, n): + idx = self.seen_tasks.index(task_id) + return int(idx / n * self.mem_limit), int((idx + 1) / n * self.mem_limit) + + def reallocate_memory(self, old_mem, old_occ): + new_mem = {'x': np.zeros((self.mem_limit, self.input_size)), + 'y': [None] * self.mem_limit, + 'y_extra': [None] * self.mem_limit} + new_occ = {} + n_tasks = len(self.seen_tasks) + for task in self.seen_tasks: + old_offset_start, old_offset_stop = self.compute_offset(task, n_tasks) + new_offset_start, new_offset_stop = self.compute_offset(task, n_tasks + 1) + i = 0 + while new_offset_start + i < new_offset_stop: + new_offset = new_offset_start + i + old_offset = old_offset_start + i + for k in old_mem: + new_mem[k][new_offset] = old_mem[k][old_offset] + i += 1 + new_occ[task] = old_occ[task] # min(old_occ[task], new_offset_stop - new_offset_start) + return new_mem, new_occ + + def update_mem_balanced(self, x, y): + y_attr, y_obj = y_to_np(y) + y_obj = y_obj.item() + y_attr = y_attr.item() + x = x.cpu().numpy() + + if y_obj not in self.seen_tasks: + # reallocate memory by expanding seen task by 1 + new_mem, new_occ = self.reallocate_memory(self.reservoir, self.mem_occupied) + self.reservoir = new_mem + self.mem_occupied = new_occ + self.mem_occupied[y_obj] = 0 + self.seen_tasks.append(y_obj) + offset_start, offset_stop = self.compute_offset(y_obj, len(self.seen_tasks)) + + if self.mem_occupied[y_obj] < offset_stop - offset_start: + pos = self.mem_occupied[y_obj] + offset_start + self.reservoir['x'][pos] = x + self.reservoir['y'][pos] = y + # self.reservoir['y_attr'][pos] = y_attr + else: + j = np.random.RandomState(self.example_seen + self.cfg.SEED).randint(0, self.mem_occupied[y_obj]) + if j < offset_stop - offset_start: + pos = j + offset_start + self.reservoir['x'][pos] = x + self.reservoir['y'][pos] = y + # self.reservoir['y_attr'][pos] = y_attr + + self.mem_occupied[y_obj] += 1 + self.example_seen += 1 + + def get_available_index(self): + l = [] + for idx, t in enumerate(self.seen_tasks): + offset_start, offset_stop = self.compute_offset(t, len(self.seen_tasks)) + for i in range(offset_start, min(offset_start + self.mem_occupied[t], offset_stop)): + l.append(i) + return l + + def get_random(self, seed=1): + random_state = None + for i in range(seed): + if random_state is None: + random_state = np.random.RandomState(self.example_seen + self.cfg.SEED) + else: + random_state = np.random.RandomState(random_state.randint(0, int(1e5))) + return random_state + + def store_cache(self): + # for i, param in enumerate(self.net.parameters()): + # self.parameter_cache[i].copy_(param.data) + self.cache = copy.deepcopy(self.net.state_dict()) + + def load_cache(self): + self.net.load_state_dict(self.cache) + self.net.zero_grad() + + def get_loss_and_pseudo_update(self, x, y, task_ids): + ret_dict_d = self.forward_net(x, y, task_ids) + self.optimizer.zero_grad() + ret_dict_d['loss'].backward(retain_graph=False) + if isinstance(self.optimizer, torch.optim.SGD): + step_wo_state_update_sgd(self.optimizer, amp=1.) + elif isinstance(self.optimizer, torch.optim.Adam): + step_wo_state_update_adam(self.optimizer, amp=1.) 
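+        # Lookahead ("virtual") step: apply the current gradients through a
+        # step_wo_state_update_* helper, which moves the parameters but leaves
+        # the optimizer state (momentum buffers / Adam moments, step counts)
+        # untouched. Callers bracket this with store_cache()/load_cache(), so
+        # the step only probes how memory losses would change.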
+        elif AdamW is not None and isinstance(self.optimizer, AdamW):
+            step_wo_state_update_adamw(self.optimizer)
+        else:
+            raise NotImplementedError
+        return ret_dict_d
+
+    def decide_mir_mem(self, x, y, task_ids, mir_k, cand_x, cand_y, cand_task_ids, indices, mir_least):
+        if cand_x.size(0) < mir_k:
+            return cand_x, cand_y, cand_task_ids, indices
+        else:
+            self.store_cache()
+            if type(cand_y[0]) not in [list, tuple]:
+                cand_y = concat_with_padding(cand_y)
+            else:
+                cand_y = [torch.stack(_).to(x.device) for _ in zip(*cand_y)]
+            with torch.no_grad():
+                ret_dict_mem_before = self.forward_net(cand_x, cand_y, reduce=False, task_ids=cand_task_ids)
+            ret_dict_d = self.get_loss_and_pseudo_update(x, y, task_ids)
+            with torch.no_grad():
+                ret_dict_mem_after = self.forward_net(cand_x, cand_y, reduce=False, task_ids=cand_task_ids)
+            loss_increase = ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']
+            with torch.no_grad():
+                if self.goal == 'captioning':
+                    if self.mir_agg == 'avg':
+                        loss_increase_by_ts = loss_increase.view(cand_x.size(0), -1).sum(1)
+                        mask_num_by_ts = (cand_y[2] != -1).sum(1).float() + 1e-10
+                        loss_increase = loss_increase_by_ts / mask_num_by_ts
+                    elif self.mir_agg == 'max':
+                        loss_increase, _ = loss_increase.view(cand_x.size(0), -1).max(1)
+
+                _, topi = loss_increase.topk(mir_k, largest=not mir_least)
+                if type(cand_y) is not list:
+                    mem_x, mem_y, mem_task_ids = cand_x[topi], cand_y[topi], cand_task_ids[topi]
+                else:
+                    mem_x, mem_task_ids = cand_x[topi], cand_task_ids[topi]
+                    mem_y = [_[topi] for _ in cand_y]
+
+            self.load_cache()
+            return mem_x, mem_y, mem_task_ids, indices[topi.cpu()]
+
+
+    def sample_mem_batch(self, device, return_indices=False, k=None, seed=1,
+                         mir=False, input_x=None, input_y=None, input_task_ids=None, mir_k=0,
+                         skip_task=None, mir_least=False):
+        random_state = self.get_random(seed)
+        if k is None:
+            k = self.mem_bs
+
+        if not self.balanced:
+            # reservoir
+            n_max = min(self.mem_limit, self.example_seen)
+            available_indices = [_ for _ in range(n_max)]
+            if skip_task is not None and get_config_attr(self.cfg, 'EXTERNAL.REPLAY.FILTER_SELF', default=0, mute=True):
+                available_indices = list(filter(lambda x: self.reservoir['y_extra'][x] != skip_task, available_indices))
+            if not available_indices:
+                return None, None, None
+            elif len(available_indices) < k:
+                indices = np.array(available_indices)  # keep the skip_task filter even when fewer than k candidates remain
+            else:
+                indices = random_state.choice(available_indices, k, replace=False)
+        else:
+            available_index = self.get_available_index()
+            if len(available_index) == 0:
+                return None, None, None
+            elif len(available_index) < k:
+                indices = np.array(available_index)
+            else:
+                indices = random_state.choice(available_index, k, replace=False)
+        x = self.reservoir['x'][indices]
+        x = torch.from_numpy(x).to(device).float()
+
+        y = index_select(self.reservoir['y'], indices, device)  # [ [...], [...]
] + y_extra = index_select(self.reservoir['y_extra'], indices, device) + if type(y[0]) not in [list, tuple]: + y_pad = concat_with_padding(y) + else: + y_pad = [torch.stack(_).to(device) for _ in zip(*y)] + y_extra = concat_with_padding(y_extra) + + if mir: + x, y_pad, y_extra, indices = self.decide_mir_mem(input_x, input_y, input_task_ids, mir_k, + x, y, y_extra, indices, mir_least) + + if not return_indices: + return x, y_pad, y_extra + else: + return (x, indices), y_pad, y_extra + + def observe(self, x, y, task_ids=None, extra=None, optimize=True): + # recover image, feat from x + if task_ids is None: + task_ids = torch.zeros(x.size(0)).to(x.device).long() + + if self.mir: + mem_x, mem_y, mem_task_ids = self.sample_mem_batch(x.device, input_x=x, input_y=y, input_task_ids=task_ids, + mir_k=self.mir_k, mir=self.mir, + skip_task=task_ids[0].item()) + else: + mem_x, mem_y, mem_task_ids = self.sample_mem_batch(x.device, skip_task=task_ids[0].item()) + + batch_size = x.size(0) + if mem_x is not None and not self.separate_replay and not self.goal == 'captioning': # a dirty fix to prevent oom + if not self.mem_augment: + combined_x = torch.cat([x, mem_x], 0) + combined_y = concat_with_padding([y, mem_y]) + combined_task_ids = concat_with_padding([task_ids, mem_task_ids]) + else: + aug_mem_x = self.transform_image_batch(mem_x) + combined_x = torch.cat([x,mem_x,aug_mem_x],0) + combined_y = concat_with_padding([y,mem_y,mem_y]) + combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids]) + else: + combined_x, combined_y, combined_task_ids = x, y, task_ids + + ret_dict = self.forward_net(combined_x, combined_y, combined_task_ids, + reduce=self.concat_replay or self.separate_replay) + + loss_tmp = ret_dict['loss'] + if optimize: + # loss = loss_tmp.mean() + # print(loss.item()) + if self.concat_replay or self.separate_replay: + loss = ret_dict['loss'] + else: + loss = loss_tmp[: x.size(0)].mean() + if mem_x is not None: + loss += loss_tmp[x.size(0):].mean() + + self.optimizer.zero_grad() + + if self.concat_replay and mem_x is not None: + loss = loss / 2 + + loss.backward() + + #if mem_x is None or (not self.separate_replay and not self.goal == 'captioning'): + if not self.concat_replay or mem_x is None: + self.optimizer.step() + + if (self.separate_replay or self.goal == 'captioning') and mem_x is not None: + ret_dict_mem = self.forward_net(mem_x, mem_y, mem_task_ids, reduce=True) + + if not self.concat_replay: + self.optimizer.zero_grad() + + if self.concat_replay: + ret_dict_mem['loss'] = ret_dict_mem['loss'] / 2 + + ret_dict_mem['loss'].backward() + self.optimizer.step() + ret_dict['loss'] = (ret_dict['loss'] + ret_dict_mem['loss']) / 2 + + for b in range(batch_size): # x.size(0) + if type(y) is tuple: + y_ = [_[b] for _ in y] + else: + y_ = y[b] + self.update_mem(x[b], y_, task_ids[b] if task_ids is not None else None, + x_feat=None) + return ret_dict + + def dump_reservoir(self, path, verbose=False): + f = open(path, 'wb') + pickle.dump({ + 'reservoir_x': self.reservoir['x'] if verbose else None, + 'reservoir_y': self.reservoir['y'], + 'reservoir_y_extra': self.reservoir['y_extra'], + 'mem_occupied': self.mem_occupied, + 'example_seen': self.example_seen, + 'seen_tasks': self.seen_tasks, + 'balanced': self.balanced + }, f) + f.close() + + def load_reservoir(self, path): + try: + f = open(path, 'rb') + dic = pickle.load(f) + for k in dic: + setattr(self, k, dic[k]) + f.close() + return dic + except FileNotFoundError: + print('no replay buffer dump file') + return {} + + def 
load_reservoir_from_dic(self, dic): + for k in dic: + setattr(self, k, dic[k]) + + def get_reservoir(self): + return {'reservoir': self.reservoir, 'mem_occupied': self.mem_occupied, + 'example_seen': self.example_seen, 'seen_tasks': self.seen_tasks, + 'balanced': self.balanced} + + def mean_of_exemplar_classify(self, cropped_image_inp): + if not hasattr(self, 'mean_exemplar_vec'): + mean_exemplar_vec = [] + for task in self.seen_tasks: + offset_start, offset_stop = self.compute_offset(task, len(self.seen_tasks)) + x = self.reservoir['x'][offset_start: min(offset_stop, offset_start + self.mem_occupied[task])] + x = torch.from_numpy(x).float().to(cropped_image_inp.device) + cropped_image = x[:, : 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE] \ + .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE) + feat = self.net.get_obj_features(cropped_image) + mean_feat = feat.mean(0) + mean_exemplar_vec.append(mean_feat) + self.mean_exemplar_vec = torch.stack(mean_exemplar_vec) # [C, H] + + feat = self.net.get_obj_features(cropped_image_inp) # [B, H] + mean_exemplar_vec_expand = self.mean_exemplar_vec.unsqueeze(0).expand(feat.size(0), -1, -1) # [B,C,H] + feat_expand = feat.unsqueeze(1).expand(-1, self.mean_exemplar_vec.size(0), -1) # [B,C,H] + dist = torch.sum((mean_exemplar_vec_expand - feat_expand) ** 2, -1) # [B,C] + dist = torch.sqrt(dist) + _, pred = dist.min(-1) # [B] + + pred_index_fix = torch.zeros(pred.size()).to(pred.device) + for b in range(pred.size(0)): + pred_index_fix[b] = self.seen_tasks[pred[b]] + + return pred_index_fix + + +def step_wo_state_update_adam(adam, closure=None, amp=1.): + """Performs a single optimization step. Do not update optimizer states + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if type(adam) is not Adam: + raise ValueError + loss = None + if closure is not None: + loss = closure() + + for group in adam.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = adam.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + # state['step'] += 1 + + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg = exp_avg.mul(beta1).add(1 - beta1, grad) + exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 * amp + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss + + +def step_wo_state_update_sgd(sgd, closure=None, amp=1.): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in sgd.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad.data + if weight_decay != 0: + d_p.add_(weight_decay, p.data) + if momentum != 0: + param_state = sgd.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(1 - dampening, d_p) + if nesterov: + d_p = d_p.add(momentum, buf) + else: + d_p = buf + + p.data.add_(-group['lr'] * amp, d_p) + + return loss + + +def step_wo_state_update_adamw(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + #if len(state) == 0: + # state['step'] = 0 + # # Exponential moving average of gradient values + # state['exp_avg'] = torch.zeros_like(p.data) + # # Exponential moving average of squared gradient values + # state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + #state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg = exp_avg.mul(beta1).add(1.0 - beta1, grad) + exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(1.0 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state['step'] + bias_correction2 = 1.0 - beta2 ** state['step'] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group['weight_decay'] > 0.0: + p.data.add_(-group['lr'] * group['weight_decay'], p.data) + + return loss + +def get_updated_weights_sgd(sgd, closure=None, amp=1.): + """Performs a single optimization step. 
+    Arguments:
+        closure (callable, optional): A closure that reevaluates the model
+            and returns the loss.
+    """
+    loss = None
+    if closure is not None:
+        loss = closure()
+    weights = []
+    for group in sgd.param_groups:
+        weight_decay = group['weight_decay']
+        momentum = group['momentum']
+        dampening = group['dampening']
+        nesterov = group['nesterov']
+
+        for p in group['params']:
+            if p.grad is None:
+                weights.append(p.data)
+                continue  # keep the parameter unchanged when it has no gradient
+            d_p = p.grad
+            if weight_decay != 0:
+                d_p = d_p.add(weight_decay, p.data)
+            if momentum != 0:
+                param_state = sgd.state[p]
+                if 'momentum_buffer' not in param_state:
+                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
+                else:
+                    buf = param_state['momentum_buffer']
+                    buf.mul_(momentum).add_(1 - dampening, d_p)
+                if nesterov:
+                    d_p = d_p.add(momentum, buf)
+                else:
+                    d_p = buf
+            weights.append(p.data - group['lr'] * amp * d_p)
+    return weights
\ No newline at end of file
diff --git a/ocl/naive.py b/ocl/naive.py
new file mode 100644
index 0000000..f523206
--- /dev/null
+++ b/ocl/naive.py
@@ -0,0 +1,100 @@
+from torch import nn
+from torch import optim
+import torch
+from utils.utils import get_config_attr
+import numpy as np
+
+class NaiveWrapper(nn.Module):
+    def __init__(self, base, optimizer, input_size, cfg, goal, **kwargs):
+        super().__init__()
+        self.net = base
+        self.optimizer = optimizer
+        self.input_size = input_size
+        self.cfg = cfg
+        self.goal = goal
+
+        if 'caption' in self.goal:
+            self.clip_grad = True
+        self.use_image_feat = get_config_attr(self.cfg, 'EXTERNAL.USE_IMAGE_FEAT', default=0)
+        self.spatial_feat_shape = (2048, 7, 7)
+        self.bbox_feat_shape = (100, 2048)
+        self.bbox_shape = (100, 4)
+        self.rev_optim = None
+        if hasattr(self.net, 'rev_update_modules'):
+            self.rev_optim = optim.Adam(lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999),
+                                        params=self.net.rev_update_modules.parameters())
+
+    def forward(self, *args, **kwargs):
+        return self.net(*args, **kwargs)
+
+    def forward_net(self, x, y, task_ids=None, **kwargs):
+        if self.goal in ['classification']:
+            if type(y) is tuple:
+                attr_labels, obj_labels = y
+            else:
+                attr_labels, obj_labels = y, None
+            cropped_image = x[:, : 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE] \
+                .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+            images = x[:, 3 * self.cfg.EXTERNAL.IMAGE_SIZE * self.cfg.EXTERNAL.IMAGE_SIZE:] \
+                .view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+            ret_dict = self.net(
+                bbox_images=cropped_image, spatial_feat=None,
+                images=images,
+                attr_labels=attr_labels,
+                obj_labels=obj_labels,
+                task_ids=task_ids,
+                **kwargs
+            )
+        elif self.goal == 'captioning':
+            # legacy
+            if self.use_image_feat:
+                images = x.view(-1, 3, self.cfg.EXTERNAL.IMAGE_SIZE, self.cfg.EXTERNAL.IMAGE_SIZE)
+                captions, caption_lens, labels = y
+                ret_dict = self.net(images=images, captions=captions, caption_lens=caption_lens, labels=labels, **kwargs)
+            else:
+                # rebuild spatial and packed inputs
+                batch_size = x.size(0)
+                bfeat_dim = np.prod(self.bbox_feat_shape)
+                # bbox_feats, spatial_feats, bboxes = x[:, :bfeat_dim].view(batch_size, *self.bbox_feat_shape), \
+                #     x[:,bfeat_dim:bfeat_dim + sfeat_dim].view(batch_size, *self.spatial_feat_shape), \
+                #     x[:,bfeat_dim + sfeat_dim:].view(batch_size, *self.bbox_shape)
+                bbox_feats, bboxes = x[:, :bfeat_dim].view(batch_size, *self.bbox_feat_shape), \
+                    x[:, bfeat_dim:].view(batch_size, *self.bbox_shape)
+                captions, caption_lens, labels = y
+                ret_dict = self.net(bbox_feats=bbox_feats, captions=captions,
caption_lens=caption_lens, labels=labels, + bboxes=bboxes, **kwargs) + + elif self.goal == 'split_mnist' or self.goal == 'permute_mnist' or self.goal == 'rotated_mnist': + images = x.view(x.size(0), -1) + ret_dict = self.net(images, y, task_ids=task_ids, **kwargs) + elif self.goal == 'split_cifar': + images = x.view(-1, 3, 32, 32) + ret_dict = self.net(bbox_images=images, spatial_feat=None, images=None, attr_labels=y, obj_labels=None, + task_ids=task_ids, **kwargs) + elif self.goal == 'split_mini_imagenet': + images = x.view(-1,3,84,84) + ret_dict = self.net(bbox_images=images, spatial_feat=None, images=None, attr_labels=y, obj_labels=None, + task_ids=task_ids, **kwargs) + else: + raise ValueError + return ret_dict + + def observe(self, x, y, task_ids, optimize=True): + # if deprecated is not None: + # y = (y, deprecated) + # recover image, feat from x + self.optimizer.zero_grad() + ret_dict = \ + self.forward_net(x, y, task_ids) + + loss = ret_dict['loss'] + if optimize: + loss.backward() + #if self.clip_grad: + # torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0) + self.optimizer.step() + + return ret_dict + + def initialize_word_emb(self, *args, **kwargs): + self.net.initialize_word_emb(*args, **kwargs) diff --git a/ocl/utils.py b/ocl/utils.py new file mode 100644 index 0000000..b59903e --- /dev/null +++ b/ocl/utils.py @@ -0,0 +1,39 @@ +import numpy as np + +def get_glove_matrix(vocab, glove_path, initial_embedding_np): + """ + return a glove embedding matrix + :param initial_embedding_np: + :return: np array of [V,E] + """ + ef = open(glove_path, 'r', encoding='utf-8') + cnt = 0 + vec_array = initial_embedding_np + old_avg = np.average(vec_array) + old_std = np.std(vec_array) + vec_array = vec_array.astype(np.float32) + new_avg, new_std = 0, 0 + + indices = [] + + for line in ef.readlines(): + line = line.strip().split(' ') + word, vec = line[0].lower(), line[1:] + vec = np.array(vec, np.float32) + if word in vocab: + cnt += 1 + word_idx = vocab[word] + if word_idx < vec_array.shape[0]: + vec_array[word_idx] = vec + new_avg += np.average(vec) + new_std += np.std(vec) + indices.append(word_idx) + new_avg /= cnt + new_std /= cnt + ef.close() + print('%d known embedding. old mean: %f new mean %f, old std %f new std %f' % (cnt, old_avg, + new_avg, old_std, new_std)) + # scale the added embeddings + vec_array[indices] *= old_avg / new_avg + + return vec_array \ No newline at end of file diff --git a/ocl/ver.py b/ocl/ver.py new file mode 100644 index 0000000..1f00c0a --- /dev/null +++ b/ocl/ver.py @@ -0,0 +1,141 @@ +from .er import * +import math +from torch.optim import Adam +import copy +from utils.utils import get_config_attr + +class FOExperienceEvolve(ExperienceReplay): + def __init__(self, base, optimizer, input_size, cfg, goal): + super().__init__(base, optimizer, input_size, cfg, goal) + # self.reservoir = {'x': np.zeros((self.mem_limit, input_size)), + # 'y': [None] * self.mem_limit, + # 'y_extra': [None] * self.mem_limit, + # 'x_origin': np.zeros((self.mem_limit, input_size)), + # 'x_edit_state': [None] * self.mem_limit, + # 'loss_stats': [None] * self.mem_limit, + # 'loss_stat_steps': [None] * self.mem_limit, + # 'forget': [None] * self.mem_limit, + # 'support': [None] * self.mem_limit + # } + self.itf_cnt = 0 + self.total_cnt = 0 + self.grad_iter = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_ITER', default=1) + self.grad_stride = get_config_attr(cfg, 'EXTERNAL.OCL.GRAD_STRIDE', default=10.) + self.edit_decay = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_DECAY', default=0.) 
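+        # GMED editing hyperparameters: grad_stride is the editing stride
+        # (alpha in the README) and edit_decay shrinks the stride for older
+        # memory slots. A sketch of the edit applied in observe() below:
+        #   x_mem <- x_mem + grad_stride * (1 - edit_decay) ** age
+        #                  * d/dx_mem [ loss(x_mem; theta') - loss(x_mem; theta) ]
+        # where theta' is the model after a virtual update on the incoming batch.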
+        self.no_write_back = get_config_attr(cfg, 'EXTERNAL.OCL.NO_WRITE_BACK', default=0)
+        self.reservoir['age'] = np.zeros(self.mem_limit)
+
+    def get_mem_ages(self, indices, astype):
+        ages = self.reservoir['age'][indices]
+        if torch.is_tensor(astype):
+            ages = torch.from_numpy(ages).float().to(astype.device)
+        return ages
+
+    def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False):
+        sequential = True  # always use the per-example (sequential) virtual update below
+        global total_cnt, itf_cnt
+        self.optimizer.zero_grad()
+        mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True)
+
+        batch_size = x.size(0)
+        if mem_x_indices is None:
+            # the memory is still empty: train on the incoming batch only
+            combined_x, combined_y = x, y
+        else:
+            mem_x, indices = mem_x_indices
+            self.store_cache()
+            for i in range(self.grad_iter):
+                # evaluate loss on mem_x, mem_y
+                mem_x.requires_grad = True
+                mem_x.grad = None
+
+                # evaluate grad of l wrt mem
+                self.optimizer.zero_grad()
+                ret_dict_mem_before = self.forward_net(mem_x, mem_y, reduce=False, task_ids=task_ids)
+
+                # train the model on D
+                if not sequential:
+                    self.get_loss_and_pseudo_update(x, y, task_ids)
+                else:
+                    self.train(False)
+                    for b in range(batch_size):
+                        x_b = x[b].unsqueeze(0)
+                        if type(y) in [tuple, list]:
+                            y_b = [_[b].unsqueeze(0) for _ in y]
+                        else:
+                            y_b = y[b].unsqueeze(0)
+                        ret_dict_db = self.forward_net(x_b, y_b)
+                        self.optimizer.zero_grad()
+                        ret_dict_db['loss'].backward()
+                        if isinstance(self.optimizer, torch.optim.SGD):
+                            step_wo_state_update_sgd(self.optimizer, amp=1.)
+                        elif isinstance(self.optimizer, torch.optim.Adam):
+                            step_wo_state_update_adam(self.optimizer, amp=1.)
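+                        # One virtual SGD/Adam step per incoming example,
+                        # approximating sequential training on the batch;
+                        # self.train(False) above keeps BatchNorm statistics
+                        # frozen during the lookahead, and load_cache() below
+                        # restores the cached parameters afterwards.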
+ else: + raise NotImplementedError + self.train(True) + + ret_dict_mem_after = self.forward_net(mem_x, mem_y, reduce=False) + if 'mask_cnts' not in ret_dict_mem_after: + loss_increase = (ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']).mean() + else: + loss_increase = (ret_dict_mem_after['loss'] - ret_dict_mem_before['loss']).sum() / \ + (sum(ret_dict_mem_after['mask_cnts']) + 1e-10) + loss_increase.backward(retain_graph=False) + grad_delta = mem_x.grad + + self.load_cache() + + mem_ages = self.get_mem_ages(mem_x_indices, astype=mem_x) + stride_decayed = (1 - self.edit_decay) ** mem_ages + + proposed_mem_x = mem_x + self.grad_stride * stride_decayed.view(-1,1) * grad_delta + proposed_mem_x.detach_() + + mem_x = proposed_mem_x + mem_x = mem_x.detach() + self.evolve_mem(mem_x, indices) + + # load cached parameters back + self.load_cache() + combined_x = torch.cat([x, mem_x], 0) + combined_y = concat_with_padding([y, mem_y]) + + ret_dict = self.forward_net(combined_x, combined_y) + + for b in range(batch_size): + if type(y) is tuple: + self.update_mem(x[b], [_[b] for _ in y], extra[b] if extra is not None else None) + else: + self.update_mem(x[b], y[b], extra[b] if extra is not None else None) + + loss = ret_dict['loss'] + if optimize: + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return ret_dict + + def evolve_mem(self, x, indices): + for i, idx in enumerate(indices): + #if self.reservoir['age'][idx] < 10: + self.reservoir['x'][idx] = x[i].cpu().numpy() + self.reservoir['age'][idx] += 1 + + + +total_cnt, itf_cnt = 0, 0 +def proj_grad(a, b, binary, always_proj): + # project b to the direction of a + + dotp = torch.dot(a,b) + if dotp >= 0: + if not always_proj: + return b + else: + return b - (torch.dot(-a,b) / torch.dot(a,a)) * a + else: + if binary: return torch.zeros_like(a) + else: return b - (torch.dot(a,b) / torch.dot(a,a)) * a diff --git a/ocl/ver_approx.py b/ocl/ver_approx.py new file mode 100644 index 0000000..d076403 --- /dev/null +++ b/ocl/ver_approx.py @@ -0,0 +1,365 @@ +from .ver import * +from torch.nn import functional as F +import random + + +class ExperienceEvolveApprox(FOExperienceEvolve): + def __init__(self, base, optimizer, input_size, cfg, goal): + super().__init__(base, optimizer, input_size, cfg, goal) + self.edit_least = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_LEAST', default=0) + self.edit_random = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_RANDOM', default=0) + + self.edit_interfere = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_INTERFERE', default=1) + self.edit_replace = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_REPLACE', default=0) + self.replace_reweight = get_config_attr(cfg, 'EXTERNAL.OCL.REPLACE_REWEIGHT', default=0) + self.use_relu = get_config_attr(cfg, 'EXTERNAL.OCL.USE_RELU', default=0) + self.reg_strength = get_config_attr(cfg, 'EXTERNAL.OCL.REG_STRENGTH', default=0.1) + self.always_proj = get_config_attr(cfg, 'EXTERNAL.OCL.ALWAYS_PROJ', default=0) + + self.edit_mir_k = get_config_attr(cfg, 'EXTERNAL.OCL.EDIT_K', default=-1) + + self.hal_mem = get_config_attr(cfg, 'EXTERNAL.OCL.HAL_MEM', default=0) + self.post_edit_mem_aug = get_config_attr(cfg,'EXTERNAL.OCL.POST_EDIT_MEM_AUG', default=1) + self.edit_aug_mem = get_config_attr(cfg,'EXTERNAL.OCL.EDIT_AUG_MEM', default=0) + self.double_weight = get_config_attr(cfg,'EXTERNAL.OCL.DOUBLE_WEIGHT', default=0) + + if self.edit_mir_k == -1: + self.edit_mir_k = self.mir_k + + def sample_mem_batch_same_task(self, device, task_id_or_label, return_indices=False, mem_k=None, seed=0, 
use_same_label=False): + if mem_k is None: + mem_k = self.mem_bs + if use_same_label: + label = task_id_or_label + else: + task_id = task_id_or_label + + n_max = min(self.mem_limit, self.example_seen) + indices = [] + for i in range(n_max): + if use_same_label: + if self.reservoir['y'][i] == label: + indices.append(i) + else: + if self.reservoir['y_extra'][i] == task_id: + indices.append(i) + + # reservoir + if not indices: + return None, None, None + elif len(indices) >= mem_k: + indices = np.random.RandomState(seed * self.example_seen + self.cfg.SEED).\ + choice(indices, mem_k, replace=False) + + x = self.reservoir['x'][indices] + #x_origin = self.reservoir['x_origin'][indices] + + x = torch.from_numpy(x).to(device).float() + #x_origin = torch.from_numpy(x_origin).to(device).float() + y = index_select(self.reservoir['y'], indices, device) # [ [...], [...] ] + y_extra = index_select(self.reservoir['y_extra'], indices, device) + y_extra = concat_with_padding(y_extra) + if type(y[0]) not in [list, tuple]: + y_pad = concat_with_padding(y) + else: + y_pad = [torch.stack(_).to(device) for _ in zip(*y)] + + if not return_indices: + return x, y_pad, y_extra + else: + return (x, indices), y_pad, y_extra + + def clear_mem_grad(self, mem_x): + mem_x.detach_() + mem_x.grad = None + mem_x.requires_grad = True + + + def observe(self, x, y, task_ids, extra=None, optimize=True, sequential=False): + n_iter = get_config_attr(self.cfg, 'EXTERNAL.OCL.N_ITER', default=1, mute=True) + batch_size = x.size(0) + self.store_cache() + for i_iter in range(n_iter): + if not self.mir: + mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True, seed=i_iter + 1) + if self.edit_random: # select another batch for editing + edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True, + k=self.mir_k, seed=i_iter + 2) + else: + edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids + else: + mem_x_indices, mem_y, mem_task_ids = self.sample_mem_batch(x.device, return_indices=True, + input_x=x, input_y=y, input_task_ids=task_ids, + mir_k=self.mir_k, mir=self.mir, + skip_task=task_ids[0].item(), + seed=i_iter + 1 + ) + if self.edit_least: + edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True, + input_x=x, input_y=y, input_task_ids=task_ids, + mir_k=self.edit_mir_k, mir=self.mir, + skip_task=task_ids[0].item(), + mir_least=True, + seed=i_iter + 2 + ) + elif self.edit_random: + edit_x_indices, edit_y, edit_task_ids = self.sample_mem_batch(x.device, return_indices=True, + k=self.edit_mir_k, seed=i_iter + 2) + else: + edit_x_indices, edit_y, edit_task_ids = mem_x_indices, mem_y, mem_task_ids + self.optimizer.zero_grad() + + edit_x_val_indices, edit_y_val, _ = self.sample_mem_batch_same_task(x.device, task_ids.cpu().numpy()[0], + return_indices=True, seed=i_iter + 2, + mem_k=self.mir_k if self.mir else self.mem_bs) + edit_task_ids_val = task_ids + #mem_x_val_indices, _, mem_y_val, mem_task_ids_val = self.sample_mem_batch(x.device, return_indices=True, seed=1) + + if edit_x_indices is None: + combined_x, combined_y, combined_task_ids = x, y, task_ids + else: + mem_x, mem_indices = mem_x_indices + edit_x, edit_indices = edit_x_indices + if edit_x_val_indices is not None: + edit_x_val, indices_val = edit_x_val_indices + else: + edit_x_val, indices_val = None, None + if self.mem_augment: + aug_mem_x = self.transform_image_batch(mem_x) + if i_iter == 0 and self.edit_interfere: + if self.hal_mem: + train_x, train_y, 
_ = self.sample_mem_batch(x.device, return_indices=False, + seed=i_iter + 3) + else: + train_x, train_y = x, y + + edit_x, mem_x = self.edit_mem_interfere(train_x, train_y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids, + edit_x_val, edit_y_val, edit_task_ids_val, edit_indices) + if not self.no_write_back: + self.evolve_mem(edit_x, edit_indices) + + # load cached parameters back + self.load_cache() + if not self.mem_augment: + combined_x = torch.cat([x, mem_x], 0) + combined_y = concat_with_padding([y, mem_y]) + combined_task_ids = concat_with_padding([task_ids, mem_task_ids]) + else: + if self.post_edit_mem_aug: + aug_mem_x = self.transform_image_batch(mem_x) + if self.edit_aug_mem: + aug_mem_x, _ = self.edit_mem_interfere(train_x, train_y, task_ids, aug_mem_x, mem_y, aug_mem_x, + mem_y, edit_task_ids, + edit_x_val, edit_y_val, edit_task_ids_val, edit_indices) + if self.double_weight: + combined_x = torch.cat([x, mem_x, aug_mem_x], 0) + combined_y = concat_with_padding([y, mem_y, mem_y]) + combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids]) + else: + combined_x = torch.cat([x, mem_x, aug_mem_x], 0) + combined_y = concat_with_padding([y, mem_y, mem_y]) + combined_task_ids = concat_with_padding([task_ids, mem_task_ids, mem_task_ids]) + + ret_dict = self.forward_net(combined_x, combined_y, task_ids=combined_task_ids, reduce=False) + loss_tmp = ret_dict['loss'] + if optimize: + loss = loss_tmp[: x.size(0)].mean() + if mem_x_indices is not None: + loss += loss_tmp[x.size(0):].mean() #* (2 if self.double_weight else 1) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + if self.lb_reservoir and mem_x_indices is not None: + self.update_loss_states(loss_tmp[x.size(0):], mem_indices) + for b in range(batch_size): + if type(y) is tuple: + self.update_mem(x[b], [_[b] for _ in y], task_ids[b]) + else: + self.update_mem(x[b], y[b], task_ids[b]) + + return ret_dict + + def edit_mem_interfere(self, x, y, task_ids, mem_x, mem_y, edit_x, edit_y, edit_task_ids, + edit_x_val, edit_y_val, edit_task_ids_val, edit_indices): + """ + Edit memory so that they are more inter + :param x: + :param y: + :param task_ids: + :param mem_x: + :param mem_y: + :param edit_x: + :param edit_y: + :param edit_task_ids: + :param edit_x_val: + :param edit_y_val: + :param edit_task_ids_val: + :return: + """ + device = x.device + # only edit at the first iter + + for i in range(self.grad_iter): + # evaluate loss on edit_x, edit_y + self.clear_mem_grad(edit_x) + # evaluate grad of l wrt edit + ret_dict_edit_before = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids) + # train the model on D + grad_reg = -torch.autograd.grad(torch.sum(ret_dict_edit_before['loss']), + edit_x, retain_graph=True)[0] + ret_dict_edit_before['loss'].backward() + edit_x_grad1 = edit_x.grad + + self.clear_mem_grad(edit_x) + + for _ in range(1): + ret_dict_d = self.forward_net(x, y, task_ids=task_ids) + self.optimizer.zero_grad() + ret_dict_d['loss'].backward(retain_graph=False) + if isinstance(self.optimizer, torch.optim.SGD): + step_wo_state_update_sgd(self.optimizer, amp=1.) + elif isinstance(self.optimizer, torch.optim.Adam): + step_wo_state_update_adam(self.optimizer, amp=1.) 
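+                    # edit_x_grad2 - edit_x_grad1 (grad_delta below) estimates,
+                    # to first order, how the gradient of the memory loss w.r.t.
+                    # the edited examples changes after one virtual step on the
+                    # incoming data; USE_LOSS_1 picks this direction or one of
+                    # the ablation variants (random, negated, adversarial)
+                    # implemented in the helpers further down.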
+ else: + raise NotImplementedError + + ret_dict_edit_after = self.forward_net(edit_x, edit_y, reduce=True, task_ids=edit_task_ids) + if 'mask_cnts' not in ret_dict_edit_after: + loss_increase = ret_dict_edit_after['loss'] - ret_dict_edit_before['loss'] + else: + loss_increase = (ret_dict_edit_after['loss'] - ret_dict_edit_before['loss']).sum() / \ + (sum(ret_dict_edit_after['mask_cnts']) + 1e-10) + + #if self.use_relu: + # loss_increase = F.relu(loss_increase) + ret_dict_edit_after['loss'].backward() + edit_x_grad2 = edit_x.grad + grad_delta = edit_x_grad2 - edit_x_grad1 + + grad_delta_2 = 0 + + self.clear_mem_grad(edit_x) + self.load_cache() + + total_grad = 0 + + if self.cfg.EXTERNAL.OCL.USE_LOSS_1: + if self.cfg.EXTERNAL.OCL.USE_LOSS_1 == 1: + total_grad += self.cfg.EXTERNAL.OCL.USE_LOSS_1 * grad_delta + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -2: # random tied direction uniform norm + random_vecs = self.get_random_grad(grad_delta, edit_indices) + total_grad += random_vecs + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -3: # random tied direction keep norm + random_vecs = self.get_direction_perturbed_grad(grad_delta, edit_indices) + total_grad += random_vecs + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -4: # random untied direction keep norm + random_vecs = self.get_direction_perturbed_grad_untied(grad_delta, edit_indices) + total_grad += random_vecs + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -5: # negated gradient + neg_vecs = self.get_neg_gradients(grad_delta, edit_indices) + total_grad += neg_vecs + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -6: # adversarial-continuous + total_grad += edit_x_grad1 + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -7: # untied, fixed + random_vecs = self.get_direction_perturbed_grad_untied_fixed(grad_delta, edit_indices) + total_grad += random_vecs + elif self.cfg.EXTERNAL.OCL.USE_LOSS_1 == -8: # PGD + total_grad += torch.sign(edit_x_grad1) + else: + raise ValueError + + if self.cfg.EXTERNAL.OCL.USE_LOSS_2: + total_grad += self.cfg.EXTERNAL.OCL.USE_LOSS_2 * grad_delta_2 + + if type(total_grad) is not int: # has grad update + #if self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG: + if self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG == 1: + for b in range(total_grad.size(0)): + total_grad[b] = proj_grad(-grad_reg[b], total_grad[b], binary=False, always_proj=self.always_proj) + elif self.cfg.EXTERNAL.OCL.PROJ_LOSS_REG == 2: + for b in range(total_grad.size(0)): + total_grad[b] = -grad_reg[b] * self.reg_strength + total_grad[b] + + mem_ages = self.get_mem_ages(edit_indices, astype=edit_x) + stride_decayed = (1 - self.edit_decay) ** mem_ages + + for b in range(total_grad.size(0)): + edit_x[b] = edit_x[b] + self.grad_stride * stride_decayed[b] * total_grad[b] + edit_x = edit_x.detach() + mem_x = mem_x.detach() + + return edit_x, mem_x + + # code 2: edit memory to a random direction tied to each input example + def get_random_grad(self, grad, indices): + if not hasattr(self, 'random_dirs'): + self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).uniform_(-1,1) + random_vecs = [] + for i, indice in enumerate(indices): + random_vec = self.random_dirs[indice].to(grad.device) + random_vec = random_vec / random_vec.norm() + random_vecs.append(random_vec) + random_vecs = torch.stack(random_vecs) + return random_vecs + + # code 3: perturbate editing direction - keep norm + def get_direction_perturbed_grad(self, grad, indices): + if not hasattr(self, 'random_dirs'): + self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).uniform_(-1, 1) + random_vecs = [] + for i, indice in 
enumerate(indices): + random_vec = self.random_dirs[indice].to(grad.device) + random_vec = random_vec / random_vec.norm() * grad[i].norm() + random_vecs.append(random_vec) + random_vecs = torch.stack(random_vecs) + return random_vecs + + # code 4: perturbate editing direction - not tied + def get_direction_perturbed_grad_untied(self, grad, indices): + self.random_dirs = torch.zeros(self.mem_limit, *grad[0].size()).to(grad.device).uniform_(-1, 1) + random_vecs = [] + for i, indice in enumerate(indices): + random_vec = self.random_dirs[indice] + random_vec = random_vec / random_vec.norm() * grad[i].norm() + random_vecs.append(random_vec) + random_vecs = torch.stack(random_vecs) + return random_vecs + + # code 5: flip the update direction + def get_neg_gradients(self, grad, indices): + return -grad + + # code 7: untied and fixed + def get_direction_perturbed_grad_untied_fixed(self, grad, indices): + self.random_dirs = torch.zeros(len(indices), *grad[0].size()).to(grad.device).normal_(0, 1) + random_vecs = [] + for i, indice in enumerate(indices): + random_vec = self.random_dirs[i] + random_vec = random_vec / random_vec.norm() + random_vecs.append(random_vec) + random_vecs = torch.stack(random_vecs) + return random_vecs + + def to_mem_type(self, x, y, y_extra): + x = x.cpu().numpy() + if type(y) not in [list, tuple]: + y = y_to_np(y) + else: + y = y_to_cpu(y) + if type(y_extra) not in [list, tuple]: + y_extra = y_to_np(y_extra) + else: + y_extra = y_to_cpu(y_extra) + return x, y, y_extra + + def indices_to_examples(self, indices, device): + cand_x = self.reservoir['x'][indices] + cand_x = torch.from_numpy(cand_x).to(device).float() + cand_y = index_select(self.reservoir['y'], indices, device) # [ [...], [...] ] + cand_y_extra = index_select(self.reservoir['y_extra'], indices, device) + if type(cand_y[0]) not in [list, tuple]: + cand_y_pad = concat_with_padding(cand_y) + else: + cand_y_pad = [torch.stack(_).to(device) for _ in zip(*cand_y)] + cand_y_extra = concat_with_padding(cand_y_extra) + return cand_x, cand_y_pad, cand_y_extra \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a68cacd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,125 @@ +absl-py==0.9.0 +apex==0.1 +astunparse==1.6.3 +attrs==19.3.0 +backcall==0.1.0 +bleach==3.1.0 +boto3==1.11.9 +botocore==1.14.9 +cachetools==4.1.0 +certifi==2019.11.28 +cffi==1.13.2 +chardet==3.0.4 +cityscapesScripts==1.1.0 +Click==7.0 +cycler==0.10.0 +Cython==0.29.14 +decorator==4.4.1 +defusedxml==0.6.0 +docutils==0.15.2 +easydict==1.9 +entrypoints==0.3 +gast==0.3.3 +google-auth==1.15.0 +google-auth-oauthlib==0.4.1 +google-pasta==0.2.0 +grpcio==1.29.0 +h5py==2.10.0 +idna==2.8 +imagecodecs==2020.2.18 +imageio==2.8.0 +importlib-metadata==1.5.0 +inflect==4.1.0 +ipykernel==5.1.4 +ipython==7.10.2 +ipython-genutils==0.2.0 +jaraco.itertools==5.0.0 +jedi==0.15.1 +Jinja2==2.11.1 +jmespath==0.9.4 +joblib==0.14.1 +jsonlines==1.2.0 +jsonschema==3.2.0 +jupyter-client==5.3.4 +jupyter-core==4.6.1 +Keras-Preprocessing==1.1.2 +kiwisolver==1.1.0 +Markdown==3.2.2 +MarkupSafe==1.1.1 +matplotlib==3.1.2 +mistune==0.8.4 +mkl-fft==1.0.15 +mkl-random==1.1.0 +mkl-service==2.3.0 +more-itertools==8.2.0 +nbconvert==5.6.1 +nbformat==5.0.4 +networkx==2.4 +ninja==1.9.0.post1 +nltk==3.4.5 +notebook==6.0.3 +numpy==1.18.0 +oauthlib==3.1.0 +olefile==0.46 +opencv-python==4.1.2.30 +opt-einsum==3.2.1 +pandas==1.0.3 +pandocfilters==1.4.2 +parso==0.5.2 +pexpect==4.7.0 +pickleshare==0.7.5 +Pillow==6.2.1 +prometheus-client==0.7.1 
+prompt-toolkit==3.0.2 +protobuf==3.10.0 +ptyprocess==0.6.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycocotools==2.0 +pycparser==2.19 +Pygments==2.5.2 +pygtrie==2.3.3 +pyparsing==2.4.6 +pyrsistent==0.15.7 +python-dateutil==2.8.1 +pytorch-transformers==1.2.0 +pytz==2020.1 +PyWavelets==1.1.1 +PyYAML==5.2 +pyzmq==18.1.1 +quadprog==0.1.7 +recordclass==0.12.0.1 +regex==2019.8.19 +requests==2.22.0 +requests-oauthlib==1.3.0 +rsa==4.0 +s3transfer==0.3.2 +sacremoses==0.0.38 +scikit-image==0.17.2 +scikit-learn==0.22 +scipy==1.3.2 +Send2Trash==1.5.0 +sentencepiece==0.1.85 +six==1.13.0 +stanfordnlp==0.2.0 +tensorboard==2.2.1 +tensorboard-plugin-wit==1.6.0.post3 +tensorboardX==2.0 +tensorflow==2.2.0 +tensorflow-estimator==2.2.0 +termcolor==1.1.0 +terminado==0.8.3 +testpath==0.4.4 +tifffile==2020.5.11 +torch==1.1.0 +torchvision==0.2.2 +tornado==6.0.3 +tqdm==4.41.0 +traitlets==4.3.3 +urllib3==1.25.8 +wcwidth==0.1.7 +webencodings==0.5.1 +Werkzeug==1.0.1 +wrapt==1.12.1 +yacs==0.1.6 +zipp==2.1.0 diff --git a/scripts/gmed_mini_imagenet.sh b/scripts/gmed_mini_imagenet.sh new file mode 100644 index 0000000..337f56e --- /dev/null +++ b/scripts/gmed_mini_imagenet.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +min_seed=5 +max_seed=9 +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} +extra_opt=${8} +extra="" +extra_args="" + +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi +if [[ -n "${11}" ]]; then + extra_args=${11} +fi +#if [[ -n "$9" ]]; then +# mem_bs=${9} +#else +# mem_bs=25 +#fi +mem_bs=25 + +seed=$min_seed + + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == 'relu' ]] +then + extra="EXTERNAL.OCL.USE_RELU=1" +fi + +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config configs/memevolve/verx_mini_imagenet.yaml --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/gmed_permuted_mnist.sh b/scripts/gmed_permuted_mnist.sh new file mode 100644 index 0000000..9eac682 --- /dev/null +++ b/scripts/gmed_permuted_mnist.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +min_seed=0 +max_seed=4 +seed=$min_seed +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} +extra_opt=${8} +extra="" +extra_args="" + +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi +if [[ -n "${11}" ]]; then + extra_args=${11} +fi +#if [[ -n "$9" ]]; then +# mem_bs=${9} +#else +# mem_bs=25 +#fi +mem_bs=50 + +seed=$min_seed + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + 
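# "mirr" runs MIR candidate selection and applies GMED editing to randomly chosen memory examples (MIR+GMED in the README) +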
extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == "edit_replace" ]] +then + extra="EXTERNAL.OCL.EDIT_INTERFERE=0 EXTERNAL.OCL.EDIT_REPLACE=1" +fi +if [[ ${extra_opt} == "rrw" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1" +fi +if [[ ${extra_opt} == "rrwr" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2" +fi +if [[ ${extra_opt} == "supp_proj" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=1" +fi +if [[ ${extra_opt} == "supp_reg" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=2" +fi + +config_file="configs/memevolve/ver_permute_approx.yaml" +if [[ ${extra_opt} == "hal" ]] +then + config_file="configs/baselines/hal_permuted_mnist.yaml" +fi + +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/gmed_rotate_mnist.sh b/scripts/gmed_rotate_mnist.sh new file mode 100644 index 0000000..66cf003 --- /dev/null +++ b/scripts/gmed_rotate_mnist.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +min_seed=0 +max_seed=9 +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} +extra_opt=${8} +extra="" +extra_arg="" + + +#if [[ -n "$9" ]]; then +# mem_bs=${9} +#else +# mem_bs=50 +#fi +mem_bs=50 + +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi + +seed=$min_seed + +if [[ -n "${11}" ]]; then + extra_args=${11} +fi + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == "supp_proj" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=1" +fi +if [[ ${extra_opt} == "supp_reg" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=2" +fi +if [[ ${extra_opt} == "supp_proj_meta" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=3" +fi +if [[ ${extra_opt} == "supp_reg_meta" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=4" +fi + +config_file="configs/memevolve/ver_permute_approx.yaml" +if [[ ${extra_opt} == "hal" ]] +then + config_file="configs/baselines/hal_permuted_mnist.yaml" +fi + + +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config configs/memevolve/ver_rotate_approx.yaml --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/gmed_split_cifar10.sh b/scripts/gmed_split_cifar10.sh new file mode 100644 index 0000000..e328ca0 --- /dev/null +++ b/scripts/gmed_split_cifar10.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +min_seed=0 +max_seed=4 +seed=$min_seed +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} 
+extra_opt=${8} +extra="" +extra_args="" + +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi +if [[ -n "${11}" ]]; then + extra_args=${11} +fi +#if [[ -n "$9" ]]; then +# mem_bs=${9} +#else +# mem_bs=25 +#fi +mem_bs=50 + +seed=$min_seed + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == "rrw" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1" +fi +if [[ ${extra_opt} == "rrwr" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2" +fi +if [[ ${extra_opt} == "supp_proj" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=1" +fi +if [[ ${extra_opt} == "supp_reg" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=2" +fi +if [[ ${extra_opt} == 'relu' ]] +then + extra="EXTERNAL.OCL.USE_RELU=1" +fi + +config_file="configs/memevolve/verx_cifar.yaml" +if [[ ${extra_opt} == "hal" ]] +then + config_file="configs/baselines/hal_split_cifar.yaml" +fi + +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/gmed_split_cifar100.sh b/scripts/gmed_split_cifar100.sh new file mode 100644 index 0000000..b1f34b4 --- /dev/null +++ b/scripts/gmed_split_cifar100.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +min_seed=0 +max_seed=4 +seed=$min_seed +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} +extra_opt=${8} +extra="" +extra_args="" + +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi +if [[ -n "${11}" ]]; then + extra_args=${11} +fi +#if [[ -n "$9" ]]; then +# mem_bs=${9} +#else +# mem_bs=25 +#fi +mem_bs=50 + +seed=$min_seed + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == "rrw" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1" +fi +if [[ ${extra_opt} == "rrwr" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2" +fi +if [[ ${extra_opt} == "supp_proj" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=1" +fi +if [[ ${extra_opt} == "supp_reg" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=2" +fi +if [[ ${extra_opt} == 'relu' ]] +then + extra="EXTERNAL.OCL.USE_RELU=1" +fi + +config_file="configs/memevolve/verx_cifar100.yaml" + +if [[ ${extra_opt} == "hal" ]] +then + config_file="configs/baselines/hal_split_cifar100.yaml" +fi + +if [[ ${extra_opt} == "cndpm" ]] +then + config_file="configs/baselines/cndpm_split_cifar100.yaml" +fi + +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config 
${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/gmed_split_mnist.sh b/scripts/gmed_split_mnist.sh new file mode 100755 index 0000000..0924f8b --- /dev/null +++ b/scripts/gmed_split_mnist.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +min_seed=0 +max_seed=4 +seed=$min_seed +model_name=${1} +grad_iter=${2} +grad_stride=${3} +use_loss_1=${4} +use_loss_2=${5} +proj_loss_reg=${6} +mem_size=${7} +extra_opt=${8} +extra="" + +mem_bs=50 +if [[ -n "$9" ]]; then + min_seed=${9} +fi +if [[ -n "${10}" ]]; then + max_seed=${10} +fi + +seed=$min_seed + +if [[ -n "${11}" ]]; then + extra_args=${11} +fi + + +if [[ ${extra_opt} == "mir" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1" +fi +if [[ ${extra_opt} == "mirl" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_LEAST=1" +fi +if [[ ${extra_opt} == "mirr" ]] +then + extra="EXTERNAL.REPLAY.MEM_BS=${mem_bs} EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.MIR=1 EXTERNAL.OCL.EDIT_RANDOM=1" +fi +if [[ ${extra_opt} == "edit_replace" ]] +then + extra="EXTERNAL.OCL.EDIT_INTERFERE=0 EXTERNAL.OCL.EDIT_REPLACE=1" +fi +if [[ ${extra_opt} == "rrw" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=1" +fi +if [[ ${extra_opt} == "rrwr" ]] +then + extra="EXTERNAL.OCL.REPLACE_REWEIGHT=2" +fi +if [[ ${extra_opt} == "supp_proj" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=1" +fi +if [[ ${extra_opt} == "supp_reg" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=2" +fi +if [[ ${extra_opt} == "supp_proj_meta" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=3" +fi +if [[ ${extra_opt} == "supp_reg_meta" ]] +then + extra="EXTERNAL.OCL.REG_SUPPORTIVE=4" +fi + +config_file="configs/memevolve/ver_approx.yaml" + +if [[ ${extra_opt} == "hal" ]] +then + config_file="configs/baselines/hal_split_mnist.yaml" +fi +while(( $seed<= $max_seed )) +do + python train.py --name ${1} --config ${config_file} --seed ${seed} --cfg EXTERNAL.OCL.GRAD_ITER=${grad_iter} EXTERNAL.OCL.GRAD_STRIDE=${grad_stride} EXTERNAL.OCL.USE_LOSS_1=${use_loss_1} EXTERNAL.OCL.USE_LOSS_2=${use_loss_2} EXTERNAL.OCL.PROJ_LOSS_REG=${proj_loss_reg} EXTERNAL.REPLAY.MEM_LIMIT=${mem_size} ${extra} ${extra_args} + let "seed++" +done \ No newline at end of file diff --git a/scripts/tune_params.py b/scripts/tune_params.py new file mode 100644 index 0000000..d348c32 --- /dev/null +++ b/scripts/tune_params.py @@ -0,0 +1,96 @@ +import os +import logging +import argparse +import random + +logger = logging.getLogger(__name__) + +ds_to_name = {'split_cifar': 'cifar10_5tasks_2class_ci_verx', 'mini_imagenet': 'mini_imagenet_ci_verx', + 'rotated_mnist': 'rotated_mnist_verx', 'split_mnist':'split_mnist_verx', + 'permuted_mnist': 'pm_verx'} +ds_to_config = {'split_cifar': 'verx_cifar.yaml', 'mini_imagenet': 'verx_mini_imagenet.yaml', + 'rotated_mnist': 'ver_rotate_approx.yaml', 'split_mnist':'ver_approx.yaml', + 'permuted_mnist': 'ver_permute_approx.yaml'} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--start_seed', type=int, default=0) + parser.add_argument('--stop_seed', type=int, default=5) + parser.add_argument('--dataset') + parser.add_argument('--strides', type=float, nargs='+', 
default=[0.1,0.05,0.5,1.0]) + parser.add_argument('--proj_flags', type=int, nargs='+', default=[2]) + parser.add_argument('--reg_strengths', type=float, nargs='+', default=[0.001,0.01,0.1]) + parser.add_argument('--lr', type=float, nargs='*', default=[0]) + parser.add_argument('--iters', type=int, nargs='+') + parser.add_argument('--mem_sizes', type=int, nargs='+') + parser.add_argument('--use_relu', type=int, default=0) + parser.add_argument('--epoch', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=10) + parser.add_argument('--task_num', type=int, default=3) + parser.add_argument('--extra', default='') + parser.add_argument('--name', default='') + parser.add_argument('--overwrite', action='store_true') + parser.add_argument('--mir', action='store_true') + parser.add_argument('--rand', action='store_true') + + args = parser.parse_args() + + logger.addHandler(logging.FileHandler('logs/{}_{}.txt'.format(args.dataset, random.randint(0,10000000)))) + logger.setLevel(logging.INFO) + + strides = args.strides + mem_sizes = args.mem_sizes + iters = args.iters + + use_loss_1 = 1 if not args.rand else -2 + + extra_auto = '' + if args.mir: + extra_auto = 'EXTERNAL.OCL.MIR=1 EXTERNAL.REPLAY.MEM_BS=10 EXTERNAL.REPLAY.MIR_K=10 EXTERNAL.OCL.EDIT_RANDOM=1' + + for stride in strides: + for mem_size in mem_sizes: + for proj_flag in args.proj_flags: + for reg_strength in args.reg_strengths: + for lr in args.lr: + + for iter_n in iters: + for seed in range(args.start_seed, args.stop_seed): + name_mark = '' + if args.use_relu: + name_mark += 'relu' + if args.mir: + name_mark += 'mir' + + name = '{}_iter{}_stride_{}_{}_0_{}{}_m{}_{}_{}_{}'.format(ds_to_name[args.dataset], iter_n, stride, use_loss_1, + proj_flag, reg_strength, mem_size, seed, + name_mark, + args.name) + if lr != 0: + name += '_lr{}_epoch{}'.format(lr, args.epoch) + result_file = 'runs/{}_{}/result_tune_k{}.json'.format(name, seed, args.task_num) + if not args.overwrite and os.path.isfile(result_file): + print('result file exists, skipping') + else: + command = 'python train.py --name {name} --tune --config configs/memevolve/{config} --seed {seed} ' \ + '--cfg EXTERNAL.OCL.GRAD_ITER={grad_iter} EXTERNAL.OCL.GRAD_STRIDE={grad_stride} '\ + 'EXTERNAL.OCL.USE_LOSS_1={use_loss_1} EXTERNAL.OCL.USE_LOSS_2=0 '\ + 'EXTERNAL.OCL.PROJ_LOSS_REG={proj_flag} EXTERNAL.REPLAY.MEM_LIMIT={mem_size} ' \ + 'EXTERNAL.OCL.REG_STRENGTH={reg_strength} ' \ + 'EXTERNAL.OCL.USE_RELU={use_relu} ' \ + 'EXTERNAL.EPOCH={epoch} ' \ + 'EXTERNAL.BATCH_SIZE={batch_size} ' \ + 'EXTERNAL.OCL.TASK_NUM={task_num} ' \ + '{extra} ' \ + '{extra_auto} '\ + .format(**{'name': name, 'config': ds_to_config[args.dataset], 'seed':seed, + 'grad_iter':iter_n, 'grad_stride': stride, 'mem_size': mem_size, + 'extra': args.extra, 'proj_flag': proj_flag, 'reg_strength': reg_strength, + 'use_relu': args.use_relu, 'epoch': args.epoch, 'extra_auto': extra_auto, + 'batch_size': args.batch_size, 'task_num': args.task_num, 'use_loss_1': use_loss_1}) + if lr != 0: + command += ' SOLVER.BASE_LR={} '.format(lr) + logger.info(command) + exit_code = os.system(command) + logger.info('Exit code {}'.format(exit_code)) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..effb24a --- /dev/null +++ b/train.py @@ -0,0 +1,232 @@ +import argparse +import os +import sys + +import torch +import numpy as np +import random +from trainer_benchmark import ocl_train_mnist, ocl_train_cifar +from utils.utils import get_exp_id,
set_config_attr, get_config_attr +from yacs.config import CfgNode + +from nets.classifier import ResNetClassifier, ResNetClassifierWObj +from nets.simplenet import mnist_simple_net_400width_classlearning_1024input_10cls_1ds + + +from ocl import NaiveWrapper, ExperienceReplay, AGEM, ExperienceEvolveApprox + + +def train(cfg, local_rank, distributed, tune=False): + is_ocl = hasattr(cfg.EXTERNAL.OCL, 'ALGO') and cfg.EXTERNAL.OCL.ALGO != 'PLAIN' + task_incremental = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_INCREMENTAL', default=False) + + cfg.TUNE = tune + + algo = cfg.EXTERNAL.OCL.ALGO + if hasattr(cfg,'MNIST'): + if cfg.MNIST.TASK == 'split': + goal = 'split_mnist' + elif cfg.MNIST.TASK == 'permute': + goal = 'permute_mnist' + elif cfg.MNIST.TASK == 'rotate': + goal = 'rotated_mnist' + if hasattr(cfg, 'CIFAR'): + goal = 'split_cifar' + if get_config_attr(cfg, 'CIFAR.DATASET', default="") == 'CIFAR100': + goal = 'split_cifar100' + if get_config_attr(cfg, 'CIFAR.MINI_IMAGENET', default=0): + goal = 'split_mini_imagenet' + + + + if hasattr(cfg,'MNIST'): + + num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int) + base_model = mnist_simple_net_400width_classlearning_1024input_10cls_1ds(num_of_datasets=num_of_datasets, + num_of_classes=num_of_classes, + task_incremental=task_incremental) + + base_model.cfg = cfg + elif hasattr(cfg, 'CIFAR'): + if goal == 'split_cifar': + num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + num_of_classes = 10 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int) + elif goal == 'split_cifar100': + num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int) + elif goal == 'split_mini_imagenet': + num_of_datasets = 1 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + num_of_classes = 100 if not task_incremental else get_config_attr(cfg, 'EXTERNAL.OCL.CLASS_NUM', totype=int) + + + base_model = ResNetClassifier(cfg, depth='18', mlp=1, ignore_index=-100, num_of_datasets=num_of_datasets, + num_of_classes=num_of_classes, task_incremental=task_incremental, goal=goal) + base_model.cfg = cfg + else: + base_model = ResNetClassifier(cfg) + + device = torch.device(cfg.MODEL.DEVICE) + base_model.to(device) + if cfg.EXTERNAL.OPTIMIZER.ADAM: + optimizer = torch.optim.Adam( + filter(lambda x: x.requires_grad, base_model.parameters()), + lr=cfg.SOLVER.BASE_LR, betas=(0.9, 0.999) + ) + else: + optimizer = torch.optim.SGD( + filter(lambda x: x.requires_grad, base_model.parameters()), + lr=cfg.SOLVER.BASE_LR + ) + + # algorithm specific model wrapper + x_size = 3 * 2 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2 if goal == 'classification' else \ + 3 * base_model.cfg.EXTERNAL.IMAGE_SIZE ** 2 + if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist': x_size = 28 * 28 + if goal == 'split_cifar' or goal == 'split_cifar100': x_size = 3 * 32 * 32 + if goal == 'split_mini_imagenet': x_size = 3 * 84 * 84 + + if algo == 'ER': + model = ExperienceReplay(base_model, optimizer, x_size, base_model.cfg, goal) + elif algo == 'VERX': + model = ExperienceEvolveApprox(base_model, optimizer, x_size, base_model.cfg, goal) + elif algo == 'AGEM': + 
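# A-GEM: instead of replaying, project the incoming gradient so that, to first order, the loss on a sampled memory batch does not increase +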
model = AGEM(base_model, optimizer, x_size, base_model.cfg, goal) + elif algo == 'naive': + model = NaiveWrapper(base_model, optimizer, x_size, base_model.cfg, goal) + model.to(device) + + use_mixed_precision = cfg.DTYPE == "float16" + arguments = {"iteration": 0, "global_step": 0, "epoch": 0} + output_dir = cfg.OUTPUT_DIR + writer = None + epoch_num = 1 + for e in range(epoch_num): + print("epoch") + + arguments['iteration'] = 0 + epoch = arguments['epoch'] + if goal == 'split_mnist' or goal == 'permute_mnist' or goal == 'rotated_mnist': + ocl_train_mnist(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune) + elif goal == 'split_cifar' or goal == 'split_cifar100' or goal == 'split_mini_imagenet': + ocl_train_cifar(model, optimizer, None, device, arguments, writer, epoch, goal, tune=tune) + else: + raise NotImplementedError + arguments['epoch'] += 1 + + with open(os.path.join(output_dir, 'model.bin'),'wb') as wf: + torch.save(model.state_dict(), wf) + # else: + # break + if is_ocl and hasattr(model, 'dump_reservoir') and args.dump_reservoir: + model.dump_reservoir(os.path.join(cfg.OUTPUT_DIR, 'mem_dump.pkl'), verbose=args.dump_reservoir_verbose) + return model + +def set_cfg_from_args(args, cfg): + cfg_params = args.cfg + if cfg_params is None: return + for param in cfg_params: + k, v = param.split('=') + set_config_attr(cfg, k, v) + +def count_params(m: torch.nn.Module, only_trainable: bool = False): + """ + returns the total number of parameters used by `m` (only counting + shared parameters once); if `only_trainable` is True, then only + includes parameters with `requires_grad = True` + """ + parameters = m.parameters() + if only_trainable: + parameters = list(p for p in parameters if p.requires_grad) + unique = dict((p.data_ptr(), p) for p in parameters).values() + return sum(p.numel() for p in unique) + + +def main(args): + if '%id' in args.name: + exp_name = args.name.replace('%id', get_exp_id()) + else: + exp_name = args.name + + combined_cfg = CfgNode(new_allowed=True) + combined_cfg.merge_from_file(args.config) + cfg = combined_cfg + cfg.EXTERNAL.EXPERIMENT_NAME = exp_name + cfg.SEED = args.seed + cfg.DEBUG = args.debug + + set_cfg_from_args(args, cfg) + + output_dir = get_config_attr(cfg, 'OUTPUT_DIR', default='') + if output_dir == '.': output_dir = 'runs/' + cfg.OUTPUT_DIR = os.path.join(output_dir, + '{}_{}'.format(cfg.EXTERNAL.EXPERIMENT_NAME, cfg.SEED)) + cfg.MODE = 'train' + + # cfg.freeze() + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + distributed = num_gpus > 1 + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + + if distributed: + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend="nccl", init_method="env://" + ) + + output_dir = cfg.OUTPUT_DIR + + # save overloaded model config in the output directory + model = train(cfg, local_rank, distributed, tune=args.tune) + + output_args_path = os.path.join(output_dir, 'args.txt') + wf = open(output_args_path, 'w') + wf.write(' '.join(sys.argv)) + wf.close() + +def seed_everything(seed): + ''' + :param seed: + :param device: + :return: + ''' + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # some cudnn methods can be random even after fixing the seed + # unless you tell it to be deterministic + torch.backends.cudnn.deterministic = True + +def seed_everything_old(seed): + random.seed(seed) + 
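# also seeds numpy and torch below; unlike seed_everything above, this legacy variant leaves cuDNN non-deterministic +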
np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', type=int) + parser.add_argument('--name', type=str, default='%id') + parser.add_argument('--config', type=str, default='config.yaml') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed_old', action='store_true') + parser.add_argument('--n_runs', type=int, default=1) + parser.add_argument('--tune', action='store_true') + parser.add_argument('--cfg', nargs='*') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--dump_reservoir', action='store_true') + parser.add_argument('--dump_reservoir_verbose', action='store_true') + parser.add_argument('--single_word', action='store') + + args = parser.parse_args() + if not args.seed_old: + seed_everything(args.seed) + else: + seed_everything_old(args.seed) + n_runs = args.n_runs + for i in range(n_runs): + main(args) diff --git a/trainer_benchmark.py b/trainer_benchmark.py new file mode 100644 index 0000000..db24761 --- /dev/null +++ b/trainer_benchmark.py @@ -0,0 +1,242 @@ +#from trainer import * +import logging +import torch + +from tqdm import tqdm +import json, os + +from inference import to_list, f1_score +from dataloader import get_split_mnist_dataloader, get_permute_mnist_dataloader,\ + get_split_cifar_dataloader, get_split_cifar100_dataloader, get_split_mini_imagenet_dataloader,\ + get_rotated_mnist_dataloader, IIDDataset +from utils.utils import get_config_attr +from torch.utils.data import DataLoader + +def exp_decay_lr(optimizer, step, total_step, init_lr): + gamma = (1 / 6) ** (step / total_step) + for param_group in optimizer.param_groups: + param_group['lr'] = init_lr * gamma + +def ocl_train_mnist(model, optimizer, checkpointer, device, arguments, writer, epoch, + goal='split', tune=False): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training @ epoch {:02d}".format(arguments['epoch'])) + model.train() + cfg = model.cfg + pbar = tqdm( + position=0, + desc='GPU: 0' + ) + + num_instances = cfg.MNIST.INSTANCE_NUM + + if goal == 'split_mnist': + task_num = 5 + loader_func = get_split_mnist_dataloader + elif goal == 'permute_mnist': + task_num = 10 + loader_func = get_permute_mnist_dataloader + elif goal == 'rotated_mnist': + task_num = 20 + loader_func = get_rotated_mnist_dataloader + else: + raise ValueError + + if tune: + task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int, default=3) + + num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1) + total_step = task_num * 1000 + base_lr = get_config_attr(cfg,'SOLVER.BASE_LR',totype=float) + # whether iid + iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool) + do_exp_lr_decay = get_config_attr(cfg,'EXTERNAL.OCL.EXP_LR_DECAY',0) + + all_accs = [] + best_avg_accs = [] + step = 0 + for task_id in range(task_num): + if iid: + if task_id != 0: break + data_loaders = [loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE, + max_instance=num_instances) for task_id in range(task_num)] + data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE) + num_instances *= task_num + else: + data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE, + max_instance=num_instances) + + best_avg_acc = -1 + #model.net.set_task(task_id) # choose the classifier head if the model supports + for epoch in range(num_epoch): + 
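# "seen" caps the number of stream examples consumed per epoch at num_instances +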
seen = 0 + for i, data in enumerate(data_loader): + if seen >= num_instances: break + inputs, labels = data + inputs, labels = (inputs.to(device), labels.to(device)) + task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device) + inputs = inputs.flatten(1) + model.observe(inputs, labels, task_ids=task_ids) + step += 1 + if do_exp_lr_decay: + exp_decay_lr(optimizer, step, total_step, base_lr) + + seen += labels.size(0) + # run evaluation + with torch.no_grad(): + if iid: + accs, _, avg_acc = inference_mnist(model, task_num, loader_func, device, tune=tune) + else: + accs, _, avg_acc = inference_mnist(model, task_id + 1, loader_func, device, tune=tune) + logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc)) + for i, acc in enumerate(accs): + logger.info('::Val Task {}\t Acc {}'.format(i, acc)) + all_accs.append(accs) + if avg_acc > best_avg_acc: + best_avg_acc = avg_acc + else: + break + best_avg_accs.append(best_avg_acc) + file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num) + result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w') + json.dump({'all_accs': all_accs, 'avg_acc': avg_acc}, result_file, indent=4) + result_file.close() + + +def inference_mnist(model, max_task, loader_func, device, tune=False): + model.train(False) + accs, instance_nums = [], [] + for val_task_id in range(0, max_task): + #task_id = 0 + all_pred, all_truth = [], [] + val_data_loader = loader_func(model.cfg, 'test' if not tune else 'val', [val_task_id], + batch_size=model.cfg.EXTERNAL.BATCH_SIZE) + print('-------len val data loader {}-------'.format(len(val_data_loader))) + for i, data in enumerate(val_data_loader): + + inputs, labels = data + inputs, labels = (inputs.to(device), labels.to(device)) + task_ids = torch.LongTensor([val_task_id] * labels.size(0)).to(inputs.device) + ret_dict = model(inputs, labels, task_ids=task_ids) + score = ret_dict['score'] + _, pred = torch.max(score, -1) + all_pred.extend(to_list(pred)) + all_truth.extend(to_list(labels)) + acc = f1_score(all_truth, all_pred, average='micro') + accs.append(acc) + instance_nums.append(len(all_pred)) + total_instance_num = sum(instance_nums) + model.train(True) + return accs, instance_nums, sum([x * y / total_instance_num for x,y in zip(accs, instance_nums)]) + + +def ocl_train_cifar(model, optimizer, checkpointer, device, arguments, writer, epoch, + goal='split_cifar', tune=False): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training @ epoch {:02d}".format(arguments['epoch'])) + model.train() + cfg = model.cfg + + num_epoch = cfg.CIFAR.EPOCH + if goal == 'split_cifar': + loader_func = get_split_cifar_dataloader + total_step = 4750 + elif goal == 'split_cifar100': + loader_func = get_split_cifar100_dataloader + total_step = 25000 + else: + loader_func = get_split_mini_imagenet_dataloader + total_step = 22500 + max_instance = cfg.CIFAR.INSTANCE_NUM if hasattr(cfg.CIFAR, 'INSTANCE_NUM') else 1e10 + if not tune: + task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + else: + task_num = get_config_attr(cfg, 'EXTERNAL.OCL.TASK_NUM', totype=int) + + + do_exp_lr_decay = get_config_attr(cfg,'EXTERNAL.OCL.EXP_LR_DECAY',0) + base_lr = get_config_attr(cfg,'SOLVER.BASE_LR',totype=float) + step = 0 + + num_epoch = get_config_attr(cfg, 'EXTERNAL.EPOCH', totype=int, default=1) + all_accs = [] + best_avg_accs = [] + iid = not get_config_attr(cfg, 'EXTERNAL.OCL.ACTIVATED', totype=bool) + for task_id in range(task_num): + if iid: + if task_id 
!= 0: break + data_loaders = [loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE, + max_instance=max_instance) for task_id in range(task_num)] + data_loader = DataLoader(IIDDataset(data_loaders), batch_size=cfg.EXTERNAL.BATCH_SIZE) + max_instance *= task_num + else: + data_loader = loader_func(cfg, 'train', [task_id], batch_size=cfg.EXTERNAL.BATCH_SIZE, max_instance=max_instance) + pbar = tqdm( + position=0, + desc='GPU: 0', + total=len(data_loader) + ) + best_avg_acc = -1 + for epoch in range(num_epoch): + seen = 0 + for i, data in enumerate(data_loader): + if seen >= max_instance: break + pbar.update(1) + inputs, labels = data + inputs, labels = (inputs.to(device), labels.to(device)) + inputs = inputs.flatten(1) + task_ids = torch.LongTensor([task_id] * labels.size(0)).to(inputs.device) + model.observe(inputs, labels, task_ids) + seen += inputs.size(0) + if do_exp_lr_decay: + exp_decay_lr(optimizer, step, total_step, base_lr) + step += 1 + # # run evaluation + with torch.no_grad(): + if iid: + accs, _, avg_acc = inference_cifar(model, task_num, loader_func, device, goal, tune=tune) + else: + accs, _, avg_acc = inference_cifar(model, task_id + 1, loader_func, device, goal, tune=tune) + logger.info('Epoch {}\tTask {}\tAcc {}'.format(epoch, task_id, avg_acc)) + for i, acc in enumerate(accs): + logger.info('::Val Task {}\t Acc {}'.format(i, acc)) + all_accs.append(accs) + if avg_acc > best_avg_acc: + best_avg_acc = avg_acc + else: + break + best_avg_accs.append(best_avg_acc) + file_name = 'result.json' if not tune else 'result_tune_k{}.json'.format(task_num) + result_file = open(os.path.join(cfg.OUTPUT_DIR, file_name), 'w') + json.dump({'all_accs': all_accs, 'avg_acc': avg_acc, 'best_avg_accs': best_avg_accs}, result_file, indent=4) + result_file.close() + return best_avg_accs + +def inference_cifar(model, max_task, loader_func, device, goal, tune=False): + accs, instance_nums = [], [] + model.train(False) + for val_task_id in range(0, max_task): + all_pred, all_truth = [], [] + val_data_loader = loader_func(model.cfg, 'test' if not tune else 'val', [val_task_id], batch_size=model.cfg.EXTERNAL.BATCH_SIZE) + for i, data in enumerate(val_data_loader): + inputs, labels = data + inputs, labels = (inputs.to(device), labels.to(device)) + inputs = inputs.view(-1, 3, 32, 32) if goal == 'split_cifar' or goal == 'split_cifar100' else inputs.view(-1, 3, 84, 84) + task_ids = torch.LongTensor([val_task_id] * labels.size(0)).to(inputs.device) + if model.cfg.EXTERNAL.OCL.ALGO == 'CNDPM': + score = model(inputs) + else: + ret_dict = model(bbox_images=inputs, spatial_feat=None, attr_labels=labels, + obj_labels=None, images=None, task_ids=task_ids) + score = ret_dict['score'] + _, pred = torch.max(score, -1) + all_pred.extend(to_list(pred)) + all_truth.extend(to_list(labels)) + acc = f1_score(all_truth, all_pred, average='micro') + accs.append(acc) + instance_nums.append(len(all_pred)) + total_instance_num = sum(instance_nums) + model.train(True) + return accs, instance_nums, sum([x * y / total_instance_num for x,y in zip(accs, instance_nums)]) + + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..b7be57f --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.append('.') \ No newline at end of file diff --git a/utils/build_transforms.py b/utils/build_transforms.py new file mode 100644 index 0000000..3a02647 --- /dev/null +++ b/utils/build_transforms.py @@ -0,0 +1,45 @@ +# from maskrcnn_benchmark.data.transforms import transforms as 
T +from torchvision.transforms import transforms as T + +def build_transforms(cfg, split="train"): + if split=="train": + min_size = min(cfg.EXTERNAL.IMAGE.HEIGHT,cfg.EXTERNAL.IMAGE.WIDTH) + max_size = max(cfg.EXTERNAL.IMAGE.HEIGHT,cfg.EXTERNAL.IMAGE.WIDTH) + flip_horizontal_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN + flip_vertical_prob = cfg.INPUT.VERTICAL_FLIP_PROB_TRAIN + brightness = cfg.INPUT.BRIGHTNESS + contrast = cfg.INPUT.CONTRAST + saturation = cfg.INPUT.SATURATION + hue = cfg.INPUT.HUE + else: + min_size = min(cfg.EXTERNAL.IMAGE.HEIGHT, cfg.EXTERNAL.IMAGE.WIDTH) + max_size = max(cfg.EXTERNAL.IMAGE.HEIGHT, cfg.EXTERNAL.IMAGE.WIDTH) + flip_horizontal_prob = 0.0 + flip_vertical_prob = 0.0 + brightness = 0.0 + contrast = 0.0 + saturation = 0.0 + hue = 0.0 + + to_bgr255 = cfg.INPUT.TO_BGR255 + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, + ) + color_jitter = T.ColorJitter( + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue, + ) + + transform = T.Compose( + [ + color_jitter, + T.Resize((min_size, max_size)), + T.RandomHorizontalFlip(flip_horizontal_prob), + T.RandomVerticalFlip(flip_vertical_prob), + T.ToTensor(), + normalize_transform, + ] + ) + return transform \ No newline at end of file diff --git a/utils/tupperware.py b/utils/tupperware.py new file mode 100644 index 0000000..8961239 --- /dev/null +++ b/utils/tupperware.py @@ -0,0 +1,73 @@ +from collections import UserDict +import collections +from recordclass import recordclass + +def tupperware(mapping): + """ Convert mappings to 'tupperwares' recursively. + Lets you use dicts like they're JavaScript Object Literals (~=JSON)... + It recursively turns mappings (dictionaries) into namedtuples. + Thus, you can cheaply create an object whose attributes are accessible + by dotted notation (all the way down). + Use cases: + * Fake objects (useful for dependency injection when you're making + fakes/stubs that are simpler than proper mocks) + * Storing data (like fixtures) in a structured way, in Python code + (data whose initial definition reads nicely like JSON). You could do + this with dictionaries, but namedtuples are immutable, and their + dotted notation can be clearer in some contexts. + .. doctest:: + >>> t = tupperware({ + ... 'foo': 'bar', + ... 'baz': {'qux': 'quux'}, + ... 'tito': { + ... 'tata': 'tutu', + ... 'totoro': 'tots', + ... 'frobnicator': ['this', 'is', 'not', 'a', 'mapping'] + ... } + ... }) + >>> t # doctest: +ELLIPSIS + Tupperware(tito=Tupperware(...), foo='bar', baz=Tupperware(qux='quux')) + >>> t.tito # doctest: +ELLIPSIS + Tupperware(frobnicator=[...], tata='tutu', totoro='tots') + >>> t.tito.tata + 'tutu' + >>> t.tito.frobnicator + ['this', 'is', 'not', 'a', 'mapping'] + >>> t.foo + 'bar' + >>> t.baz.qux + 'quux' + Args: + mapping: An object that might be a mapping. If it's a mapping, convert + it (and all of its contents that are mappings) to namedtuples + (called 'Tupperwares'). + Returns: + A tupperware (a namedtuple (of namedtuples (of namedtuples (...)))). + If argument is not a mapping, it just returns it (this enables the + recursion).
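+ Note: mappings are converted in place as a side effect; wrap a dict in ProtectedDict (below) to keep it as a plain mapping.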
+ """ + + if (isinstance(mapping, collections.Mapping) and + not isinstance(mapping, ProtectedDict)): + for key, value in mapping.items(): + mapping[key] = tupperware(value) + return namedtuple_from_mapping(mapping) + return mapping + + +def namedtuple_from_mapping(mapping, name="Tupperware"): + # this_namedtuple_maker = collections.namedtuple(name, mapping.keys()) + this_namedtuple_maker = recordclass(name, mapping.keys()) + return this_namedtuple_maker(**mapping) + + +class ProtectedDict(UserDict): + """ A class that exists just to tell `tupperware` not to eat it. + `tupperware` eats all dicts you give it, recursively; but what if you + actually want a dictionary in there? This will stop it. Just do + ProtectedDict({...}) or ProtectedDict(kwarg=foo). + """ + + +def tupperware_from_kwargs(**kwargs): + return tupperware(kwargs) diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..5ba1d07 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,291 @@ +from collections import Counter +import uuid, time, datetime, os, torch, logging +from collections import OrderedDict +import numpy as np + +def get_top_k_by_frequency(vocab, top_k=None): + ''' + Get the top k elements from a list by frequency + :param vocab: the list + :parm top_k: top_k + :return: top k elements + ''' + counter = Counter(vocab) + + sorted_vocab = sorted( + [t for t in counter], + key=counter.get, + reverse=True + ) + + if top_k: + return sorted_vocab[:top_k] + + return sorted_vocab + +def get_exp_id(): + return uuid.uuid4().hex[:6] + + +class Timer(object): + def __init__(self): + self.reset() + + @property + def average_time(self): + return self.total_time / self.calls if self.calls > 0 else 0.0 + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.add(time.time() - self.start_time) + if average: + return self.average_time + else: + return self.diff + + def add(self, time_diff): + self.diff = time_diff + self.total_time += self.diff + self.calls += 1 + + def reset(self): + self.total_time = 0.0 + self.calls = 0 + self.start_time = 0.0 + self.diff = 0.0 + + def avg_time_str(self): + time_str = str(datetime.timedelta(seconds=self.average_time)) + return time_str + + class Checkpointer(object): + def __init__( + self, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.save_dir = save_dir + self.save_to_disk = save_to_disk + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger + + def save(self, name, **kwargs): + if not self.save_dir: + return + + if not self.save_to_disk: + return + + data = {} + data["model"] = self.model.state_dict() + if self.optimizer is not None: + data["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + data["scheduler"] = self.scheduler.state_dict() + data.update(kwargs) + + save_file = os.path.join(self.save_dir, "{}.pth".format(name)) + self.logger.info("Saving checkpoint to {}".format(save_file)) + torch.save(data, save_file) + self.tag_last_checkpoint(save_file) + + def load(self, f=None, use_latest=True): + if self.has_checkpoint() and use_latest: + # override argument with existing checkpoint + f = self.get_checkpoint_file() + if not f: + # no checkpoint could be found + self.logger.info("No checkpoint found. 
Initializing model from scratch") + return {} + self.logger.info("Loading checkpoint from {}".format(f)) + checkpoint = self._load_file(f) + self._load_model(checkpoint) + if "optimizer" in checkpoint and self.optimizer: + self.logger.info("Loading optimizer from {}".format(f)) + self.optimizer.load_state_dict(checkpoint.pop("optimizer")) + if "scheduler" in checkpoint and self.scheduler: + self.logger.info("Loading scheduler from {}".format(f)) + self.scheduler.load_state_dict(checkpoint.pop("scheduler")) + + # return any further checkpoint data + return checkpoint + + def has_checkpoint(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + return os.path.exists(save_file) + + def get_checkpoint_file(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + try: + with open(save_file, "r") as f: + last_saved = f.read() + last_saved = last_saved.strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + last_saved = "" + return last_saved + + def tag_last_checkpoint(self, last_filename): + save_file = os.path.join(self.save_dir, "last_checkpoint") + with open(save_file, "w") as f: + f.write(last_filename) + + def _load_file(self, f): + return torch.load(f, map_location=torch.device("cpu")) + + def _load_model(self, checkpoint): + load_state_dict(self.model, checkpoint.pop("model")) + +def load_state_dict(model, loaded_state_dict): + model_state_dict = model.state_dict() + # if the state_dict comes from a model that was wrapped in a + # DataParallel or DistributedDataParallel during serialization, + # remove the "module" prefix before performing the matching + loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") + align_and_update_state_dicts(model_state_dict, loaded_state_dict) + + # use strict loading + model.load_state_dict(model_state_dict) + +def align_and_update_state_dicts(model_state_dict, loaded_state_dict): + """ + Strategy: suppose that the models that we will create will have prefixes appended + to each of its keys, for example due to an extra level of nesting that the original + pre-trained weights from ImageNet won't contain. For example, model.state_dict() + might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains + res2.conv1.weight. We thus want to match both parameters together. + For that, we look for each model weight, look among all loaded keys if there is one + that is a suffix of the current weight name, and use it if that's the case. + If multiple matches exist, take the one with longest size + of the corresponding name. For example, for the same model as before, the pretrained + weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, + we want to match backbone[0].body.conv1.weight to conv1.weight, and + backbone[0].body.res2.conv1.weight to res2.conv1.weight. 
+ """ + current_keys = sorted(list(model_state_dict.keys())) + loaded_keys = sorted(list(loaded_state_dict.keys())) + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # loaded_key string, if it matches + match_matrix = [ + len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys + ] + match_matrix = torch.as_tensor(match_matrix).view( + len(current_keys), len(loaded_keys) + ) + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + # used for logging + max_size = max([len(key) for key in current_keys]) if current_keys else 1 + max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 + log_str_template = "{: <{}} loaded from {: <{}} of shape {}" + logger = logging.getLogger(__name__) + for idx_new, idx_old in enumerate(idxs.tolist()): + if idx_old == -1: + continue + key = current_keys[idx_new] + key_old = loaded_keys[idx_old] + model_state_dict[key] = loaded_state_dict[key_old] + logger.info( + log_str_template.format( + key, + max_size, + key_old, + max_size_loaded, + tuple(loaded_state_dict[key_old].shape), + ) + ) + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict + +def get_config_attr(cfg, attr_string, default=None, totype=None, mute=False): + try: + attrs = attr_string.split('.') + obj = cfg + for s in attrs: + obj = getattr(obj, s) + if totype is None: + return type(default)(obj) + else: + if totype is bool and obj not in ['True','False',True,False]: + raise ValueError('malformed boolean input: {}, {}'.format(obj,type(obj))) + if totype is bool and obj in ['False',False]: + return False + return totype(obj) + except AttributeError: + if not mute: + print('Warning: attribute {} not found. Default: {}'.format(attr_string, default)) + return default + +class DotDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __init__(self, **dct): + for key, value in dct.items(): + #if hasattr(value, 'keys'): + # value = DotDict(**value) + self[key] = value + + +def set_config_attr(cfg, attr_key, attr_value): + attrs = attr_key.split('.') + obj = cfg + for attr in attrs[:-1]: + if not hasattr(obj, attr): + setattr(obj, attr, DotDict()) + obj = getattr(obj, attr) + + + try: + attr_value = int(attr_value) + except ValueError: + try: + attr_value = float(attr_value) + except ValueError: + pass + pass + + if attr_value == 'True': attr_value = True + if attr_value == 'False': attr_value = False + + setattr(obj, attrs[-1], attr_value) + + +def filter_outliers(l): + q1 = np.quantile(l, 0.25) + q2 = np.quantile(l, 0.75) + iqr = q2 - q1 + lb, rb = q1 - 1.5 * iqr, q2 + 1.5 * iqr + l = [x for x in l if lb <= x <= rb] + return l + +def set_cfg_from_args(args, cfg): + cfg_params = args.cfg + if cfg_params is None: return + for param in cfg_params: + k, v = param.split('=') + set_config_attr(cfg, k, v)