pytorch-randaugment

December 30, 2019

Unofficial PyTorch reimplementation of RandAugment. Most of the code comes from Fast AutoAugment.

Introduction

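With RandAugment, models can be trained on the dataset of interest directly, with no need for a separate proxy task. By tuning only two hyperparameters, N (the number of transformations applied to each image) and M (the magnitude shared by all of them), you can achieve performance competitive with AutoAugment.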
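The sampling step behind N and M fits in a few lines. The snippet below is modeled on the policy pseudocode in the RandAugment paper: each image gets N transformations drawn uniformly at random (with replacement) from a fixed list of 14, and all of them use the same magnitude M. The function name `randaugment_policy` is hypothetical and not part of this package; the sketch only returns (name, magnitude) pairs and does not apply any image operations.

```python
import random

# The 14 transformation names listed in the RandAugment paper's policy sketch.
TRANSFORMS = [
    "Identity", "AutoContrast", "Equalize", "Rotate", "Solarize", "Color",
    "Posterize", "Contrast", "Brightness", "Sharpness",
    "ShearX", "ShearY", "TranslateX", "TranslateY",
]


def randaugment_policy(n, m, rng=random):
    """Hypothetical helper: sample N ops uniformly with replacement, all at magnitude M."""
    sampled_ops = rng.choices(TRANSFORMS, k=n)
    return [(op, m) for op in sampled_ops]


# Seeded generator so the sampled policy is reproducible.
print(randaugment_policy(2, 9, random.Random(0)))
```

Because every op shares one magnitude and the choice is uniform, the whole search space collapses to the grid of (N, M) values, which is why no proxy task is needed.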

Install

$ pip install git+https://github.com/ildoonet/pytorch-randaugment

Usage

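from torchvision import transforms
from RandAugment import RandAugment

# Commonly used CIFAR-10 per-channel mean and std
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
N, M = 3, 5  # example values: N transformations per image, magnitude M

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Prepend RandAugment(N, M) so it runs on the PIL image, before ToTensor()
transform_train.transforms.insert(0, RandAugment(N, M))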
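To illustrate what calling a RandAugment-style transform on an image involves, here is a minimal stdlib-only sketch. `ToyRandAugment`, the toy pixel ops, and their value ranges are hypothetical stand-ins, not this package's API. It assumes M lies on a 0 to 30 scale and is mapped linearly into each op's [minval, maxval] range; images are flat lists of 8-bit grayscale pixels rather than PIL images.

```python
import random


class ToyRandAugment:
    """Hypothetical RandAugment-style transform (not this package's class).

    ops: list of (fn, minval, maxval), where fn(img, value) returns a new img.
    Assumption: m lies on a 0..max_m scale and maps linearly into [minval, maxval].
    """

    def __init__(self, n, m, ops, max_m=30, rng=None):
        self.n = n
        self.m = m
        self.ops = ops
        self.max_m = max_m
        self.rng = rng or random.Random()

    def __call__(self, img):
        # Draw N ops uniformly with replacement and apply them in sequence.
        for fn, minval, maxval in self.rng.choices(self.ops, k=self.n):
            value = (self.m / self.max_m) * (maxval - minval) + minval
            img = fn(img, value)
        return img


# Toy ops on a flat list of 8-bit grayscale pixels (stand-ins for PIL ops).
def identity(img, _):
    return list(img)


def solarize(img, threshold):
    # Invert every pixel at or above the threshold.
    return [255 - p if p >= threshold else p for p in img]


def brightness(img, factor):
    # Scale intensities and clip to [0, 255].
    return [min(255, max(0, round(p * factor))) for p in img]


TOY_OPS = [(identity, 0, 1), (solarize, 256, 0), (brightness, 1.0, 1.9)]

aug = ToyRandAugment(n=2, m=15, ops=TOY_OPS, rng=random.Random(0))
print(aug([0, 64, 128, 192, 255]))
```

With PIL-based ops in place of the toy ones, a callable of this shape could be prepended to `transform_train.transforms` the same way RandAugment is in the Usage snippet. Listing Solarize's range as (256, 0) makes a larger M mean a stronger effect: at M=0 the threshold is 256 and nothing changes, while at M=30 every pixel is inverted.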

Experiment

We use the same hyperparameters as the paper and observe results similar to those reported.

You can run an experiment with:

$ python RandAugment/train.py -c confs/wresnet28x10_cifar10_b256.yaml --save cifar10_wres28x10.pth

CIFAR-10 Classification

Model | Paper's Result (accuracy %) | Ours (accuracy %)
Wide-ResNet 28x10 | 97.3 | 97.4
Shake26 2x96d | 98.0 | 98.1
Pyramid272 | 98.5 |

CIFAR-100 Classification

Model | Paper's Result (accuracy %) | Ours (accuracy %)
Wide-ResNet 28x10 | 83.3 | 83.3

SVHN Classification

Model | Paper's Result (accuracy %) | Ours (accuracy %)
Wide-ResNet 28x10 | 98.9 | 98.8

ImageNet Classification

I have had some difficulty reproducing the paper's ImageNet results.

Issue: https://github.com/ildoonet/pytorch-randaugment/issues/9

Model | Paper's Result (top-1 / top-5 accuracy %) | Ours
ResNet-50 | 77.6 / 92.8 | TODO
EfficientNet-B5 | 83.2 / 96.7 | TODO
EfficientNet-B7 | 84.4 / 97.1 | TODO

References