aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-15 16:18:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 16:21:04 -0700
commited3adf62db3a4371e01d6b7ac8f69a40f5914f1a (patch)
tree944dd500a23188b097f861513c5e83271110ac96
parent0e85bc7b36d05f585d76d21e55dd09b40c94145a (diff)
Fixes Eager mode of dynamic_rnn for RNNCells with unbalanced output
PiperOrigin-RevId: 200791012
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py41
-rw-r--r--tensorflow/python/ops/rnn.py3
2 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index fe5ad84c10..e9ae105c28 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -81,6 +81,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
return (input_, state + 1)
+class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell):
+ """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
+
+ @property
+ def output_size(self):
+ return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2))
+
+ @property
+ def state_size(self):
+ return tensor_shape.TensorShape([])
+
+ def zero_state(self, batch_size, dtype):
+ return array_ops.zeros([], dtype=dtypes.int32)
+
+ def call(self, input_, state, scope=None):
+ concatenated = array_ops.concat((input_, input_), axis=-1)
+ return (input_, concatenated), state + 1
+
+
class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
"""RNN Cell its state as a TensorArray."""
@@ -183,6 +202,28 @@ class RNNTest(test.TestCase):
self.assertAllEqual(4, state)
@test_util.run_in_graph_and_eager_modes()
+ def testUnbalancedOutputIsAccepted(self):
+ cell = UnbalancedOutputRNNCell()
+ in_eager_mode = context.executing_eagerly()
+
+ if in_eager_mode:
+ inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
+ else:
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
+
+ with self.test_session() as sess:
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32, sequence_length=[4])
+ if not in_eager_mode:
+ outputs, state = sess.run(
+ [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})
+
+ self.assertIsInstance(outputs, tuple)
+ self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0])
+ self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
+ self.assertAllEqual(4, state)
+
+ @test_util.run_in_graph_and_eager_modes()
def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell()
in_eager_mode = context.executing_eagerly()
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 10d576c95b..215140e987 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -828,7 +828,8 @@ def _dynamic_rnn_loop(cell,
final_outputs = nest.pack_sequence_as(
structure=cell.output_size, flat_sequence=final_outputs)
if not in_graph_mode:
- final_outputs = array_ops.stack(final_outputs, axis=0)
+ final_outputs = nest.map_structure_up_to(
+ cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
return (final_outputs, final_state)