diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2016-08-11 18:06:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-11 19:17:28 -0700 |
commit | bc156957812b64d86d263bee940cacf7fb6fb9a3 (patch) | |
tree | 2147c502149a4ad616628317903a00742b46031e /tensorflow/contrib/grid_rnn | |
parent | fc218ef4d53e697c1a97d90a57a7437255575f89 (diff) |
** BREAKING CHANGE **
Core RNNCell implementations now use state_is_tuple=True by default
This is part of the deprecation process for non-tuple LSTM and MultiRNNCell
states.
Change: 130059769
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py | 210 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 356 |
2 files changed, 357 insertions, 209 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index 11fb208f49..abd06ffa26 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Tests for GridRNN cells.""" from __future__ import absolute_import @@ -27,7 +26,8 @@ class GridRNNCellTest(tf.test.TestCase): def testGrid2BasicLSTMCell(self): with self.test_session() as sess: - with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)) as root_scope: + with tf.variable_scope( + 'root', initializer=tf.constant_initializer(0.2)) as root_scope: x = tf.zeros([1, 3]) m = tf.zeros([1, 8]) cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2) @@ -38,15 +38,18 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (1, 8)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]]) - self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181, - 0.72320831, 0.80555487, 0.39102408, 0.42150158]]) + self.assertAllClose(res[0], [[0.36617181, 0.36617181]]) + self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181, + 0.36617181, 0.72320831, 0.80555487, + 0.39102408, 0.42150158]]) - # emulate a loop through the input sequence, where we call cell() multiple times + # emulate a loop through the input sequence, + # where we call cell() multiple times root_scope.reuse_variables() g2, s2 = cell(x, m) self.assertEqual(g2.get_shape(), (1, 2)) @@ -56,8 +59,9 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) self.assertAllClose(res[0], [[0.58847463, 0.58847463]]) - self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, 0.58847463, - 0.97726452, 1.04626071, 0.4927212, 0.51137757]]) + self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, + 0.58847463, 0.97726452, 1.04626071, + 0.4927212, 0.51137757]]) def testGrid2BasicLSTMCellTied(self): with self.test_session() as sess: @@ -72,27 +76,31 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (1, 8)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[ 0.36617181, 0.36617181]]) - self.assertAllClose(res[1], [[ 0.71053141, 0.71053141, 0.36617181, 0.36617181, - 0.72320831, 0.80555487, 0.39102408, 0.42150158]]) + self.assertAllClose(res[0], [[0.36617181, 0.36617181]]) + self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181, + 0.36617181, 0.72320831, 0.80555487, + 0.39102408, 0.42150158]]) res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) self.assertAllClose(res[0], [[0.36703536, 0.36703536]]) - self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, 0.36703536, - 0.80941606, 0.87550586, 0.40108523, 0.42199609]]) + self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, + 0.36703536, 0.80941606, 0.87550586, + 0.40108523, 0.42199609]]) def testGrid2BasicLSTMCellWithRelu(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.2)): x = tf.zeros([1, 3]) m = tf.zeros([1, 4]) - cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell(2, tied=False, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.Grid2BasicLSTMCell( + 2, tied=False, non_recurrent_fn=tf.nn.relu) self.assertEqual(cell.state_size, 4) g, s = cell(x, m) @@ -104,11 +112,11 @@ class GridRNNCellTest(tf.test.TestCase): m: np.array([[0.1, 0.2, 0.3, 0.4]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[ 0.31667367, 0.31667367]]) - self.assertAllClose(res[1], [[ 0.29530135, 0.37520045, 0.17044567, 0.21292259]]) + self.assertAllClose(res[0], [[0.31667367, 0.31667367]]) + self.assertAllClose(res[1], + [[0.29530135, 0.37520045, 0.17044567, 0.21292259]]) - """ - LSTMCell + """LSTMCell """ def testGrid2LSTMCell(self): @@ -124,20 +132,23 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (1, 8)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]]) - self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918, - 1.38917875, 1.49043763, 0.83884692, 0.86036491]]) + self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) + self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918, + 0.95686918, 1.38917875, 1.49043763, + 0.83884692, 0.86036491]]) def testGrid2LSTMCellTied(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): x = tf.zeros([1, 3]) m = tf.zeros([1, 8]) - cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, tied=True, use_peepholes=True) + cell = tf.contrib.grid_rnn.Grid2LSTMCell( + 2, tied=True, use_peepholes=True) self.assertEqual(cell.state_size, 8) g, s = cell(x, m) @@ -145,20 +156,23 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (1, 8)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[ 0.95686918, 0.95686918]]) - self.assertAllClose(res[1], [[ 2.41515064, 2.41515064, 0.95686918, 0.95686918, - 1.38917875, 1.49043763, 0.83884692, 0.86036491]]) + self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) + self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918, + 0.95686918, 1.38917875, 1.49043763, + 0.83884692, 0.86036491]]) def testGrid2LSTMCellWithRelu(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): x = tf.zeros([1, 3]) m = tf.zeros([1, 4]) - cell = tf.contrib.grid_rnn.Grid2LSTMCell(2, use_peepholes=True, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.Grid2LSTMCell( + 2, use_peepholes=True, non_recurrent_fn=tf.nn.relu) self.assertEqual(cell.state_size, 4) g, s = cell(x, m) @@ -170,11 +184,11 @@ class GridRNNCellTest(tf.test.TestCase): m: np.array([[0.1, 0.2, 0.3, 0.4]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[ 2.1831727, 2.1831727]]) - self.assertAllClose(res[1], [[ 0.92270052, 1.02325559, 0.66159075, 0.70475441]]) + self.assertAllClose(res[0], [[2.1831727, 2.1831727]]) + self.assertAllClose(res[1], + [[0.92270052, 1.02325559, 0.66159075, 0.70475441]]) - """ - RNNCell + """RNNCell """ def testGrid2BasicRNNCell(self): @@ -190,14 +204,16 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (2, 4)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]), - m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1.], [2., 2.]]), + m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) self.assertEqual(res[0].shape, (2, 2)) self.assertEqual(res[1].shape, (2, 4)) self.assertAllClose(res[0], [[0.94685763, 0.94685763], [0.99480951, 0.99480951]]) - self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908], - [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + self.assertAllClose(res[1], + [[0.94685763, 0.94685763, 0.80049908, 0.80049908], + [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) def testGrid2BasicRNNCellTied(self): with self.test_session() as sess: @@ -212,21 +228,24 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (2, 4)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1.], [2., 2.]]), - m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) + res = sess.run( + [g, s], {x: np.array([[1., 1.], [2., 2.]]), + m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])}) self.assertEqual(res[0].shape, (2, 2)) self.assertEqual(res[1].shape, (2, 4)) self.assertAllClose(res[0], [[0.94685763, 0.94685763], [0.99480951, 0.99480951]]) - self.assertAllClose(res[1], [[0.94685763, 0.94685763, 0.80049908, 0.80049908], - [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + self.assertAllClose(res[1], + [[0.94685763, 0.94685763, 0.80049908, 0.80049908], + [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) def testGrid2BasicRNNCellWithRelu(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): x = tf.zeros([1, 2]) m = tf.zeros([1, 2]) - cell = tf.contrib.grid_rnn.Grid2BasicRNNCell(2, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.Grid2BasicRNNCell( + 2, non_recurrent_fn=tf.nn.relu) self.assertEqual(cell.state_size, 2) g, s = cell(x, m) @@ -234,19 +253,20 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(s.get_shape(), (1, 2)) sess.run([tf.initialize_all_variables()]) - res = sess.run([g, s], {x: np.array([[1., 1.]]), m: np.array([[0.1, 0.1]])}) + res = sess.run([g, s], {x: np.array([[1., 1.]]), + m: np.array([[0.1, 0.1]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 2)) self.assertAllClose(res[0], [[1.80049896, 1.80049896]]) self.assertAllClose(res[1], [[0.80049896, 0.80049896]]) - """ - 1-LSTM + """1-LSTM """ def testGrid1LSTMCell(self): with self.test_session() as sess: - with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)) as root_scope: + with tf.variable_scope( + 'root', initializer=tf.constant_initializer(0.5)) as root_scope: x = tf.zeros([1, 3]) m = tf.zeros([1, 4]) cell = tf.contrib.grid_rnn.Grid1LSTMCell(2, use_peepholes=True) @@ -262,7 +282,8 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 4)) self.assertAllClose(res[0], [[0.91287315, 0.91287315]]) - self.assertAllClose(res[1], [[2.26285243, 2.26285243, 0.91287315, 0.91287315]]) + self.assertAllClose(res[1], + [[2.26285243, 2.26285243, 0.91287315, 0.91287315]]) root_scope.reuse_variables() @@ -276,7 +297,8 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 4)) self.assertAllClose(res[0], [[0.9032144, 0.9032144]]) - self.assertAllClose(res[1], [[2.79966092, 2.79966092, 0.9032144, 0.9032144]]) + self.assertAllClose(res[1], + [[2.79966092, 2.79966092, 0.9032144, 0.9032144]]) g3, s3 = cell(x2, m) self.assertEqual(g3.get_shape(), (1, 2)) @@ -287,11 +309,12 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 4)) self.assertAllClose(res[0], [[0.92727238, 0.92727238]]) - self.assertAllClose(res[1], [[3.3529923, 3.3529923, 0.92727238, 0.92727238]]) + self.assertAllClose(res[1], + [[3.3529923, 3.3529923, 0.92727238, 0.92727238]]) + """3-LSTM """ - 3-LSTM - """ + def testGrid3LSTMCell(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): @@ -306,18 +329,20 @@ class GridRNNCellTest(tf.test.TestCase): sess.run([tf.initialize_all_variables()]) res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, -0.4]])}) + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, + 0.8, -0.1, -0.2, -0.3, -0.4]])}) self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[1].shape, (1, 12)) self.assertAllClose(res[0], [[0.96892911, 0.96892911]]) - self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, 0.96892911, - 1.33592629, 1.4373529, 0.80867189, 0.83247656, - 0.7317788, 0.63205892, 0.56548983, 0.50446129]]) + self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, + 0.96892911, 1.33592629, 1.4373529, + 0.80867189, 0.83247656, 0.7317788, + 0.63205892, 0.56548983, 0.50446129]]) + """Edge cases """ - Edge cases - """ + def testGridRNNEdgeCasesLikeRelu(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): @@ -325,8 +350,13 @@ class GridRNNCellTest(tf.test.TestCase): m = tf.zeros([0, 0]) # this is equivalent to relu - cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=1, input_dims=0, output_dims=0, - non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.GridRNNCell( + num_units=2, + num_dims=1, + input_dims=0, + output_dims=0, + non_recurrent_dims=0, + non_recurrent_fn=tf.nn.relu) g, s = cell(x, m) self.assertEqual(g.get_shape(), (3, 2)) self.assertEqual(s.get_shape(), (0, 0)) @@ -340,12 +370,17 @@ class GridRNNCellTest(tf.test.TestCase): def testGridRNNEdgeCasesNoOutput(self): with self.test_session() as sess: with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): - x = tf.zeros([1, 2]) + x = tf.zeros([1, 2]) m = tf.zeros([1, 4]) # This cell produces no output - cell = tf.contrib.grid_rnn.GridRNNCell(num_units=2, num_dims=2, input_dims=0, output_dims=None, - non_recurrent_dims=0, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.GridRNNCell( + num_units=2, + num_dims=2, + input_dims=0, + output_dims=None, + non_recurrent_dims=0, + non_recurrent_fn=tf.nn.relu) g, s = cell(x, m) self.assertEqual(g.get_shape(), (0, 0)) self.assertEqual(s.get_shape(), (1, 4)) @@ -356,8 +391,7 @@ class GridRNNCellTest(tf.test.TestCase): self.assertEqual(res[0].shape, (0, 0)) self.assertEqual(res[1].shape, (1, 4)) - """ - Test with tf.nn.rnn + """Test with tf.nn.rnn """ def testGrid2LSTMCellWithRNN(self): @@ -370,7 +404,9 @@ class GridRNNCellTest(tf.test.TestCase): cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units) inputs = max_length * [ - tf.placeholder(tf.float32, shape=(batch_size, input_size))] + tf.placeholder( + tf.float32, shape=(batch_size, input_size)) + ] outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) @@ -386,8 +422,7 @@ class GridRNNCellTest(tf.test.TestCase): sess.run(tf.initialize_all_variables()) input_value = np.ones((batch_size, input_size)) - values = sess.run(outputs + [state], - feed_dict={inputs[0]: input_value}) + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) for v in values: self.assertTrue(np.all(np.isfinite(v))) @@ -398,10 +433,13 @@ class GridRNNCellTest(tf.test.TestCase): num_units = 2 with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): - cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.Grid2LSTMCell( + num_units=num_units, non_recurrent_fn=tf.nn.relu) inputs = max_length * [ - tf.placeholder(tf.float32, shape=(batch_size, input_size))] + tf.placeholder( + tf.float32, shape=(batch_size, input_size)) + ] outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) @@ -417,8 +455,7 @@ class GridRNNCellTest(tf.test.TestCase): sess.run(tf.initialize_all_variables()) input_value = np.ones((batch_size, input_size)) - values = sess.run(outputs + [state], - feed_dict={inputs[0]: input_value}) + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) for v in values: self.assertTrue(np.all(np.isfinite(v))) @@ -429,10 +466,13 @@ class GridRNNCellTest(tf.test.TestCase): num_units = 2 with tf.variable_scope('root', initializer=tf.constant_initializer(0.5)): - cell = tf.contrib.grid_rnn.Grid3LSTMCell(num_units=num_units, non_recurrent_fn=tf.nn.relu) + cell = tf.contrib.grid_rnn.Grid3LSTMCell( + num_units=num_units, non_recurrent_fn=tf.nn.relu) inputs = max_length * [ - tf.placeholder(tf.float32, shape=(batch_size, input_size))] + tf.placeholder( + tf.float32, shape=(batch_size, input_size)) + ] outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) @@ -448,12 +488,10 @@ class GridRNNCellTest(tf.test.TestCase): sess.run(tf.initialize_all_variables()) input_value = np.ones((batch_size, input_size)) - values = sess.run(outputs + [state], - feed_dict={inputs[0]: input_value}) + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) for v in values: self.assertTrue(np.all(np.isfinite(v))) - def testGrid1LSTMCellWithRNN(self): batch_size = 3 input_size = 5 @@ -464,8 +502,8 @@ class GridRNNCellTest(tf.test.TestCase): cell = tf.contrib.grid_rnn.Grid1LSTMCell(num_units=num_units) # for 1-LSTM, we only feed the first step - inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))] \ - + (max_length - 1) * [tf.zeros([batch_size, input_size])] + inputs = ([tf.placeholder(tf.float32, shape=(batch_size, input_size))] + + (max_length - 1) * [tf.zeros([batch_size, input_size])]) outputs, state = tf.nn.rnn(cell, inputs, dtype=tf.float32) @@ -480,10 +518,10 @@ class GridRNNCellTest(tf.test.TestCase): sess.run(tf.initialize_all_variables()) input_value = np.ones((batch_size, input_size)) - values = sess.run(outputs + [state], - feed_dict={inputs[0]: input_value}) + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) for v in values: self.assertTrue(np.all(np.isfinite(v))) -if __name__ == "__main__": + +if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 8a2d433141..45862f43d7 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# TODO(b/28879898) Fix all lint issues and clean the code. """Module for constructing GridRNN cells""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -35,46 +35,67 @@ class GridRNNCell(rnn_cell.RNNCell): http://arxiv.org/pdf/1507.01526v3.pdf - This is the generic implementation of GridRNN. Users can specify arbitrary number of dimensions, + This is the generic implementation of GridRNN. Users can specify arbitrary + number of dimensions, set some of them to be priority (section 3.2), non-recurrent (section 3.3) and input/output dimensions (section 3.4). Weight sharing can also be specified using the `tied` parameter. Type of recurrent units can be specified via `cell_fn`. """ - def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None, priority_dims=None, - non_recurrent_dims=None, tied=False, cell_fn=None, non_recurrent_fn=None): + def __init__(self, + num_units, + num_dims=1, + input_dims=None, + output_dims=None, + priority_dims=None, + non_recurrent_dims=None, + tied=False, + cell_fn=None, + non_recurrent_fn=None): """Initialize the parameters of a Grid RNN cell Args: num_units: int, The number of units in all dimensions of this GridRNN cell num_dims: int, Number of dimensions of this grid. input_dims: int or list, List of dimensions which will receive input data. - output_dims: int or list, List of dimensions from which the output will be recorded. - priority_dims: int or list, List of dimensions to be considered as priority dimensions. + output_dims: int or list, List of dimensions from which the output will be + recorded. + priority_dims: int or list, List of dimensions to be considered as + priority dimensions. If None, no dimension is prioritized. - non_recurrent_dims: int or list, List of dimensions that are not recurrent. - The transfer function for non-recurrent dimensions is specified via `non_recurrent_fn`, + non_recurrent_dims: int or list, List of dimensions that are not + recurrent. + The transfer function for non-recurrent dimensions is specified + via `non_recurrent_fn`, which is default to be `tensorflow.nn.relu`. - tied: bool, Whether to share the weights among the dimensions of this GridRNN cell. - If there are non-recurrent dimensions in the grid, weights are shared between each + tied: bool, Whether to share the weights among the dimensions of this + GridRNN cell. + If there are non-recurrent dimensions in the grid, weights are + shared between each group of recurrent and non-recurrent dimensions. - cell_fn: function, a function which returns the recurrent cell object. Has to be in the following signature: + cell_fn: function, a function which returns the recurrent cell object. Has + to be in the following signature: def cell_func(num_units, input_size): # ... - and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used. - non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions + and returns an object of type `RNNCell`. If None, LSTMCell with + default parameters will be used. + non_recurrent_fn: a tensorflow Op that will be the transfer function of + the non-recurrent dimensions """ if num_dims < 1: raise ValueError('dims must be >= 1: {}'.format(num_dims)) - self._config = _parse_rnn_config(num_dims, input_dims, output_dims, priority_dims, - non_recurrent_dims, non_recurrent_fn or nn.relu, tied, num_units) + self._config = _parse_rnn_config(num_dims, input_dims, output_dims, + priority_dims, non_recurrent_dims, + non_recurrent_fn or nn.relu, tied, + num_units) cell_input_size = (self._config.num_dims - 1) * num_units if cell_fn is None: - self._cell = rnn_cell.LSTMCell(num_units=num_units, input_size=cell_input_size) + self._cell = rnn_cell.LSTMCell( + num_units=num_units, input_size=cell_input_size, state_is_tuple=False) else: self._cell = cell_fn(num_units, cell_input_size) if not isinstance(self._cell, rnn_cell.RNNCell): @@ -100,7 +121,8 @@ class GridRNNCell(rnn_cell.RNNCell): Args: inputs: input Tensor, 2D, batch x input_size. Or None - state: state Tensor, 2D, batch x state_size. Note that state_size = cell_state_size * recurrent_dims + state: state Tensor, 2D, batch x state_size. Note that state_size = + cell_state_size * recurrent_dims scope: VariableScope for the created subgraph; defaults to "GridRNNCell". Returns: @@ -112,24 +134,32 @@ class GridRNNCell(rnn_cell.RNNCell): """ state_sz = state.get_shape().as_list()[1] if self.state_size != state_sz: - raise ValueError('Actual state size not same as specified: {} vs {}.'.format(state_sz, self.state_size)) + raise ValueError( + 'Actual state size not same as specified: {} vs {}.'.format( + state_sz, self.state_size)) conf = self._config dtype = inputs.dtype if inputs is not None else state.dtype - # c_prev is `m`, and m_prev is `h` in the paper. Keep c and m here for consistency with the codebase + # c_prev is `m`, and m_prev is `h` in the paper. + # Keep c and m here for consistency with the codebase c_prev = [None] * self._config.num_dims m_prev = [None] * self._config.num_dims cell_output_size = self._cell.state_size - conf.num_units # for LSTM : state = memory cell + output, hence cell_output_size > 0 - # for GRU/RNN: state = output (whose size is equal to _num_units), hence cell_output_size = 0 - for recurrent_dim, start_idx in zip(self._config.recurrents, range(0, self.state_size, self._cell.state_size)): + # for GRU/RNN: state = output (whose size is equal to _num_units), + # hence cell_output_size = 0 + for recurrent_dim, start_idx in zip(self._config.recurrents, range( + 0, self.state_size, self._cell.state_size)): if cell_output_size > 0: - c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units]) - m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx + conf.num_units], [-1, cell_output_size]) + c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) + m_prev[recurrent_dim] = array_ops.slice( + state, [0, start_idx + conf.num_units], [-1, cell_output_size]) else: - m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], [-1, conf.num_units]) + m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) new_output = [None] * conf.num_dims new_state = [None] * conf.num_dims @@ -137,150 +167,212 @@ class GridRNNCell(rnn_cell.RNNCell): with vs.variable_scope(scope or type(self).__name__): # GridRNNCell # project input - if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(conf.inputs) > 0: + if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len( + conf.inputs) > 0: input_splits = array_ops.split(1, len(conf.inputs), inputs) input_sz = input_splits[0].get_shape().as_list()[1] for i, j in enumerate(conf.inputs): - input_project_m = vs.get_variable('project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) + input_project_m = vs.get_variable( + 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) if cell_output_size > 0: - input_project_c = vs.get_variable('project_c_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) + input_project_c = vs.get_variable( + 'project_c_{}'.format(j), [input_sz, conf.num_units], + dtype=dtype) c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) - - _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, new_output, new_state, True) - _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, new_state, False) + _propagate(conf.non_priority, conf, self._cell, c_prev, m_prev, + new_output, new_state, True) + _propagate(conf.priority, conf, self._cell, c_prev, m_prev, new_output, + new_state, False) output_tensors = [new_output[i] for i in self._config.outputs] - output = array_ops.zeros([0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(1, - output_tensors) + output = array_ops.zeros( + [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat( + 1, output_tensors) state_tensors = [new_state[i] for i in self._config.recurrents] - states = array_ops.zeros([0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat(1, state_tensors) + states = array_ops.zeros( + [0, 0], dtype) if len(state_tensors) == 0 else array_ops.concat( + 1, state_tensors) return output, states +"""Specialized cells, for convenience """ -Specialized cells, for convenience -""" + class Grid1BasicRNNCell(GridRNNCell): """1D BasicRNN cell""" + def __init__(self, num_units): - super(Grid1BasicRNNCell, self).__init__(num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i)) + super(Grid1BasicRNNCell, self).__init__( + num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, tied=False, + cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i)) class Grid2BasicRNNCell(GridRNNCell): """2D BasicRNN cell - This creates a 2D cell which receives input and gives output in the first dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + + This creates a 2D cell which receives input and gives output in the first + dimension. + + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ + def __init__(self, num_units, tied=False, non_recurrent_fn=None): - super(Grid2BasicRNNCell, self).__init__(num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, - non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + super(Grid2BasicRNNCell, self).__init__( + num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.BasicRNNCell(num_units=n, input_size=i), + non_recurrent_fn=non_recurrent_fn) class Grid1BasicLSTMCell(GridRNNCell): """1D BasicLSTM cell""" + def __init__(self, num_units, forget_bias=1): - super(Grid1BasicLSTMCell, self).__init__(num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn_cell.BasicLSTMCell(num_units=n, - forget_bias=forget_bias, input_size=i)) + super(Grid1BasicLSTMCell, self).__init__( + num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, tied=False, + cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( + num_units=n, + forget_bias=forget_bias, input_size=i, + state_is_tuple=False)) class Grid2BasicLSTMCell(GridRNNCell): """2D BasicLSTM cell - This creates a 2D cell which receives input and gives output in the first dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + + This creates a 2D cell which receives input and gives output in the first + dimension. + + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None, forget_bias=1): - super(Grid2BasicLSTMCell, self).__init__(num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, - non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( - num_units=n, forget_bias=forget_bias, input_size=i), - non_recurrent_fn=non_recurrent_fn) + + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + forget_bias=1): + super(Grid2BasicLSTMCell, self).__init__( + num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.BasicLSTMCell( + num_units=n, forget_bias=forget_bias, input_size=i, + state_is_tuple=False), + non_recurrent_fn=non_recurrent_fn) class Grid1LSTMCell(GridRNNCell): """1D LSTM cell - This is different from Grid1BasicLSTMCell because it gives options to specify the forget bias and enabling peepholes + + This is different from Grid1BasicLSTMCell because it gives options to + specify the forget bias and enabling peepholes """ + def __init__(self, num_units, use_peepholes=False, forget_bias=1.0): - super(Grid1LSTMCell, self).__init__(num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( - num_units=n, input_size=i, use_peepholes=use_peepholes, - forget_bias=forget_bias)) + super(Grid1LSTMCell, self).__init__( + num_units=num_units, num_dims=1, + input_dims=0, output_dims=0, priority_dims=0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, use_peepholes=use_peepholes, + forget_bias=forget_bias, state_is_tuple=False)) class Grid2LSTMCell(GridRNNCell): """2D LSTM cell - This creates a 2D cell which receives input and gives output in the first dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + + This creates a 2D cell which receives input and gives output in the first + dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None, - use_peepholes=False, forget_bias=1.0): - super(Grid2LSTMCell, self).__init__(num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, - non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes), - non_recurrent_fn=non_recurrent_fn) + + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + use_peepholes=False, + forget_bias=1.0): + super(Grid2LSTMCell, self).__init__( + num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, forget_bias=forget_bias, + use_peepholes=use_peepholes, state_is_tuple=False), + non_recurrent_fn=non_recurrent_fn) class Grid3LSTMCell(GridRNNCell): """3D BasicLSTM cell - This creates a 2D cell which receives input and gives output in the first dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + + This creates a 2D cell which receives input and gives output in the first + dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. The second and third dimensions are LSTM. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None, - use_peepholes=False, forget_bias=1.0): - super(Grid3LSTMCell, self).__init__(num_units=num_units, num_dims=3, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, - non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes), - non_recurrent_fn=non_recurrent_fn) + + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + use_peepholes=False, + forget_bias=1.0): + super(Grid3LSTMCell, self).__init__( + num_units=num_units, num_dims=3, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.LSTMCell( + num_units=n, input_size=i, forget_bias=forget_bias, + use_peepholes=use_peepholes, state_is_tuple=False), + non_recurrent_fn=non_recurrent_fn) + class Grid2GRUCell(GridRNNCell): """2D LSTM cell - This creates a 2D cell which receives input and gives output in the first dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is specified. + + This creates a 2D cell which receives input and gives output in the first + dimension. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ def __init__(self, num_units, tied=False, non_recurrent_fn=None): - super(Grid2GRUCell, self).__init__(num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, - non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + super(Grid2GRUCell, self).__init__( + num_units=num_units, num_dims=2, + input_dims=0, output_dims=0, priority_dims=0, tied=tied, + non_recurrent_dims=None if non_recurrent_fn is None else 0, + cell_fn=lambda n, i: rnn_cell.GRUCell(num_units=n, input_size=i), + non_recurrent_fn=non_recurrent_fn) + +"""Helpers """ -Helpers -""" -_GridRNNDimension = namedtuple('_GridRNNDimension', ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) +_GridRNNDimension = namedtuple( + '_GridRNNDimension', + ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) + +_GridRNNConfig = namedtuple('_GridRNNConfig', + ['num_dims', 'dims', 'inputs', 'outputs', + 'recurrents', 'priority', 'non_priority', 'tied', + 'num_units']) -_GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims', - 'inputs', 'outputs', 'recurrents', - 'priority', 'non_priority', 'tied', 'num_units']) +def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, + ls_non_recurrent_dims, non_recurrent_fn, tied, num_units): -def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims, - non_recurrent_fn, tied, num_units): def check_dim_list(ls): if ls is None: ls = [] @@ -288,7 +380,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls = [ls] ls = sorted(set(ls)) if any(_ < 0 or _ >= num_dims for _ in ls): - raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, num_dims)) + raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, + num_dims)) return ls input_dims = check_dim_list(ls_input_dims) @@ -298,42 +391,58 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, rnn_dims = [] for i in range(num_dims): - rnn_dims.append(_GridRNNDimension(idx=i, is_input=(i in input_dims), is_output=(i in output_dims), - is_priority=(i in priority_dims), - non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else None)) - return _GridRNNConfig(num_dims=num_dims, dims=rnn_dims, inputs=input_dims, outputs=output_dims, - recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], - priority=priority_dims, - non_priority=[x for x in range(num_dims) if x not in priority_dims], - tied=tied, num_units=num_units) - - -def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, first_call): - """ - Propagates through all the cells in dim_indices dimensions. + rnn_dims.append( + _GridRNNDimension( + idx=i, + is_input=(i in input_dims), + is_output=(i in output_dims), + is_priority=(i in priority_dims), + non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else + None)) + return _GridRNNConfig( + num_dims=num_dims, + dims=rnn_dims, + inputs=input_dims, + outputs=output_dims, + recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], + priority=priority_dims, + non_priority=[x for x in range(num_dims) if x not in priority_dims], + tied=tied, + num_units=num_units) + + +def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, + first_call): + """Propagates through all the cells in dim_indices dimensions. """ if len(dim_indices) == 0: return - # Because of the way RNNCells are implemented, we take the last dimension (H_{N-1}) out - # and feed it as the state of the RNN cell (in `last_dim_output`) + # Because of the way RNNCells are implemented, we take the last dimension + # (H_{N-1}) out and feed it as the state of the RNN cell + # (in `last_dim_output`). # The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs` if conf.num_dims > 1: ls_cell_inputs = [None] * (conf.num_dims - 1) for d in conf.dims[:-1]: - ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[d.idx] is not None else m_prev[d.idx] + ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[ + d.idx] is not None else m_prev[d.idx] cell_inputs = array_ops.concat(1, ls_cell_inputs) else: - cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype) + cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], + m_prev[0].dtype) last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1] for i in dim_indices: d = conf.dims[i] if d.non_recurrent_fn: - linear_args = array_ops.concat(1, [cell_inputs, last_dim_output]) if conf.num_dims > 1 else last_dim_output - with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)): - if conf.tied and not(first_call and i == dim_indices[0]): + linear_args = array_ops.concat( + 1, [cell_inputs, last_dim_output + ]) if conf.num_dims > 1 else last_dim_output + with vs.variable_scope('non_recurrent' if conf.tied else + 'non_recurrent/cell_{}'.format(i)): + if conf.tied and not (first_call and i == dim_indices[0]): vs.get_variable_scope().reuse_variables() new_output[d.idx] = layers.legacy_fully_connected( linear_args, @@ -348,7 +457,8 @@ def _propagate(dim_indices, conf, cell, c_prev, m_prev, new_output, new_state, f # for GRU/RNN, the state is just the previous output cell_state = last_dim_output - with vs.variable_scope('recurrent' if conf.tied else 'recurrent/cell_{}'.format(i)): + with vs.variable_scope('recurrent' if conf.tied else + 'recurrent/cell_{}'.format(i)): if conf.tied and not (first_call and i == dim_indices[0]): vs.get_variable_scope().reuse_variables() new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state) |