STViT-R

March 29, 2023

This folder contains the implementation of STViT-R for image classification.

Install

  • Clone this repo:
git clone https://github.com/changsn/STViT-R.git
cd STViT-R
  • Create a conda virtual environment and activate it:
conda create -n stvit-r python=3.7 -y
conda activate stvit-r
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
  • Install timm==0.3.2:
pip install timm==0.3.2
  • Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
  • Install other requirements:
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
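
After installation, it can help to confirm that the environment matches the pinned versions. The following is a minimal sketch, not part of this repo: the helper name `find_mismatches` is hypothetical, and it assumes the conda `pytorch` package is registered under the distribution name `torch`.

```python
"""Sketch: compare installed package versions against the pins listed above."""
try:
    # Available in the standard library from Python 3.8 onward.
    from importlib.metadata import version, PackageNotFoundError
except ImportError:
    # Python 3.7 fallback; relies on setuptools being installed.
    from pkg_resources import get_distribution
    from pkg_resources import DistributionNotFound as PackageNotFoundError

    def version(name):
        return get_distribution(name).version

# Pins copied from the install steps above.
# Assumption: conda's `pytorch` package shows up as the distribution "torch".
PINS = {
    "torch": "1.7.1",
    "torchvision": "0.8.2",
    "timm": "0.3.2",
    "opencv-python": "4.4.0.46",
    "termcolor": "1.1.0",
    "yacs": "0.1.8",
}


def find_mismatches(pins=PINS):
    """Return {name: (expected, installed_or_None)} for each unsatisfied pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Ignore a local build suffix such as "+cu101" when comparing.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in find_mismatches().items():
        print(f"{name}: expected {expected}, found {installed}")
```

Running the script inside the activated `stvit-r` environment prints one line per package whose version differs from the pin, and nothing if everything matches. Apex is left out because the install step above does not pin a version for it.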

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to load data:

The file structure should look like:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
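
The layout above can be sanity-checked before training. The following is a minimal sketch using only the Python standard library; the function name `check_imagenet_layout` is hypothetical and not part of this repo, and the check only looks at folder structure, not at image contents or class names.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems under `root`; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        entries = list(split_dir.iterdir())
        class_dirs = [p for p in entries if p.is_dir()]
        if not class_dirs:
            problems.append(f"no class sub-folders in {split_dir}")
        # Files placed directly under train/ or val/ are not in a class folder.
        loose = [p for p in entries if p.is_file()]
        if loose:
            problems.append(f"{len(loose)} file(s) directly under {split_dir}")
        for class_dir in class_dirs:
            if not any(p.is_file() for p in class_dir.iterdir()):
                problems.append(f"empty class folder: {class_dir}")
    return problems


if __name__ == "__main__":
    # Assumed location; adjust to wherever the dataset was extracted.
    for problem in check_imagenet_layout("data/imagenet"):
        print(problem)
```

A common failure this catches is a freshly downloaded validation set, where images sit directly under `val` instead of in per-class sub-folders.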