On the Versatile Uses of Partial Distance Correlation in Deep Learning - Official PyTorch Implementation

December 23, 2022

On the Versatile Uses of Partial Distance Correlation in Deep Learning
Xingjian Zhen, Zihang Meng, Rudrasis Chakraborty, Vikas Singh
European Conference on Computer Vision (ECCV), 2022.

Updates:

  • Fixed a typo in Partial_Distance_Correlation.ipynb: Peasor_Correlation() is now Pearson_Correlation()
  • Reimplementation in TensorFlow: TF_Partial_Distance_Correlation.ipynb
  • TensorFlow reimplementation of Diverge Training: the TF_Diverge_Training module

Abstract: Comparing the functional behavior of neural network models, whether it is a single network over time or two (or more) networks during or post-training, is an essential step in understanding what they are learning (and what they are not), and for identifying strategies for regularization or efficiency improvements. Despite recent progress, e.g., comparing vision transformers to CNNs, systematic comparison of function, especially across different networks, remains difficult and is often carried out layer by layer. Approaches such as canonical correlation analysis (CCA) are applicable in principle, but have been sparingly used so far. In this paper, we revisit a (less widely known) measure from statistics, called distance correlation (and its partial variant), designed to evaluate correlation between feature spaces of different dimensions. We describe the steps necessary to deploy it for large-scale models; this opens the door to a surprising array of applications, ranging from conditioning one deep model w.r.t. another, to learning disentangled representations, to optimizing diverse models that would directly be more robust to adversarial attacks. Our experiments suggest a versatile regularizer (or constraint) with many advantages, which avoids some of the common difficulties one faces in such analyses.
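To make "correlation between feature spaces of different dimensions" concrete, here is a minimal NumPy sketch of the standard (biased, V-statistic) squared distance correlation of Székely, Rizzo and Bakirov (2007). It only uses pairwise distances within each feature set, so X and Y may have different widths as long as their n rows are paired. The helper names are hypothetical and this is not necessarily how Partial_Distance_Correlation.ipynb computes it.

```python
import numpy as np


def pairwise_distances(z):
    """Euclidean distance matrix between the rows of z, shape (n, d) -> (n, n)."""
    diff = z[:, None, :] - z[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def double_center(d):
    """Subtract row and column means, then add back the grand mean."""
    return d - d.mean(axis=0, keepdims=True) - d.mean(axis=1, keepdims=True) + d.mean()


def distance_correlation_sq(x, y):
    """Squared (biased) distance correlation between x (n, p) and y (n, q).

    p and q may differ; only the n samples must be paired.
    """
    a = double_center(pairwise_distances(x))
    b = double_center(pairwise_distances(y))
    dcov_xy = (a * b).mean()  # squared distance covariance
    dcov_xx = (a * a).mean()  # squared distance variance of x
    dcov_yy = (b * b).mean()  # squared distance variance of y
    denom = np.sqrt(dcov_xx * dcov_yy)
    return 0.0 if denom == 0 else float(dcov_xy / denom)


rng = np.random.default_rng(0)
x = rng.normal(size=(300, 5))
# Duplicating the columns changes the width (5 -> 10) but only rescales every
# pairwise distance, so the two feature sets are perfectly dependent.
print(distance_correlation_sq(x, np.hstack([x, x])))  # ≈ 1.0
# Independent features of a different width give a value near 0.
print(distance_correlation_sq(x, rng.normal(size=(300, 3))))
```

The value lies in [0, 1] and is 0 (in the population limit) only under independence, which is what lets it act as an independence penalty or as a comparison score between networks.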

YouTube Introduction

Please also check out the project webpage.

Results

Independent Features Help Robustness (Diverge Training)

Table 1: Test accuracy (%) of a model $f_2$ on adversarial examples generated using a model $f_1$ with the same architecture. "Baseline": $f_2$ trained without any constraint. "Ours": $f_2$ trained to be independent of $f_1$. "Clean": test accuracy without adversarial examples.

| Dataset | Network | Method | Clean | FGM $\epsilon=0.03$ | PGD $\epsilon=0.03$ | FGM $\epsilon=0.05$ | PGD $\epsilon=0.05$ | FGM $\epsilon=0.10$ | PGD $\epsilon=0.10$ |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR10 | Resnet 18 | Baseline | 89.14 | 72.10 | 66.34 | 62.00 | 49.42 | 48.23 | 27.41 |
| CIFAR10 | Resnet 18 | Ours | 87.61 | 74.76 | 72.85 | 65.56 | 59.33 | 50.24 | 36.11 |
| ImageNet | Mobilenet-v3-small | Baseline | 47.16 | 29.64 | 30.00 | 23.52 | 24.81 | 13.90 | 17.15 |
| ImageNet | Mobilenet-v3-small | Ours | 42.34 | 34.47 | 36.98 | 29.53 | 33.77 | 19.53 | 28.04 |
| ImageNet | Efficientnet-B0 | Baseline | 57.85 | 26.72 | 28.22 | 18.96 | 19.45 | 12.04 | 11.17 |
| ImageNet | Efficientnet-B0 | Ours | 55.82 | 30.42 | 35.99 | 22.05 | 27.56 | 14.16 | 17.62 |
| ImageNet | Resnet 34 | Baseline | 64.01 | 52.62 | 56.61 | 45.45 | 51.11 | 33.75 | 41.70 |
| ImageNet | Resnet 34 | Ours | 63.77 | 53.19 | 57.18 | 46.50 | 52.28 | 35.00 | 43.35 |
| ImageNet | Resnet 152 | Baseline | 66.88 | 56.56 | 59.19 | 50.61 | 53.49 | 40.50 | 44.49 |
| ImageNet | Resnet 152 | Ours | 68.04 | 58.34 | 61.33 | 52.59 | 56.05 | 42.61 | 47.17 |
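The evaluation protocol in Table 1 (attack $f_1$, then measure $f_2$ on the resulting inputs) can be sketched in PyTorch as below. This is a hedged illustration, not the repository's evaluation script: it assumes FGM with the L∞ norm (a single sign-gradient step), inputs scaled to [0, 1] so that $\epsilon$ is on that pixel scale, and it omits PGD. The functions `fgm_linf`, `accuracy` and `transfer_accuracy` and the toy linear models are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgm_linf(model, images, labels, eps):
    """One-step L-infinity fast gradient attack (sign of the input gradient)."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    (grad,) = torch.autograd.grad(loss, images)
    adv = images + eps * grad.sign()
    return adv.clamp(0.0, 1.0).detach()  # assumes inputs live in [0, 1]


@torch.no_grad()
def accuracy(model, images, labels):
    """Fraction of correctly classified samples."""
    return (model(images).argmax(dim=1) == labels).float().mean().item()


def transfer_accuracy(f1, f2, images, labels, eps):
    """Accuracy of f2 on adversarial examples crafted against f1."""
    f1.eval()
    f2.eval()
    adv = fgm_linf(f1, images, labels, eps)
    return accuracy(f2, adv, labels)


# Toy stand-ins for two classifiers with the same architecture.
torch.manual_seed(0)
f1 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
f2 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
images = torch.rand(16, 3, 8, 8)
labels = torch.randint(0, 10, (16,))
print(transfer_accuracy(f1, f2, images, labels, eps=0.03))
```

With `eps=0` the attack is a no-op, so the function reduces to the "Clean" column.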

Diverge Training
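The Diverge Training idea (keep $f_2$ accurate while making it independent of $f_1$) can be sketched as a penalized loss. This is a minimal PyTorch sketch under our own assumptions: $f_1$ is already trained and frozen, the penalty is the squared distance correlation between the two models' logits on the same batch, and `alpha`, `diverge_step` and the toy linear models are hypothetical. The training scripts in the repository may use different layers, weights, or schedules.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pairwise_dist(z, eps=1e-12):
    """Euclidean distances between rows of z; eps keeps sqrt differentiable at 0."""
    diff = z.unsqueeze(1) - z.unsqueeze(0)
    return torch.sqrt((diff ** 2).sum(dim=-1) + eps)


def dcor_sq(x, y, eps=1e-9):
    """Differentiable squared (biased) distance correlation of two batches."""
    a = pairwise_dist(x.flatten(1))
    b = pairwise_dist(y.flatten(1))
    A = a - a.mean(0, keepdim=True) - a.mean(1, keepdim=True) + a.mean()
    B = b - b.mean(0, keepdim=True) - b.mean(1, keepdim=True) + b.mean()
    return (A * B).mean() / torch.sqrt((A * A).mean() * (B * B).mean() + eps)


def diverge_step(f1, f2, optimizer, images, labels, alpha=1.0):
    """One update of f2: cross-entropy plus a penalty on its dependence on f1."""
    f1.eval()
    f2.train()
    with torch.no_grad():
        out1 = f1(images)  # f1 is kept frozen; no gradients flow into it
    out2 = f2(images)
    loss = F.cross_entropy(out2, labels) + alpha * dcor_sq(out1, out2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
f1 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
f2 = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
optimizer = torch.optim.SGD(f2.parameters(), lr=0.1)  # only f2 is optimized
images = torch.rand(32, 3, 8, 8)
labels = torch.randint(0, 10, (32,))
for _ in range(5):
    print(diverge_step(f1, f2, optimizer, images, labels, alpha=1.0))
```

The small `eps` inside the square root matters: the diagonal of the distance matrix is exactly zero, and without it the backward pass through `sqrt` would produce NaNs.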

Informative Comparisons between Networks (Partial Distance Correlation)

Table 2: Partial DC between network $\Theta_X$ conditioned on network $\Theta_Y$ and the ImageNet class-name embedding (GT). A higher value indicates more information.

| Network $\Theta_X$ | Network $\Theta_Y$ | $\mathcal{R}^2(X, GT)$ | $\mathcal{R}^2(Y, GT)$ | $\mathcal{R}^2(X \mid Y, GT)$ | $\mathcal{R}^2(Y \mid X, GT)$ |
|---|---|---|---|---|---|
| ViT$^1$ | Resnet 18$^2$ | 0.042 | 0.025 | 0.035 | 0.007 |
| ViT | Resnet 50$^3$ | 0.043 | 0.036 | 0.028 | 0.017 |
| ViT | Resnet 152$^4$ | 0.044 | 0.020 | 0.040 | 0.009 |
| ViT | VGG 19 BN$^5$ | 0.042 | 0.037 | 0.026 | 0.015 |
| ViT | Densenet121$^6$ | 0.043 | 0.026 | 0.035 | 0.007 |
| ViT large$^7$ | Resnet 18 | 0.046 | 0.027 | 0.038 | 0.007 |
| ViT large | Resnet 50 | 0.046 | 0.037 | 0.031 | 0.016 |
| ViT large | Resnet 152 | 0.046 | 0.021 | 0.042 | 0.010 |
| ViT large | ViT | 0.045 | 0.043 | 0.019 | 0.013 |
| ViT+Resnet 50$^8$ | Resnet 18 | 0.044 | 0.024 | 0.037 | 0.005 |
| Resnet 152 | Resnet 18 | 0.019 | 0.025 | 0.013 | 0.020 |
| Resnet 152 | Resnet 50 | 0.021 | 0.037 | 0.003 | 0.030 |
| Resnet 50 | Resnet 18 | 0.036 | 0.025 | 0.027 | 0.008 |
| Resnet 50 | VGG 19 BN | 0.036 | 0.036 | 0.020 | 0.019 |

Note (accuracy of the numbered models): 1. 84.40%; 2. 69.76%; 3. 79.02%; 4. 82.54%; 5. 74.22%; 6. 75.57%; 7. 85.68%; 8. 84.13%
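Table 2 relies on partial distance correlation. As a hedged illustration, the NumPy sketch below follows the bias-corrected construction of Székely and Rizzo (2014): U-center each distance matrix, project the conditioning network's matrix out of the others with an inner-product projection, and correlate what remains. The function names are hypothetical, and whether the repository removes $\Theta_Y$ from the GT side as well (as this textbook version does) or uses exactly this normalization is an assumption; because the projection is orthogonal, the two choices share the same numerator and differ only in normalization.

```python
import numpy as np


def pairwise_distances(z):
    """Euclidean distance matrix between the rows of z, shape (n, d) -> (n, n)."""
    diff = z[:, None, :] - z[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def u_center(d):
    """U-centered distance matrix (requires n > 3); the diagonal is set to 0."""
    n = d.shape[0]
    u = (d
         - d.sum(axis=1, keepdims=True) / (n - 2)
         - d.sum(axis=0, keepdims=True) / (n - 2)
         + d.sum() / ((n - 1) * (n - 2)))
    np.fill_diagonal(u, 0.0)
    return u


def u_inner(a, b):
    """Inner product of two U-centered matrices (unbiased squared distance covariance)."""
    n = a.shape[0]
    return float((a * b).sum() / (n * (n - 3)))


def u_corr(a, b):
    """Bias-corrected squared distance correlation from U-centered matrices."""
    denom = np.sqrt(u_inner(a, a) * u_inner(b, b))
    return 0.0 if denom == 0 else u_inner(a, b) / denom


def project_out(a, c):
    """Remove from a its component along c in the space of U-centered matrices."""
    cc = u_inner(c, c)
    return a if cc == 0 else a - (u_inner(a, c) / cc) * c


def partial_dcor_sq(x, gt, y):
    """Partial distance correlation of x and gt given y, i.e. R*(x, gt; y)."""
    a, b, c = (u_center(pairwise_distances(v)) for v in (x, gt, y))
    return u_corr(project_out(a, c), project_out(b, c))


# Random stand-ins for features of Theta_X, Theta_Y and class-name embeddings.
rng = np.random.default_rng(0)
gt = rng.normal(size=(200, 8))
feat_y = gt[:, :4] + 0.5 * rng.normal(size=(200, 4))
feat_x = np.hstack([feat_y, rng.normal(size=(200, 6))])
print(partial_dcor_sq(feat_x, gt, feat_y))
```

In the table's notation, `partial_dcor_sq(features_X, class_embeddings, features_Y)` plays the role of $\mathcal{R}^2(X \mid Y, GT)$, and swapping the first and last arguments gives $\mathcal{R}^2(Y \mid X, GT)$.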

Grad-CAM Heat Map

Disentanglement

Visualization: Distance Correlation in Disentanglement

Quantitative measurement (distance correlation between the residual and each attribute of interest)

Table 3: DC between the residual attributes (R) and each attribute of interest, when the attribute of interest is measured using ground-truth CLIP-labeled data. Values range from 0 to 1; smaller is better.

| age vs R | gender vs R | ethnicity vs R | hair color vs R | beard vs R | glasses vs R |
|---|---|---|---|---|---|
| 0.0329 | 0.0180 | 0.0222 | 0.0242 | 0.0219 | 0.0255 |

Table 4: DC between the residual attributes (R) and each attribute of interest, when the attribute of interest is predicted by the in-model classifier. Values range from 0 to 1; smaller is better.

| age vs R | gender vs R | ethnicity vs R | hair color vs R | beard vs R | glasses vs R |
|---|---|---|---|---|---|
| 0.0430 | 0.0124 | 0.0376 | 0.0259 | 0.0490 | 0.0188 |

Requirements

Python 3.8, PyTorch 1.8, CUDA 10.2

Training

Please refer to each experiment's directory for detailed training steps.

Citation

@inproceedings{zhen2022versatile,
  author    = {Zhen, Xingjian and Meng, Zihang and Chakraborty, Rudrasis and Singh, Vikas},
  title     = {On the Versatile Uses of Partial Distance Correlation in Deep Learning},
  booktitle = {Proceedings of the European conference on computer vision (ECCV)},
  year      = {2022}
}

If you use distance correlation for disentanglement, please also credit the following paper: Measuring the Biases and Effectiveness of Content-Style Disentanglement, which gives a nice demonstration that distance correlation helps content-style disentanglement. We were not aware of this paper when we wrote ours last year, and we thank Sotirios Tsaftaris for communicating his findings to us.