July 9, 2024
rlbase

This codebase implements simple reinforcement learning algorithms in JAX and supports several environments. The goal is to provide solid single-file implementations of various RL algorithms for research use. Both online and offline methods are included.
Online Algorithms Implemented:
- Proximal Policy Optimization (PPO): algs_online/ppo.py
- Soft Actor-Critic (SAC): algs_online/sac.py
- Twin Delayed DDPG (TD3): algs_online/td3.py
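For orientation, the core of PPO is its clipped surrogate objective. The sketch below is a minimal, generic JAX version of that loss and is not taken from algs_online/ppo.py. It assumes per-sample log-probabilities and advantages have already been computed; the function name and the default clip_eps=0.2 are illustrative assumptions, not this repo's API.

```python
import jax.numpy as jnp


def ppo_clip_loss(new_log_prob, old_log_prob, advantages, clip_eps=0.2):
    """Illustrative PPO clipped surrogate loss (to be minimized); not this repo's code."""
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    ratio = jnp.exp(new_log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the elementwise minimum of the two terms; negate it to get a loss.
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```

When the new and old log-probabilities are equal, the ratio is 1 and the loss reduces to the negative mean advantage. The function can be wrapped in jax.jit or differentiated with jax.grad as usual.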
Offline Algorithms Implemented:
- Behavior Cloning (BC): algs_offline/bc.py
- Implicit Q-Learning (IQL): algs_offline/iql.py
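The distinctive piece of IQL is expectile regression. It fits the state-value function V(s) toward Q(s, a) using only dataset actions, with the asymmetric loss |tau - 1(u < 0)| * u^2 where u = Q(s, a) - V(s). The minimal JAX sketch below is a generic illustration rather than the code in algs_offline/iql.py. The function names and the default expectile=0.7 are assumptions.

```python
import jax.numpy as jnp


def expectile_loss(diff, expectile=0.7):
    """Asymmetric L2 loss |tau - 1(diff < 0)| * diff**2 (illustrative, not this repo's code)."""
    # Positive residuals (Q > V) are weighted by tau, negative ones by 1 - tau.
    weight = jnp.where(diff > 0, expectile, 1.0 - expectile)
    return weight * diff**2


def iql_value_loss(target_q, v, expectile=0.7):
    """Hypothetical helper: mean expectile loss of V(s) against Q(s, a) on a batch."""
    # target_q: Q(s, a) for dataset actions (typically from a target critic); v: V(s).
    return jnp.mean(expectile_loss(target_q - v, expectile))
```

With expectile=0.5 this becomes half the ordinary squared error. Values closer to 1 push V(s) toward the upper range of Q(s, a) over in-dataset actions.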
Environments Supported:
- (Online) Gym MuJoCo Locomotion: HalfCheetah-v2, CartPole-v1, etc.
- (Online) DeepMind Control: cheetah_run, pendulum_swingup, etc.
- (Offline) D4RL MuJoCo Locomotion: halfcheetah-medium-expert-v2, etc.
- (Offline) D4RL AntMaze + Goal-Conditioned: antmaze-large-diverse-v2, gc-antmaze-large-diverse-v2
- (Offline) ExORL: exorl_cheetah_walk, etc.
- See envs/env_helper.py for the full list.
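The example names above follow a recognizable pattern for each suite. The hypothetical helper below (classify_env_name is not part of this repo, and the suite labels it returns are made up for illustration) only encodes the conventions visible in this list. envs/env_helper.py remains the authoritative mapping.

```python
import re


def classify_env_name(env_name: str) -> tuple[str, str]:
    """Hypothetical helper: guess (setting, suite) from the naming patterns listed above."""
    if env_name.startswith("exorl_"):
        return ("offline", "exorl")
    if env_name.startswith("gc-"):
        # Goal-conditioned D4RL variants, e.g. gc-antmaze-large-diverse-v2.
        return ("offline", "d4rl_goal_conditioned")
    if re.fullmatch(r"[a-z0-9]+(-[a-z0-9]+)+-v\d+", env_name):
        # Lowercase dataset names, e.g. halfcheetah-medium-expert-v2, antmaze-large-diverse-v2.
        return ("offline", "d4rl")
    if re.fullmatch(r"[A-Z][A-Za-z0-9]*-v\d+", env_name):
        # CamelCase Gym ids, e.g. HalfCheetah-v2, CartPole-v1.
        return ("online", "gym")
    if re.fullmatch(r"[a-z0-9]+_[a-z0-9_]+", env_name):
        # DeepMind Control domain_task names, e.g. cheetah_run.
        return ("online", "dmc")
    raise ValueError(f"Unrecognized environment name: {env_name}")
```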
Installation
For the cleanest installation, create a conda environment:
conda env create -f deps/environment.yml
For full reproducibility, you can also use the Singularity container definition in deps/base_container.def.
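After activating the new conda environment, you can run the quick check below to confirm that JAX is installed and can compile code. It uses only standard JAX calls and is a suggested sanity check, not a script shipped with this repo.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; accelerator entries appear only if a GPU/TPU build is installed.
print(jax.devices())
print(jax.default_backend())  # e.g. "cpu" or "gpu"


# Compile and run a tiny function to check that jit works end to end.
@jax.jit
def add_one(x):
    return x + 1.0


print(add_one(jnp.arange(3.0)))
```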
Reproduction
We provide a set of stable results comparing each algorithm to a reference implementation. For full training curves, see the wandb reports for the online and offline results.
You can reproduce these results using the commands in run_baselines.py.
The simplest starting point is to run an individual algorithm file directly, e.g.:
python algs_online/ppo.py --env_name walker_walk --agent.gamma 0.99
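The example command overrides agent.gamma, which we assume is the usual RL discount factor. To illustrate what that parameter controls, the sketch below computes per-step discounted returns G_t = r_t + gamma * (1 - done_t) * G_{t+1} with jax.lax.scan. It is not code from this repo, and discounted_returns and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def discounted_returns(rewards, dones, gamma=0.99):
    """Hypothetical helper: G_t = r_t + gamma * (1 - done_t) * G_{t+1}, computed backwards."""

    def step(next_return, inputs):
        reward, done = inputs
        # A terminal step (done = 1) cuts off bootstrapping from later rewards.
        g = reward + gamma * (1.0 - done) * next_return
        return g, g

    # reverse=True scans from the last step to the first but keeps outputs in time order.
    _, returns = jax.lax.scan(step, jnp.zeros(()), (rewards, dones), reverse=True)
    return returns
```

For example, three rewards of 1.0 with gamma=0.5 and no terminals give returns [1.75, 1.5, 1.0]. Larger gamma values weight distant rewards more heavily.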
Offline Results
| Env | Best Performance (ours) | Best Original Performance (reference paper) |
|---|---|---|
| exorl_cheetah_run | 257.5 (IQL-DDPG) | ~250 (TD3) source (exorl) |
| exorl_walker_run | 471.9 (IQL-DDPG) | ~200 (TD3) source (exorl) |
| halfcheetah-medium-expert-v2 | 83.8 (IQL) | 90.7 (TD3+BC) source (iql) |
| walker2d-medium-expert-v2 | 106.8 (BC) | 110.1 (TD3+BC) source (iql) |
| hopper-medium-expert-v2 | 98.9 (IQL) | 98.7 (CQL) source (iql) |
| gc-antmaze-large-diverse-v2 | 52.5 (IQL) | 50.7 (IQL) source (hiql) |
| gc-maze2d-large-v1 | 97.5 (IQL) | N/A |
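The D4RL locomotion rows above (for example halfcheetah-medium-expert-v2) appear to use the normalized 0 to 100 scale common in D4RL papers such as IQL. On that scale, 0 corresponds to a random policy and 100 to an expert policy. A minimal sketch of the normalization follows. The per-environment reference returns are deliberately not reproduced here and should come from D4RL itself. The d4rl package exposes the same computation through env.get_normalized_score, which returns a fraction rather than a percentage.

```python
def d4rl_normalized_score(raw_return, random_return, expert_return):
    """D4RL-style normalization (illustrative): 0 = random policy, 100 = expert policy."""
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)
```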
Online Results
| Env | Best Performance (ours) | Best Original Performance (reference paper) |
|---|---|---|
| HalfCheetah-v2 | 11029 (SAC) | 12138.8 (SAC) source (tianshou) |
| Walker2d-v2 | 5101.8 (SAC-Tianshoulike) | 5007 (SAC) source (tianshou) |
| Hopper-v2 | 2714.4 (REDQ) | 3542.2 (SAC) source (tianshou) |
| cheetah_run | 918.9 (REDQ) | 800 (SAC) source (pytorch-sac) |
| walker_run | 835.7 (TD3) | 700 (SAC) source (pytorch-sac) |
| hopper_hop | 474.9 (TD3) | 210 (SAC) source (pytorch-sac) |
| quadruped_run | 920.8 (TD3) | 700 (SAC) source (pytorch-sac) |
| humanoid_run | 211.8 (REDQ) | 90 (SAC) source (pytorch-sac) |
| pendulum_swingup | 790.2 (SAC) | 920 (SAC) source (pytorch-sac) |
History
This code is based largely on the jaxrl_m repo and also takes inspiration from jaxrl and cleanrl.