From 5f600c2b1daa004d45b4d63df112f85be1ee5e4b Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 9 Feb 2017 19:21:29 -0800 Subject: Refactor some code in the new seq2seq api: 1. Rename samplers to helpers and move them to helper.py. 2. Remove the redundant name "Basic" from all helpers' names. 3. Rename SamplingDecoder to BasicDecoder. Change: 147112283 --- tensorflow/contrib/seq2seq/BUILD | 4 +- .../python/kernel_tests/basic_decoder_test.py | 267 ++++++++++++ .../seq2seq/python/kernel_tests/decoder_test.py | 29 +- .../python/kernel_tests/sampling_decoder_test.py | 264 ------------ .../contrib/seq2seq/python/ops/basic_decoder.py | 121 ++++++ tensorflow/contrib/seq2seq/python/ops/helper.py | 361 +++++++++++++++++ .../contrib/seq2seq/python/ops/sampling_decoder.py | 447 --------------------- 7 files changed, 765 insertions(+), 728 deletions(-) create mode 100644 tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py delete mode 100644 tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py create mode 100644 tensorflow/contrib/seq2seq/python/ops/basic_decoder.py create mode 100644 tensorflow/contrib/seq2seq/python/ops/helper.py delete mode 100644 tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index c10d77f7ef..d5a9dcdce1 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -80,9 +80,9 @@ cuda_py_test( ) cuda_py_test( - name = "sampling_decoder_test", + name = "basic_decoder_test", size = "medium", - srcs = ["python/kernel_tests/sampling_decoder_test.py"], + srcs = ["python/kernel_tests/basic_decoder_test.py"], additional_deps = [ ":seq2seq_py", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py new file mode 100644 index 0000000000..7ef0095b2e --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -0,0 +1,267 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder.""" +# pylint: disable=unused-import,g-bad-import-order +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# pylint: enable=unused-import + +import sys + +# TODO(jart): #6568 Remove this hack that makes dlopen() not crash. +if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): + import ctypes # pylint: disable=g-import-not-at-top + sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) + +# pylint: disable=g-import-not-at-top +import numpy as np + +from tensorflow.contrib.rnn import core_rnn_cell +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +# pylint: enable=g-import-not-at-top + + +class BasicDecoderTest(test.TestCase): + + def testStepWithTrainingHelper(self): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + + with self.test_session() as sess: + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + cell = core_rnn_cell.LSTMCell(cell_depth) + helper = helper_py.TrainingHelper( + inputs, sequence_length, time_major=False) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + sess_results["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + sess_results["step_finished"]) + self.assertAllEqual( + np.argmax(sess_results["step_outputs"].rnn_output, -1), + sess_results["step_outputs"].sample_id) + + def testStepWithGreedyEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + start_tokens = [0] * batch_size + end_token = 1 + + with self.test_session() as sess: + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + cell = core_rnn_cell.LSTMCell(vocabulary_size) + helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens, + end_token) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + expected_sample_ids = np.argmax( + sess_results["step_outputs"].rnn_output, -1) + expected_step_finished = (expected_sample_ids == end_token) + expected_step_next_inputs = embeddings[expected_sample_ids] + self.assertAllEqual([False, False, False, False, False], + sess_results["first_finished"]) + self.assertAllEqual(expected_step_finished, sess_results["step_finished"]) + self.assertAllEqual(expected_sample_ids, + sess_results["step_outputs"].sample_id) + self.assertAllEqual(expected_step_next_inputs, + sess_results["step_next_inputs"]) + + def testStepWithScheduledEmbeddingTrainingHelper(self): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + vocabulary_size = 10 + + with self.test_session() as sess: + inputs = np.random.randn( + batch_size, max_time, input_depth).astype(np.float32) + embeddings = np.random.randn( + vocabulary_size, input_depth).astype(np.float32) + half = constant_op.constant(0.5) + cell = core_rnn_cell.LSTMCell(vocabulary_size) + helper = helper_py.ScheduledEmbeddingTrainingHelper( + inputs=inputs, + sequence_length=sequence_length, + embedding=embeddings, + sampling_probability=half, + time_major=False) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(vocabulary_size, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, vocabulary_size), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[1].get_shape()) + self.assertEqual((batch_size, input_depth), + step_next_inputs.get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + sess_results["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + sess_results["step_finished"]) + sample_ids = sess_results["step_outputs"].sample_id + batch_where_not_sampling = np.where(sample_ids == -1) + batch_where_sampling = np.where(sample_ids > -1) + self.assertAllClose( + sess_results["step_next_inputs"][batch_where_sampling], + embeddings[sample_ids[batch_where_sampling]]) + self.assertAllClose( + sess_results["step_next_inputs"][batch_where_not_sampling], + np.squeeze(inputs[batch_where_not_sampling, 1])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 851eb6ef92..76de154b48 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -31,7 +31,8 @@ import numpy as np from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.contrib.seq2seq.python.ops import sampling_decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.python.framework import dtypes from tensorflow.python.ops import rnn from tensorflow.python.ops import variables @@ -59,11 +60,11 @@ class DynamicDecodeRNNTest(test.TestCase): inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(cell_depth) - sampler = sampling_decoder.BasicTrainingSampler( + helper = helper_py.TrainingHelper( inputs, sequence_length, time_major=time_major) - my_decoder = sampling_decoder.BasicSamplingDecoder( + my_decoder = basic_decoder.BasicDecoder( cell=cell, - sampler=sampler, + helper=helper, initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) @@ -77,7 +78,7 @@ class DynamicDecodeRNNTest(test.TestCase): return shape self.assertTrue( - isinstance(final_outputs, sampling_decoder.SamplingDecoderOutput)) + isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple)) self.assertEqual( @@ -116,7 +117,7 @@ class DynamicDecodeRNNTest(test.TestCase): def testDynamicDecodeRNNOneMaxIter(self): self._testDynamicDecodeRNN(time_major=True, maximum_iterations=1) - def _testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( self, use_sequence_length): sequence_length = [3, 4, 3, 1, 0] batch_size = 5 @@ -131,9 +132,9 @@ class DynamicDecodeRNNTest(test.TestCase): cell = core_rnn_cell.LSTMCell(cell_depth) zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size) - sampler = sampling_decoder.BasicTrainingSampler(inputs, sequence_length) - my_decoder = sampling_decoder.BasicSamplingDecoder( - cell=cell, sampler=sampler, initial_state=zero_state) + helper = helper_py.TrainingHelper(inputs, sequence_length) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, helper=helper, initial_state=zero_state) # Match the variable scope of dynamic_rnn below so we end up # using the same variables @@ -169,14 +170,12 @@ class DynamicDecodeRNNTest(test.TestCase): self.assertAllClose(sess_results["final_decoder_state"], sess_results["final_rnn_state"]) - def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNWithSeqLen( - self): - self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( use_sequence_length=True) - def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNNoSeqLen( - self): - self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( use_sequence_length=False) if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py deleted file mode 100644 index 3f8b4c077d..0000000000 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for contrib.seq2seq.python.seq2seq.sampling_decoder.""" -# pylint: disable=unused-import,g-bad-import-order -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -# pylint: enable=unused-import - -import sys - -# TODO(jart): #6568 Remove this hack that makes dlopen() not crash. -if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): - import ctypes # pylint: disable=g-import-not-at-top - sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) - -# pylint: disable=g-import-not-at-top -import numpy as np - -from tensorflow.contrib.rnn import core_rnn_cell -from tensorflow.contrib.seq2seq.python.ops import sampling_decoder -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -# pylint: enable=g-import-not-at-top - - -class BasicSamplingDecoderTest(test.TestCase): - - def testStepWithBasicTrainingSampler(self): - sequence_length = [3, 4, 3, 1, 0] - batch_size = 5 - max_time = 8 - input_depth = 7 - cell_depth = 10 - - with self.test_session() as sess: - inputs = np.random.randn(batch_size, max_time, - input_depth).astype(np.float32) - cell = core_rnn_cell.LSTMCell(cell_depth) - sampler = sampling_decoder.BasicTrainingSampler( - inputs, sequence_length, time_major=False) - my_decoder = sampling_decoder.BasicSamplingDecoder( - cell=cell, - sampler=sampler, - initial_state=cell.zero_state( - dtype=dtypes.float32, batch_size=batch_size)) - output_size = my_decoder.output_size - output_dtype = my_decoder.output_dtype - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(cell_depth, - tensor_shape.TensorShape([])), - output_size) - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), - output_dtype) - - (first_finished, first_inputs, first_state) = my_decoder.initialize() - (step_outputs, step_state, step_next_inputs, - step_finished) = my_decoder.step( - constant_op.constant(0), first_inputs, first_state) - batch_size_t = my_decoder.batch_size - - self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) - self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) - self.assertEqual((batch_size,), step_outputs[1].get_shape()) - self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) - self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) - self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) - self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) - - sess.run(variables.global_variables_initializer()) - sess_results = sess.run({ - "batch_size": batch_size_t, - "first_finished": first_finished, - "first_inputs": first_inputs, - "first_state": first_state, - "step_outputs": step_outputs, - "step_state": step_state, - "step_next_inputs": step_next_inputs, - "step_finished": step_finished - }) - - self.assertAllEqual([False, False, False, False, True], - sess_results["first_finished"]) - self.assertAllEqual([False, False, False, True, True], - sess_results["step_finished"]) - self.assertAllEqual( - np.argmax(sess_results["step_outputs"].rnn_output, -1), - sess_results["step_outputs"].sample_id) - - def testStepWithGreedyEmbeddingSampler(self): - batch_size = 5 - vocabulary_size = 7 - cell_depth = vocabulary_size # cell's logits must match vocabulary size - input_depth = 10 - start_tokens = [0] * batch_size - end_token = 1 - - with self.test_session() as sess: - embeddings = np.random.randn(vocabulary_size, - input_depth).astype(np.float32) - cell = core_rnn_cell.LSTMCell(vocabulary_size) - sampler = sampling_decoder.GreedyEmbeddingSampler( - embeddings, start_tokens, end_token) - my_decoder = sampling_decoder.BasicSamplingDecoder( - cell=cell, - sampler=sampler, - initial_state=cell.zero_state( - dtype=dtypes.float32, batch_size=batch_size)) - output_size = my_decoder.output_size - output_dtype = my_decoder.output_dtype - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(cell_depth, - tensor_shape.TensorShape([])), - output_size) - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), - output_dtype) - - (first_finished, first_inputs, first_state) = my_decoder.initialize() - (step_outputs, step_state, step_next_inputs, - step_finished) = my_decoder.step( - constant_op.constant(0), first_inputs, first_state) - batch_size_t = my_decoder.batch_size - - self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) - self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) - self.assertEqual((batch_size,), step_outputs[1].get_shape()) - self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) - self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) - self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) - self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) - - sess.run(variables.global_variables_initializer()) - sess_results = sess.run({ - "batch_size": batch_size_t, - "first_finished": first_finished, - "first_inputs": first_inputs, - "first_state": first_state, - "step_outputs": step_outputs, - "step_state": step_state, - "step_next_inputs": step_next_inputs, - "step_finished": step_finished - }) - - expected_sample_ids = np.argmax( - sess_results["step_outputs"].rnn_output, -1) - expected_step_finished = (expected_sample_ids == end_token) - expected_step_next_inputs = embeddings[expected_sample_ids] - self.assertAllEqual([False, False, False, False, False], - sess_results["first_finished"]) - self.assertAllEqual(expected_step_finished, sess_results["step_finished"]) - self.assertAllEqual(expected_sample_ids, - sess_results["step_outputs"].sample_id) - self.assertAllEqual(expected_step_next_inputs, - sess_results["step_next_inputs"]) - - def testStepWithScheduledEmbeddingTrainingSampler(self): - sequence_length = [3, 4, 3, 1, 0] - batch_size = 5 - max_time = 8 - input_depth = 7 - vocabulary_size = 10 - - with self.test_session() as sess: - inputs = np.random.randn( - batch_size, max_time, input_depth).astype(np.float32) - embeddings = np.random.randn( - vocabulary_size, input_depth).astype(np.float32) - half = constant_op.constant(0.5) - cell = core_rnn_cell.LSTMCell(vocabulary_size) - sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler( - inputs=inputs, sequence_length=sequence_length, - embedding=embeddings, sampling_probability=half, - time_major=False) - my_decoder = sampling_decoder.BasicSamplingDecoder( - cell=cell, - sampler=sampler, - initial_state=cell.zero_state( - dtype=dtypes.float32, batch_size=batch_size)) - output_size = my_decoder.output_size - output_dtype = my_decoder.output_dtype - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(vocabulary_size, - tensor_shape.TensorShape([])), - output_size) - self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), - output_dtype) - - (first_finished, first_inputs, first_state) = my_decoder.initialize() - (step_outputs, step_state, step_next_inputs, - step_finished) = my_decoder.step( - constant_op.constant(0), first_inputs, first_state) - batch_size_t = my_decoder.batch_size - - self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) - self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) - self.assertEqual((batch_size, vocabulary_size), - step_outputs[0].get_shape()) - self.assertEqual((batch_size,), step_outputs[1].get_shape()) - self.assertEqual((batch_size, vocabulary_size), - first_state[0].get_shape()) - self.assertEqual((batch_size, vocabulary_size), - first_state[1].get_shape()) - self.assertEqual((batch_size, vocabulary_size), - step_state[0].get_shape()) - self.assertEqual((batch_size, vocabulary_size), - step_state[1].get_shape()) - self.assertEqual((batch_size, input_depth), - step_next_inputs.get_shape()) - - sess.run(variables.global_variables_initializer()) - sess_results = sess.run({ - "batch_size": batch_size_t, - "first_finished": first_finished, - "first_inputs": first_inputs, - "first_state": first_state, - "step_outputs": step_outputs, - "step_state": step_state, - "step_next_inputs": step_next_inputs, - "step_finished": step_finished - }) - - self.assertAllEqual([False, False, False, False, True], - sess_results["first_finished"]) - self.assertAllEqual([False, False, False, True, True], - sess_results["step_finished"]) - sample_ids = sess_results["step_outputs"].sample_id - batch_where_not_sampling = np.where(sample_ids == -1) - batch_where_sampling = np.where(sample_ids > -1) - self.assertAllClose( - sess_results["step_next_inputs"][batch_where_sampling], - embeddings[sample_ids[batch_where_sampling]]) - self.assertAllClose( - sess_results["step_next_inputs"][batch_where_not_sampling], - np.squeeze(inputs[batch_where_not_sampling, 1])) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py new file mode 100644 index 0000000000..eac2c179b2 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -0,0 +1,121 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A class of Decoders that may sample to generate the next input. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.rnn import core_rnn_cell +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import nest + + +__all__ = [ + "BasicDecoderOutput", + "BasicDecoder", +] + + +class BasicDecoderOutput( + collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))): + pass + + +class BasicDecoder(decoder.Decoder): + """Basic sampling decoder.""" + + def __init__(self, cell, helper, initial_state): + """Initialize BasicDecoder. + + Args: + cell: An `RNNCell` instance. + helper: A `Helper` instance. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell` or `helper` + is not an instance of `Helper`. + """ + if not isinstance(cell, core_rnn_cell.RNNCell): + raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + if not isinstance(helper, helper_py.Helper): + raise TypeError("helper must be a Helper, received: %s" % type(helper)) + self._cell = cell + self._helper = helper + self._initial_state = initial_state + + @property + def batch_size(self): + return self._helper.batch_size + + @property + def output_size(self): + # Return the cell output and the id + return BasicDecoderOutput( + rnn_output=self._cell.output_size, + sample_id=tensor_shape.TensorShape([])) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_state)[0].dtype + return BasicDecoderOutput( + nest.map_structure(lambda _: dtype, self._cell.output_size), + dtypes.int32) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, first_inputs, initial_state)`. + """ + return self._helper.initialize() + (self._initial_state,) + + def step(self, time, inputs, state, name=None): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + name: Name scope for any created operations. + + Returns: + `(outputs, next_state, next_inputs, finished)`. + """ + with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)): + cell_outputs, cell_state = self._cell(inputs, state) + sample_ids = self._helper.sample( + time=time, outputs=cell_outputs, state=cell_state) + (finished, next_inputs, next_state) = self._helper.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids) + outputs = BasicDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py new file mode 100644 index 0000000000..46d0563fe0 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -0,0 +1,361 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A library of helpers for use with SamplingDecoders. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.distributions.python.ops import categorical +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + +__all__ = [ + "Helper", + "TrainingHelper", + "GreedyEmbeddingHelper", + "CustomHelper", + "ScheduledEmbeddingTrainingHelper", +] + +_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access + + +@six.add_metaclass(abc.ABCMeta) +class Helper(object): + """Helper interface. Helper instances are used by SamplingDecoder.""" + + @abc.abstractproperty + def batch_size(self): + """Returns a scalar int32 tensor.""" + raise NotImplementedError("batch_size has not been implemented") + + @abc.abstractmethod + def initialize(self, name=None): + """Returns `(initial_finished, initial_inputs)`.""" + pass + + @abc.abstractmethod + def sample(self, time, outputs, state, name=None): + """Returns `sample_ids`.""" + pass + + @abc.abstractmethod + def next_inputs(self, time, outputs, state, sample_ids, name=None): + """Returns `(finished, next_inputs, next_state)`.""" + pass + + +class CustomHelper(Helper): + """Base abstract class that allows the user to customize sampling.""" + + def __init__(self, initialize_fn, sample_fn, next_inputs_fn): + """Initializer. + + Args: + initialize_fn: callable that returns `(finished, next_inputs)` + for the first iteration. + sample_fn: callable that takes `(time, outputs, state)` + and emits tensor `sample_ids`. + next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` + and emits `(finished, next_inputs, next_state)`. + """ + self._initialize_fn = initialize_fn + self._sample_fn = sample_fn + self._next_inputs_fn = next_inputs_fn + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + def initialize(self, name=None): + with ops.name_scope(name, "%sInitialize" % type(self).__name__): + (finished, next_inputs) = self._initialize_fn() + if self._batch_size is None: + self._batch_size = array_ops.size(finished) + return (finished, next_inputs) + + def sample(self, time, outputs, state, name=None): + with ops.name_scope( + name, "%sSample" % type(self).__name__, (time, outputs, state)): + return self._sample_fn(time=time, outputs=outputs, state=state) + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with ops.name_scope( + name, "%sNextInputs" % type(self).__name__, (time, outputs, state)): + return self._next_inputs_fn( + time=time, outputs=outputs, state=state, sample_ids=sample_ids) + + +class TrainingHelper(Helper): + """A helper for use during training. Only reads inputs. + + Returned sample_ids are the argmax of the RNN output logits. + """ + + def __init__(self, inputs, sequence_length, time_major=False, name=None): + """Initializer. + + Args: + inputs: A (structure of) input tensors. + sequence_length: An int32 vector tensor. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + name: Name scope for any created operations. + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): + inputs = ops.convert_to_tensor(inputs, name="inputs") + if not time_major: + inputs = nest.map_structure(_transpose_batch_time, inputs) + + def _unstack_ta(inp): + return tensor_array_ops.TensorArray( + dtype=inp.dtype, size=array_ops.shape(inp)[0], + element_shape=inp.get_shape()[1:]).unstack(inp) + + self._input_tas = nest.map_structure(_unstack_ta, inputs) + self._sequence_length = ops.convert_to_tensor( + sequence_length, name="sequence_length") + if self._sequence_length.get_shape().ndims != 1: + raise ValueError( + "Expected sequence_length to be a vector, but received shape: %s" % + self._sequence_length.get_shape()) + + self._zero_inputs = nest.map_structure( + lambda inp: array_ops.zeros_like(inp[0, :]), inputs) + + self._batch_size = array_ops.size(sequence_length) + + @property + def batch_size(self): + return self._batch_size + + def initialize(self, name=None): + with ops.name_scope(name, "TrainingHelperInitialize"): + finished = math_ops.equal(0, self._sequence_length) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, lambda: self._zero_inputs, + lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas)) + return (finished, next_inputs) + + def sample(self, time, outputs, name=None, **unused_kwargs): + with ops.name_scope(name, "TrainingHelperSample", [time, outputs]): + sample_ids = math_ops.cast( + math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, name=None, **unused_kwargs): + """next_inputs_fn for TrainingHelper.""" + with ops.name_scope(name, "TrainingHelperNextInputs", + [time, outputs, state]): + next_time = time + 1 + finished = (next_time >= self._sequence_length) + all_finished = math_ops.reduce_all(finished) + def read_from_ta(inp): + return inp.read(next_time) + next_inputs = control_flow_ops.cond( + all_finished, lambda: self._zero_inputs, + lambda: nest.map_structure(read_from_ta, self._input_tas)) + return (finished, next_inputs, state) + + +class ScheduledEmbeddingTrainingHelper(TrainingHelper): + """A training helper that adds scheduled sampling. + + Returns -1s for sample_ids where no sampling took place; valid sample id + values elsewhere. + """ + + def __init__(self, inputs, sequence_length, embedding, sampling_probability, + time_major=False, seed=None, scheduling_seed=None, name=None): + """Initializer. + + Args: + inputs: A (structure of) input tensors. + sequence_length: An int32 vector tensor. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + sampling_probability: A 0D `float32` tensor: the probability of sampling + categorically from the output ids instead of reading directly from the + inputs. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + scheduling_seed: The schedule decision rule sampling seed. + name: Name scope for any created operations. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper", + [embedding, sampling_probability]): + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + self._sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self._sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. " + "saw shape: %s" % (self._sampling_probability.get_shape())) + self._seed = seed + self._scheduling_seed = scheduling_seed + super(ScheduledEmbeddingTrainingHelper, self).__init__( + inputs=inputs, + sequence_length=sequence_length, + time_major=time_major, + name=name) + + def initialize(self, name=None): + return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name) + + def sample(self, time, outputs, state, name=None): + with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", + [time, outputs, state]): + # Return -1s where we did not sample, and sample_ids elsewhere + select_sample_noise = random_ops.random_uniform( + [self.batch_size], seed=self._scheduling_seed) + select_sample = (self._sampling_probability > select_sample_noise) + sample_id_sampler = categorical.Categorical(logits=outputs) + return array_ops.where( + select_sample, + sample_id_sampler.sample(seed=self._seed), + array_ops.tile([-1], [self.batch_size])) + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", + [time, outputs, state, sample_ids]): + (finished, base_next_inputs, state) = ( + super(ScheduledEmbeddingTrainingHelper, self).next_inputs( + time=time, + outputs=outputs, + state=state, + sample_ids=sample_ids, + name=name)) + + def maybe_sample(): + """Perform scheduled sampling.""" + where_sampling = math_ops.cast( + array_ops.where(sample_ids > -1), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(sample_ids <= -1), dtypes.int32) + where_sampling_flat = array_ops.reshape(where_sampling, [-1]) + where_not_sampling_flat = array_ops.reshape(where_not_sampling, [-1]) + sample_ids_sampling = array_ops.gather(sample_ids, where_sampling_flat) + inputs_not_sampling = array_ops.gather( + base_next_inputs, where_not_sampling_flat) + sampled_next_inputs = self._embedding_fn(sample_ids_sampling) + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd(indices=where_sampling, + updates=sampled_next_inputs, + shape=base_shape) + + array_ops.scatter_nd(indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) + + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, lambda: base_next_inputs, maybe_sample) + return (finished, next_inputs, state) + + +class GreedyEmbeddingHelper(Helper): + """A helper for use during inference. + + Uses the argmax of the output (treated as logits) and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding, start_tokens, end_token): + """Initializer. + + Args: + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(start_tokens) + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + self._start_inputs = self._embedding_fn(self._start_tokens) + + @property + def batch_size(self): + return self._batch_size + + def initialize(self, name=None): + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self._start_inputs) + + def sample(self, time, outputs, state, name=None): + """sample for GreedyEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, use argmax to get the most probable id + if not isinstance(outputs, ops.Tensor): + raise TypeError("Expected outputs to be a single Tensor, got: %s" % + outputs) + sample_ids = math_ops.cast( + math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + """next_inputs_fn for GreedyEmbeddingHelper.""" + del time, outputs # unused by next_inputs_fn + finished = math_ops.equal(sample_ids, self._end_token) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + # If we're finished, the next_inputs value doesn't matter + lambda: self._start_inputs, + lambda: self._embedding_fn(sample_ids)) + return (finished, next_inputs, state) diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py deleted file mode 100644 index 3cd986cb04..0000000000 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ /dev/null @@ -1,447 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A class of Decoders that may sample to generate the next input. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import collections - -import six - -from tensorflow.contrib.distributions.python.ops import categorical -from tensorflow.contrib.rnn import core_rnn_cell -from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.util import nest - -__all__ = [ - "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder", - "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler", - "ScheduledEmbeddingTrainingSampler", -] - -_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access - - -@six.add_metaclass(abc.ABCMeta) -class Sampler(object): - """Sampler interface. Sampler instances are used by BasicSamplingDecoder.""" - - @abc.abstractproperty - def batch_size(self): - """Returns a scalar int32 tensor.""" - raise NotImplementedError("batch_size has not been implemented") - - @abc.abstractmethod - def initialize(self, name=None): - """Returns `(initial_finished, initial_inputs)`.""" - pass - - @abc.abstractmethod - def sample(self, time, outputs, state, name=None): - """Returns `sample_ids`.""" - pass - - @abc.abstractmethod - def next_inputs(self, time, outputs, state, sample_ids, name=None): - """Returns `(finished, next_inputs, next_state)`.""" - pass - - -class SamplingDecoderOutput( - collections.namedtuple("SamplingDecoderOutput", - ("rnn_output", "sample_id"))): - pass - - -class BasicSamplingDecoder(decoder.Decoder): - """Basic sampling decoder.""" - - def __init__(self, cell, sampler, initial_state): - """Initialize BasicSamplingDecoder. - - Args: - cell: An `RNNCell` instance. - sampler: A `Sampler` instance. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. - - Raises: - TypeError: if `cell` is not an instance of `RNNCell` or `sampler` - is not an instance of `Sampler`. - """ - if not isinstance(cell, core_rnn_cell.RNNCell): - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) - if not isinstance(sampler, Sampler): - raise TypeError("sampler must be a Sampler, received: %s" % - type(sampler)) - self._cell = cell - self._sampler = sampler - self._initial_state = initial_state - - @property - def batch_size(self): - return self._sampler.batch_size - - @property - def output_size(self): - # Return the cell output and the id - return SamplingDecoderOutput( - rnn_output=self._cell.output_size, - sample_id=tensor_shape.TensorShape([])) - - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_state)[0].dtype - return SamplingDecoderOutput( - nest.map_structure(lambda _: dtype, self._cell.output_size), - dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, first_inputs, initial_state)`. - """ - return self._sampler.initialize() + (self._initial_state,) - - def step(self, time, inputs, state, name=None): - """Perform a decoding step. - - Args: - time: scalar `int32` tensor. - inputs: A (structure of) input tensors. - state: A (structure of) state tensors and TensorArrays. - name: Name scope for any created operations. - - Returns: - `(outputs, next_state, next_inputs, finished)`. - """ - with ops.name_scope( - name, "BasicSamplingDecoderStep", (time, inputs, state)): - cell_outputs, cell_state = self._cell(inputs, state) - sample_ids = self._sampler.sample( - time=time, outputs=cell_outputs, state=cell_state) - (finished, next_inputs, next_state) = self._sampler.next_inputs( - time=time, outputs=cell_outputs, state=cell_state, - sample_ids=sample_ids) - outputs = SamplingDecoderOutput(cell_outputs, sample_ids) - return (outputs, next_state, next_inputs, finished) - - -class CustomSampler(Sampler): - """Base abstract class that allows the user to customize sampling.""" - - def __init__(self, initialize_fn, sample_fn, next_inputs_fn): - """Initializer. - - Args: - initialize_fn: callable that returns `(finished, next_inputs)` - for the first iteration. - sample_fn: callable that takes `(time, outputs, state)` - and emits tensor `sample_ids`. - next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` - and emits `(finished, next_inputs, next_state)`. - """ - self._initialize_fn = initialize_fn - self._sample_fn = sample_fn - self._next_inputs_fn = next_inputs_fn - self._batch_size = None - - @property - def batch_size(self): - if self._batch_size is None: - raise ValueError("batch_size accessed before initialize was called") - return self._batch_size - - def initialize(self, name=None): - with ops.name_scope(name, "%sInitialize" % type(self).__name__): - (finished, next_inputs) = self._initialize_fn() - if self._batch_size is None: - self._batch_size = array_ops.size(finished) - return (finished, next_inputs) - - def sample(self, time, outputs, state, name=None): - with ops.name_scope( - name, "%sSample" % type(self).__name__, (time, outputs, state)): - return self._sample_fn(time=time, outputs=outputs, state=state) - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - with ops.name_scope( - name, "%sNextInputs" % type(self).__name__, (time, outputs, state)): - return self._next_inputs_fn( - time=time, outputs=outputs, state=state, sample_ids=sample_ids) - - -class BasicTrainingSampler(Sampler): - """A (non-)sampler for use during training. Only reads inputs. - - Returned sample_ids are the argmax of the RNN output logits. - """ - - def __init__(self, inputs, sequence_length, time_major=False, name=None): - """Initializer. - - Args: - inputs: A (structure of) input tensors. - sequence_length: An int32 vector tensor. - time_major: Python bool. Whether the tensors in `inputs` are time major. - If `False` (default), they are assumed to be batch major. - name: Name scope for any created operations. - - Raises: - ValueError: if `sequence_length` is not a 1D tensor. - """ - with ops.name_scope( - name, "BasicTrainingSampler", [inputs, sequence_length]): - inputs = ops.convert_to_tensor(inputs, name="inputs") - if not time_major: - inputs = nest.map_structure(_transpose_batch_time, inputs) - - def _unstack_ta(inp): - return tensor_array_ops.TensorArray( - dtype=inp.dtype, size=array_ops.shape(inp)[0], - element_shape=inp.get_shape()[1:]).unstack(inp) - - self._input_tas = nest.map_structure(_unstack_ta, inputs) - self._sequence_length = ops.convert_to_tensor( - sequence_length, name="sequence_length") - if self._sequence_length.get_shape().ndims != 1: - raise ValueError( - "Expected sequence_length to be a vector, but received shape: %s" % - self._sequence_length.get_shape()) - - self._zero_inputs = nest.map_structure( - lambda inp: array_ops.zeros_like(inp[0, :]), inputs) - - self._batch_size = array_ops.size(sequence_length) - - @property - def batch_size(self): - return self._batch_size - - def initialize(self, name=None): - with ops.name_scope(name, "BasicTrainingSamplerInitialize"): - finished = math_ops.equal(0, self._sequence_length) - all_finished = math_ops.reduce_all(finished) - next_inputs = control_flow_ops.cond( - all_finished, lambda: self._zero_inputs, - lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas)) - return (finished, next_inputs) - - def sample(self, time, outputs, name=None, **unused_kwargs): - with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]): - sample_ids = math_ops.cast( - math_ops.argmax(outputs, axis=-1), dtypes.int32) - return sample_ids - - def next_inputs(self, time, outputs, state, name=None, **unused_kwargs): - """next_inputs_fn for BasicTrainingSampler.""" - with ops.name_scope( - name, "BasicTrainingSamplerNextInputs", [time, outputs, state]): - next_time = time + 1 - finished = (next_time >= self._sequence_length) - all_finished = math_ops.reduce_all(finished) - def read_from_ta(inp): - return inp.read(next_time) - next_inputs = control_flow_ops.cond( - all_finished, lambda: self._zero_inputs, - lambda: nest.map_structure(read_from_ta, self._input_tas)) - return (finished, next_inputs, state) - - -class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler): - """A training sampler that adds scheduled sampling. - - Returns -1s for sample_ids where no sampling took place; valid sample id - values elsewhere. - """ - - def __init__(self, inputs, sequence_length, embedding, sampling_probability, - time_major=False, seed=None, scheduling_seed=None, name=None): - """Initializer. - - Args: - inputs: A (structure of) input tensors. - sequence_length: An int32 vector tensor. - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - sampling_probability: A 0D `float32` tensor: the probability of sampling - categorically from the output ids instead of reading directly from the - inputs. - time_major: Python bool. Whether the tensors in `inputs` are time major. - If `False` (default), they are assumed to be batch major. - seed: The sampling seed. - scheduling_seed: The schedule decision rule sampling seed. - name: Name scope for any created operations. - - Raises: - ValueError: if `sampling_probability` is not a scalar or vector. - """ - with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper", - [embedding, sampling_probability]): - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - self._sampling_probability = ops.convert_to_tensor( - sampling_probability, name="sampling_probability") - if self._sampling_probability.get_shape().ndims not in (0, 1): - raise ValueError( - "sampling_probability must be either a scalar or a vector. " - "saw shape: %s" % (self._sampling_probability.get_shape())) - self._seed = seed - self._scheduling_seed = scheduling_seed - super(ScheduledEmbeddingTrainingSampler, self).__init__( - inputs=inputs, - sequence_length=sequence_length, - time_major=time_major, - name=name) - - def initialize(self, name=None): - return super(ScheduledEmbeddingTrainingSampler, self).initialize( - name=name) - - def sample(self, time, outputs, state, name=None): - with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", - [time, outputs, state]): - # Return -1s where we did not sample, and sample_ids elsewhere - select_sample_noise = random_ops.random_uniform( - [self.batch_size], seed=self._scheduling_seed) - select_sample = (self._sampling_probability > select_sample_noise) - sample_id_sampler = categorical.Categorical(logits=outputs) - return array_ops.where( - select_sample, - sample_id_sampler.sample(seed=self._seed), - array_ops.tile([-1], [self.batch_size])) - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", - [time, outputs, state, sample_ids]): - (finished, base_next_inputs, state) = ( - super(ScheduledEmbeddingTrainingSampler, self).next_inputs( - time=time, outputs=outputs, state=state, sample_ids=sample_ids, - name=name)) - - def maybe_sample(): - """Perform scheduled sampling.""" - where_sampling = math_ops.cast( - array_ops.where(sample_ids > -1), dtypes.int32) - where_not_sampling = math_ops.cast( - array_ops.where(sample_ids <= -1), dtypes.int32) - where_sampling_flat = array_ops.reshape(where_sampling, [-1]) - where_not_sampling_flat = array_ops.reshape(where_not_sampling, [-1]) - sample_ids_sampling = array_ops.gather(sample_ids, where_sampling_flat) - inputs_not_sampling = array_ops.gather( - base_next_inputs, where_not_sampling_flat) - sampled_next_inputs = self._embedding_fn(sample_ids_sampling) - base_shape = array_ops.shape(base_next_inputs) - return (array_ops.scatter_nd(indices=where_sampling, - updates=sampled_next_inputs, - shape=base_shape) - + array_ops.scatter_nd(indices=where_not_sampling, - updates=inputs_not_sampling, - shape=base_shape)) - - all_finished = math_ops.reduce_all(finished) - next_inputs = control_flow_ops.cond( - all_finished, lambda: base_next_inputs, maybe_sample) - return (finished, next_inputs, state) - - -class GreedyEmbeddingSampler(Sampler): - """A (non-)sampler for use during inference. - - Uses the argmax of the output (treated as logits) and passes the - result through an embedding layer to get the next input. - """ - - def __init__(self, embedding, start_tokens, end_token): - """Initializer. - - Args: - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. - end_token: `int32` scalar, the token that marks end of decoding. - - Raises: - ValueError: if `sequence_length` is not a 1D tensor. - """ - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._batch_size = array_ops.size(start_tokens) - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - self._start_inputs = self._embedding_fn(self._start_tokens) - - @property - def batch_size(self): - return self._batch_size - - def initialize(self, name=None): - finished = array_ops.tile([False], [self._batch_size]) - return (finished, self._start_inputs) - - def sample(self, time, outputs, state, name=None): - """sample for GreedyEmbeddingSampler.""" - del time, state # unused by sample_fn - # Outputs are logits, use argmax to get the most probable id - if not isinstance(outputs, ops.Tensor): - raise TypeError("Expected outputs to be a single Tensor, got: %s" % - outputs) - sample_ids = math_ops.cast( - math_ops.argmax(outputs, axis=-1), dtypes.int32) - return sample_ids - - def next_inputs(self, time, outputs, state, sample_ids, name=None): - """next_inputs_fn for GreedyEmbeddingSampler.""" - del time, outputs # unused by next_inputs_fn - finished = math_ops.equal(sample_ids, self._end_token) - all_finished = math_ops.reduce_all(finished) - next_inputs = control_flow_ops.cond( - all_finished, - # If we're finished, the next_inputs value doesn't matter - lambda: self._start_inputs, - lambda: self._embedding_fn(sample_ids)) - return (finished, next_inputs, state) -- cgit v1.2.3