Guided Cluster Aggregation
October 26, 2023 ยท View on GitHub
This repository contains the code for our paper Guided Cluster Aggregation: A Hierarchical Approach to Generalized Category Discovery.
Installation
The dependencies can be installed via conda:
conda env create -f environment.yaml
To activate the environment run
conda activate gca
Data
We use the follwoing datasets in our paper:
- CIFAR10
- CIFAR100
- ImageNet-100 (subset of regular ImageNet)
- CUB200-2011
- Stanford Cars
- FGVC Aircraft
Base Models
We use two pretrainings in our paper: GCD and PromptCAL. While PromptCAL offers pretrained models for download, GCD does not, so you will have to train the GCD models yourself.
Usage
We provide experiment configs in the ./config/experiment folder.
The naming scheme is ./config/experiment/{pretraining}_{dataset}.yaml.
We'll use the ./config/experiment/promptcal_cub.yaml config as an example.
To reproduce the results from our paper follow these steps:
1. Obtain the correct pretrained model
In this case just use the link from the PromptCAL repository.
Place it in the directory specified in the config.
The paths given in the config file are relative to the ./src folder, so in
our case the full path of the pretrained model would be
./pretrained/GCD/promptcal_cub.ckpt.
2. Compute the nearest neighbors.
To do so, open a shell in the ./src folder
and run
python evaluate.py trainer=debug experiment=promptcal_cub datamodule.data_dir=[data location] lightning_module.knn_file=null
The nearest neighbor file will be saved in the
./src/evaluation_results/promptcal_cub folder.
Place it in the directory specified in the config, in this case
./neighbors/cub_promptcal_finetuned.npy.
3. Run the experiment
To do so, open a shell in the ./src folder
and run
python train.py trainer=train experiment=promptcal_cub datamodule.data_dir=[data location]
This will run the model using tensorboard logging.
You can also use wandb by adding
trainer/logger=wandb trainer.logger.entity=[entity] trainer.logger.project=[project]
to the command.
You can also run the training on a slurm cluster by adapting
./config/cluster/example.yaml and ./config/slurm_config/gca_gcd_example.yaml.
Then run
python train.py cluster=example trainer=train experiment=promptcal_cub datamodule.data_dir=[data location]
Acknowledgements
Most of the dataset loading code is based on GCD. The code for the VPT-Vit is based on PromptCAL, as well as some code regarding the experiments with less labeled data (e.g. CIFAR100 c10l50). Both repositories are licensed under the MIT license.