TRFL: Reinforcement Learning Building Blocks
April 16, 2020
TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes several useful building blocks for implementing Reinforcement Learning agents.
Background
Common RL algorithms describe a particular update to a policy, a value function, or an action-value (Q) function. In deep RL, a policy, value function, or Q-function is typically represented by a neural network (the model, not to be confused with an environment model, which is used in model-based RL). We formulate common RL update rules for these neural networks as differentiable loss functions, as is common in supervised and unsupervised learning. Under automatic differentiation, the original update rule is recovered. We find that loss functions are more modular and composable than traditional RL updates, and more natural to combine with supervised or unsupervised objectives.
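As a minimal sketch of this idea (plain Python, no TensorFlow; the function names and numbers are illustrative and not part of TRFL), consider a single scalar value estimate. If the loss is half the squared TD error and the target is held fixed, one gradient descent step on the loss coincides with the familiar TD update `v <- v + alpha * td_error`.

```python
def td_loss(v_tm1, target):
    """Half the squared TD error, with the target treated as a constant."""
    td_error = target - v_tm1
    return 0.5 * td_error ** 2


def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


# Hypothetical numbers for one transition.
v_tm1, r_t, pcont_t, v_t = 1.0, 1.0, 0.9, 2.0
alpha = 0.1  # learning rate (assumed)

target = r_t + pcont_t * v_t
td_error = target - v_tm1

# Classic TD update rule.
v_classic = v_tm1 + alpha * td_error

# Gradient descent step on the loss, differentiating with respect to v_tm1 only.
grad = numerical_grad(lambda v: td_loss(v, target), v_tm1)
v_from_loss = v_tm1 - alpha * grad
```

Both routes give the same new estimate, because the gradient of the loss with respect to `v_tm1` is `-td_error`.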
The loss functions and other operations provided here are implemented in pure TensorFlow. They are not complete algorithms, but implementations of RL-specific mathematical operations needed when building fully-functional RL agents. In particular, the updates are only valid if the input data are sampled in the correct manner. For example, the sequence-advantage-actor-critic loss (i.e. A2C) is only valid if the input trajectory is an unbiased sample from the current policy; i.e. the data are on-policy. This library cannot check or enforce such constraints.
Installation
TRFL can be installed with pip directly from GitHub, using the following command:
pip install git+https://github.com/deepmind/trfl.git
TRFL works with both the CPU and GPU versions of TensorFlow. For that reason it does not list TensorFlow as a requirement, so you need to install TensorFlow and TensorFlow Probability separately if you haven't already done so.
Example usage
Import TensorFlow and TRFL.
import tensorflow as tf
import trfl
Define the relevant data associated with a transition in the environment from
state s_tm1 to state s_t. This typically includes action values (or another
characterization of the agent's policy) in both the source and destination
states. The action a_tm1 is the one selected after observing s_tm1, and it
resulted in observing the immediate reward r_t and the subsequent state s_t.
pcont_t represents a time-dependent discount factor, or (equivalently) the
continuation probability from state s_t. In most applications, its value will
be equal to a constant factor (e.g., 0.99), except if s_t is the final state
in an episode, in which case it is set to zero.
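One way to build pcont_t from episode-termination flags is sketched here in plain Python. The helper name `continuation_probabilities` and the `done_flags` input are assumptions made for illustration; they are not part of TRFL.

```python
def continuation_probabilities(done_flags, discount=0.99):
    """Returns `discount` for non-terminal transitions and 0.0 for terminal ones."""
    return [0.0 if done else discount for done in done_flags]


# Hypothetical batch of three transitions; the second one ends its episode.
pcont_t = continuation_probabilities([False, True, False])
```

Setting the value to zero on the terminal transition removes the bootstrapped term from the target, so no value is propagated across the episode boundary.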
# Q-values for the previous and next timesteps, shape [batch_size, num_actions].
q_tm1 = tf.get_variable(
    "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
q_t = tf.get_variable(
    "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
# Action indices, discounts and rewards, shape [batch_size].
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([1, 1], dtype=tf.float32)
pcont_t = tf.constant([0, 1], dtype=tf.float32) # the discount factor
# Q-learning loss, and auxiliary data.
loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
loss is the tensor representing the loss. For Q-learning, it is half the
squared difference between the predicted Q-values and the TD targets, shape
[batch_size]. Extra information is in the q_learning namedtuple, including
q_learning.td_error and q_learning.target.
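The arithmetic for the example inputs can be reproduced by hand. The sketch below is plain Python rather than TRFL code, and it assumes the standard Q-learning target, `r_t + pcont_t * max_a q_t`; it is meant only to show which numbers to expect.

```python
# Same values as the example above, as plain Python lists.
q_tm1 = [[1., 1., 0.], [1., 2., 0.]]
q_t = [[0., 1., 0.], [1., 2., 0.]]
a_tm1 = [0, 1]
r_t = [1., 1.]
pcont_t = [0., 1.]

# TD target: reward plus discounted maximum action value in the next state.
target = [r + p * max(q) for r, p, q in zip(r_t, pcont_t, q_t)]

# Q-value of the action that was actually taken in the source state.
qa_tm1 = [q[a] for q, a in zip(q_tm1, a_tm1)]

# TD error and per-element loss, both of shape [batch_size].
td_error = [t - q for t, q in zip(target, qa_tm1)]
loss = [0.5 * e ** 2 for e in td_error]

print(target)    # [1.0, 3.0]
print(td_error)  # [0.0, 1.0]
print(loss)      # [0.0, 0.5]
```

The first transition has pcont_t equal to zero, so its target is just the reward and its loss is zero.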
Most of the time, you may only be interested in the loss, in which case you can use either of the following expressions:
loss, _ = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
loss = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t).loss
The loss tensor can be differentiated to derive the corresponding RL update.
Note that in Q-learning, as in other bootstrapped losses, the TD targets
are wrapped in a tf.stop_gradient. Differentiating loss therefore
results in gradients with respect to q_tm1 but not with respect to q_t.
reduced_loss = tf.reduce_mean(loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(reduced_loss)
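The effect of holding the TD target constant can be illustrated with hand-derived gradients for a single transition. This is a sketch in plain Python under the assumption that the loss is `0.5 * td_error ** 2`; the two helper functions are hypothetical and exist only for this comparison.

```python
def semi_gradients(qa_tm1, q_t_max, r_t, pcont_t):
    """Gradients of 0.5 * td_error**2 when the target is held constant."""
    td_error = (r_t + pcont_t * q_t_max) - qa_tm1
    grad_qa_tm1 = -td_error  # flows into the source-state estimate
    grad_q_t_max = 0.0       # blocked: the target is treated as a constant
    return grad_qa_tm1, grad_q_t_max


def full_gradients(qa_tm1, q_t_max, r_t, pcont_t):
    """Gradients of the same loss if the target were NOT held constant."""
    td_error = (r_t + pcont_t * q_t_max) - qa_tm1
    return -td_error, pcont_t * td_error


# Second transition of the example above:
# qa_tm1 = 2, max_a q_t = 2, r_t = 1, pcont_t = 1, so td_error = 1.
semi = semi_gradients(2.0, 2.0, 1.0, 1.0)
full = full_gradients(2.0, 2.0, 1.0, 1.0)

print(semi)  # (-1.0, 0.0)
print(full)  # (-1.0, 1.0)
```

With the target held constant, an optimizer step moves only the source-state estimate toward the target and leaves the destination-state estimate untouched.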
All loss functions in the package return both a loss tensor and a namedtuple
with extra information, using the above convention, but different functions
may have different extra fields. Check the documentation of each function
below for more information.
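The return convention can be mimicked with a small stand-in, which makes the two access patterns concrete. Everything in this sketch is hypothetical (`LossOutput`, `ToyExtra`, and `toy_td_loss` are illustrative names, and the function works on Python lists rather than tensors).

```python
import collections

# Hypothetical stand-ins for the containers a loss function might return.
LossOutput = collections.namedtuple("LossOutput", ["loss", "extra"])
ToyExtra = collections.namedtuple("ToyExtra", ["target", "td_error"])


def toy_td_loss(v_tm1, r_t, pcont_t, v_t):
    """Toy per-element TD loss that mimics the (loss, extra) return convention."""
    target = [r + p * v for r, p, v in zip(r_t, pcont_t, v_t)]
    td_error = [t - v for t, v in zip(target, v_tm1)]
    loss = [0.5 * e ** 2 for e in td_error]
    return LossOutput(loss=loss, extra=ToyExtra(target=target, td_error=td_error))


out = toy_td_loss([1., 2.], [1., 1.], [0., 1.], [1., 2.])

loss, extra = out     # tuple unpacking
same_loss = out.loss  # attribute access on the namedtuple
```

Because the outer container is a namedtuple, both unpacking and attribute access refer to the same object; only the fields inside `extra` differ from one loss function to another.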
Naming Conventions and Developer Guidelines
Throughout the package, we use the following conventions:
- Time indices and variable names:
  - q_tm1: the action value in the source state of a transition.
  - a_tm1: the action that was selected in the source state.
  - r_t: the resulting rewards collected in the destination state.
  - pcont_t: the continuation probability / discount for a transition.
  - q_t: the action values in the destination state.
- Tensor shapes:
  - All ops support minibatches only. We use B to denote the batch size.
  - Batches of rewards, continuation probabilities / discounts have shape [B].
  - Batches of state-values have shape [B].
  - Batches of action-values / q-values have shape [B, num_actions].
  - All losses have shape [B], i.e. the loss is not reduced over the batch dimension. This allows the user to easily weight the loss for different elements of the batch (e.g., by their importance sampling weights).
  - For ops that take batches of sequences of data, T denotes the sequence length. Tensors are time-major, and have shape [T, B, ...]. Index 0 of the time dimension is assumed to be the start of the sequence.
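Two of the shape conventions above can be sketched with nested Python lists standing in for tensors: weighting an unreduced loss of shape [B] before averaging, and converting batch-major data of shape [B, T] to time-major [T, B]. The helper names and the numbers are assumptions made for illustration, not TRFL functions.

```python
def weighted_mean_loss(loss, weights):
    """Scales each per-element loss by its weight, then averages over the batch."""
    if len(loss) != len(weights):
        raise ValueError("loss and weights must both have shape [B]")
    return sum(w * l for w, l in zip(weights, loss)) / len(loss)


def to_time_major(batch_major):
    """Converts nested lists of shape [B, T] to shape [T, B]."""
    return [list(step) for step in zip(*batch_major)]


# Hypothetical per-element losses (shape [B]) and importance sampling weights.
loss = [0.0, 0.5, 2.0]
weights = [1.0, 0.5, 0.25]
reduced = weighted_mean_loss(loss, weights)

# Hypothetical rewards for B = 2 sequences of length T = 3.
rewards_bt = [[1., 2., 3.], [4., 5., 6.]]
rewards_tb = to_time_major(rewards_bt)

print(reduced)     # 0.25
print(rewards_tb)  # [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```

Keeping the loss unreduced until the last step is what makes the per-element weighting possible; after the transpose, index 0 of the outer list holds the first timestep of every sequence.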
Learning updates
- State Value learning:
- Discrete-action Value learning:
- Distributional Value learning:
- Continuous-action Policy Gradient:
- Deterministic Policy Gradient:
- Discrete-action Policy Gradient:
  - discrete_policy_entropy_loss
  - sequence_advantage_actor_critic_loss: this is the commonly-used A2C/A3C loss function.
  - discrete_policy_gradient
  - discrete_policy_gradient_loss
- Pixel control:
- Retrace:
- Target Network Updating:
- V-trace:
Other
- Clipping ops
- Distributions
- Indexing ops
- Periodic execution ops
- Policy ops
- Sequence ops