author    Eugene Brevdo <ebrevdo@google.com>  2016-08-11 18:06:42 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-08-11 19:17:28 -0700
commit    bc156957812b64d86d263bee940cacf7fb6fb9a3 (patch)
tree      2147c502149a4ad616628317903a00742b46031e /tensorflow/contrib/grid_rnn
parent    fc218ef4d53e697c1a97d90a57a7437255575f89 (diff)
** BREAKING CHANGE **
Core RNNCell implementations now use state_is_tuple=True by default. This is
part of the deprecation process for non-tuple LSTM and MultiRNNCell states.
Change: 130059769
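
For illustration, a minimal sketch of the behavioral difference (assuming the
TF 0.x-era tf.nn.rnn_cell API of this commit; illustrative, not part of the
patch):

    import tensorflow as tf

    num_units = 2
    # New default after this change: the LSTM state is an
    # LSTMStateTuple(c, h) pair rather than one concatenated Tensor.
    cell_tuple = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    # Old behavior, now opt-in: c and h concatenated into a single
    # [batch, 2 * num_units] Tensor, which is what GridRNN still expects.
    cell_concat = tf.nn.rnn_cell.BasicLSTMCell(num_units, state_is_tuple=False)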
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--  tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py  210
-rw-r--r--  tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py           356
2 files changed, 357 insertions, 209 deletions
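
For context on the state layout asserted throughout the tests below: GridRNN
cells keep a single concatenated state Tensor, so a Grid2BasicLSTMCell with
num_units=2 and two recurrent LSTM dimensions has state_size = 2 dims * (c + h)
* 2 units = 8. A minimal usage sketch (TF 0.x contrib API; illustrative only):

    import tensorflow as tf

    cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
    assert cell.state_size == 8    # 2 recurrent dims x (c + h) x num_units
    x = tf.zeros([1, 3])           # batch x input_size
    m = tf.zeros([1, cell.state_size])
    g, s = cell(x, m)              # g: [1, 2] output; s: [1, 8] new state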
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index 11fb208f49..abd06ffa26 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for GridRNN cells."""
from __future__ import absolute_import
@@ -27,7 +26,8 @@ class GridRNNCellTest(tf.test.TestCase):
def testGrid2BasicLSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope:
+ with tf.variable_scope(
+ 'root', initializer=tf.constant_initializer(0.2)) as root_scope:
x = tf.zeros([1, 3])
m = tf.zeros([1, 8])
cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2)
@@ -38,15 +38,18 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
- self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
- 0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+ self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
+ self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
+ 0.36617181, 0.72320831, 0.80555487,
+ 0.39102408, 0.42150158]])
- # emulate a loop through the input sequence, where we call cell() multiple times
+ # emulate a loop through the input sequence,
+ # where we call cell() multiple times
root_scope.reuse_variables()
g2, s2 = cell(x, m)
self.assertEqual(g2.get_shape(), (1, 2))
@@ -56,8 +59,9 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.58847463, 0.58847463]])
- self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463,
- 0.97726452, 1.04626071, 0.4927212, 0.51137757]])
+ self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463,
+ 0.58847463, 0.97726452, 1.04626071,
+ 0.4927212, 0.51137757]])
def testGrid2BasicLSTMCellTied(self):
with self.test_session() as sess:
@@ -72,27 +76,31 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]])
- self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181,
- 0.72320831, 0.80555487, 0.39102408, 0.42150158]])
+ self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
+ self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
+ 0.36617181, 0.72320831, 0.80555487,
+ 0.39102408, 0.42150158]])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.36703536, 0.36703536]])
- self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536,
- 0.80941606, 0.87550586, 0.40108523, 0.42199609]])
+ self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536,
+ 0.36703536, 0.80941606, 0.87550586,
+ 0.40108523, 0.42199609]])
def testGrid2BasicLSTMCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
- cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(
+ 2, tied=False, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
@@ -104,11 +112,11 @@ class GridRNNCellTest(tf.test.TestCase):
m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]])
- self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]])
+ self.assertAllClose(res[0], [[0.31667367, 0.31667367]])
+ self.assertAllClose(res[1],
+ [[0.29530135, 0.37520045, 0.17044567, 0.21292259]])
- """
- LSTMCell
+ """LSTMCell
"""
def testGrid2LSTMCell(self):
@@ -124,20 +132,23 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
- self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
- 1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+ self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
+ self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
+ 0.95686918, 1.38917875, 1.49043763,
+ 0.83884692, 0.86036491]])
def testGrid2LSTMCellTied(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 8])
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True)
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+ 2, tied=True, use_peepholes=True)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
@@ -145,20 +156,23 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]])
- self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918,
- 1.38917875, 1.49043763, 0.83884692, 0.86036491]])
+ self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
+ self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
+ 0.95686918, 1.38917875, 1.49043763,
+ 0.83884692, 0.86036491]])
def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+ 2, use_peepholes=True, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 4)
g, s = cell(x, m)
@@ -170,11 +184,11 @@ class GridRNNCellTest(tf.test.TestCase):
m: np.array([[0.1, 0.2, 0.3, 0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]])
- self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]])
+ self.assertAllClose(res[0], [[2.1831727, 2.1831727]])
+ self.assertAllClose(res[1],
+ [[0.92270052, 1.02325559, 0.66159075, 0.70475441]])
- """
- RNNCell
+ """RNNCell
"""
def testGrid2BasicRNNCell(self):
@@ -190,14 +204,16 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (2, 4))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
[0.99480951, 0.99480951]])
- self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
- [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+ self.assertAllClose(res[1],
+ [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+ [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess:
@@ -212,21 +228,24 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (2, 4))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
+ res = sess.run(
+ [g, s], {x: np.array([[1., 1.], [2., 2.]]),
+ m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])})
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 4))
self.assertAllClose(res[0], [[0.94685763, 0.94685763],
[0.99480951, 0.99480951]])
- self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
- [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+ self.assertAllClose(res[1],
+ [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
+ [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
x = tf.zeros([1, 2])
m = tf.zeros([1, 2])
- cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(
+ 2, non_recurrent_fn=tf.nn.relu)
self.assertEqual(cell.state_size, 2)
g, s = cell(x, m)
@@ -234,19 +253,20 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(s.get_shape(), (1, 2))
sess.run([tf.initialize_all_variables()])
- res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])})
+ res = sess.run([g, s], {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 2))
self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
self.assertAllClose(res[1], [[0.80049896, 0.80049896]])
- """
- 1-LSTM
+ """1-LSTM
"""
def testGrid1LSTMCell(self):
with self.test_session() as sess:
- with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope:
+ with tf.variable_scope(
+ 'root', initializer=tf.constant_initializer(0.5)) as root_scope:
x = tf.zeros([1, 3])
m = tf.zeros([1, 4])
cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True)
@@ -262,7 +282,8 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
- self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
+ self.assertAllClose(res[1],
+ [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
root_scope.reuse_variables()
@@ -276,7 +297,8 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.9032144, 0.9032144]])
- self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
+ self.assertAllClose(res[1],
+ [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
g3, s3 = cell(x2, m)
self.assertEqual(g3.get_shape(), (1, 2))
@@ -287,11 +309,12 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 4))
self.assertAllClose(res[0], [[0.92727238, 0.92727238]])
- self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+ self.assertAllClose(res[1],
+ [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+ """3-LSTM
"""
- 3-LSTM
- """
+
def testGrid3LSTMCell(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@@ -306,18 +329,20 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run([tf.initialize_all_variables()])
res = sess.run([g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])})
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
+ 0.8, -0.1, -0.2, -0.3, -0.4]])})
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 12))
self.assertAllClose(res[0], [[0.96892911, 0.96892911]])
- self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911,
- 1.33592629, 1.4373529, 0.80867189, 0.83247656,
- 0.7317788, 0.63205892, 0.56548983, 0.50446129]])
+ self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911,
+ 0.96892911, 1.33592629, 1.4373529,
+ 0.80867189, 0.83247656, 0.7317788,
+ 0.63205892, 0.56548983, 0.50446129]])
+ """Edge cases
"""
- Edge cases
- """
+
def testGridRNNEdgeCasesLikeRelu(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
@@ -325,8 +350,13 @@ class GridRNNCellTest(tf.test.TestCase):
m = tf.zeros([0, 0])
# this is equivalent to relu
- cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0,
- non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.GridRNNCell(
+ num_units=2,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ non_recurrent_dims=0,
+ non_recurrent_fn=tf.nn.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (3, 2))
self.assertEqual(s.get_shape(), (0, 0))
@@ -340,12 +370,17 @@ class GridRNNCellTest(tf.test.TestCase):
def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess:
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- x = tf.zeros([1, 2])
+ x = tf.zeros([1, 2])
m = tf.zeros([1, 4])
# This cell produces no output
- cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None,
- non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.GridRNNCell(
+ num_units=2,
+ num_dims=2,
+ input_dims=0,
+ output_dims=None,
+ non_recurrent_dims=0,
+ non_recurrent_fn=tf.nn.relu)
g, s = cell(x, m)
self.assertEqual(g.get_shape(), (0, 0))
self.assertEqual(s.get_shape(), (1, 4))
@@ -356,8 +391,7 @@ class GridRNNCellTest(tf.test.TestCase):
self.assertEqual(res[0].shape, (0, 0))
self.assertEqual(res[1].shape, (1, 4))
- """
- Test with tf.nn.rnn
+ """Test with tf.nn.rnn
"""
def testGrid2LSTMCellWithRNN(self):
@@ -370,7 +404,9 @@ class GridRNNCellTest(tf.test.TestCase):
cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units)
inputs = max_length * [
- tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ tf.placeholder(
+ tf.float32, shape=(batch_size, input_size))
+ ]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -386,8 +422,7 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
- values = sess.run(outputs + [state],
- feed_dict={inputs[0]: input_value})
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
@@ -398,10 +433,13 @@ class GridRNNCellTest(tf.test.TestCase):
num_units = 2
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.Grid2LSTMCell(
+ num_units=num_units, non_recurrent_fn=tf.nn.relu)
inputs = max_length * [
- tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ tf.placeholder(
+ tf.float32, shape=(batch_size, input_size))
+ ]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -417,8 +455,7 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
- values = sess.run(outputs + [state],
- feed_dict={inputs[0]: input_value})
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
@@ -429,10 +466,13 @@ class GridRNNCellTest(tf.test.TestCase):
num_units = 2
with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)):
- cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu)
+ cell = tf.contrib.grid_rnn.Grid3LSTMCell(
+ num_units=num_units, non_recurrent_fn=tf.nn.relu)
inputs = max_length * [
- tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ tf.placeholder(
+ tf.float32, shape=(batch_size, input_size))
+ ]
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -448,12 +488,10 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
- values = sess.run(outputs + [state],
- feed_dict={inputs[0]: input_value})
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
-
def testGrid1LSTMCellWithRNN(self):
batch_size = 3
input_size = 5
@@ -464,8 +502,8 @@ class GridRNNCellTest(tf.test.TestCase):
cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units)
# for 1-LSTM, we only feed the first step
- inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \
- + (max_length - 1) * [tf.zeros([batch_size, input_size])]
+ inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))]
+ + (max_length - 1) * [tf.zeros([batch_size, input_size])])
outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32)
@@ -480,10 +518,10 @@ class GridRNNCellTest(tf.test.TestCase):
sess.run(tf.initialize_all_variables())
input_value = np.ones((batch_size, input_size))
- values = sess.run(outputs + [state],
- feed_dict={inputs[0]: input_value})
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for v in values:
self.assertTrue(np.all(np.isfinite(v)))
-if __name__ == "__main__":
+
+if __name__ == '__main__':
tf.test.main()
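
The companion change in grid_rnn_cell.py (next diff) pins state_is_tuple=False
on every inner cell so GridRNN keeps this concatenated-state layout. The
recurring pattern, sketched outside the patch (import as in the module):

    from tensorflow.python.ops import rnn_cell

    # Each specialized GridRNN cell now builds its inner cell with an
    # explicit non-tuple state (illustrative restatement of the diff below):
    cell_fn = lambda n, i: rnn_cell.LSTMCell(
        num_units=n, input_size=i, state_is_tuple=False)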
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 8a2d433141..45862f43d7 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-# TODO(b/28879898) Fix all lint issues and clean the code.
"""Module for constructing GridRNN cells"""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -35,46 +35,67 @@ class GridRNNCell(rnn_cell.RNNCell):
http://arxiv.org/pdf/1507.01526v3.pdf
- This is the generic implementation of GridRNN. Users can specify arbitrary number of dimensions,
+ This is the generic implementation of GridRNN. Users can specify arbitrary
+ number of dimensions,
set some of them to be priority (section 3.2), non-recurrent (section 3.3)
and input/output dimensions (section 3.4).
Weight sharing can also be specified using the `tied` parameter.
Type of recurrent units can be specified via `cell_fn`.
"""
- def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None,
- non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None):
+ def __init__(self,
+ num_units,
+ num_dims=1,
+ input_dims=None,
+ output_dims=None,
+ priority_dims=None,
+ non_recurrent_dims=None,
+ tied=False,
+ cell_fn=None,
+ non_recurrent_fn=None):
"""Initialize the parameters of a Grid RNN cell
Args:
num_units: int, The number of units in all dimensions of this GridRNN cell
num_dims: int, Number of dimensions of this grid.
input_dims: int or list, List of dimensions which will receive input data.
- output_dims: int or list, List of dimensions from which the output will be recorded.
- priority_dims: int or list, List of dimensions to be considered as priority dimensions.
+ output_dims: int or list, List of dimensions from which the output will be
+ recorded.
+ priority_dims: int or list, List of dimensions to be considered as
+ priority dimensions.
If None, no dimension is prioritized.
- non_recurrent_dims: int or list, List of dimensions that are not recurrent.
- The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`,
+ non_recurrent_dims: int or list, List of dimensions that are not
+ recurrent.
+ The transfer function for non-recurrent dimensions is specified
+ via `non_recurrent_fn`,
which is default to be `tensorflow.nn.relu`.
- tied: bool, Whether to share the weights among the dimensions of this GridRNN cell.
- If there are non-recurrent dimensions in the grid, weights are shared between each
+ tied: bool, Whether to share the weights among the dimensions of this
+ GridRNN cell.
+ If there are non-recurrent dimensions in the grid, weights are
+ shared between each
group of recurrent and non-recurrent dimensions.
- cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature:
+ cell_fn: function, a function which returns the recurrent cell object. Has
+ to be in the following signature:
def cell_func(num_units, input_size):
# ...
- and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used.
- non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions
+ and returns an object of type `RNNCell`. If None, LSTMCell with
+ default parameters will be used.
+ non_recurrent_fn: a tensorflow Op that will be the transfer function of
+ the non-recurrent dimensions
"""
if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims))
- self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims,
- non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units)
+ self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
+ priority_dims, non_recurrent_dims,
+ non_recurrent_fn or nn.relu, tied,
+ num_units)
cell_input_size = (self._config.num_dims - 1) * num_units
if cell_fn is None:
- self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size)
+ self._cell = rnn_cell.LSTMCell(
+ num_units=num_units, input_size=cell_input_size, state_is_tuple=False)
else:
self._cell = cell_fn(num_units, cell_input_size)
if not isinstance(self._cell, rnn_cell.RNNCell):
@@ -100,7 +121,8 @@ class GridRNNCell(rnn_cell.RNNCell):
Args:
inputs: input Tensor, 2D, batch x input_size. Or None
- state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims
+ state: state Tensor, 2D, batch x state_size. Note that state_size =
+ cell_state_size * recurrent_dims
scope: VariableScope for the created subgraph; defaults to "GridRNNCell".
Returns:
@@ -112,24 +134,32 @@ class GridRNNCell(rnn_cell.RNNCell):
"""
state_sz = state.get_shape().as_list()[1]
if self.state_size != state_sz:
- raise ValueError('Actual state size not same as specified: {} vs {}.'.format(state_sz, self.state_size))
+ raise ValueError(
+ 'Actual state size not same as specified: {} vs {}.'.format(
+ state_sz, self.state_size))
conf = self._config
dtype = inputs.dtype if inputs is not None else state.dtype
- # c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase
+ # c_prev is `m`, and m_prev is `h` in the paper.
+ # Keep c and m here for consistency with the codebase
c_prev = [None] * self._config.num_dims
m_prev = [None] * self._config.num_dims
cell_output_size = self._cell.state_size - conf.num_units
# for LSTM : state = memory cell + output, hence cell_output_size > 0
- # for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0
- for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)):
+ # for GRU/RNN: state = output (whose size is equal to _num_units),
+ # hence cell_output_size = 0
+ for recurrent_dim, start_idx in zip(self._config.recurrents, range(
+ 0, self.state_size, self._cell.state_size)):
if cell_output_size > 0:
- c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
- m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+ c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
+ m_prev[recurrent_dim] = array_ops.slice(
+ state, [0, start_idx + conf.num_units], [-1, cell_output_size])
else:
- m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units])
+ m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
new_output = [None] * conf.num_dims
new_state = [None] * conf.num_dims
@@ -137,150 +167,212 @@ class GridRNNCell(rnn_cell.RNNCell):
with vs.variable_scope(scope or type(self).__name__): # GridRNNCell
# project input
- if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0:
+ if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
+ conf.inputs) > 0:
input_splits = array_ops.split(1, len(conf.inputs), inputs)
input_sz = input_splits[0].get_shape().as_list()[1]
for i, j in enumerate(conf.inputs):
- input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+ input_project_m = vs.get_variable(
+ 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
if cell_output_size > 0:
- input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
+ input_project_c = vs.get_variable(
+ 'project_c_{}'.format(j), [input_sz, conf.num_units],
+ dtype=dtype)
c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
-
- _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True)
- _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False)
+ _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev,
+ new_output, new_state, True)
+ _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output,
+ new_state, False)
output_tensors = [new_output[i] for i in self._config.outputs]
- output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1,
- output_tensors)
+ output = array_ops.zeros(
+ [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(
+ 1, output_tensors)
state_tensors = [new_state[i] for i in self._config.recurrents]
- states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors)
+ states = array_ops.zeros(
+ [0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(
+ 1, state_tensors)
return output, states
+"""Specialized cells, for convenience
"""
-Specialized cells, for convenience
-"""
+
class Grid1BasicRNNCell(GridRNNCell):
"""1D BasicRNN cell"""
+
def __init__(self, num_units):
- super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
+ super(Grid1BasicRNNCell, self).__init__(
+ num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0, tied=False,
+ cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i))
class Grid2BasicRNNCell(GridRNNCell):
"""2D BasicRNN cell
- This creates a 2D cell which receives input and gives output in the first dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
+
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
+
def __init__(self, num_units, tied=False, non_recurrent_fn=None):
- super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ super(Grid2BasicRNNCell, self).__init__(
+ num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i),
+ non_recurrent_fn=non_recurrent_fn)
class Grid1BasicLSTMCell(GridRNNCell):
"""1D BasicLSTM cell"""
+
def __init__(self, num_units, forget_bias=1):
- super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n,
- forget_bias=forget_bias, input_size=i))
+ super(Grid1BasicLSTMCell, self).__init__(
+ num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0, tied=False,
+ cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+ num_units=n,
+ forget_bias=forget_bias, input_size=i,
+ state_is_tuple=False))
class Grid2BasicLSTMCell(GridRNNCell):
"""2D BasicLSTM cell
- This creates a 2D cell which receives input and gives output in the first dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
+
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1):
- super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
- num_units=n, forget_bias=forget_bias, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ forget_bias=1):
+ super(Grid2BasicLSTMCell, self).__init__(
+ num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(
+ num_units=n, forget_bias=forget_bias, input_size=i,
+ state_is_tuple=False),
+ non_recurrent_fn=non_recurrent_fn)
class Grid1LSTMCell(GridRNNCell):
"""1D LSTM cell
- This is different from Grid1BasicLSTMCell because it gives options to specify the forget bias and enabling peepholes
+
+ This is different from Grid1BasicLSTMCell because it gives options to
+ specify the forget bias and enabling peepholes
"""
+
def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
- super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
- num_units=n, input_size=i, use_peepholes=use_peepholes,
- forget_bias=forget_bias))
+ super(Grid1LSTMCell, self).__init__(
+ num_units=num_units, num_dims=1,
+ input_dims=0, output_dims=0, priority_dims=0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, use_peepholes=use_peepholes,
+ forget_bias=forget_bias, state_is_tuple=False))
class Grid2LSTMCell(GridRNNCell):
"""2D LSTM cell
- This creates a 2D cell which receives input and gives output in the first dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None,
- use_peepholes=False, forget_bias=1.0):
- super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes),
- non_recurrent_fn=non_recurrent_fn)
+
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ use_peepholes=False,
+ forget_bias=1.0):
+ super(Grid2LSTMCell, self).__init__(
+ num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, forget_bias=forget_bias,
+ use_peepholes=use_peepholes, state_is_tuple=False),
+ non_recurrent_fn=non_recurrent_fn)
class Grid3LSTMCell(GridRNNCell):
"""3D BasicLSTM cell
- This creates a 2D cell which receives input and gives output in the first dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
The second and third dimensions are LSTM.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None,
- use_peepholes=False, forget_bias=1.0):
- super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes),
- non_recurrent_fn=non_recurrent_fn)
+
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ use_peepholes=False,
+ forget_bias=1.0):
+ super(Grid3LSTMCell, self).__init__(
+ num_units=num_units, num_dims=3,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.LSTMCell(
+ num_units=n, input_size=i, forget_bias=forget_bias,
+ use_peepholes=use_peepholes, state_is_tuple=False),
+ non_recurrent_fn=non_recurrent_fn)
+
class Grid2GRUCell(GridRNNCell):
"""2D LSTM cell
- This creates a 2D cell which receives input and gives output in the first dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified.
+
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
def __init__(self, num_units, tied=False, non_recurrent_fn=None):
- super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ super(Grid2GRUCell, self).__init__(
+ num_units=num_units, num_dims=2,
+ input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i),
+ non_recurrent_fn=non_recurrent_fn)
+
+"""Helpers
"""
-Helpers
-"""
-_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+_GridRNNDimension = namedtuple(
+ '_GridRNNDimension',
+ ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+
+_GridRNNConfig = namedtuple('_GridRNNConfig',
+ ['num_dims', 'dims', 'inputs', 'outputs',
+ 'recurrents', 'priority', 'non_priority', 'tied',
+ 'num_units'])
-_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims',
- 'inputs', 'outputs', 'recurrents',
- 'priority', 'non_priority', 'tied', 'num_units'])
+def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
+ ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
-def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims,
- non_recurrent_fn, tied, num_units):
def check_dim_list(ls):
if ls is None:
ls = []
@@ -288,7 +380,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
ls = [ls]
ls = sorted(set(ls))
if any(_ < 0 or _ >= num_dims for _ in ls):
- raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims))
+ raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls,
+ num_dims))
return ls
input_dims = check_dim_list(ls_input_dims)
@@ -298,42 +391,58 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
rnn_dims = []
for i in range(num_dims):
- rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims),
- is_priority=(i in priority_dims),
- non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None))
- return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims,
- recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
- priority=priority_dims,
- non_priority=[x for x in range(num_dims) if x not in priority_dims],
- tied=tied, num_units=num_units)
-
-
-def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call):
- """
- Propagates through all the cells in dim_indices dimensions.
+ rnn_dims.append(
+ _GridRNNDimension(
+ idx=i,
+ is_input=(i in input_dims),
+ is_output=(i in output_dims),
+ is_priority=(i in priority_dims),
+ non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
+ None))
+ return _GridRNNConfig(
+ num_dims=num_dims,
+ dims=rnn_dims,
+ inputs=input_dims,
+ outputs=output_dims,
+ recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
+ priority=priority_dims,
+ non_priority=[x for x in range(num_dims) if x not in priority_dims],
+ tied=tied,
+ num_units=num_units)
+
+
+def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state,
+ first_call):
+ """Propagates through all the cells in dim_indices dimensions.
"""
if len(dim_indices) == 0:
return
- # Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out
- # and feed it as the state of the RNN cell (in `last_dim_output`)
+ # Because of the way RNNCells are implemented, we take the last dimension
+ # (H_{N-1}) out and feed it as the state of the RNN cell
+ # (in `last_dim_output`).
# The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs`
if conf.num_dims > 1:
ls_cell_inputs = [None] * (conf.num_dims - 1)
for d in conf.dims[:-1]:
- ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx]
+ ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[
+ d.idx] is not None else m_prev[d.idx]
cell_inputs = array_ops.concat(1, ls_cell_inputs)
else:
- cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype)
+ cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
+ m_prev[0].dtype)
last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
for i in dim_indices:
d = conf.dims[i]
if d.non_recurrent_fn:
- linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output
- with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)):
- if conf.tied and not(first_call and i == dim_indices[0]):
+ linear_args = array_ops.concat(
+ 1, [cell_inputs, last_dim_output
+ ]) if conf.num_dims > 1 else last_dim_output
+ with vs.variable_scope('non_recurrent' if conf.tied else
+ 'non_recurrent/cell_{}'.format(i)):
+ if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
new_output[d.idx] = layers.legacy_fully_connected(
linear_args,
@@ -348,7 +457,8 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, f
# for GRU/RNN, the state is just the previous output
cell_state = last_dim_output
- with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)):
+ with vs.variable_scope('recurrent' if conf.tied else
+ 'recurrent/cell_{}'.format(i)):
if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)