diff options
author | 2017-05-05 09:09:05 -0800 | |
---|---|---|
committer | 2017-05-05 10:26:00 -0700 | |
commit | 692fad20f913ffa2cb874a87578ecabb03cc4557 (patch) | |
tree | 172717f537c91b0d1ac0366731b4eb2093fb743b /tensorflow/contrib/grid_rnn | |
parent | b329dd821e29e64c93b1b9bf38e61871c6cb53da (diff) |
Merge changes from github.
Change: 155209832
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py | 631 | ||||
-rw-r--r-- | tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 497 |
2 files changed, 743 insertions, 385 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index 758e0bcc07..280271a42d 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -34,180 +34,228 @@ from tensorflow.python.platform import test class GridRNNCellTest(test.TestCase): def testGrid2BasicLSTMCell(self): - with self.test_session() as sess: + with self.test_session(use_gpu=False) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.2)) as root_scope: x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 8]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) cell = grid_rnn_cell.Grid2BasicLSTMCell(2) - self.assertEqual(cell.state_size, 8) + self.assertEqual(cell.state_size, ((2, 2), (2, 2))) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 8)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) + self.assertEqual(s[1].c.get_shape(), (1, 2)) + self.assertEqual(s[1].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), + (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))) }) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.36617181, 0.36617181]]) - self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181, - 0.36617181, 0.72320831, 0.80555487, - 0.39102408, 0.42150158]]) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + self.assertEqual(res_s[1].c.shape, (1, 2)) + self.assertEqual(res_s[1].h.shape, (1, 2)) + + self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],)) + self.assertAllClose( + res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]), + ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]]))) # emulate a loop through the input sequence, # where we call cell() multiple times root_scope.reuse_variables() g2, s2 = cell(x, m) - self.assertEqual(g2.get_shape(), (1, 2)) - self.assertEqual(s2.get_shape(), (1, 8)) - - res = sess.run([g2, s2], {x: np.array([[2., 2., 2.]]), m: res[1]}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.58847463, 0.58847463]]) - self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463, - 0.58847463, 0.97726452, 1.04626071, - 0.4927212, 0.51137757]]) + self.assertEqual(g2[0].get_shape(), (1, 2)) + self.assertEqual(s2[0].c.get_shape(), (1, 2)) + self.assertEqual(s2[0].h.get_shape(), (1, 2)) + self.assertEqual(s2[1].c.get_shape(), (1, 2)) + self.assertEqual(s2[1].h.get_shape(), (1, 2)) + + res_g2, res_s2 = sess.run([g2, s2], + {x: np.array([[2., 2., 2.]]), + m: res_s}) + self.assertEqual(res_g2[0].shape, (1, 2)) + self.assertEqual(res_s2[0].c.shape, (1, 2)) + self.assertEqual(res_s2[0].h.shape, (1, 2)) + self.assertEqual(res_s2[1].c.shape, (1, 2)) + self.assertEqual(res_s2[1].h.shape, (1, 2)) + self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]]) + self.assertAllClose( + res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]), + ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]]))) def testGrid2BasicLSTMCellTied(self): - with self.test_session() as sess: + with self.test_session(use_gpu=False) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.2)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 8]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) cell = grid_rnn_cell.Grid2BasicLSTMCell(2, tied=True) - self.assertEqual(cell.state_size, 8) + self.assertEqual(cell.state_size, ((2, 2), (2, 2))) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 8)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) + self.assertEqual(s[1].c.get_shape(), (1, 2)) + self.assertEqual(s[1].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), + (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))) }) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.36617181, 0.36617181]]) - self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181, - 0.36617181, 0.72320831, 0.80555487, - 0.39102408, 0.42150158]]) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + self.assertEqual(res_s[1].c.shape, (1, 2)) + self.assertEqual(res_s[1].h.shape, (1, 2)) - res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.36703536, 0.36703536]]) - self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536, - 0.36703536, 0.80941606, 0.87550586, - 0.40108523, 0.42199609]]) + self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]]) + self.assertAllClose( + res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]), + ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]]))) + + res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s}) + self.assertEqual(res_g[0].shape, (1, 2)) + + self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]]) + self.assertAllClose( + res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]), + ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]]))) def testGrid2BasicLSTMCellWithRelu(self): - with self.test_session() as sess: + with self.test_session(use_gpu=False) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.2)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 4]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) cell = grid_rnn_cell.Grid2BasicLSTMCell( 2, tied=False, non_recurrent_fn=nn_ops.relu) - self.assertEqual(cell.state_size, 4) + self.assertEqual(cell.state_size, ((2, 2),)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 4)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [g, s], - {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4]])}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[0.31667367, 0.31667367]]) - self.assertAllClose(res[1], [[0.29530135, 0.37520045, 0.17044567, - 0.21292259]]) + res_g, res_s = sess.run([g, s], { + x: np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),) + }) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]]) + self.assertAllClose(res_s, (([[0.29530135, 0.37520045]], + [[0.17044567, 0.21292259]]),)) """LSTMCell """ def testGrid2LSTMCell(self): - with self.test_session() as sess: + with self.test_session(use_gpu=False) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 8]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True) - self.assertEqual(cell.state_size, 8) + self.assertEqual(cell.state_size, ((2, 2), (2, 2))) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 8)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) + self.assertEqual(s[1].c.get_shape(), (1, 2)) + self.assertEqual(s[1].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), + (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))) }) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) - self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918, - 0.95686918, 1.38917875, 1.49043763, - 0.83884692, 0.86036491]]) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + self.assertEqual(res_s[1].c.shape, (1, 2)) + self.assertEqual(res_s[1].h.shape, (1, 2)) + + self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]]) + self.assertAllClose( + res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]), + ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]]))) def testGrid2LSTMCellTied(self): - with self.test_session() as sess: + with self.test_session(use_gpu=False) as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 8]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) cell = grid_rnn_cell.Grid2LSTMCell(2, tied=True, use_peepholes=True) - self.assertEqual(cell.state_size, 8) + self.assertEqual(cell.state_size, ((2, 2), (2, 2))) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 8)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) + self.assertEqual(s[1].c.get_shape(), (1, 2)) + self.assertEqual(s[1].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), + (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))) }) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 8)) - self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) - self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918, - 0.95686918, 1.38917875, 1.49043763, - 0.83884692, 0.86036491]]) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + self.assertEqual(res_s[1].c.shape, (1, 2)) + self.assertEqual(res_s[1].h.shape, (1, 2)) + + self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]]) + self.assertAllClose( + res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]), + ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]]))) def testGrid2LSTMCellWithRelu(self): with self.test_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 4]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) cell = grid_rnn_cell.Grid2LSTMCell( 2, use_peepholes=True, non_recurrent_fn=nn_ops.relu) - self.assertEqual(cell.state_size, 4) + self.assertEqual(cell.state_size, ((2, 2),)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 4)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [g, s], - {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4]])}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[2.1831727, 2.1831727]]) - self.assertAllClose(res[1], [[0.92270052, 1.02325559, 0.66159075, - 0.70475441]]) + res_g, res_s = sess.run([g, s], { + x: np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),) + }) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]]) + self.assertAllClose(res_s, (([[0.92270052, 1.02325559]], + [[0.66159075, 0.70475441]]),)) """RNNCell """ @@ -217,74 +265,84 @@ class GridRNNCellTest(test.TestCase): with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) - m = array_ops.zeros([2, 4]) + m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2])) cell = grid_rnn_cell.Grid2BasicRNNCell(2) - self.assertEqual(cell.state_size, 4) + self.assertEqual(cell.state_size, (2, 2)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (2, 2)) - self.assertEqual(s.get_shape(), (2, 4)) + self.assertEqual(g[0].get_shape(), (2, 2)) + self.assertEqual(s[0].get_shape(), (2, 2)) + self.assertEqual(s[1].get_shape(), (2, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1.], [2., 2.]]), - m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1.], [2., 2.]]), + m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1], + [0.2, 0.2]])) }) - self.assertEqual(res[0].shape, (2, 2)) - self.assertEqual(res[1].shape, (2, 4)) - self.assertAllClose(res[0], [[0.94685763, 0.94685763], - [0.99480951, 0.99480951]]) - self.assertAllClose(res[1], - [[0.94685763, 0.94685763, 0.80049908, 0.80049908], - [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + self.assertEqual(res_g[0].shape, (2, 2)) + self.assertEqual(res_s[0].shape, (2, 2)) + self.assertEqual(res_s[1].shape, (2, 2)) + + self.assertAllClose(res_g, ([[0.94685763, 0.94685763], + [0.99480951, 0.99480951]],)) + self.assertAllClose( + res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]], + [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellTied(self): with self.test_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) - m = array_ops.zeros([2, 4]) + m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2])) cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True) - self.assertEqual(cell.state_size, 4) + self.assertEqual(cell.state_size, (2, 2)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (2, 2)) - self.assertEqual(s.get_shape(), (2, 4)) + self.assertEqual(g[0].get_shape(), (2, 2)) + self.assertEqual(s[0].get_shape(), (2, 2)) + self.assertEqual(s[1].get_shape(), (2, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { - x: np.array([[1., 1.], [2., 2.]]), - m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]]) + res_g, res_s = sess.run([g, s], { + x: + np.array([[1., 1.], [2., 2.]]), + m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1], + [0.2, 0.2]])) }) - self.assertEqual(res[0].shape, (2, 2)) - self.assertEqual(res[1].shape, (2, 4)) - self.assertAllClose(res[0], [[0.94685763, 0.94685763], - [0.99480951, 0.99480951]]) - self.assertAllClose(res[1], - [[0.94685763, 0.94685763, 0.80049908, 0.80049908], - [0.99480951, 0.99480951, 0.97574311, 0.97574311]]) + self.assertEqual(res_g[0].shape, (2, 2)) + self.assertEqual(res_s[0].shape, (2, 2)) + self.assertEqual(res_s[1].shape, (2, 2)) + + self.assertAllClose(res_g, ([[0.94685763, 0.94685763], + [0.99480951, 0.99480951]],)) + self.assertAllClose( + res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]], + [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellWithRelu(self): with self.test_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) + m = (array_ops.zeros([1, 2]),) cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu) - self.assertEqual(cell.state_size, 2) + self.assertEqual(cell.state_size, (2,)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 2)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], - {x: np.array([[1., 1.]]), - m: np.array([[0.1, 0.1]])}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 2)) - self.assertAllClose(res[0], [[1.80049896, 1.80049896]]) - self.assertAllClose(res[1], [[0.80049896, 0.80049896]]) + res_g, res_s = sess.run( + [g, s], {x: np.array([[1., 1.]]), + m: np.array([[0.1, 0.1]])}) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].shape, (1, 2)) + self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],)) + self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],)) """1-LSTM """ @@ -294,51 +352,59 @@ class GridRNNCellTest(test.TestCase): with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope: x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 4]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True) - self.assertEqual(cell.state_size, 4) + self.assertEqual(cell.state_size, ((2, 2),)) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 4)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [g, s], - {x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.2, 0.3, 0.4]])}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[0.91287315, 0.91287315]]) - self.assertAllClose(res[1], - [[2.26285243, 2.26285243, 0.91287315, 0.91287315]]) + res_g, res_s = sess.run([g, s], { + x: np.array([[1., 1., 1.]]), + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),) + }) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + + self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],)) + self.assertAllClose(res_s, (([[2.26285243, 2.26285243]], + [[0.91287315, 0.91287315]]),)) root_scope.reuse_variables() x2 = array_ops.zeros([0, 0]) g2, s2 = cell(x2, m) - self.assertEqual(g2.get_shape(), (1, 2)) - self.assertEqual(s2.get_shape(), (1, 4)) + self.assertEqual(g2[0].get_shape(), (1, 2)) + self.assertEqual(s2[0].c.get_shape(), (1, 2)) + self.assertEqual(s2[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g2, s2], {m: res[1]}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[0.9032144, 0.9032144]]) - self.assertAllClose(res[1], - [[2.79966092, 2.79966092, 0.9032144, 0.9032144]]) + res_g2, res_s2 = sess.run([g2, s2], {m: res_s}) + self.assertEqual(res_g2[0].shape, (1, 2)) + self.assertEqual(res_s2[0].c.shape, (1, 2)) + self.assertEqual(res_s2[0].h.shape, (1, 2)) + + self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],)) + self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]], + [[0.9032144, 0.9032144]]),)) g3, s3 = cell(x2, m) - self.assertEqual(g3.get_shape(), (1, 2)) - self.assertEqual(s3.get_shape(), (1, 4)) + self.assertEqual(g3[0].get_shape(), (1, 2)) + self.assertEqual(s3[0].c.get_shape(), (1, 2)) + self.assertEqual(s3[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g3, s3], {m: res[1]}) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 4)) - self.assertAllClose(res[0], [[0.92727238, 0.92727238]]) - self.assertAllClose(res[1], - [[3.3529923, 3.3529923, 0.92727238, 0.92727238]]) + res_g3, res_s3 = sess.run([g3, s3], {m: res_s2}) + self.assertEqual(res_g3[0].shape, (1, 2)) + self.assertEqual(res_s3[0].c.shape, (1, 2)) + self.assertEqual(res_s3[0].h.shape, (1, 2)) + self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],)) + self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]], + [[0.92727238, 0.92727238]]),)) """3-LSTM """ @@ -348,32 +414,42 @@ class GridRNNCellTest(test.TestCase): with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 12]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), + (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) cell = grid_rnn_cell.Grid3LSTMCell(2, use_peepholes=True) - self.assertEqual(cell.state_size, 12) + self.assertEqual(cell.state_size, ((2, 2), (2, 2), (2, 2))) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (1, 2)) - self.assertEqual(s.get_shape(), (1, 12)) + self.assertEqual(g[0].get_shape(), (1, 2)) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) + self.assertEqual(s[1].c.get_shape(), (1, 2)) + self.assertEqual(s[1].h.get_shape(), (1, 2)) + self.assertEqual(s[2].c.get_shape(), (1, 2)) + self.assertEqual(s[2].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], { + res_g, res_s = sess.run([g, s], { x: np.array([[1., 1., 1.]]), - m: - np.array([[ - 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3, - -0.4 - ]]) + m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), + (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array( + [[-0.1, -0.2]]), np.array([[-0.3, -0.4]]))) }) - self.assertEqual(res[0].shape, (1, 2)) - self.assertEqual(res[1].shape, (1, 12)) - - self.assertAllClose(res[0], [[0.96892911, 0.96892911]]) - self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911, - 0.96892911, 1.33592629, 1.4373529, - 0.80867189, 0.83247656, 0.7317788, - 0.63205892, 0.56548983, 0.50446129]]) + self.assertEqual(res_g[0].shape, (1, 2)) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) + self.assertEqual(res_s[1].c.shape, (1, 2)) + self.assertEqual(res_s[1].h.shape, (1, 2)) + self.assertEqual(res_s[2].c.shape, (1, 2)) + self.assertEqual(res_s[2].h.shape, (1, 2)) + + self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],)) + self.assertAllClose( + res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]), + ([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]), + ([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]]))) """Edge cases """ @@ -383,7 +459,7 @@ class GridRNNCellTest(test.TestCase): with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([3, 2]) - m = array_ops.zeros([0, 0]) + m = () # this is equivalent to relu cell = grid_rnn_cell.GridRNNCell( @@ -394,21 +470,22 @@ class GridRNNCellTest(test.TestCase): non_recurrent_dims=0, non_recurrent_fn=nn_ops.relu) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (3, 2)) - self.assertEqual(s.get_shape(), (0, 0)) + self.assertEqual(g[0].get_shape(), (3, 2)) + self.assertEqual(s, ()) sess.run([variables.global_variables_initializer()]) - res = sess.run([g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])}) - self.assertEqual(res[0].shape, (3, 2)) - self.assertEqual(res[1].shape, (0, 0)) - self.assertAllClose(res[0], [[0, 0], [0, 0], [0.5, 0.5]]) + res_g, res_s = sess.run([g, s], + {x: np.array([[1., -1.], [-2, 1], [2, -1]])}) + self.assertEqual(res_g[0].shape, (3, 2)) + self.assertEqual(res_s, ()) + self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],)) def testGridRNNEdgeCasesNoOutput(self): with self.test_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 4]) + m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) # This cell produces no output cell = grid_rnn_cell.GridRNNCell( @@ -419,16 +496,18 @@ class GridRNNCellTest(test.TestCase): non_recurrent_dims=0, non_recurrent_fn=nn_ops.relu) g, s = cell(x, m) - self.assertEqual(g.get_shape(), (0, 0)) - self.assertEqual(s.get_shape(), (1, 4)) + self.assertEqual(g, ()) + self.assertEqual(s[0].c.get_shape(), (1, 2)) + self.assertEqual(s[0].h.get_shape(), (1, 2)) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [g, s], - {x: np.array([[1., 1.]]), - m: np.array([[0.1, 0.1, 0.1, 0.1]])}) - self.assertEqual(res[0].shape, (0, 0)) - self.assertEqual(res[1].shape, (1, 4)) + res_g, res_s = sess.run([g, s], { + x: np.array([[1., 1.]]), + m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),) + }) + self.assertEqual(res_g, ()) + self.assertEqual(res_s[0].c.shape, (1, 2)) + self.assertEqual(res_s[0].h.shape, (1, 2)) """Test with tf.nn.rnn """ @@ -451,20 +530,29 @@ class GridRNNCellTest(test.TestCase): outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) - self.assertEqual(state.get_shape(), (batch_size, 8)) + self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[0].h.get_shape(), (batch_size, 2)) + self.assertEqual(state[1].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[1].h.get_shape(), (batch_size, 2)) for out, inp in zip(outputs, inputs): - self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) - self.assertEqual(out.get_shape()[1], num_units) - self.assertEqual(out.dtype, inp.dtype) + self.assertEqual(len(out), 1) + self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out[0].get_shape()[1], num_units) + self.assertEqual(out[0].dtype, inp.dtype) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) - for v in values: - self.assertTrue(np.all(np.isfinite(v))) + for tp in values[:-1]: + for v in tp: + self.assertTrue(np.all(np.isfinite(v))) + for tp in values[-1]: + for st in tp: + for v in st: + self.assertTrue(np.all(np.isfinite(v))) def testGrid2LSTMCellReLUWithRNN(self): batch_size = 3 @@ -478,27 +566,33 @@ class GridRNNCellTest(test.TestCase): num_units=num_units, non_recurrent_fn=nn_ops.relu) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) - self.assertEqual(state.get_shape(), (batch_size, 4)) + self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[0].h.get_shape(), (batch_size, 2)) for out, inp in zip(outputs, inputs): - self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) - self.assertEqual(out.get_shape()[1], num_units) - self.assertEqual(out.dtype, inp.dtype) + self.assertEqual(len(out), 1) + self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out[0].get_shape()[1], num_units) + self.assertEqual(out[0].dtype, inp.dtype) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) - for v in values: - self.assertTrue(np.all(np.isfinite(v))) + for tp in values[:-1]: + for v in tp: + self.assertTrue(np.all(np.isfinite(v))) + for tp in values[-1]: + for st in tp: + for v in st: + self.assertTrue(np.all(np.isfinite(v))) def testGrid3LSTMCellReLUWithRNN(self): batch_size = 3 @@ -512,27 +606,35 @@ class GridRNNCellTest(test.TestCase): num_units=num_units, non_recurrent_fn=nn_ops.relu) inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(batch_size, input_size)) + array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) ] outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) - self.assertEqual(state.get_shape(), (batch_size, 8)) + self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[0].h.get_shape(), (batch_size, 2)) + self.assertEqual(state[1].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[1].h.get_shape(), (batch_size, 2)) for out, inp in zip(outputs, inputs): - self.assertEqual(out.get_shape()[0], inp.get_shape()[0]) - self.assertEqual(out.get_shape()[1], num_units) - self.assertEqual(out.dtype, inp.dtype) + self.assertEqual(len(out), 1) + self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0]) + self.assertEqual(out[0].get_shape()[1], num_units) + self.assertEqual(out[0].dtype, inp.dtype) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) - for v in values: - self.assertTrue(np.all(np.isfinite(v))) + for tp in values[:-1]: + for v in tp: + self.assertTrue(np.all(np.isfinite(v))) + for tp in values[-1]: + for st in tp: + for v in st: + self.assertTrue(np.all(np.isfinite(v))) def testGrid1LSTMCellWithRNN(self): batch_size = 3 @@ -553,20 +655,91 @@ class GridRNNCellTest(test.TestCase): outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) self.assertEqual(len(outputs), len(inputs)) - self.assertEqual(state.get_shape(), (batch_size, 4)) + self.assertEqual(state[0].c.get_shape(), (batch_size, 2)) + self.assertEqual(state[0].h.get_shape(), (batch_size, 2)) for out, inp in zip(outputs, inputs): - self.assertEqual(out.get_shape(), (3, num_units)) - self.assertEqual(out.dtype, inp.dtype) + self.assertEqual(len(out), 1) + self.assertEqual(out[0].get_shape(), (3, num_units)) + self.assertEqual(out[0].dtype, inp.dtype) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) - for v in values: - self.assertTrue(np.all(np.isfinite(v))) + for tp in values[:-1]: + for v in tp: + self.assertTrue(np.all(np.isfinite(v))) + for tp in values[-1]: + for st in tp: + for v in st: + self.assertTrue(np.all(np.isfinite(v))) + + def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self): + """Test for #4296.""" + input_size = 5 + max_length = 6 # unrolled up to this length + num_units = 2 + + with variable_scope.variable_scope( + 'root', initializer=init_ops.constant_initializer(0.5)): + cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units) + inputs = max_length * [ + array_ops.placeholder(dtypes.float32, shape=(None, input_size)) + ] + + outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) + + self.assertEqual(len(outputs), len(inputs)) + + for out, inp in zip(outputs, inputs): + self.assertEqual(len(out), 1) + self.assertTrue(out[0].get_shape()[0].value is None) + self.assertEqual(out[0].get_shape()[1], num_units) + self.assertEqual(out[0].dtype, inp.dtype) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + + input_value = np.ones((3, input_size)) + values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) + for tp in values[:-1]: + for v in tp: + self.assertTrue(np.all(np.isfinite(v))) + for tp in values[-1]: + for st in tp: + for v in st: + self.assertTrue(np.all(np.isfinite(v))) + + def testGrid2LSTMCellLegacy(self): + """Test for legacy case (when state_is_tuple=False).""" + with self.test_session() as sess: + with variable_scope.variable_scope( + 'root', initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 8]) + cell = grid_rnn_cell.Grid2LSTMCell( + 2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False) + self.assertEqual(cell.state_size, 8) + + g, s = cell(x, m) + self.assertEqual(g.get_shape(), (1, 2)) + self.assertEqual(s.get_shape(), (1, 8)) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, s], { + x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + self.assertEqual(res[1].shape, (1, 8)) + self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) + self.assertAllClose(res[1], [[ + 2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875, + 1.49043763, 0.83884692, 0.86036491 + ]]) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 269b224581..252788140f 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -25,6 +25,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs + +from tensorflow.python.platform import tf_logging as logging from tensorflow.contrib import layers from tensorflow.contrib import rnn @@ -53,7 +55,9 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_dims=None, tied=False, cell_fn=None, - non_recurrent_fn=None): + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): """Initialize the parameters of a Grid RNN cell Args: @@ -68,26 +72,47 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_dims: int or list, List of dimensions that are not recurrent. The transfer function for non-recurrent dimensions is specified - via `non_recurrent_fn`, - which is default to be `tensorflow.nn.relu`. + via `non_recurrent_fn`, which is + default to be `tensorflow.nn.relu`. tied: bool, Whether to share the weights among the dimensions of this GridRNN cell. If there are non-recurrent dimensions in the grid, weights are - shared between each - group of recurrent and non-recurrent dimensions. - cell_fn: function, a function which returns the recurrent cell object. Has - to be in the following signature: - def cell_func(num_units, input_size): + shared between each group of recurrent and non-recurrent + dimensions. + cell_fn: function, a function which returns the recurrent cell object. + Has to be in the following signature: + ``` + def cell_func(num_units): # ... - + ``` and returns an object of type `RNNCell`. If None, LSTMCell with default parameters will be used. + Note that if you use a custom RNNCell (with `cell_fn`), it is your + responsibility to make sure the inner cell use `state_is_tuple=True`. + non_recurrent_fn: a tensorflow Op that will be the transfer function of the non-recurrent dimensions + state_is_tuple: If True, accepted and returned states are tuples of the + states of the recurrent dimensions. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + + Note that if you use a custom RNNCell (with `cell_fn`), it is your + responsibility to make sure the inner cell use `state_is_tuple=True`. + + output_is_tuple: If True, the output is a tuple of the outputs of the + recurrent dimensions. If False, they are concatenated along the + column axis. The later behavior will soon be deprecated. Raises: TypeError: if cell_fn does not return an RNNCell instance. """ + if not state_is_tuple: + logging.warning('%s: Using a concatenated state is slower and will ' + 'soon be deprecated. Use state_is_tuple=True.', self) + if not output_is_tuple: + logging.warning('%s: Using a concatenated output is slower and will' + 'soon be deprecated. Use output_is_tuple=True.', self) + if num_dims < 1: raise ValueError('dims must be >= 1: {}'.format(num_dims)) @@ -96,37 +121,41 @@ class GridRNNCell(rnn.RNNCell): non_recurrent_fn or nn.relu, tied, num_units) - cell_input_size = (self._config.num_dims - 1) * num_units + self._state_is_tuple = state_is_tuple + self._output_is_tuple = output_is_tuple + if cell_fn is None: my_cell_fn = functools.partial( - rnn.LSTMCell, - num_units=num_units, input_size=cell_input_size, - state_is_tuple=False) + rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple) else: - my_cell_fn = lambda: cell_fn(num_units, cell_input_size) + my_cell_fn = lambda: cell_fn(num_units) if tied: self._cells = [my_cell_fn()] * num_dims else: self._cells = [my_cell_fn() for _ in range(num_dims)] if not isinstance(self._cells[0], rnn.RNNCell): - raise TypeError( - 'cell_fn must return an RNNCell instance, saw: %s' - % type(self._cells[0])) + raise TypeError('cell_fn must return an RNNCell instance, saw: %s' % + type(self._cells[0])) - @property - def input_size(self): - # temporarily using num_units as the input_size of each dimension. - # The actual input size only determined when this cell get invoked, - # so this information can be considered unreliable. - return self._config.num_units * len(self._config.inputs) + if self._output_is_tuple: + self._output_size = tuple(self._cells[0].output_size + for _ in self._config.outputs) + else: + self._output_size = self._cells[0].output_size * len(self._config.outputs) + + if self._state_is_tuple: + self._state_size = tuple(self._cells[0].state_size + for _ in self._config.recurrents) + else: + self._state_size = self._cell_state_size() * len(self._config.recurrents) @property def output_size(self): - return self._cells[0].output_size * len(self._config.outputs) + return self._output_size @property def state_size(self): - return self._cells[0].state_size * len(self._config.recurrents) + return self._state_size def __call__(self, inputs, state, scope=None): """Run one step of GridRNN. @@ -145,76 +174,148 @@ class GridRNNCell(rnn.RNNCell): - A 2D, batch x state_size, Tensor representing the new state of the cell after reading "inputs" when previous state was "state". """ - state_sz = state.get_shape().as_list()[1] - if self.state_size != state_sz: - raise ValueError( - 'Actual state size not same as specified: {} vs {}.'.format( - state_sz, self.state_size)) - conf = self._config - dtype = inputs.dtype if inputs is not None else state.dtype + dtype = inputs.dtype - # c_prev is `m`, and m_prev is `h` in the paper. - # Keep c and m here for consistency with the codebase - c_prev = [None] * self._config.num_dims - m_prev = [None] * self._config.num_dims - cell_output_size = self._cells[0].state_size - conf.num_units - - # for LSTM : state = memory cell + output, hence cell_output_size > 0 - # for GRU/RNN: state = output (whose size is equal to _num_units), - # hence cell_output_size = 0 - for recurrent_dim, start_idx in zip(self._config.recurrents, range( - 0, self.state_size, self._cells[0].state_size)): - if cell_output_size > 0: - c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], - [-1, conf.num_units]) - m_prev[recurrent_dim] = array_ops.slice( - state, [0, start_idx + conf.num_units], [-1, cell_output_size]) - else: - m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], - [-1, conf.num_units]) + c_prev, m_prev, cell_output_size = self._extract_states(state) new_output = [None] * conf.num_dims new_state = [None] * conf.num_dims with vs.variable_scope(scope or type(self).__name__): # GridRNNCell + # project input, populate c_prev and m_prev + self._project_input(inputs, c_prev, m_prev, cell_output_size > 0) - # project input - if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len( - conf.inputs) > 0: - input_splits = array_ops.split( - value=inputs, num_or_size_splits=len(conf.inputs), axis=1) - input_sz = input_splits[0].get_shape().as_list()[1] - - for i, j in enumerate(conf.inputs): - input_project_m = vs.get_variable( - 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype) - m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) - - if cell_output_size > 0: - input_project_c = vs.get_variable( - 'project_c_{}'.format(j), [input_sz, conf.num_units], - dtype=dtype) - c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) - + # propagate along dimensions, first for non-priority dimensions + # then priority dimensions _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev, new_output, new_state, True) _propagate(conf.priority, conf, self._cells, c_prev, m_prev, new_output, new_state, False) + # collect outputs and states output_tensors = [new_output[i] for i in self._config.outputs] - output = array_ops.zeros( - [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat( - output_tensors, 1) + if self._output_is_tuple: + output = tuple(output_tensors) + else: + if output_tensors: + output = array_ops.concat(output_tensors, 1) + else: + output = array_ops.zeros([0, 0], dtype) - state_tensors = [new_state[i] for i in self._config.recurrents] - states = array_ops.zeros( - [0, 0], - dtype) if len(state_tensors) == 0 else array_ops.concat(state_tensors, - 1) + if self._state_is_tuple: + states = tuple(new_state[i] for i in self._config.recurrents) + else: + # concat each state first, then flatten the whole thing + state_tensors = [ + x for i in self._config.recurrents for x in new_state[i] + ] + if state_tensors: + states = array_ops.concat(state_tensors, 1) + else: + states = array_ops.zeros([0, 0], dtype) return output, states + def _extract_states(self, state): + """Extract the cell and previous output tensors from the given state. + + Args: + state: The RNN state. + + Returns: + Tuple of the cell value, previous output, and cell_output_size. + + Raises: + ValueError: If len(self._config.recurrents) != len(state). + """ + conf = self._config + + # c_prev is `m` (cell value), and + # m_prev is `h` (previous output) in the paper. + # Keeping c and m here for consistency with the codebase + c_prev = [None] * conf.num_dims + m_prev = [None] * conf.num_dims + + # for LSTM : state = memory cell + output, hence cell_output_size > 0 + # for GRU/RNN: state = output (whose size is equal to _num_units), + # hence cell_output_size = 0 + total_cell_state_size = self._cell_state_size() + cell_output_size = total_cell_state_size - conf.num_units + + if self._state_is_tuple: + if len(conf.recurrents) != len(state): + raise ValueError('Expected state as a tuple of {} ' + 'element'.format(len(conf.recurrents))) + + for recurrent_dim, recurrent_state in zip(conf.recurrents, state): + if cell_output_size > 0: + c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state + else: + m_prev[recurrent_dim] = recurrent_state + else: + for recurrent_dim, start_idx in zip(conf.recurrents, + range(0, self.state_size, + total_cell_state_size)): + if cell_output_size > 0: + c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) + m_prev[recurrent_dim] = array_ops.slice( + state, [0, start_idx + conf.num_units], [-1, cell_output_size]) + else: + m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], + [-1, conf.num_units]) + return c_prev, m_prev, cell_output_size + + def _project_input(self, inputs, c_prev, m_prev, with_c): + """Fills in c_prev and m_prev with projected input, for input dimensions. + + Args: + inputs: inputs tensor + c_prev: cell value + m_prev: previous output + with_c: boolean; whether to include project_c. + + Raises: + ValueError: if len(self._config.input) != len(inputs) + """ + conf = self._config + + if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and + conf.inputs): + if isinstance(inputs, tuple): + if len(conf.inputs) != len(inputs): + raise ValueError('Expect inputs as a tuple of {} ' + 'tensors'.format(len(conf.inputs))) + input_splits = inputs + else: + input_splits = array_ops.split( + value=inputs, num_or_size_splits=len(conf.inputs), axis=1) + input_sz = input_splits[0].get_shape().with_rank(2)[1].value + + for i, j in enumerate(conf.inputs): + input_project_m = vs.get_variable( + 'project_m_{}'.format(j), [input_sz, conf.num_units], + dtype=inputs.dtype) + m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) + + if with_c: + input_project_c = vs.get_variable( + 'project_c_{}'.format(j), [input_sz, conf.num_units], + dtype=inputs.dtype) + c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) + + def _cell_state_size(self): + """Total size of the state of the inner cell used in this grid. + + Returns: + Total size of the state of the inner cell. + """ + state_sizes = self._cells[0].state_size + if isinstance(state_sizes, tuple): + return sum(state_sizes) + return state_sizes + """Specialized cells, for convenience """ @@ -223,11 +324,17 @@ class GridRNNCell(rnn.RNNCell): class Grid1BasicRNNCell(GridRNNCell): """1D BasicRNN cell""" - def __init__(self, num_units): + def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True): super(Grid1BasicRNNCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=False, + cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2BasicRNNCell(GridRNNCell): @@ -240,71 +347,112 @@ class Grid2BasicRNNCell(GridRNNCell): specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None): + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): super(Grid2BasicRNNCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid1BasicLSTMCell(GridRNNCell): - """1D BasicLSTM cell""" + """1D BasicLSTM cell.""" - def __init__(self, num_units, forget_bias=1): + def __init__(self, + num_units, + forget_bias=1, + state_is_tuple=True, + output_is_tuple=True): + def cell_fn(n): + return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) super(Grid1BasicLSTMCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, tied=False, - cell_fn=lambda n, i: rnn.BasicLSTMCell( - num_units=n, - forget_bias=forget_bias, input_size=i, - state_is_tuple=False)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=False, + cell_fn=cell_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2BasicLSTMCell(GridRNNCell): - """2D BasicLSTM cell + """2D BasicLSTM cell. - This creates a 2D cell which receives input and gives output in the first - dimension. + This creates a 2D cell which receives input and gives output in the first + dimension. - The first dimension can optionally be non-recurrent if `non_recurrent_fn` is - specified. + The first dimension can optionally be non-recurrent if `non_recurrent_fn` is + specified. """ def __init__(self, num_units, tied=False, non_recurrent_fn=None, - forget_bias=1): + forget_bias=1, + state_is_tuple=True, + output_is_tuple=True): + def cell_fn(n): + return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) super(Grid2BasicLSTMCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.BasicLSTMCell( - num_units=n, forget_bias=forget_bias, input_size=i, - state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid1LSTMCell(GridRNNCell): - """1D LSTM cell + """1D LSTM cell. - This is different from Grid1BasicLSTMCell because it gives options to - specify the forget bias and enabling peepholes + This is different from Grid1BasicLSTMCell because it gives options to + specify the forget bias and enabling peepholes. """ - def __init__(self, num_units, use_peepholes=False, forget_bias=1.0): + def __init__(self, + num_units, + use_peepholes=False, + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid1LSTMCell, self).__init__( - num_units=num_units, num_dims=1, - input_dims=0, output_dims=0, priority_dims=0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, use_peepholes=use_peepholes, - forget_bias=forget_bias, state_is_tuple=False)) + num_units=num_units, + num_dims=1, + input_dims=0, + output_dims=0, + priority_dims=0, + cell_fn=cell_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2LSTMCell(GridRNNCell): - """2D LSTM cell + """2D LSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -317,19 +465,30 @@ class Grid2LSTMCell(GridRNNCell): tied=False, non_recurrent_fn=None, use_peepholes=False, - forget_bias=1.0): + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid2LSTMCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes, state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid3LSTMCell(GridRNNCell): - """3D BasicLSTM cell + """3D BasicLSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -343,19 +502,30 @@ class Grid3LSTMCell(GridRNNCell): tied=False, non_recurrent_fn=None, use_peepholes=False, - forget_bias=1.0): + forget_bias=1.0, + state_is_tuple=True, + output_is_tuple=True): + + def cell_fn(n): + return rnn.LSTMCell( + num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) + super(Grid3LSTMCell, self).__init__( - num_units=num_units, num_dims=3, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=3, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.LSTMCell( - num_units=n, input_size=i, forget_bias=forget_bias, - use_peepholes=use_peepholes, state_is_tuple=False), - non_recurrent_fn=non_recurrent_fn) + cell_fn=cell_fn, + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) class Grid2GRUCell(GridRNNCell): - """2D LSTM cell + """2D LSTM cell. This creates a 2D cell which receives input and gives output in the first dimension. @@ -363,21 +533,31 @@ class Grid2GRUCell(GridRNNCell): specified. """ - def __init__(self, num_units, tied=False, non_recurrent_fn=None): + def __init__(self, + num_units, + tied=False, + non_recurrent_fn=None, + state_is_tuple=True, + output_is_tuple=True): super(Grid2GRUCell, self).__init__( - num_units=num_units, num_dims=2, - input_dims=0, output_dims=0, priority_dims=0, tied=tied, + num_units=num_units, + num_dims=2, + input_dims=0, + output_dims=0, + priority_dims=0, + tied=tied, non_recurrent_dims=None if non_recurrent_fn is None else 0, - cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i), - non_recurrent_fn=non_recurrent_fn) + cell_fn=lambda n: rnn.GRUCell(num_units=n), + non_recurrent_fn=non_recurrent_fn, + state_is_tuple=state_is_tuple, + output_is_tuple=output_is_tuple) -"""Helpers -""" +# Helpers -_GridRNNDimension = namedtuple( - '_GridRNNDimension', - ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) +_GridRNNDimension = namedtuple('_GridRNNDimension', [ + 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn' +]) _GridRNNConfig = namedtuple('_GridRNNConfig', ['num_dims', 'dims', 'inputs', 'outputs', @@ -387,7 +567,6 @@ _GridRNNConfig = namedtuple('_GridRNNConfig', def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, ls_non_recurrent_dims, non_recurrent_fn, tied, num_units): - def check_dim_list(ls): if ls is None: ls = [] @@ -412,8 +591,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, is_input=(i in input_dims), is_output=(i in output_dims), is_priority=(i in priority_dims), - non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else - None)) + non_recurrent_fn=non_recurrent_fn + if i in non_recurrent_dims else None)) return _GridRNNConfig( num_dims=num_dims, dims=rnn_dims, @@ -440,34 +619,40 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state, if conf.num_dims > 1: ls_cell_inputs = [None] * (conf.num_dims - 1) for d in conf.dims[:-1]: - ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[ - d.idx] is not None else m_prev[d.idx] + if new_output[d.idx] is None: + ls_cell_inputs[d.idx] = m_prev[d.idx] + else: + ls_cell_inputs[d.idx] = new_output[d.idx] cell_inputs = array_ops.concat(ls_cell_inputs, 1) else: cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], m_prev[0].dtype) - last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1] + last_dim_output = (new_output[-1] + if new_output[-1] is not None else m_prev[-1]) for i in dim_indices: d = conf.dims[i] if d.non_recurrent_fn: - linear_args = array_ops.concat( - [cell_inputs, last_dim_output], - 1) if conf.num_dims > 1 else last_dim_output + if conf.num_dims > 1: + linear_args = array_ops.concat([cell_inputs, last_dim_output], 1) + else: + linear_args = last_dim_output with vs.variable_scope('non_recurrent' if conf.tied else 'non_recurrent/cell_{}'.format(i)): if conf.tied and not (first_call and i == dim_indices[0]): vs.get_variable_scope().reuse_variables() - new_output[d.idx] = layers.legacy_fully_connected( + + new_output[d.idx] = layers.fully_connected( linear_args, - num_output_units=conf.num_units, + num_outputs=conf.num_units, activation_fn=d.non_recurrent_fn, - weight_init=vs.get_variable_scope().initializer or - layers.initializers.xavier_initializer) + weights_initializer=(vs.get_variable_scope().initializer or + layers.initializers.xavier_initializer), + weights_regularizer=vs.get_variable_scope().regularizer) else: if c_prev[i] is not None: - cell_state = array_ops.concat([c_prev[i], last_dim_output], 1) + cell_state = (c_prev[i], last_dim_output) else: # for GRU/RNN, the state is just the previous output cell_state = last_dim_output |