Adversarial Discriminative Domain Adaptation
March 29, 2017
Getting started
This code requires Python 3 and is implemented in TensorFlow.
Things should be fairly easy to run out of the box:
pip install -r requirements.txt
mkdir data snapshot
export PYTHONPATH="$PWD:$PYTHONPATH"
scripts/svhn-mnist.sh
The provided script does the following things:
- Train a base LeNet model on SVHN (downloading SVHN under data/svhn in the process)
- Use ADDA to adapt the SVHN model to MNIST (downloading MNIST under data/mnist in the process)
- Run an evaluation on MNIST using the source-only model (stored at snapshot/lenet_svhn)
- Run an evaluation on MNIST using the ADDA model (stored at snapshot/adda_lenet_svhn_mnist)
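After the script finishes, one quick way to confirm that each step left its output behind is to check for the four paths listed above. The following is a minimal sketch and not part of the repository: the function name `missing_outputs` is hypothetical, and it assumes each path exists as a file or directory relative to the repository root, which may not hold exactly if snapshots are written with a different on-disk layout.

```python
import os

# Paths taken from the list above. How each one is laid out on disk
# (file vs. directory) is an assumption, so we only test for existence.
EXPECTED_OUTPUTS = [
    "data/svhn",
    "data/mnist",
    "snapshot/lenet_svhn",
    "snapshot/adda_lenet_svhn_mnist",
]


def missing_outputs(root="."):
    """Return the expected paths that do not exist under `root`."""
    return [
        path for path in EXPECTED_OUTPUTS
        if not os.path.exists(os.path.join(root, path))
    ]


if __name__ == "__main__":
    missing = missing_outputs()
    if missing:
        print("Missing outputs:", ", ".join(missing))
    else:
        print("All expected outputs are present.")
```

Run it from the repository root. An empty result suggests all four steps completed, though it does not verify the contents of the snapshots.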
Areas of interest
- Check scripts/svhn-mnist.sh for hyperparameters.
- The LeNet model definition is in adda/models/lenet.py.
- The model is annotated with data preprocessing info, which is used in the preprocessing function in adda/models/model.py.
- The main ADDA logic happens in tools/train_adda.py.
- The adversarial discriminator model definition is in adda/adversary.py.
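To give a feel for the adversarial step in ADDA before reading tools/train_adda.py, here is a minimal pure-Python sketch of the two opposing losses. The discriminator is trained to tell source features from target features, while the target encoder is trained against the same discriminator with the labels flipped. This is an illustration under assumptions, not the repository's code: the function names, the label convention (source = 1, target = 0), and the use of a single sigmoid logit per example are all choices made here for brevity.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Numerically stable binary cross-entropy on a raw logit."""
    # Equals -[label*log(s) + (1-label)*log(1-s)] with s = sigmoid(logit).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(source_logits, target_logits):
    """Discriminator objective: label source features 1, target features 0."""
    losses = [sigmoid_cross_entropy(z, 1.0) for z in source_logits]
    losses += [sigmoid_cross_entropy(z, 0.0) for z in target_logits]
    return sum(losses) / len(losses)


def target_encoder_loss(target_logits):
    """Target encoder objective: flipped labels, so target looks like source."""
    losses = [sigmoid_cross_entropy(z, 1.0) for z in target_logits]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Hypothetical discriminator logits for a small batch.
    source_logits = [2.0, 1.5]
    target_logits = [-2.0, -1.0]
    print("discriminator loss:", discriminator_loss(source_logits, target_logits))
    print("target encoder loss:", target_encoder_loss(target_logits))
```

In a training loop, the first loss would update only the discriminator's parameters and the second only the target encoder's, with the source encoder held fixed. The actual variable handling and label convention in tools/train_adda.py may differ from this sketch.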