DP-KIP

May 24, 2024

Installing the JAX environment

  conda create --name dpkip python=3.9
  conda activate dpkip
  pip install -e /path/to/the/code/folder/of/this/repo/on/your/local/drive

Note: This installs the CPU-only version of JAX. To install JAX with GPU support, follow the instructions at: https://github.com/google/jax#installation
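To confirm which backend the installed JAX actually uses, a small check like the one below can be run inside the `dpkip` environment. It relies only on `jax.default_backend()` and `jax.devices()`. The helper name `jax_backend` is hypothetical and not part of this repo.

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend name (e.g. 'cpu' or 'gpu'), or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check also works when JAX is absent
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed in this environment.")
    else:
        import jax
        print(f"JAX backend: {backend}; devices: {jax.devices()}")
```

A CPU-only install should report `cpu`; a working GPU install should report `gpu`.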

Running the KIP and DP-KIP code

Image data KRR downstream classifier for infinite-width FC-NTK

python dpkip_inf_ntk.py --dpsgd=True --l2_norm_clip=1e-6 --epochs=10 --learning_rate=1e-2 --batch_size=50 --epsilon=1 --architecture='FC' --width=1024 --dataset='mnist' --support_size=10
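The KRR downstream classifier fits kernel ridge regression on the learned support set and labels test points by the argmax of the KRR prediction. Below is a minimal NumPy sketch of that step using the analytic infinite-width NTK of a fully connected ReLU network, computed with the standard arc-cosine recursion. The weight variance of 2, zero bias, the `depth` argument, the plain `reg * I` ridge term and the function names are assumptions for illustration; `dpkip_inf_ntk.py` may parameterize the kernel and the regularizer differently.

```python
import numpy as np


def relu_fc_ntk(X1, X2, depth=1):
    """Infinite-width NTK of a fully connected ReLU network.

    Assumes weight variance 2, zero bias variance and `depth` hidden layers.
    Rows of X1 and X2 must be non-zero.
    """
    dim = X1.shape[1]
    sigma = X1 @ X2.T / dim                   # Sigma^(0)(x, x')
    diag1 = np.sum(X1 * X1, axis=1) / dim     # Sigma^(0)(x, x)
    diag2 = np.sum(X2 * X2, axis=1) / dim     # Sigma^(0)(x', x')
    # With weight variance 2, ReLU layers keep Sigma(x, x) unchanged,
    # so this normaliser stays fixed across layers.
    norm = np.sqrt(np.outer(diag1, diag2))
    ntk = sigma.copy()                        # Theta^(0) = Sigma^(0)
    for _ in range(depth):
        cos = np.clip(sigma / norm, -1.0, 1.0)
        angle = np.arccos(cos)
        # Arc-cosine expectations of ReLU' and ReLU, each multiplied by weight variance 2.
        sigma_dot = (np.pi - angle) / np.pi
        sigma = norm * (np.sin(angle) + (np.pi - angle) * cos) / np.pi
        ntk = ntk * sigma_dot + sigma         # Theta^(h) = Theta^(h-1) * Sigma_dot^(h) + Sigma^(h)
    return ntk


def krr_predict(K_test_support, K_support, y_support, reg=1e-6):
    """Kernel ridge regression prediction K_ts (K_ss + reg * I)^{-1} y_s."""
    n = K_support.shape[0]
    alpha = np.linalg.solve(K_support + reg * np.eye(n), y_support)
    return K_test_support @ alpha


# Usage: 10 support points (one per class, one-hot labels), classify by argmax.
rng = np.random.default_rng(0)
X_support = rng.normal(size=(10, 784))
y_support = np.eye(10)
X_test = rng.normal(size=(5, 784))
K_ss = relu_fc_ntk(X_support, X_support)
preds = krr_predict(relu_fc_ntk(X_test, X_support), K_ss, y_support)
labels = preds.argmax(axis=1)
```

With a small ridge term the predictor interpolates the support labels, which is a quick sanity check on the kernel.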

Image data KRR downstream classifier for ScatterNet features

python dp_kip_other_features.py --dpsgd=True --learning_rate=1e-1 --batch_size=2000 --kip_loss_reg=1e-3 --feature_type="wavelet" --dataset='mnist' --epochs=10 --rand_init=True --support_size=10 --l2_norm_clip=1e-2 --epsilon=10
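With `--dpsgd=True` the support set is trained with DP-SGD: `--l2_norm_clip` bounds each per-example gradient before Gaussian noise is added, and the privacy budget `--epsilon` is presumably converted into a noise level by a privacy accountant. The NumPy sketch below shows the standard DP-SGD clip-and-noise step only. It assumes the noise multiplier has already been derived from `--epsilon` (that accounting step is not shown), and `privatize_grads` is a hypothetical name, not a function from this repo.

```python
import numpy as np


def privatize_grads(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip each per-example gradient to l2_norm_clip, sum, add Gaussian noise, then average."""
    batch = per_example_grads.shape[0]
    flat = per_example_grads.reshape(batch, -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    # Scale down only the gradients whose norm exceeds the clipping bound.
    scale = np.minimum(1.0, l2_norm_clip / np.maximum(norms, 1e-12))
    clipped = flat * scale
    # Noise std is noise_multiplier * l2_norm_clip, the sensitivity of the clipped sum.
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip, size=flat.shape[1])
    noisy_mean = (clipped.sum(axis=0) + noise) / batch
    return noisy_mean.reshape(per_example_grads.shape[1:])


rng = np.random.default_rng(0)
grads = rng.normal(size=(50, 10, 3)) * 5.0   # stand-in per-example gradients w.r.t. the support set
private_grad = privatize_grads(grads, l2_norm_clip=1e-2, noise_multiplier=1.0, rng=rng)
```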

Image data KRR downstream classifier for ScatterNet features (non-DP)

python dp_kip_other_features.py --dpsgd=False --learning_rate=1e-1 --batch_size=2000 --kip_loss_reg=1e-3 --feature_type="wavelet" --dataset='mnist' --epochs=10 --support_size=10
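For feature-based variants such as ScatterNet, the kernel can be taken as a linear kernel on fixed features, K(x, x') = φ(x)·φ(x'). A minimal sketch of a KIP-style loss under that assumption: KRR is fitted on the support set, evaluated on a batch of real data, and the squared error is the quantity minimized with respect to the support set. Here `--kip_loss_reg` is assumed to enter as a plain `reg * I` ridge term, random matrices stand in for ScatterNet features, and `kip_loss` is a hypothetical name.

```python
import numpy as np


def kip_loss(feat_target, y_target, feat_support, y_support, reg=1e-3):
    """Half the squared error of KRR fitted on the support set, evaluated on a target batch."""
    K_ss = feat_support @ feat_support.T      # linear kernel on features
    K_ts = feat_target @ feat_support.T
    n = K_ss.shape[0]
    alpha = np.linalg.solve(K_ss + reg * np.eye(n), y_support)
    residual = y_target - K_ts @ alpha
    return 0.5 * np.sum(residual ** 2)


rng = np.random.default_rng(0)
feat_support = rng.normal(size=(10, 64))      # stand-in for features of 10 support images
y_support = np.eye(10)
feat_target = rng.normal(size=(200, 64))      # stand-in for a batch of real-data features
y_target = np.eye(10)[rng.integers(0, 10, size=200)]
loss = kip_loss(feat_target, y_target, feat_support, y_support)
```

In KIP the support set is then updated by gradient descent on this loss; in a JAX implementation that gradient would typically come from `jax.grad`.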

Image data KRR downstream classifier for PFs (non-DP)

python dp_kip_other_features.py --dpsgd=False --learning_rate=1e-1 --batch_size=2000 --kip_loss_reg=1e-3 --feature_type="resnet" --pretrained_encoder=False --normalize_features=True --dataset='svhn_cropped' --epochs=10 --support_size=10
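`--normalize_features=True` rescales the extracted features before the kernel is formed. The exact normalization is not documented here; one common choice, assumed in this sketch, is to L2-normalize each feature vector so that the linear kernel becomes a cosine similarity with unit diagonal. `normalize_rows` is a hypothetical helper.

```python
import numpy as np


def normalize_rows(features, eps=1e-12):
    """Scale each feature vector (row) to unit L2 norm."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)


rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 512)) * 30.0      # stand-in for unnormalized ResNet features
unit = normalize_rows(feats)
K = unit @ unit.T                             # linear kernel = cosine similarity
```

Under this choice, kernel entries stay in [-1, 1] regardless of the feature scale, which can make a fixed `--kip_loss_reg` behave more consistently across feature extractors.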

Image data KRR downstream classifier for e-NTK (non-DP)

python KIP_lenet_ntk.py --disable-dp --batch-size 2000 --epochs 10 --lr 1e-1 --reg 1e-3 --sup_size 10 --dataset 'fashion_mnist'
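The e-NTK variant uses the empirical (finite-width) NTK, Θ(x, x') = J(x)·J(x'), where J(x) is the gradient of the network output with respect to all parameters. `KIP_lenet_ntk.py` builds this from a LeNet; the NumPy sketch below instead uses a tiny one-hidden-layer ReLU network with a scalar output so the Jacobian can be written by hand. The network, its 1/√m scaling and the function names are assumptions for illustration.

```python
import numpy as np


def forward(x, W, v):
    """Scalar-output network f(x) = v . relu(W x) / sqrt(m)."""
    m = W.shape[0]
    return v @ np.maximum(W @ x, 0.0) / np.sqrt(m)


def jacobian(x, W, v):
    """Gradient of f(x) w.r.t. all parameters, flattened as [df/dW (row-major), df/dv]."""
    m = W.shape[0]
    pre = W @ x
    d_v = np.maximum(pre, 0.0) / np.sqrt(m)
    d_W = np.outer(v * (pre > 0), x) / np.sqrt(m)
    return np.concatenate([d_W.ravel(), d_v])


def empirical_ntk(X1, X2, W, v):
    """e-NTK matrix with entries J(x1) . J(x2)."""
    J1 = np.stack([jacobian(x, W, v) for x in X1])
    J2 = np.stack([jacobian(x, W, v) for x in X2])
    return J1 @ J2.T


rng = np.random.default_rng(0)
W = rng.normal(size=(256, 20))
v = rng.normal(size=256)
X = rng.normal(size=(6, 20))
K = empirical_ntk(X, X, W, v)
```

The resulting matrix plays the same role as the analytic NTK in the KRR classifier, but it depends on the particular parameter values of the network.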

Tabular data

python dpkip_tab_data.py --dpsgd=True --reg=1e-6 --learning_rate=1e-1 --l2_norm_clip=1e-1 --batch_rate=0.01 --epochs=10 --dataset='credit' --undersampled_rate=0.01 --architecture='FC' --support_size=2 --width=1024
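The tabular script takes `--batch_rate` and `--undersampled_rate` instead of a fixed batch size. Their exact semantics are not documented in this README. A plausible reading, assumed in the sketch below, is that `batch_rate` is a Poisson sampling rate (each record joins a batch independently with that probability, as in many DP-SGD privacy analyses) and that `undersampled_rate` is the fraction of majority-class records kept to reduce class imbalance on datasets like `credit`. Both helpers are hypothetical.

```python
import numpy as np


def undersample_majority(X, y, rate, rng):
    """Keep all minority-class rows and a random `rate` fraction of majority-class rows."""
    classes, counts = np.unique(y, return_counts=True)
    majority = classes[np.argmax(counts)]
    maj_idx = np.flatnonzero(y == majority)
    keep_maj = rng.choice(maj_idx, size=max(1, int(rate * len(maj_idx))), replace=False)
    keep = np.sort(np.concatenate([np.flatnonzero(y != majority), keep_maj]))
    return X[keep], y[keep]


def poisson_batch(n, batch_rate, rng):
    """Indices of a Poisson-subsampled batch: each of n records is included with prob. batch_rate."""
    return np.flatnonzero(rng.random(n) < batch_rate)


rng = np.random.default_rng(0)
X = rng.normal(size=(10000, 30))
y = (rng.random(10000) < 0.02).astype(int)    # imbalanced binary labels, about 2% positives
X_bal, y_bal = undersample_majority(X, y, rate=0.01, rng=rng)
batch = poisson_batch(len(y_bal), batch_rate=0.01, rng=rng)
```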