GreenAug: Green Screen Augmentation Enables Scene Generalisation in Robotic Manipulation
July 15, 2024
Eugene Teoh, Sumit Patidar, Xiao Ma, Stephen James
This repo contains the following augmentation methods:
- greenaug.greenaug_random.GreenAugRandom: This applies random textures to the chroma-keyed background. In our paper, we used mil_data.
- greenaug.greenaug_generative.GreenAugGenerative: This uses the chroma-keyed mask to inpaint realistic or imagined backgrounds using Stable Diffusion.
- greenaug.greenaug_mask.GreenAugMask: This uses a masking network to isolate backgrounds as dark pixels during inference. One first needs to train a masking network (see instructions below).
- greenaug.generative_augmentation.GenerativeAugmentation: This is an implementation of generative augmentation (e.g. CACTI, GenAug, ROSIE). The implementation is close to ROSIE but with open-source models (Grounding DINO, Segment Anything, Stable Diffusion).
These augmentation methods can be integrated into policy learning (imitation learning or reinforcement learning). In our experiments, we used ACT and Coarse-to-fine Q-Network.
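To make the idea behind chroma-keyed augmentation concrete, here is a minimal NumPy sketch of keying out a green background and compositing a random texture into it. It is not the repository's implementation: the function names, the thresholds (`min_green`, `dominance`), and the simple RGB dominance rule are all assumptions chosen for illustration.

```python
import numpy as np


def chroma_key_mask(image, min_green=100, dominance=1.3):
    """Return a boolean mask that is True where a pixel looks like green screen.

    image: uint8 array of shape (H, W, 3) in RGB order.
    The thresholds are illustrative assumptions, not values from the paper.
    """
    rgb = image.astype(np.float32)
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    # A pixel counts as background if green is bright and clearly dominant.
    return (g >= min_green) & (g > dominance * r) & (g > dominance * b)


def replace_background(image, texture, mask):
    """Composite `texture` into the pixels selected by `mask`."""
    if texture.shape != image.shape:
        raise ValueError("texture must have the same shape as image")
    # Broadcast the (H, W) mask over the channel axis.
    return np.where(mask[..., None], texture, image)


rng = np.random.default_rng(0)

# Toy frame: pure green background with a red 2x2 "object" in the centre.
frame = np.zeros((4, 4, 3), dtype=np.uint8)
frame[...] = (0, 255, 0)
frame[1:3, 1:3] = (200, 30, 30)

# Stand-in for a texture image sampled from a texture dataset.
texture = rng.integers(0, 256, size=frame.shape, dtype=np.uint8)

mask = chroma_key_mask(frame)
augmented = replace_background(frame, texture, mask)
```

In this toy frame the 12 green pixels are replaced by the texture while the 4 object pixels are left untouched. Real footage usually needs a more robust key (for example in HSV space, with handling for shadows and colour spill).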
Installation
Install GreenAug as a Python package:
pip install "greenaug @ git+https://github.com/eugeneteoh/greenaug.git"
To use the generative variants (GreenAugGenerative and GenerativeAugmentation), set the CUDA_HOME environment variable and install cuda-toolkit:
conda create -n greenaug python=3.10 -y
conda activate greenaug
conda env config vars set CUDA_HOME=$CONDA_PREFIX
# Re-activate the environment so that CUDA_HOME takes effect
conda activate greenaug
# Install PyTorch
# Follow instructions at https://pytorch.org/get-started/locally/
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
conda install cuda-toolkit -c nvidia/label/cuda-12.1.1 -y
pip install "greenaug[generative] @ git+https://github.com/eugeneteoh/greenaug.git"
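If the generative install fails, it can help to confirm that CUDA_HOME is visible to Python and that an nvcc binary is reachable. The helper below is a hypothetical diagnostic, not part of greenaug; it assumes a POSIX-style layout in which nvcc lives at $CUDA_HOME/bin/nvcc.

```python
import os
import shutil
from pathlib import Path


def check_cuda_home(env=None):
    """Report whether CUDA_HOME is set and where an nvcc binary can be found."""
    env = os.environ if env is None else env
    cuda_home = env.get("CUDA_HOME")
    report = {"cuda_home": cuda_home, "nvcc": None}

    if cuda_home:
        # Assumed layout: <CUDA_HOME>/bin/nvcc
        candidate = Path(cuda_home) / "bin" / "nvcc"
        if candidate.is_file():
            report["nvcc"] = str(candidate)

    if report["nvcc"] is None:
        # Fall back to whatever nvcc is on PATH, if any.
        report["nvcc"] = shutil.which("nvcc", path=env.get("PATH"))

    return report


if __name__ == "__main__":
    print(check_cuda_home())
```

A `None` value for `cuda_home` usually means the environment was not re-activated after `conda env config vars set`.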
To use GreenAugMask:
pip install "greenaug[mask] @ git+https://github.com/eugeneteoh/greenaug.git"
Then see the example below.
Example Usage
Check examples under examples/.
import torch
from greenaug import GreenAugRandom
augmenter = GreenAugRandom() # This is a torch.nn.Module
out = augmenter(image, ...)
To train the GreenAugMask masking network:
# Download data
huggingface-cli download --repo-type dataset eugeneteoh/greenaug --include "GreenScreenDemoCollection/open_drawer_green_screen.mp4" --local-dir "assets/mask/raw/"
huggingface-cli download --repo-type dataset eugeneteoh/mil_data --include "*.png" --local-dir "assets/mask/background/"
# Preprocess data
python scripts/preprocess_masking_data.py
# Train Masking Network
python scripts/train_masking_network.py
# Run example
python examples/greenaug_mask.py --checkpoint /path/to/checkpoint
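GreenAugMask is described above as isolating backgrounds as dark pixels during inference. The final masking step could look roughly like the sketch below. It assumes, hypothetically, that the trained network outputs a per-pixel foreground probability map; the function name, the `threshold` parameter, and the toy inputs are illustrative and not taken from the repository.

```python
import numpy as np


def mask_background(image, foreground_prob, threshold=0.5):
    """Set pixels with low predicted foreground probability to black.

    image: array of shape (H, W, C).
    foreground_prob: array of shape (H, W) with values in [0, 1]
        (assumed output format of the masking network).
    """
    if foreground_prob.shape != image.shape[:2]:
        raise ValueError("foreground_prob must have shape (H, W)")
    keep = foreground_prob >= threshold
    # Multiply by a 0/1 mask broadcast over the channel axis.
    return image * keep[..., None].astype(image.dtype)


# Toy example: a uniform grey image and a hand-written probability map.
image = np.full((2, 3, 3), 180, dtype=np.uint8)
prob = np.array([[0.9, 0.2, 0.7],
                 [0.1, 0.6, 0.4]])

masked = mask_background(image, prob)
```

Pixels whose probability falls below the threshold become zero in every channel, so the policy sees only the foreground on a dark background.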