aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 20:38:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 20:42:32 -0800
commit17a0b492b1548830b87a048b931522b59bd7466a (patch)
treed80d8959d98d5de09f6cf402c624f16a17a8ddf5
parent99e29f79576a8a1fc4c32beae4c44f7af5ee53a7 (diff)
Makes GLSTMCell accept input of any compatible dimension.
Currently, GLSTMCell requires that the input dimension is is the same as the output dimension. After this change, the input can be any compatible dimension---i.e., anything divisible by the number of groups. The input size is still assumed to be the output size in the case where the innermost dimension of the input is not statically-defined. PiperOrigin-RevId: 188123536
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py107
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py34
2 files changed, 99 insertions, 42 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index eef1ae25e9..7de55a0bb3 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1031,57 +1031,92 @@ class RNNCellTest(test.TestCase):
num_units = 4
number_of_groups = 1
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root1", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.ones([batch_size, num_units])
- # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
- gcell = contrib_rnn_cell.GLSTMCell(
- num_units=num_units, number_of_groups=number_of_groups)
- cell = rnn_cell.LSTMCell(num_units=num_units)
- self.assertTrue(isinstance(gcell.state_size, tuple))
- zero_state = gcell.zero_state(
- batch_size=batch_size, dtype=dtypes.float32)
- gh, gs = gcell(x, zero_state)
- h, g = cell(x, zero_state)
+ # Try with input dimension equal to num_units or not.
+ for num_inputs in [num_units, num_units + number_of_groups]:
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root1_%d" % num_inputs,
+ initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.ones([batch_size, num_inputs])
+ # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
+ gcell = contrib_rnn_cell.GLSTMCell(
+ num_units=num_units, number_of_groups=number_of_groups)
+ cell = rnn_cell.LSTMCell(num_units=num_units)
+ self.assertTrue(isinstance(gcell.state_size, tuple))
+ zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
+ gh, gs = gcell(x, zero_state)
+ h, g = cell(x, zero_state)
- sess.run([variables.global_variables_initializer()])
- glstm_result = sess.run([gh, gs])
- lstm_result = sess.run([h, g])
+ sess.run([variables.global_variables_initializer()])
+ glstm_result = sess.run([gh, gs])
+ lstm_result = sess.run([h, g])
- self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
- self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
+ self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
+ self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
# Test that G-LSTM subgroup act like corresponding sub-LSTMs
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session() as sess:
+ # Try with num_inputs equal to or not equal to num_units.
+ for num_inputs in [num_units, num_units + number_of_groups]:
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root2_%d" % num_inputs,
+ initializer=init_ops.constant_initializer(0.5)):
+ # input for G-LSTM with 2 groups
+ glstm_input = array_ops.ones([batch_size, num_inputs])
+ gcell = contrib_rnn_cell.GLSTMCell(
+ num_units=num_units, number_of_groups=number_of_groups)
+ gcell_zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
+ gh, gs = gcell(glstm_input, gcell_zero_state)
+
+ # input for LSTM cell simulating single G-LSTM group
+ lstm_input = array_ops.ones(
+ [batch_size, num_inputs / number_of_groups])
+ # note division by number_of_groups. This cell one simulates G-LSTM
+ # group
+ cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
+ cell_zero_state = cell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
+ h, g = cell(lstm_input, cell_zero_state)
+
+ sess.run([variables.global_variables_initializer()])
+ [gh_res, h_res] = sess.run([gh, h])
+ self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
+ h_res, 1e-5)
+ self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
+ h_res, 1e-5)
+
+ def testGLSTMCellFailure(self):
+ batch_size = 2
+ num_units = 4
+ number_of_groups = 2
+ with self.test_session():
with variable_scope.variable_scope(
- "root2", initializer=init_ops.constant_initializer(0.5)):
- # input for G-LSTM with 2 groups
- glstm_input = array_ops.ones([batch_size, num_units])
+ "glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
num_units=num_units, number_of_groups=number_of_groups)
gcell_zero_state = gcell.zero_state(
batch_size=batch_size, dtype=dtypes.float32)
- gh, gs = gcell(glstm_input, gcell_zero_state)
- # input for LSTM cell simulating single G-LSTM group
- lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
- # note division by number_of_groups. This cell one simulates G-LSTM group
- cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
- cell_zero_state = cell.zero_state(
- batch_size=batch_size, dtype=dtypes.float32)
- h, g = cell(lstm_input, cell_zero_state)
+ # Try an input with statically-unknown innermost dimension.
+ glstm_input = array_ops.placeholder(
+ dtypes.float32, shape=[batch_size, None])
+ with self.assertRaisesRegexp(ValueError,
+ "input size must be statically known"):
+ gcell(glstm_input, gcell_zero_state)
- sess.run([variables.global_variables_initializer()])
- [gh_res, h_res] = sess.run([gh, h])
- self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
- h_res, 1e-5)
- self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
- h_res, 1e-5)
+ # Try an input whose innermost dimension isn't divisible into groups.
+ glstm_input = array_ops.placeholder(
+ dtypes.float32, shape=[batch_size, 3])
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"input size \(3\) must be divisible by number_of_groups \(2\)"):
+ gcell(glstm_input, gcell_zero_state)
class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index a6c2d9cdbb..6bea8d4a21 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2225,6 +2225,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
O. Kuchaiev and B. Ginsburg
"Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
+
+ In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each
+ sub-cell operates on an evenly-sized sub-vector of the input and produces an
+ evenly-sized sub-vector of the output. For example, a G-LSTM cell with 128
+ units and 4 groups consists of 4 LSTMs sub-cells with 32 units each. If that
+ G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part
+ of the input and produces a 32-dim part of the output.
"""
def __init__(self,
@@ -2320,9 +2327,12 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
"""Run one step of G-LSTM.
Args:
- inputs: input Tensor, 2D, [batch x num_units].
- state: this must be a tuple of state Tensors, both `2-D`,
- with column sizes `c_state` and `m_state`.
+ inputs: input Tensor, 2D, [batch x num_inputs]. num_inputs must be
+ statically-known and evenly divisible into groups. The innermost
+ vectors of the inputs are split into evenly-sized sub-vectors and fed
+ into the per-group LSTM sub-cells.
+ state: this must be a tuple of state Tensors, both `2-D`, with column
+ sizes `c_state` and `m_state`.
Returns:
A tuple containing:
@@ -2337,11 +2347,24 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
Raises:
ValueError: If input size cannot be inferred from inputs via
- static shape inference.
+ static shape inference, or if the input shape is incompatible
+ with the number of groups.
"""
(c_prev, m_prev) = state
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
+
+ # If the input size is statically-known, calculate and validate its group
+ # size. Otherwise, use the output group size.
+ input_size = inputs.shape[1].value
+ if input_size is None:
+ raise ValueError("input size must be statically known")
+ if input_size % self._number_of_groups != 0:
+ raise ValueError(
+ "input size (%d) must be divisible by number_of_groups (%d)" %
+ (input_size, self._number_of_groups))
+ input_group_size = int(input_size / self._number_of_groups)
+
dtype = inputs.dtype
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer):
@@ -2354,8 +2377,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
with vs.variable_scope("group%d" % group_id):
x_g_id = array_ops.concat(
[
- self._get_input_for_group(inputs, group_id,
- self._group_shape[0]),
+ self._get_input_for_group(inputs, group_id, input_group_size),
self._get_input_for_group(m_prev, group_id,
self._group_shape[0])
],