diff options
author | 2018-06-15 16:18:18 -0700 | |
---|---|---|
committer | 2018-06-15 16:21:04 -0700 | |
commit | ed3adf62db3a4371e01d6b7ac8f69a40f5914f1a (patch) | |
tree | 944dd500a23188b097f861513c5e83271110ac96 | |
parent | 0e85bc7b36d05f585d76d21e55dd09b40c94145a (diff) |
Fixes Eager mode of dynamic_rnn for RNNCells with unbalanced output
PiperOrigin-RevId: 200791012
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 41 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn.py | 3 |
2 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index fe5ad84c10..e9ae105c28 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -81,6 +81,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell): return (input_, state + 1) +class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell): + """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" + + @property + def output_size(self): + return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2)) + + @property + def state_size(self): + return tensor_shape.TensorShape([]) + + def zero_state(self, batch_size, dtype): + return array_ops.zeros([], dtype=dtypes.int32) + + def call(self, input_, state, scope=None): + concatenated = array_ops.concat((input_, input_), axis=-1) + return (input_, concatenated), state + 1 + + class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell): """RNN Cell its state as a TensorArray.""" @@ -183,6 +202,28 @@ class RNNTest(test.TestCase): self.assertAllEqual(4, state) @test_util.run_in_graph_and_eager_modes() + def testUnbalancedOutputIsAccepted(self): + cell = UnbalancedOutputRNNCell() + in_eager_mode = context.executing_eagerly() + + if in_eager_mode: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + + with self.test_session() as sess: + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, sequence_length=[4]) + if not in_eager_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + self.assertIsInstance(outputs, tuple) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0]) + self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) + self.assertAllEqual(4, state) + + @test_util.run_in_graph_and_eager_modes() def testTensorArrayStateIsAccepted(self): cell = TensorArrayStateRNNCell() in_eager_mode = context.executing_eagerly() diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 10d576c95b..215140e987 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -828,7 +828,8 @@ def _dynamic_rnn_loop(cell, final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) if not in_graph_mode: - final_outputs = array_ops.stack(final_outputs, axis=0) + final_outputs = nest.map_structure_up_to( + cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) return (final_outputs, final_state) |