[NeurIPS 2022] Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs

CIGA: Causality Inspired Invariant Graph LeArning

Paper · GitHub · License · Video · Slides

This repo contains the sample code for reproducing the results of our NeurIPS 2022 paper: Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs, which was also presented at the ICML SCIS Workshop under the title Invariance Principle Meets Out-of-Distribution Generalization on Graphs. 😆😆😆

TODO items:

  • The camera-ready version of the paper was released on Oct. 12 (link)!
  • Full code and instructions will be released soon!
  • Released an introductory blog post in Chinese. Check it out!
  • 2023 Jan. 4: Completed benchmarking on 14 datasets under GOODMotif, GOODCMNIST, GOODSST2, and GOODHIV (overview, full results, details)! CIGA is the SOTA graph OOD algorithm on all benchmarked datasets! 🔥🔥🔥

Introduction

Despite recent success in using the invariance principle for out-of-distribution (OOD) generalization on Euclidean data (e.g., images), studies on graph data are still limited. Unlike images, the complex nature of graphs poses unique challenges to adopting the invariance principle:

  1. Distribution shifts on graphs can appear in a variety of forms:

    • Node attributes;
    • Graph structure;
    • A mixture of both;
  2. Each distribution shift can spuriously correlate with the label in different modes. We divide the modes into FIIF and PIIF, according to whether the latent causal feature $C$ fully determines the label $Y$, i.e., whether $(S,E)\perp\mkern-9.5mu\perp Y\mid C$:

    • Fully Informative Invariant Features (FIIF): $Y\leftarrow C\rightarrow S\leftarrow E$;
    • Partially Informative Invariant Features (PIIF): $C\rightarrow Y\leftarrow S \leftarrow E$;
    • Mixed Informative Invariant Features (MIIF): mixed with both FIIF and PIIF;
  3. Domain or environment partitions, which are often required by OOD methods on Euclidean data, can be highly expensive to obtain for graphs.

Figure 1. The architecture of CIGA.

This work addresses these challenges by generalizing the causal invariance principle to graphs and instantiating it as CIGA. As shown in Figure 1, CIGA is powered by an information-theoretic objective that extracts the subgraphs that maximally preserve the invariant intra-class information. Under certain assumptions, CIGA provably identifies the underlying invariant subgraphs (shown as the orange subgraphs). Learning with these subgraphs is immune to distribution shifts.

We implement CIGA with an interpretable GNN architecture, where a featurizer $g$ extracts the invariant subgraph and a classifier $f_c$ classifies the extracted subgraph. The objective is imposed as an additional contrastive penalty that enforces the invariance of the extracted subgraphs in a latent hypersphere space (CIGAv1).

  1. When the size of the underlying invariant subgraph $G_c$ is known and fixed across different graphs and environments, CIGAv1 is able to identify $G_c$.
  2. Since the size of the underlying $G_c$ often varies in practice, we further incorporate an additional penalty that maximizes $I(G_s;Y)$ to absorb potential spurious parts in the estimated $G_c$ (CIGAv2).
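In loss form, the two variants can be summarized roughly as follows (our paraphrase, with $\alpha$ and $\beta$ denoting the --contrast and --spu_coe weights; not the paper's exact notation):

```latex
\mathcal{L}_{\text{CIGA}}
  = \underbrace{\mathbb{E}\,\ell\!\left(f_c(\widehat{G}_c),\, Y\right)}_{\text{prediction}}
  \;+\; \alpha \underbrace{\mathcal{L}_{\text{contrast}}}_{\text{maximizes } I(\widehat{G}_c;\widetilde{G}_c\mid Y)}
  \;+\; \beta \underbrace{\mathcal{L}_{\text{hinge}}}_{\text{maximizes } I(\widehat{G}_s;Y)}
```

Setting $\beta = 0$ recovers CIGAv1.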

Extensive experiments on 16 synthetic and real-world datasets, including DrugOOD, a challenging benchmark from AI-aided drug discovery, validate the superior OOD generalization ability of CIGA.

Use CIGA in Your Code

CIGA consists of two key regularization terms: a contrastive loss that maximizes $I(\widehat{G}_c;\widetilde{G}_c|Y)$, and a hinge loss that maximizes $I(\widehat{G}_s;Y)$.

The contrastive loss is implemented via a simple call (line 480 in main.py):

get_contrast_loss(causal_rep, label)

which requires two key inputs:

  • causal_rep: the representations of the extracted invariant subgraphs;
  • label: the labels corresponding to the original graphs.
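The exact implementation lives in get_contrast_loss in main.py; its core follows a supervised-contrastive form, pulling together same-label subgraph representations and pushing apart different-label ones. A minimal dependency-free sketch of that idea (illustrative only, not the repo's torch implementation):

```python
import math

def contrast_loss(reps, labels):
    """Simplified supervised-contrastive loss sketch.

    reps: list of representation vectors (one per graph)
    labels: list of class labels for the original graphs
    Same-label pairs act as positives (approximating the invariance
    condition behind I(G_c_hat; G_c_tilde | Y)); all other samples in
    the batch act as the contrastive denominator.
    """
    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        nu = math.sqrt(sum(a * a for a in u))
        nv = math.sqrt(sum(b * b for b in v))
        return dot / (nu * nv + 1e-12)

    total, count = 0.0, 0
    for i, (ri, yi) in enumerate(zip(reps, labels)):
        pos = [j for j, yj in enumerate(labels) if yj == yi and j != i]
        others = [j for j in range(len(reps)) if j != i]
        if not pos:
            continue  # anchors without positives are skipped
        denom = sum(math.exp(cos(ri, reps[j])) for j in others)
        for j in pos:
            total += -math.log(math.exp(cos(ri, reps[j])) / denom)
            count += 1
    return total / max(count, 1)
```

In the actual code, causal_rep plays the role of reps and the similarities are computed on normalized torch tensors.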

The hinge loss is implemented in line 430 to line 445 in main.py:

# a simple implementation of the hinge loss:
# only penalize samples where the spurious-subgraph loss
# exceeds the invariant-subgraph loss
spu_loss_weight = torch.zeros(spu_pred_loss.size()).to(device)
spu_loss_weight[spu_pred_loss > pred_loss] = 1.0
# average the spurious losses over the selected samples
spu_pred_loss = spu_pred_loss.dot(spu_loss_weight) / (sum(spu_pred_loss > pred_loss) + 1e-6)

which requires two key inputs:

  • spu_pred_loss: sample-wise loss values of predictions based on the spurious subgraph $\widehat{G}_s$.
  • pred_loss: sample-wise loss values of predictions based on the invariant subgraph $\widehat{G}_c$.

The weights spu_loss_weight of the hinge loss are then computed per sample from these loss values and applied to spu_pred_loss.
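Stripped of torch, the weighting scheme amounts to the following (a plain-Python paraphrase of the snippet above, with list inputs standing in for the sample-wise loss tensors):

```python
def hinge_spu_loss(spu_pred_loss, pred_loss, eps=1e-6):
    """Average the spurious-subgraph losses only over samples where
    they exceed the corresponding invariant-subgraph losses."""
    weights = [1.0 if s > p else 0.0
               for s, p in zip(spu_pred_loss, pred_loss)]
    weighted_sum = sum(s * w for s, w in zip(spu_pred_loss, weights))
    # eps guards against division by zero when no sample is selected
    return weighted_sum / (sum(weights) + eps)
```

Only samples whose spurious-subgraph loss exceeds the invariant-subgraph loss contribute; the rest are masked out.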

Instructions

Installation and data preparation

Our code is based on the following libraries:

torch==1.9.0
torch-geometric==1.7.2
scikit-image==0.19.1 

plus the DrugOOD benchmark repo.

The data used in the paper can be obtained following these instructions.

Reproduce results

We provide the hyperparameter tuning and evaluation details in the paper and appendix. Below we give a brief introduction to the commands and their usage in our code. The corresponding running scripts are provided in the script folder.

To obtain results of ERM, simply run

python main.py --erm

with corresponding datasets and model specifications.

Running with CIGA:

  • --ginv_opt specifies the interpretable GNN architectures, which can be asap or gib to test with ASAP or GIB respectively.
  • --r is also needed for interpretable GNN architectures; it specifies the interpretable ratio, i.e., the size of $G_c$.
  • --c_rep controls the inputs to the contrastive learning, e.g., the graph representations from the featurizer or from the classifier.
  • --c_in controls the inputs to the classifier, e.g., the original graph features or the features from the featurizer.
  • To test with CIGAv1, simply specify --ginv_opt as default, and --contrast a value larger than 0.
  • While for CIGAv2, additionally specify --spu_coe to include the other objective.
  • --s_rep controls the inputs for maximizing $I(\widehat{G}_s;Y)$, e.g., the graph representation of $\widehat{G}_s$ from the classifier or the featurizer.
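Putting the flags together, CIGA runs might look like the following (the dataset/model flags are omitted and the coefficient values shown are placeholders; see the script folder for the tuned commands):

```shell
# CIGAv1: contrastive penalty only
python main.py --ginv_opt default --r 0.8 --contrast 4
# CIGAv2: additionally include the hinge objective on the spurious part
python main.py --ginv_opt default --r 0.8 --contrast 4 --spu_coe 1
```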

Running with the baselines:

  • To test with DIR, simply specify --ginv_opt as default and --dir a value larger than 0.
  • To test with invariant learning baselines, specify --num_envs=2, use --irm_opt set to irm, vrex, eiil, or ib-irm to select the method, and --irm_p to specify the penalty weight.
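As a concrete example of what these baselines penalize, V-REx (--irm_opt vrex) adds the variance of the per-environment risks, weighted by --irm_p. A minimal sketch, assuming the per-environment empirical risks have already been computed:

```python
def vrex_penalty(env_risks):
    """V-REx penalty: variance of the per-environment empirical risks.

    Driving this toward zero encourages the model to perform equally
    well across the given/inferred environments (--num_envs of them).
    """
    mean_risk = sum(env_risks) / len(env_risks)
    return sum((r - mean_risk) ** 2 for r in env_risks) / len(env_risks)
```

The total loss is then the average risk plus --irm_p times this penalty.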

Because CNC additionally depends on an ERM reference model, we need to first train and save an ERM model, then load it to generate ERM predictions for positive/negative pair sampling in CNC. Here is a simple example:

python main.py --erm --contrast 0 --save_model
python main.py --erm --contrast 1 --c_sam 'cnc'

Misc

As discussed in the paper, the current code is merely a prototypical implementation based on an interpretable GNN architecture, i.e., GAE; in fact, there could be more implementation choices:

  • For the architectures: CIGA can also be implemented via GIB and GSAT.
  • For hyperparameter tuning: you may find plentiful literature on multi-task learning helpful, or try out PAIR.
  • Besides, CIGA is also compatible with state-of-the-art contrastive augmentations for graph learning; you may find useful tools in PyGCL.

You can also find more discussions on the limitations and future works in Appendix B of our paper.

That being said, CIGA is definitely not the ultimate solution, and it intrinsically has many limitations. Nevertheless, we hope the causal analysis and the inspired solution in CIGA can serve as an initial step towards more reliable graph learning algorithms that generalize to various OOD graphs from the real world.

If you find our paper and repo useful, please cite our paper:

@InProceedings{chen2022ciga,
  title       = {Learning Causally Invariant Representations for Out-of-Distribution Generalization on Graphs},
  author      = {Yongqiang Chen and Yonggang Zhang and Yatao Bian and Han Yang and Kaili Ma and Binghui Xie and Tongliang Liu and Bo Han and James Cheng},
  booktitle   = {Advances in Neural Information Processing Systems},
  year        = {2022}
}

Ack: The readme is inspired by GSAT. 😄