
A Win-win Deal: Towards Sparse and Robust Pre-trained Language Models

This repository contains the implementation of the paper "A Win-win Deal: Towards Sparse and Robust Pre-trained Language Models" (accepted by NeurIPS 2022).

The code for the debiasing methods is adapted from chrisc36/debias and UKPLab/emnlp2020-debiasing-unknown, and the code for mask training is adapted from maskbert. Our training and inference pipeline is based on huggingface/transformers.

1. Overview

This paper investigates whether there exist PLM subnetworks that are both sparse and robust against dataset bias.

We call such subnetworks SRNets and explore their existence under different pruning and fine-tuning paradigms, which are illustrated in Figure 1.

2. Setup

  conda create -n srnet python=3.6
  conda activate srnet
  conda install pytorch==1.6.0 cudatoolkit=10.1 -c pytorch
  pip install -r requirements.txt

PyTorch versions >= 1.4.0 also work.
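
Optionally, you can run a quick sanity check that PyTorch and CUDA are set up correctly (not required by the scripts):

  python -c "import torch; print(torch.__version__, torch.cuda.is_available())"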

3. Prepare Data and Pre-trained Language Models

MNLI and QQP are datasets from the GLUE benchmark. For FEVER, we use the processed training and evaluation data provided by the authors of FEVER-Symmetric. The OOD datasets can be accessed from: HANS, PAWS and FEVER-Symmetric. Download the datasets and place them in the data/ folder, which should be structured as follows:

data
├── MNLI
│   ├── train.tsv
│   ├── dev_matched.tsv
│   ├── dev_mismatched.tsv
│   └── hans
│       ├── heuristics_train_set.txt
│       └── heuristics_evaluation_set.txt
├── QQP
│   ├── train.tsv
│   ├── dev.tsv
│   ├── paws_qqp
│   │   ├── train.tsv
│   │   ├── dev.tsv
│   │   └── test.tsv
│   └── paws_wiki
│       ├── train.tsv
│       ├── dev.tsv
│       └── test.tsv
└── fever
    ├── fever.train.jsonl
    ├── fever.dev.jsonl
    ├── sym1
    │   └── test.jsonl
    └── sym2
        ├── dev.jsonl
        └── test.jsonl

By specifying the argument --model_name_or_path as bert-base-uncased, bert-large-uncased or roberta-base, the code will automatically download the PLMs. You can also manually download the models from the Hugging Face model hub and set --model_name_or_path to the path of the downloaded checkpoint.
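
For example, one way to fetch a checkpoint manually is via git (this assumes git-lfs is installed; any other download method works as well):

  git lfs install
  git clone https://huggingface.co/bert-base-uncased
  # then pass --model_name_or_path ./bert-base-uncased to the scripts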

4. Fine-tuning Full BERT

Fine-tuning with Standard Cross-Entropy (CE) Loss

To fine-tune full BERT with standard cross-entropy loss, use the scripts in scripts/full_bert/std_train. Taking MNLI as an example, run

  bash scripts/full_bert/std_train/mnli.sh

Fine-tuning with Debiasing Loss

The debiasing methods require bias models, which are trained using the code provided by chrisc36/debias. The predictions of the bias models are placed in the folder bias_model_preds.

To fine-tune full BERT with Product-of-Experts (PoE) on MNLI, run

  bash scripts/full_bert/robust_train/poe/mnli.sh

Change poe to reweighting or conf_reg to switch to Example Reweighting or Confidence Regularization, respectively.

Note that conf_reg requires first fine-tuning BERT with the standard CE loss (the teacher model) and obtaining its predictions; see the example command sequence below.
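
For example, the full sequence on MNLI might look as follows (the conf_reg script path is inferred from the poe path above by the substitution described earlier):

  # 1. fine-tune the CE teacher and obtain its predictions
  bash scripts/full_bert/std_train/mnli.sh
  # 2. fine-tune with Confidence Regularization
  bash scripts/full_bert/robust_train/conf_reg/mnli.sh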

5. Subnetworks from Fine-tuned BERT

Subnetworks from Standard Fine-tuned BERT

IMP

To perform IMP using the CE loss on a standard fine-tuned BERT (again, taking MNLI as an example), run

  bash scripts/imp/prune_after_ft/std/mnli.sh

Note that IMP will produce subnetworks with varying sparsity levels (10%~90%).

Similarly, to use the PoE loss during IMP, run

  bash scripts/imp/prune_after_ft/poe/mnli.sh

Mask Training

To perform mask training using the CE loss, with a target sparsity of 50%, run

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/std/mnli/0.5.sh

Similarly, change std in the path to poe, reweighting or conf_reg to use the debiasing methods; an example is given below.
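
For instance, the PoE variant at 50% sparsity would be invoked as follows (path inferred from the naming pattern above):

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/poe/mnli/0.5.sh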

Subnetworks from PoE Fine-tuned BERT

IMP

To perform IMP using the PoE loss on a PoE fine-tuned BERT, run

  bash scripts/imp/prune_after_robust_ft/poe/mnli.sh

Mask Training

To perform mask training using the PoE loss, with a target sparsity of 50%, run

  bash scripts/mask_train/mask_on_plm_ft/plm_poe_ft/poe/mnli/0.5.sh

6. BERT Subnetworks Fine-tuned in Isolation

IMP

To obtain the subnetworks using IMP with the PoE objective, run

  bash scripts/imp/lt/pruning/mnli.sh

This will produce subnetworks with varying sparsity levels (10%~90%). Then, fine-tune the obtained subnetwork (taking 50% sparsity as an example) by running:

  bash scripts/imp/lt/retrain/std/mnli/0.5.sh

Change the shell script to scripts/imp/lt/retrain/poe/mnli/0.5.sh to enable PoE fine-tuning.

Mask Training

We reuse the pruning masks of the subnetworks obtained from standard fine-tuned BERT (Section 5). Then, fine-tune the obtained subnetwork by running:

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/poe/mnli/retrain/std/0.5.sh

Change the shell script to scripts/mask_train/mask_on_plm_ft/plm_std_ft/poe/mnli/retrain/poe/0.5.sh to enable PoE fine-tuning.

7. BERT Subnetworks Without Fine-tuning

To obtain the subnetworks, directly perform mask training on the pre-trained BERT:

  bash scripts/mask_train/mask_on_plm_pt/std/mnli/0.5.sh

To enable mask training with PoE, change the shell script to scripts/mask_train/mask_on_plm_pt/poe/mnli/0.5.sh.

8. Sparse and Unbiased BERT Subnetworks

We utilize the OOD training data to explore the upper bound of SRNets. All three setups above are considered; examples are given as follows:

Subnetworks from Fine-tuned BERT:

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/ood/mnli/0.5.sh

BERT Subnetworks Fine-tuned in Isolation:

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/ood/mnli/retrain/0.5.sh

BERT Subnetworks Without Fine-tuning:

  bash scripts/mask_train/mask_on_plm_pt/ood/mnli/0.5.sh

9. Refining the SRNets Searching Process

The Timing to Start Searching SRNets

To start mask training from a standard full BERT fine-tuned for 5000 steps, run

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/poe/mnli/mask_on_checkpoints/5000.sh

The sparsity of the subnetworks is set to 70% by default.

Gradual Sparsity Increase

To perform mask training with gradual sparsity increase, run the following command:

  bash scripts/mask_train/mask_on_plm_ft/plm_std_ft/poe/mnli/gradual_sparsity_increase/0.9.sh

The initial and final sparsity levels can be set in the script via the arguments --init_sparsity and --zero_rate, respectively. Note that we adopt the soft version of magnitude initialization when using gradual sparsity increase (by setting --controlled_init to magnitude_soft). A sketch of how these arguments fit together is given below.
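
As an illustration, a hypothetical invocation might look like the following (the script name and the initial sparsity value are placeholders; consult the actual shell script for the real entry point and defaults):

  # hypothetical entry point; only the three flags below are taken from the text above
  python run_mask_training.py \
      --init_sparsity 0.5 \
      --zero_rate 0.9 \
      --controlled_init magnitude_soft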

Todo

  • Summarize the actual values of the results.

Citation

If you use this repository in published research, please cite our paper:

@article{Liu2022SRNets,
  author        = {Liu, Yuanxin and Meng, Fandong and Lin, Zheng and Li, Jiangnan and Fu, Peng and Cao, Yanan and Wang, Weiping and Zhou, Jie},
  title         = {A Win-win Deal: Towards Sparse and Robust Pre-trained Language Models},
  year          = {2022},
  eprint        = {2210.05211},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL}
}
