aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/grid_rnn
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2017-05-05 09:09:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 10:26:00 -0700
commit692fad20f913ffa2cb874a87578ecabb03cc4557 (patch)
tree172717f537c91b0d1ac0366731b4eb2093fb743b /tensorflow/contrib/grid_rnn
parentb329dd821e29e64c93b1b9bf38e61871c6cb53da (diff)
Merge changes from github.
Change: 155209832
Diffstat (limited to 'tensorflow/contrib/grid_rnn')
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py631
-rw-r--r--tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py497
2 files changed, 743 insertions, 385 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index 758e0bcc07..280271a42d 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -34,180 +34,228 @@ from tensorflow.python.platform import test
class GridRNNCellTest(test.TestCase):
def testGrid2BasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 8])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
cell = grid_rnn_cell.Grid2BasicLSTMCell(2)
- self.assertEqual(cell.state_size, 8)
+ self.assertEqual(cell.state_size, ((2, 2), (2, 2)))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 8))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
+ self.assertEqual(s[1].c.get_shape(), (1, 2))
+ self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
- self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
- 0.36617181, 0.72320831, 0.80555487,
- 0.39102408, 0.42150158]])
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+ self.assertEqual(res_s[1].c.shape, (1, 2))
+ self.assertEqual(res_s[1].h.shape, (1, 2))
+
+ self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],))
+ self.assertAllClose(
+ res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+ ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
# emulate a loop through the input sequence,
# where we call cell() multiple times
root_scope.reuse_variables()
g2, s2 = cell(x, m)
- self.assertEqual(g2.get_shape(), (1, 2))
- self.assertEqual(s2.get_shape(), (1, 8))
-
- res = sess.run([g2, s2], {x: np.array([[2., 2., 2.]]), m: res[1]})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.58847463, 0.58847463]])
- self.assertAllClose(res[1], [[1.40469193, 1.40469193, 0.58847463,
- 0.58847463, 0.97726452, 1.04626071,
- 0.4927212, 0.51137757]])
+ self.assertEqual(g2[0].get_shape(), (1, 2))
+ self.assertEqual(s2[0].c.get_shape(), (1, 2))
+ self.assertEqual(s2[0].h.get_shape(), (1, 2))
+ self.assertEqual(s2[1].c.get_shape(), (1, 2))
+ self.assertEqual(s2[1].h.get_shape(), (1, 2))
+
+ res_g2, res_s2 = sess.run([g2, s2],
+ {x: np.array([[2., 2., 2.]]),
+ m: res_s})
+ self.assertEqual(res_g2[0].shape, (1, 2))
+ self.assertEqual(res_s2[0].c.shape, (1, 2))
+ self.assertEqual(res_s2[0].h.shape, (1, 2))
+ self.assertEqual(res_s2[1].c.shape, (1, 2))
+ self.assertEqual(res_s2[1].h.shape, (1, 2))
+ self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]])
+ self.assertAllClose(
+ res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
+ ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
def testGrid2BasicLSTMCellTied(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 8])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
cell = grid_rnn_cell.Grid2BasicLSTMCell(2, tied=True)
- self.assertEqual(cell.state_size, 8)
+ self.assertEqual(cell.state_size, ((2, 2), (2, 2)))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 8))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
+ self.assertEqual(s[1].c.get_shape(), (1, 2))
+ self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.36617181, 0.36617181]])
- self.assertAllClose(res[1], [[0.71053141, 0.71053141, 0.36617181,
- 0.36617181, 0.72320831, 0.80555487,
- 0.39102408, 0.42150158]])
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+ self.assertEqual(res_s[1].c.shape, (1, 2))
+ self.assertEqual(res_s[1].h.shape, (1, 2))
- res = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res[1]})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.36703536, 0.36703536]])
- self.assertAllClose(res[1], [[0.71200621, 0.71200621, 0.36703536,
- 0.36703536, 0.80941606, 0.87550586,
- 0.40108523, 0.42199609]])
+ self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
+ self.assertAllClose(
+ res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+ ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
+
+ res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s})
+ self.assertEqual(res_g[0].shape, (1, 2))
+
+ self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
+ self.assertAllClose(
+ res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
+ ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))
def testGrid2BasicLSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 4])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2BasicLSTMCell(
2, tied=False, non_recurrent_fn=nn_ops.relu)
- self.assertEqual(cell.state_size, 4)
+ self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 4))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [g, s],
- {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[0.31667367, 0.31667367]])
- self.assertAllClose(res[1], [[0.29530135, 0.37520045, 0.17044567,
- 0.21292259]])
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]])
+ self.assertAllClose(res_s, (([[0.29530135, 0.37520045]],
+ [[0.17044567, 0.21292259]]),))
"""LSTMCell
"""
def testGrid2LSTMCell(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 8])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True)
- self.assertEqual(cell.state_size, 8)
+ self.assertEqual(cell.state_size, ((2, 2), (2, 2)))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 8))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
+ self.assertEqual(s[1].c.get_shape(), (1, 2))
+ self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
- self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
- 0.95686918, 1.38917875, 1.49043763,
- 0.83884692, 0.86036491]])
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+ self.assertEqual(res_s[1].c.shape, (1, 2))
+ self.assertEqual(res_s[1].h.shape, (1, 2))
+
+ self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
+ self.assertAllClose(
+ res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+ ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellTied(self):
- with self.test_session() as sess:
+ with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 8])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
cell = grid_rnn_cell.Grid2LSTMCell(2, tied=True, use_peepholes=True)
- self.assertEqual(cell.state_size, 8)
+ self.assertEqual(cell.state_size, ((2, 2), (2, 2)))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 8))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
+ self.assertEqual(s[1].c.get_shape(), (1, 2))
+ self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 8))
- self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
- self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
- 0.95686918, 1.38917875, 1.49043763,
- 0.83884692, 0.86036491]])
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+ self.assertEqual(res_s[1].c.shape, (1, 2))
+ self.assertEqual(res_s[1].h.shape, (1, 2))
+
+ self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
+ self.assertAllClose(
+ res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+ ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 4])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2LSTMCell(
2, use_peepholes=True, non_recurrent_fn=nn_ops.relu)
- self.assertEqual(cell.state_size, 4)
+ self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 4))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [g, s],
- {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[2.1831727, 2.1831727]])
- self.assertAllClose(res[1], [[0.92270052, 1.02325559, 0.66159075,
- 0.70475441]])
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]])
+ self.assertAllClose(res_s, (([[0.92270052, 1.02325559]],
+ [[0.66159075, 0.70475441]]),))
"""RNNCell
"""
@@ -217,74 +265,84 @@ class GridRNNCellTest(test.TestCase):
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
- m = array_ops.zeros([2, 4])
+ m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2)
- self.assertEqual(cell.state_size, 4)
+ self.assertEqual(cell.state_size, (2, 2))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (2, 2))
- self.assertEqual(s.get_shape(), (2, 4))
+ self.assertEqual(g[0].get_shape(), (2, 2))
+ self.assertEqual(s[0].get_shape(), (2, 2))
+ self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1.], [2., 2.]]),
+ m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+ [0.2, 0.2]]))
})
- self.assertEqual(res[0].shape, (2, 2))
- self.assertEqual(res[1].shape, (2, 4))
- self.assertAllClose(res[0], [[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]])
- self.assertAllClose(res[1],
- [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
- [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+ self.assertEqual(res_g[0].shape, (2, 2))
+ self.assertEqual(res_s[0].shape, (2, 2))
+ self.assertEqual(res_s[1].shape, (2, 2))
+
+ self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
+ [0.99480951, 0.99480951]],))
+ self.assertAllClose(
+ res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+ [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
- m = array_ops.zeros([2, 4])
+ m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
- self.assertEqual(cell.state_size, 4)
+ self.assertEqual(cell.state_size, (2, 2))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (2, 2))
- self.assertEqual(s.get_shape(), (2, 4))
+ self.assertEqual(g[0].get_shape(), (2, 2))
+ self.assertEqual(s[0].get_shape(), (2, 2))
+ self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x: np.array([[1., 1.], [2., 2.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2]])
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1.], [2., 2.]]),
+ m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+ [0.2, 0.2]]))
})
- self.assertEqual(res[0].shape, (2, 2))
- self.assertEqual(res[1].shape, (2, 4))
- self.assertAllClose(res[0], [[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]])
- self.assertAllClose(res[1],
- [[0.94685763, 0.94685763, 0.80049908, 0.80049908],
- [0.99480951, 0.99480951, 0.97574311, 0.97574311]])
+ self.assertEqual(res_g[0].shape, (2, 2))
+ self.assertEqual(res_s[0].shape, (2, 2))
+ self.assertEqual(res_s[1].shape, (2, 2))
+
+ self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
+ [0.99480951, 0.99480951]],))
+ self.assertAllClose(
+ res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+ [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
- m = array_ops.zeros([1, 2])
+ m = (array_ops.zeros([1, 2]),)
cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
- self.assertEqual(cell.state_size, 2)
+ self.assertEqual(cell.state_size, (2,))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 2))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s],
- {x: np.array([[1., 1.]]),
- m: np.array([[0.1, 0.1]])})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 2))
- self.assertAllClose(res[0], [[1.80049896, 1.80049896]])
- self.assertAllClose(res[1], [[0.80049896, 0.80049896]])
+ res_g, res_s = sess.run(
+ [g, s], {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1]])})
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].shape, (1, 2))
+ self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],))
+ self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],))
"""1-LSTM
"""
@@ -294,51 +352,59 @@ class GridRNNCellTest(test.TestCase):
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 4])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
- self.assertEqual(cell.state_size, 4)
+ self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 4))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [g, s],
- {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4]])})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[0.91287315, 0.91287315]])
- self.assertAllClose(res[1],
- [[2.26285243, 2.26285243, 0.91287315, 0.91287315]])
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+
+ self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],))
+ self.assertAllClose(res_s, (([[2.26285243, 2.26285243]],
+ [[0.91287315, 0.91287315]]),))
root_scope.reuse_variables()
x2 = array_ops.zeros([0, 0])
g2, s2 = cell(x2, m)
- self.assertEqual(g2.get_shape(), (1, 2))
- self.assertEqual(s2.get_shape(), (1, 4))
+ self.assertEqual(g2[0].get_shape(), (1, 2))
+ self.assertEqual(s2[0].c.get_shape(), (1, 2))
+ self.assertEqual(s2[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g2, s2], {m: res[1]})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[0.9032144, 0.9032144]])
- self.assertAllClose(res[1],
- [[2.79966092, 2.79966092, 0.9032144, 0.9032144]])
+ res_g2, res_s2 = sess.run([g2, s2], {m: res_s})
+ self.assertEqual(res_g2[0].shape, (1, 2))
+ self.assertEqual(res_s2[0].c.shape, (1, 2))
+ self.assertEqual(res_s2[0].h.shape, (1, 2))
+
+ self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],))
+ self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]],
+ [[0.9032144, 0.9032144]]),))
g3, s3 = cell(x2, m)
- self.assertEqual(g3.get_shape(), (1, 2))
- self.assertEqual(s3.get_shape(), (1, 4))
+ self.assertEqual(g3[0].get_shape(), (1, 2))
+ self.assertEqual(s3[0].c.get_shape(), (1, 2))
+ self.assertEqual(s3[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g3, s3], {m: res[1]})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 4))
- self.assertAllClose(res[0], [[0.92727238, 0.92727238]])
- self.assertAllClose(res[1],
- [[3.3529923, 3.3529923, 0.92727238, 0.92727238]])
+ res_g3, res_s3 = sess.run([g3, s3], {m: res_s2})
+ self.assertEqual(res_g3[0].shape, (1, 2))
+ self.assertEqual(res_s3[0].c.shape, (1, 2))
+ self.assertEqual(res_s3[0].h.shape, (1, 2))
+ self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],))
+ self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]],
+ [[0.92727238, 0.92727238]]),))
"""3-LSTM
"""
@@ -348,32 +414,42 @@ class GridRNNCellTest(test.TestCase):
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 12])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
+ (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
cell = grid_rnn_cell.Grid3LSTMCell(2, use_peepholes=True)
- self.assertEqual(cell.state_size, 12)
+ self.assertEqual(cell.state_size, ((2, 2), (2, 2), (2, 2)))
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (1, 2))
- self.assertEqual(s.get_shape(), (1, 12))
+ self.assertEqual(g[0].get_shape(), (1, 2))
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
+ self.assertEqual(s[1].c.get_shape(), (1, 2))
+ self.assertEqual(s[1].h.get_shape(), (1, 2))
+ self.assertEqual(s[2].c.get_shape(), (1, 2))
+ self.assertEqual(s[2].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
+ res_g, res_s = sess.run([g, s], {
x:
np.array([[1., 1., 1.]]),
- m:
- np.array([[
- 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.1, -0.2, -0.3,
- -0.4
- ]])
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array(
+ [[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))
})
- self.assertEqual(res[0].shape, (1, 2))
- self.assertEqual(res[1].shape, (1, 12))
-
- self.assertAllClose(res[0], [[0.96892911, 0.96892911]])
- self.assertAllClose(res[1], [[2.45227885, 2.45227885, 0.96892911,
- 0.96892911, 1.33592629, 1.4373529,
- 0.80867189, 0.83247656, 0.7317788,
- 0.63205892, 0.56548983, 0.50446129]])
+ self.assertEqual(res_g[0].shape, (1, 2))
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
+ self.assertEqual(res_s[1].c.shape, (1, 2))
+ self.assertEqual(res_s[1].h.shape, (1, 2))
+ self.assertEqual(res_s[2].c.shape, (1, 2))
+ self.assertEqual(res_s[2].h.shape, (1, 2))
+
+ self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],))
+ self.assertAllClose(
+ res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]),
+ ([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]),
+ ([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]])))
"""Edge cases
"""
@@ -383,7 +459,7 @@ class GridRNNCellTest(test.TestCase):
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
- m = array_ops.zeros([0, 0])
+ m = ()
# this is equivalent to relu
cell = grid_rnn_cell.GridRNNCell(
@@ -394,21 +470,22 @@ class GridRNNCellTest(test.TestCase):
non_recurrent_dims=0,
non_recurrent_fn=nn_ops.relu)
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (3, 2))
- self.assertEqual(s.get_shape(), (0, 0))
+ self.assertEqual(g[0].get_shape(), (3, 2))
+ self.assertEqual(s, ())
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
- self.assertEqual(res[0].shape, (3, 2))
- self.assertEqual(res[1].shape, (0, 0))
- self.assertAllClose(res[0], [[0, 0], [0, 0], [0.5, 0.5]])
+ res_g, res_s = sess.run([g, s],
+ {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
+ self.assertEqual(res_g[0].shape, (3, 2))
+ self.assertEqual(res_s, ())
+ self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
- m = array_ops.zeros([1, 4])
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
# This cell produces no output
cell = grid_rnn_cell.GridRNNCell(
@@ -419,16 +496,18 @@ class GridRNNCellTest(test.TestCase):
non_recurrent_dims=0,
non_recurrent_fn=nn_ops.relu)
g, s = cell(x, m)
- self.assertEqual(g.get_shape(), (0, 0))
- self.assertEqual(s.get_shape(), (1, 4))
+ self.assertEqual(g, ())
+ self.assertEqual(s[0].c.get_shape(), (1, 2))
+ self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [g, s],
- {x: np.array([[1., 1.]]),
- m: np.array([[0.1, 0.1, 0.1, 0.1]])})
- self.assertEqual(res[0].shape, (0, 0))
- self.assertEqual(res[1].shape, (1, 4))
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1.]]),
+ m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),)
+ })
+ self.assertEqual(res_g, ())
+ self.assertEqual(res_s[0].c.shape, (1, 2))
+ self.assertEqual(res_s[0].h.shape, (1, 2))
"""Test with tf.nn.rnn
"""
@@ -451,20 +530,29 @@ class GridRNNCellTest(test.TestCase):
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
- self.assertEqual(state.get_shape(), (batch_size, 8))
+ self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[0].h.get_shape(), (batch_size, 2))
+ self.assertEqual(state[1].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[1].h.get_shape(), (batch_size, 2))
for out, inp in zip(outputs, inputs):
- self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
- self.assertEqual(out.get_shape()[1], num_units)
- self.assertEqual(out.dtype, inp.dtype)
+ self.assertEqual(len(out), 1)
+ self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out[0].get_shape()[1], num_units)
+ self.assertEqual(out[0].dtype, inp.dtype)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
- for v in values:
- self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[:-1]:
+ for v in tp:
+ self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[-1]:
+ for st in tp:
+ for v in st:
+ self.assertTrue(np.all(np.isfinite(v)))
def testGrid2LSTMCellReLUWithRNN(self):
batch_size = 3
@@ -478,27 +566,33 @@ class GridRNNCellTest(test.TestCase):
num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [
- array_ops.placeholder(
- dtypes.float32, shape=(batch_size, input_size))
+ array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
- self.assertEqual(state.get_shape(), (batch_size, 4))
+ self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[0].h.get_shape(), (batch_size, 2))
for out, inp in zip(outputs, inputs):
- self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
- self.assertEqual(out.get_shape()[1], num_units)
- self.assertEqual(out.dtype, inp.dtype)
+ self.assertEqual(len(out), 1)
+ self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out[0].get_shape()[1], num_units)
+ self.assertEqual(out[0].dtype, inp.dtype)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
- for v in values:
- self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[:-1]:
+ for v in tp:
+ self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[-1]:
+ for st in tp:
+ for v in st:
+ self.assertTrue(np.all(np.isfinite(v)))
def testGrid3LSTMCellReLUWithRNN(self):
batch_size = 3
@@ -512,27 +606,35 @@ class GridRNNCellTest(test.TestCase):
num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [
- array_ops.placeholder(
- dtypes.float32, shape=(batch_size, input_size))
+ array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
- self.assertEqual(state.get_shape(), (batch_size, 8))
+ self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[0].h.get_shape(), (batch_size, 2))
+ self.assertEqual(state[1].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[1].h.get_shape(), (batch_size, 2))
for out, inp in zip(outputs, inputs):
- self.assertEqual(out.get_shape()[0], inp.get_shape()[0])
- self.assertEqual(out.get_shape()[1], num_units)
- self.assertEqual(out.dtype, inp.dtype)
+ self.assertEqual(len(out), 1)
+ self.assertEqual(out[0].get_shape()[0], inp.get_shape()[0])
+ self.assertEqual(out[0].get_shape()[1], num_units)
+ self.assertEqual(out[0].dtype, inp.dtype)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
- for v in values:
- self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[:-1]:
+ for v in tp:
+ self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[-1]:
+ for st in tp:
+ for v in st:
+ self.assertTrue(np.all(np.isfinite(v)))
def testGrid1LSTMCellWithRNN(self):
batch_size = 3
@@ -553,20 +655,91 @@ class GridRNNCellTest(test.TestCase):
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
self.assertEqual(len(outputs), len(inputs))
- self.assertEqual(state.get_shape(), (batch_size, 4))
+ self.assertEqual(state[0].c.get_shape(), (batch_size, 2))
+ self.assertEqual(state[0].h.get_shape(), (batch_size, 2))
for out, inp in zip(outputs, inputs):
- self.assertEqual(out.get_shape(), (3, num_units))
- self.assertEqual(out.dtype, inp.dtype)
+ self.assertEqual(len(out), 1)
+ self.assertEqual(out[0].get_shape(), (3, num_units))
+ self.assertEqual(out[0].dtype, inp.dtype)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
- for v in values:
- self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[:-1]:
+ for v in tp:
+ self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[-1]:
+ for st in tp:
+ for v in st:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+ def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
+ """Test for #4296."""
+ input_size = 5
+ max_length = 6 # unrolled up to this length
+ num_units = 2
+
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
+ inputs = max_length * [
+ array_ops.placeholder(dtypes.float32, shape=(None, input_size))
+ ]
+
+ outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
+
+ self.assertEqual(len(outputs), len(inputs))
+
+ for out, inp in zip(outputs, inputs):
+ self.assertEqual(len(out), 1)
+ self.assertTrue(out[0].get_shape()[0].value is None)
+ self.assertEqual(out[0].get_shape()[1], num_units)
+ self.assertEqual(out[0].dtype, inp.dtype)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+
+ input_value = np.ones((3, input_size))
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
+ for tp in values[:-1]:
+ for v in tp:
+ self.assertTrue(np.all(np.isfinite(v)))
+ for tp in values[-1]:
+ for st in tp:
+ for v in st:
+ self.assertTrue(np.all(np.isfinite(v)))
+
+ def testGrid2LSTMCellLegacy(self):
+ """Test for legacy case (when state_is_tuple=False)."""
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 8])
+ cell = grid_rnn_cell.Grid2LSTMCell(
+ 2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False)
+ self.assertEqual(cell.state_size, 8)
+
+ g, s = cell(x, m)
+ self.assertEqual(g.get_shape(), (1, 2))
+ self.assertEqual(s.get_shape(), (1, 8))
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
+ self.assertEqual(res[0].shape, (1, 2))
+ self.assertEqual(res[1].shape, (1, 8))
+ self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
+ self.assertAllClose(res[1], [[
+ 2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
+ 1.49043763, 0.83884692, 0.86036491
+ ]])
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 269b224581..252788140f 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -25,6 +25,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
+
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib import layers
from tensorflow.contrib import rnn
@@ -53,7 +55,9 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_dims=None,
tied=False,
cell_fn=None,
- non_recurrent_fn=None):
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
"""Initialize the parameters of a Grid RNN cell
Args:
@@ -68,26 +72,47 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_dims: int or list, List of dimensions that are not
recurrent.
The transfer function for non-recurrent dimensions is specified
- via `non_recurrent_fn`,
- which is default to be `tensorflow.nn.relu`.
+        via `non_recurrent_fn`, which
+        defaults to `tensorflow.nn.relu`.
tied: bool, Whether to share the weights among the dimensions of this
GridRNN cell.
If there are non-recurrent dimensions in the grid, weights are
- shared between each
- group of recurrent and non-recurrent dimensions.
- cell_fn: function, a function which returns the recurrent cell object. Has
- to be in the following signature:
- def cell_func(num_units, input_size):
+ shared between each group of recurrent and non-recurrent
+ dimensions.
+ cell_fn: function, a function which returns the recurrent cell object.
+ Has to be in the following signature:
+ ```
+ def cell_func(num_units):
# ...
-
+ ```
and returns an object of type `RNNCell`. If None, LSTMCell with
default parameters will be used.
+ Note that if you use a custom RNNCell (with `cell_fn`), it is your
+        responsibility to make sure the inner cell uses `state_is_tuple=True`.
+
non_recurrent_fn: a tensorflow Op that will be the transfer function of
the non-recurrent dimensions
+ state_is_tuple: If True, accepted and returned states are tuples of the
+ states of the recurrent dimensions. If False, they are concatenated
+ along the column axis. The latter behavior will soon be deprecated.
+
+ Note that if you use a custom RNNCell (with `cell_fn`), it is your
+        responsibility to make sure the inner cell uses `state_is_tuple=True`.
+
+ output_is_tuple: If True, the output is a tuple of the outputs of the
+ recurrent dimensions. If False, they are concatenated along the
+        column axis. The latter behavior will soon be deprecated.
Raises:
TypeError: if cell_fn does not return an RNNCell instance.
"""
+ if not state_is_tuple:
+ logging.warning('%s: Using a concatenated state is slower and will '
+ 'soon be deprecated. Use state_is_tuple=True.', self)
+ if not output_is_tuple:
+      logging.warning('%s: Using a concatenated output is slower and will '
+ 'soon be deprecated. Use output_is_tuple=True.', self)
+
if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims))
@@ -96,37 +121,41 @@ class GridRNNCell(rnn.RNNCell):
non_recurrent_fn or nn.relu, tied,
num_units)
- cell_input_size = (self._config.num_dims - 1) * num_units
+ self._state_is_tuple = state_is_tuple
+ self._output_is_tuple = output_is_tuple
+
if cell_fn is None:
my_cell_fn = functools.partial(
- rnn.LSTMCell,
- num_units=num_units, input_size=cell_input_size,
- state_is_tuple=False)
+ rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
else:
- my_cell_fn = lambda: cell_fn(num_units, cell_input_size)
+ my_cell_fn = lambda: cell_fn(num_units)
if tied:
self._cells = [my_cell_fn()] * num_dims
else:
self._cells = [my_cell_fn() for _ in range(num_dims)]
if not isinstance(self._cells[0], rnn.RNNCell):
- raise TypeError(
- 'cell_fn must return an RNNCell instance, saw: %s'
- % type(self._cells[0]))
+ raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
+ type(self._cells[0]))
- @property
- def input_size(self):
- # temporarily using num_units as the input_size of each dimension.
- # The actual input size only determined when this cell get invoked,
- # so this information can be considered unreliable.
- return self._config.num_units * len(self._config.inputs)
+ if self._output_is_tuple:
+ self._output_size = tuple(self._cells[0].output_size
+ for _ in self._config.outputs)
+ else:
+ self._output_size = self._cells[0].output_size * len(self._config.outputs)
+
+ if self._state_is_tuple:
+ self._state_size = tuple(self._cells[0].state_size
+ for _ in self._config.recurrents)
+ else:
+ self._state_size = self._cell_state_size() * len(self._config.recurrents)
@property
def output_size(self):
- return self._cells[0].output_size * len(self._config.outputs)
+ return self._output_size
@property
def state_size(self):
- return self._cells[0].state_size * len(self._config.recurrents)
+ return self._state_size
def __call__(self, inputs, state, scope=None):
"""Run one step of GridRNN.
@@ -145,76 +174,148 @@ class GridRNNCell(rnn.RNNCell):
- A 2D, batch x state_size, Tensor representing the new state of the cell
after reading "inputs" when previous state was "state".
"""
- state_sz = state.get_shape().as_list()[1]
- if self.state_size != state_sz:
- raise ValueError(
- 'Actual state size not same as specified: {} vs {}.'.format(
- state_sz, self.state_size))
-
conf = self._config
- dtype = inputs.dtype if inputs is not None else state.dtype
+ dtype = inputs.dtype
- # c_prev is `m`, and m_prev is `h` in the paper.
- # Keep c and m here for consistency with the codebase
- c_prev = [None] * self._config.num_dims
- m_prev = [None] * self._config.num_dims
- cell_output_size = self._cells[0].state_size - conf.num_units
-
- # for LSTM : state = memory cell + output, hence cell_output_size > 0
- # for GRU/RNN: state = output (whose size is equal to _num_units),
- # hence cell_output_size = 0
- for recurrent_dim, start_idx in zip(self._config.recurrents, range(
- 0, self.state_size, self._cells[0].state_size)):
- if cell_output_size > 0:
- c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
- [-1, conf.num_units])
- m_prev[recurrent_dim] = array_ops.slice(
- state, [0, start_idx + conf.num_units], [-1, cell_output_size])
- else:
- m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
- [-1, conf.num_units])
+ c_prev, m_prev, cell_output_size = self._extract_states(state)
new_output = [None] * conf.num_dims
new_state = [None] * conf.num_dims
with vs.variable_scope(scope or type(self).__name__): # GridRNNCell
+ # project input, populate c_prev and m_prev
+ self._project_input(inputs, c_prev, m_prev, cell_output_size > 0)
- # project input
- if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
- conf.inputs) > 0:
- input_splits = array_ops.split(
- value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
- input_sz = input_splits[0].get_shape().as_list()[1]
-
- for i, j in enumerate(conf.inputs):
- input_project_m = vs.get_variable(
- 'project_m_{}'.format(j), [input_sz, conf.num_units], dtype=dtype)
- m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
-
- if cell_output_size > 0:
- input_project_c = vs.get_variable(
- 'project_c_{}'.format(j), [input_sz, conf.num_units],
- dtype=dtype)
- c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
-
+ # propagate along dimensions, first for non-priority dimensions
+ # then priority dimensions
_propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
new_output, new_state, True)
_propagate(conf.priority, conf, self._cells,
c_prev, m_prev, new_output, new_state, False)
+ # collect outputs and states
output_tensors = [new_output[i] for i in self._config.outputs]
- output = array_ops.zeros(
- [0, 0], dtype) if len(output_tensors) == 0 else array_ops.concat(
- output_tensors, 1)
+ if self._output_is_tuple:
+ output = tuple(output_tensors)
+ else:
+ if output_tensors:
+ output = array_ops.concat(output_tensors, 1)
+ else:
+ output = array_ops.zeros([0, 0], dtype)
- state_tensors = [new_state[i] for i in self._config.recurrents]
- states = array_ops.zeros(
- [0, 0],
- dtype) if len(state_tensors) == 0 else array_ops.concat(state_tensors,
- 1)
+ if self._state_is_tuple:
+ states = tuple(new_state[i] for i in self._config.recurrents)
+ else:
+ # concat each state first, then flatten the whole thing
+ state_tensors = [
+ x for i in self._config.recurrents for x in new_state[i]
+ ]
+ if state_tensors:
+ states = array_ops.concat(state_tensors, 1)
+ else:
+ states = array_ops.zeros([0, 0], dtype)
return output, states
+ def _extract_states(self, state):
+ """Extract the cell and previous output tensors from the given state.
+
+ Args:
+ state: The RNN state.
+
+ Returns:
+ Tuple of the cell value, previous output, and cell_output_size.
+
+ Raises:
+ ValueError: If len(self._config.recurrents) != len(state).
+ """
+ conf = self._config
+
+ # c_prev is `m` (cell value), and
+ # m_prev is `h` (previous output) in the paper.
+ # Keeping c and m here for consistency with the codebase
+ c_prev = [None] * conf.num_dims
+ m_prev = [None] * conf.num_dims
+
+ # for LSTM : state = memory cell + output, hence cell_output_size > 0
+ # for GRU/RNN: state = output (whose size is equal to _num_units),
+ # hence cell_output_size = 0
+ total_cell_state_size = self._cell_state_size()
+ cell_output_size = total_cell_state_size - conf.num_units
+
+ if self._state_is_tuple:
+ if len(conf.recurrents) != len(state):
+ raise ValueError('Expected state as a tuple of {} '
+                         'elements'.format(len(conf.recurrents)))
+
+ for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
+ if cell_output_size > 0:
+ c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state
+ else:
+ m_prev[recurrent_dim] = recurrent_state
+ else:
+ for recurrent_dim, start_idx in zip(conf.recurrents,
+ range(0, self.state_size,
+ total_cell_state_size)):
+ if cell_output_size > 0:
+ c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
+ m_prev[recurrent_dim] = array_ops.slice(
+ state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+ else:
+ m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
+ [-1, conf.num_units])
+ return c_prev, m_prev, cell_output_size
+
+ def _project_input(self, inputs, c_prev, m_prev, with_c):
+ """Fills in c_prev and m_prev with projected input, for input dimensions.
+
+ Args:
+ inputs: inputs tensor
+ c_prev: cell value
+ m_prev: previous output
+ with_c: boolean; whether to include project_c.
+
+ Raises:
+      ValueError: if len(self._config.inputs) != len(inputs)
+ """
+ conf = self._config
+
+ if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
+ conf.inputs):
+ if isinstance(inputs, tuple):
+ if len(conf.inputs) != len(inputs):
+ raise ValueError('Expect inputs as a tuple of {} '
+ 'tensors'.format(len(conf.inputs)))
+ input_splits = inputs
+ else:
+ input_splits = array_ops.split(
+ value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
+ input_sz = input_splits[0].get_shape().with_rank(2)[1].value
+
+ for i, j in enumerate(conf.inputs):
+ input_project_m = vs.get_variable(
+ 'project_m_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
+ m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
+
+ if with_c:
+ input_project_c = vs.get_variable(
+ 'project_c_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
+ c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
+
+ def _cell_state_size(self):
+ """Total size of the state of the inner cell used in this grid.
+
+ Returns:
+ Total size of the state of the inner cell.
+ """
+ state_sizes = self._cells[0].state_size
+ if isinstance(state_sizes, tuple):
+ return sum(state_sizes)
+ return state_sizes
+
"""Specialized cells, for convenience
"""
@@ -223,11 +324,17 @@ class GridRNNCell(rnn.RNNCell):
class Grid1BasicRNNCell(GridRNNCell):
"""1D BasicRNN cell"""
- def __init__(self, num_units):
+ def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
super(Grid1BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicRNNCell(GridRNNCell):
@@ -240,71 +347,112 @@ class Grid2BasicRNNCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.BasicRNNCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1BasicLSTMCell(GridRNNCell):
- """1D BasicLSTM cell"""
+ """1D BasicLSTM cell."""
- def __init__(self, num_units, forget_bias=1):
+ def __init__(self,
+ num_units,
+ forget_bias=1,
+ state_is_tuple=True,
+ output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid1BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n, i: rnn.BasicLSTMCell(
- num_units=n,
- forget_bias=forget_bias, input_size=i,
- state_is_tuple=False))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicLSTMCell(GridRNNCell):
- """2D BasicLSTM cell
+ """2D BasicLSTM cell.
- This creates a 2D cell which receives input and gives output in the first
- dimension.
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
- specified.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
def __init__(self,
num_units,
tied=False,
non_recurrent_fn=None,
- forget_bias=1):
+ forget_bias=1,
+ state_is_tuple=True,
+ output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid2BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.BasicLSTMCell(
- num_units=n, forget_bias=forget_bias, input_size=i,
- state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1LSTMCell(GridRNNCell):
- """1D LSTM cell
+ """1D LSTM cell.
- This is different from Grid1BasicLSTMCell because it gives options to
- specify the forget bias and enabling peepholes
+ This is different from Grid1BasicLSTMCell because it gives options to
+ specify the forget bias and enabling peepholes.
"""
- def __init__(self, num_units, use_peepholes=False, forget_bias=1.0):
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid1LSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, use_peepholes=use_peepholes,
- forget_bias=forget_bias, state_is_tuple=False))
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2LSTMCell(GridRNNCell):
- """2D LSTM cell
+ """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -317,19 +465,30 @@ class Grid2LSTMCell(GridRNNCell):
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
- forget_bias=1.0):
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid2LSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes, state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid3LSTMCell(GridRNNCell):
- """3D BasicLSTM cell
+ """3D BasicLSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -343,19 +502,30 @@ class Grid3LSTMCell(GridRNNCell):
tied=False,
non_recurrent_fn=None,
use_peepholes=False,
- forget_bias=1.0):
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid3LSTMCell, self).__init__(
- num_units=num_units, num_dims=3,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=3,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.LSTMCell(
- num_units=n, input_size=i, forget_bias=forget_bias,
- use_peepholes=use_peepholes, state_is_tuple=False),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2GRUCell(GridRNNCell):
- """2D LSTM cell
+  """2D GRU cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -363,21 +533,31 @@ class Grid2GRUCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2GRUCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n, i: rnn.GRUCell(num_units=n, input_size=i),
- non_recurrent_fn=non_recurrent_fn)
+ cell_fn=lambda n: rnn.GRUCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
-"""Helpers
-"""
+# Helpers
-_GridRNNDimension = namedtuple(
- '_GridRNNDimension',
- ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+_GridRNNDimension = namedtuple('_GridRNNDimension', [
+ 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
+])
_GridRNNConfig = namedtuple('_GridRNNConfig',
['num_dims', 'dims', 'inputs', 'outputs',
@@ -387,7 +567,6 @@ _GridRNNConfig = namedtuple('_GridRNNConfig',
def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
-
def check_dim_list(ls):
if ls is None:
ls = []
@@ -412,8 +591,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
is_input=(i in input_dims),
is_output=(i in output_dims),
is_priority=(i in priority_dims),
- non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
- None))
+ non_recurrent_fn=non_recurrent_fn
+ if i in non_recurrent_dims else None))
return _GridRNNConfig(
num_dims=num_dims,
dims=rnn_dims,
@@ -440,34 +619,40 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
if conf.num_dims > 1:
ls_cell_inputs = [None] * (conf.num_dims - 1)
for d in conf.dims[:-1]:
- ls_cell_inputs[d.idx] = new_output[d.idx] if new_output[
- d.idx] is not None else m_prev[d.idx]
+ if new_output[d.idx] is None:
+ ls_cell_inputs[d.idx] = m_prev[d.idx]
+ else:
+ ls_cell_inputs[d.idx] = new_output[d.idx]
cell_inputs = array_ops.concat(ls_cell_inputs, 1)
else:
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
m_prev[0].dtype)
- last_dim_output = new_output[-1] if new_output[-1] is not None else m_prev[-1]
+ last_dim_output = (new_output[-1]
+ if new_output[-1] is not None else m_prev[-1])
for i in dim_indices:
d = conf.dims[i]
if d.non_recurrent_fn:
- linear_args = array_ops.concat(
- [cell_inputs, last_dim_output],
- 1) if conf.num_dims > 1 else last_dim_output
+ if conf.num_dims > 1:
+ linear_args = array_ops.concat([cell_inputs, last_dim_output], 1)
+ else:
+ linear_args = last_dim_output
with vs.variable_scope('non_recurrent' if conf.tied else
'non_recurrent/cell_{}'.format(i)):
if conf.tied and not (first_call and i == dim_indices[0]):
vs.get_variable_scope().reuse_variables()
- new_output[d.idx] = layers.legacy_fully_connected(
+
+ new_output[d.idx] = layers.fully_connected(
linear_args,
- num_output_units=conf.num_units,
+ num_outputs=conf.num_units,
activation_fn=d.non_recurrent_fn,
- weight_init=vs.get_variable_scope().initializer or
- layers.initializers.xavier_initializer)
+ weights_initializer=(vs.get_variable_scope().initializer or
+ layers.initializers.xavier_initializer),
+ weights_regularizer=vs.get_variable_scope().regularizer)
else:
if c_prev[i] is not None:
- cell_state = array_ops.concat([c_prev[i], last_dim_output], 1)
+ cell_state = (c_prev[i], last_dim_output)
else:
# for GRU/RNN, the state is just the previous output
cell_state = last_dim_output