PyTorch implementation of TWP

December 31, 2020

Overcoming Catastrophic Forgetting in Graph Neural Networks, AAAI 2021

Method Overview

[Figure: method overview]

Cite

If you find this code useful in your research, please consider citing:

@inproceedings{liu2021overcoming,
Title = {Overcoming Catastrophic Forgetting in Graph Neural Networks},
Author = {Huihui Liu and Yiding Yang and Xinchao Wang},
Booktitle = {AAAI Conference on Artificial Intelligence},
Year = {2021}
}

Dependencies

The required packages are listed in the file requirements.txt.

Datasets

Node classification

We conduct experiments on four public datasets (Corafull, Amazon Computers, PPI, and Reddit), loaded through DGL.

Graph classification

We conduct experiments on one public dataset (Tox21), loaded through DGLlife.

Models

We use DGL to implement all the GNN models.

Overview

Here we provide the implementation of our method for the Corafull dataset. The folder < corafull_amazon/ > is organised as follows:

  • < LifeModel/ > contains the implementations of all the continual learning methods for GNNs, including the baseline methods and our method;
  • < dataset/ > contains code to load the dataset;
  • < models/ > contains the implementation of the GNN backbone;
  • < training/ > contains code to set the random seed;
  • The file < train.py > is used to execute the whole training process on the Corafull dataset;
  • The file < run.sh > contains an example to run the code.
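The < training/ > folder is described only as code that sets the seed. A minimal sketch of what such a helper usually does in a PyTorch project, assuming the standard Python, NumPy and PyTorch random number generators are the ones that matter; the name `set_seed` is hypothetical and not taken from this repository.

```python
import random


def set_seed(seed: int) -> None:
    """Seed every random number generator a training run may use.

    Hypothetical helper: Python's built-in RNG is always seeded, while
    NumPy and PyTorch are seeded only when they are installed.
    """
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's global generator
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)           # PyTorch's global generator
        torch.cuda.manual_seed_all(seed)  # all GPUs; silently ignored without CUDA
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


if __name__ == "__main__":
    set_seed(0)
    first = random.random()
    set_seed(0)
    assert random.random() == first  # re-seeding reproduces the same draw
```

Seeding alone does not guarantee bitwise-identical results on GPU, because some CUDA kernels are nondeterministic.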

Results

Here we compare performance on the different datasets, using GAT as the backbone. Task performance is measured by classification accuracy on the Corafull and Amazon Computers datasets and by the micro-averaged F1 score on the PPI and Reddit datasets. AP and AF denote average performance and average forgetting, respectively; ↑ (↓) indicates that higher (lower) is better.
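The README does not spell out how AP and AF are computed. A minimal sketch, assuming the definitions commonly used in continual learning: AP averages the score on every task after the last task has been learned, and AF averages, over all earlier tasks, the drop from the score measured right after learning that task to its final score. The paper's exact formulas may differ. The per-task score would be accuracy or micro-averaged F1; micro-F1 pools true positives, false positives and false negatives over all labels, and reduces to accuracy when every node has exactly one true and one predicted label. The function names and the `scores` layout are hypothetical, and the toy numbers are made up, not results from this README.

```python
from typing import List, Sequence


def micro_f1(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> float:
    """Micro-averaged F1 for multi-label 0/1 indicator rows.

    Counts are pooled over every (node, label) pair before one F1 is computed.
    """
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t and p:
                tp += 1  # label present and predicted
            elif p:
                fp += 1  # label predicted but absent
            elif t:
                fn += 1  # label present but missed
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def average_performance(scores: List[List[float]]) -> float:
    """AP: mean score over all tasks after training on the last task.

    scores[i][j] is the score on task j measured right after training on
    task i, for j <= i (a lower-triangular table).
    """
    final = scores[-1]
    return sum(final) / len(final)


def average_forgetting(scores: List[List[float]]) -> float:
    """AF: mean drop from each earlier task's just-learned score to its final score."""
    num_tasks = len(scores)
    if num_tasks < 2:
        return 0.0  # nothing can be forgotten with a single task
    final = scores[-1]
    drops = [scores[j][j] - final[j] for j in range(num_tasks - 1)]
    return sum(drops) / len(drops)


if __name__ == "__main__":
    # Toy scores for three tasks, made up for illustration only.
    toy = [
        [0.90],
        [0.80, 0.95],
        [0.70, 0.85, 0.92],
    ]
    print(f"AP = {average_performance(toy):.3f}, AF = {average_forgetting(toy):.3f}")
    print(f"micro-F1 = {micro_f1([[1, 0, 1], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]):.3f}")
```

Storing only the lower triangle of `scores` is enough, because AP reads the last row and AF reads the diagonal and the last row.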

Dataset: Corafull

| Method | AP (↑) | AF (↓) |
|---|---|---|
| Fine-tune | 51.6±6.4% | 46.1±7.0% |
| LWF | 57.3±2.3% | 39.5±3.1% |
| GEM | 84.4±1.1% | 4.2±1.0% |
| EWC | 86.9±1.7% | 6.4±1.8% |
| MAS | 84.1±1.8% | 8.6±2.2% |
| Ours | 89.0±0.8% | 3.3±0.3% |
| Joint train | 91.9±0.8% | 0.1±0.2% |

Dataset: Amazon Computers

| Method | AP (↑) | AF (↓) |
|---|---|---|
| Fine-tune | 86.5±8.0% | 12.3±12.3% |
| LWF | 90.3±6.4% | 9.9±7.0% |
| GEM | 97.1±0.9% | 0.7±0.5% |
| EWC | 94.5±3.3% | 4.6±4.5% |
| MAS | 94.0±5.5% | 5.0±6.9% |
| Ours | 97.3±0.6% | 0.6±0.2% |
| Joint train | 98.2±0.6% | 0.02±0.1% |

Dataset: PPI

| Method | AP (↑) | AF (↓) |
|---|---|---|
| Fine-tune | 0.365±0.024 | 0.178±0.019 |
| LWF | 0.382±0.024 | 0.185±0.060 |
| GEM | 0.741±0.016 | 0.112±0.030 |
| EWC | 0.826±0.027 | 0.142±0.028 |
| MAS | 0.749±0.007 | 0.092±0.008 |
| Ours | 0.853±0.004 | 0.086±0.005 |
| Joint train | 0.931±0.40 | 0.035±0.026 |

Dataset: Reddit

| Method | AP (↑) | AF (↓) |
|---|---|---|
| Fine-tune | 0.474±0.006 | 0.580±0.007 |
| LWF | 0.500±0.033 | 0.550±0.034 |
| GEM | 0.947±0.001 | 0.030±0.008 |
| EWC | 0.944±0.019 | 0.032±0.021 |
| MAS | 0.865±0.031 | 0.085±0.024 |
| Ours | 0.954±0.014 | 0.014±0.015 |
| Joint train | 0.978±0.001 | 0.001±0.001 |