Efficient LSTM cell in Torch
May 5, 2015 · View on GitHub
--[[ Efficient LSTM in Torch using nngraph library. This code was optimized by Justin Johnson (@jcjohnson) based on the trick of batching up the LSTM GEMMs, as also seen in my efficient Python LSTM gist. --]]
--- Construct one time step of an LSTM as an nngraph gModule.
-- Each of the two Linear projections emits 4 * rnn_size values, so the
-- pre-activations for every gate come out of one matrix multiply per operand.
-- Their sum is split along dimension 2:
--   [1, 3*rnn_size]                 -> sigmoid -> input | forget | output gates
--   [3*rnn_size+1, 4*rnn_size]      -> tanh    -> candidate cell update
-- NOTE(review): slicing on dimension 2 assumes batched (batch x features)
-- inputs; confirm callers never pass 1-D tensors.
-- @param input_size number of features in x
-- @param rnn_size number of units in the cell and hidden states
-- @return gModule mapping {x, prev_c, prev_h} to {next_c, next_h}
function LSTM.fast_lstm(input_size, rnn_size)
  -- Graph entry points, in the order the gModule lists them.
  local x_node = nn.Identity()()
  local c_node = nn.Identity()()
  local h_node = nn.Identity()()

  -- Take an rnn_size-wide slice of `src` along dim 2, starting at `offset`.
  local function slice(src, offset)
    return nn.Narrow(2, offset, rnn_size)(src)
  end

  -- Fused pre-activations for all four gates: W_x * x + W_h * h (+ biases).
  local from_x = nn.Linear(input_size, 4 * rnn_size)(x_node)
  local from_h = nn.Linear(rnn_size, 4 * rnn_size)(h_node)
  local preact = nn.CAddTable()({from_x, from_h})

  -- The first 3 * rnn_size columns share a single Sigmoid before splitting.
  local raw_gates = nn.Narrow(2, 1, 3 * rnn_size)(preact)
  local squashed = nn.Sigmoid()(raw_gates)
  local gate_in = slice(squashed, 1)
  local gate_forget = slice(squashed, rnn_size + 1)
  local gate_out = slice(squashed, 2 * rnn_size + 1)

  -- The remaining rnn_size columns form the tanh candidate.
  local raw_candidate = slice(preact, 3 * rnn_size + 1)
  local candidate = nn.Tanh()(raw_candidate)

  -- c_next = forget * c_prev + in * candidate;  h_next = out * tanh(c_next)
  local kept = nn.CMulTable()({gate_forget, c_node})
  local written = nn.CMulTable()({gate_in, candidate})
  local c_next = nn.CAddTable()({kept, written})
  local h_next = nn.CMulTable()({gate_out, nn.Tanh()(c_next)})

  return nn.gModule({x_node, c_node, h_node}, {c_next, h_next})
end