diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-09 19:21:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-09 19:46:05 -0800 |
commit | 5f600c2b1daa004d45b4d63df112f85be1ee5e4b (patch) | |
tree | 8f4b3f3ffa642388f847fb761223d1c5fb1bd95a | |
parent | 3b7b39ac5dd2dceebe4b80b5e0b12316720a924b (diff) |
Refactor some code in the new seq2seq API:
1. Rename samplers to helpers and move them to helper.py.
2. Remove the redundant name "Basic" from all helpers' names.
3. Rename SamplingDecoder to BasicDecoder.
Change: 147112283
-rw-r--r-- | tensorflow/contrib/seq2seq/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py (renamed from tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py) | 63 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/basic_decoder.py | 121 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/helper.py (renamed from tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py) | 150 |
5 files changed, 202 insertions, 165 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index c10d77f7ef..d5a9dcdce1 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -80,9 +80,9 @@ cuda_py_test( ) cuda_py_test( - name = "sampling_decoder_test", + name = "basic_decoder_test", size = "medium", - srcs = ["python/kernel_tests/sampling_decoder_test.py"], + srcs = ["python/kernel_tests/basic_decoder_test.py"], additional_deps = [ ":seq2seq_py", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index 3f8b4c077d..7ef0095b2e 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for contrib.seq2seq.python.seq2seq.sampling_decoder.""" +"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder.""" # pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division @@ -30,7 +30,8 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import numpy as np from tensorflow.contrib.rnn import core_rnn_cell -from tensorflow.contrib.seq2seq.python.ops import sampling_decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape @@ -39,9 +40,9 @@ from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top -class BasicSamplingDecoderTest(test.TestCase): +class 
BasicDecoderTest(test.TestCase): - def testStepWithBasicTrainingSampler(self): + def testStepWithTrainingHelper(self): sequence_length = [3, 4, 3, 1, 0] batch_size = 5 max_time = 8 @@ -52,21 +53,21 @@ class BasicSamplingDecoderTest(test.TestCase): inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(cell_depth) - sampler = sampling_decoder.BasicTrainingSampler( + helper = helper_py.TrainingHelper( inputs, sequence_length, time_major=False) - my_decoder = sampling_decoder.BasicSamplingDecoder( + my_decoder = basic_decoder.BasicDecoder( cell=cell, - sampler=sampler, + helper=helper, initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) output_size = my_decoder.output_size output_dtype = my_decoder.output_dtype self.assertEqual( - sampling_decoder.SamplingDecoderOutput(cell_depth, - tensor_shape.TensorShape([])), + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), output_size) self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), output_dtype) (first_finished, first_inputs, first_state) = my_decoder.initialize() @@ -78,7 +79,7 @@ class BasicSamplingDecoderTest(test.TestCase): self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) self.assertEqual((batch_size,), step_outputs[1].get_shape()) self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) @@ -106,7 +107,7 @@ class BasicSamplingDecoderTest(test.TestCase): np.argmax(sess_results["step_outputs"].rnn_output, -1), sess_results["step_outputs"].sample_id) - def 
testStepWithGreedyEmbeddingSampler(self): + def testStepWithGreedyEmbeddingHelper(self): batch_size = 5 vocabulary_size = 7 cell_depth = vocabulary_size # cell's logits must match vocabulary size @@ -118,21 +119,21 @@ class BasicSamplingDecoderTest(test.TestCase): embeddings = np.random.randn(vocabulary_size, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(vocabulary_size) - sampler = sampling_decoder.GreedyEmbeddingSampler( - embeddings, start_tokens, end_token) - my_decoder = sampling_decoder.BasicSamplingDecoder( + helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens, + end_token) + my_decoder = basic_decoder.BasicDecoder( cell=cell, - sampler=sampler, + helper=helper, initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) output_size = my_decoder.output_size output_dtype = my_decoder.output_dtype self.assertEqual( - sampling_decoder.SamplingDecoderOutput(cell_depth, - tensor_shape.TensorShape([])), + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), output_size) self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), output_dtype) (first_finished, first_inputs, first_state) = my_decoder.initialize() @@ -144,7 +145,7 @@ class BasicSamplingDecoderTest(test.TestCase): self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) self.assertEqual((batch_size,), step_outputs[1].get_shape()) self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) @@ -176,7 +177,7 @@ class BasicSamplingDecoderTest(test.TestCase): self.assertAllEqual(expected_step_next_inputs, 
sess_results["step_next_inputs"]) - def testStepWithScheduledEmbeddingTrainingSampler(self): + def testStepWithScheduledEmbeddingTrainingHelper(self): sequence_length = [3, 4, 3, 1, 0] batch_size = 5 max_time = 8 @@ -190,23 +191,25 @@ class BasicSamplingDecoderTest(test.TestCase): vocabulary_size, input_depth).astype(np.float32) half = constant_op.constant(0.5) cell = core_rnn_cell.LSTMCell(vocabulary_size) - sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler( - inputs=inputs, sequence_length=sequence_length, - embedding=embeddings, sampling_probability=half, + helper = helper_py.ScheduledEmbeddingTrainingHelper( + inputs=inputs, + sequence_length=sequence_length, + embedding=embeddings, + sampling_probability=half, time_major=False) - my_decoder = sampling_decoder.BasicSamplingDecoder( + my_decoder = basic_decoder.BasicDecoder( cell=cell, - sampler=sampler, + helper=helper, initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) output_size = my_decoder.output_size output_dtype = my_decoder.output_dtype self.assertEqual( - sampling_decoder.SamplingDecoderOutput(vocabulary_size, - tensor_shape.TensorShape([])), + basic_decoder.BasicDecoderOutput(vocabulary_size, + tensor_shape.TensorShape([])), output_size) self.assertEqual( - sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32), + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), output_dtype) (first_finished, first_inputs, first_state) = my_decoder.initialize() @@ -218,7 +221,7 @@ class BasicSamplingDecoderTest(test.TestCase): self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple)) self.assertTrue( - isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput)) + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) self.assertEqual((batch_size, vocabulary_size), step_outputs[0].get_shape()) self.assertEqual((batch_size,), step_outputs[1].get_shape()) diff 
--git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 851eb6ef92..76de154b48 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -31,7 +31,8 @@ import numpy as np from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.contrib.seq2seq.python.ops import sampling_decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import basic_decoder from tensorflow.python.framework import dtypes from tensorflow.python.ops import rnn from tensorflow.python.ops import variables @@ -59,11 +60,11 @@ class DynamicDecodeRNNTest(test.TestCase): inputs = np.random.randn(batch_size, max_time, input_depth).astype(np.float32) cell = core_rnn_cell.LSTMCell(cell_depth) - sampler = sampling_decoder.BasicTrainingSampler( + helper = helper_py.TrainingHelper( inputs, sequence_length, time_major=time_major) - my_decoder = sampling_decoder.BasicSamplingDecoder( + my_decoder = basic_decoder.BasicDecoder( cell=cell, - sampler=sampler, + helper=helper, initial_state=cell.zero_state( dtype=dtypes.float32, batch_size=batch_size)) @@ -77,7 +78,7 @@ class DynamicDecodeRNNTest(test.TestCase): return shape self.assertTrue( - isinstance(final_outputs, sampling_decoder.SamplingDecoderOutput)) + isinstance(final_outputs, basic_decoder.BasicDecoderOutput)) self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple)) self.assertEqual( @@ -116,7 +117,7 @@ class DynamicDecodeRNNTest(test.TestCase): def testDynamicDecodeRNNOneMaxIter(self): self._testDynamicDecodeRNN(time_major=True, maximum_iterations=1) - def _testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( self, use_sequence_length): sequence_length = [3, 4, 
3, 1, 0] batch_size = 5 @@ -131,9 +132,9 @@ class DynamicDecodeRNNTest(test.TestCase): cell = core_rnn_cell.LSTMCell(cell_depth) zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size) - sampler = sampling_decoder.BasicTrainingSampler(inputs, sequence_length) - my_decoder = sampling_decoder.BasicSamplingDecoder( - cell=cell, sampler=sampler, initial_state=zero_state) + helper = helper_py.TrainingHelper(inputs, sequence_length) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, helper=helper, initial_state=zero_state) # Match the variable scope of dynamic_rnn below so we end up # using the same variables @@ -169,14 +170,12 @@ class DynamicDecodeRNNTest(test.TestCase): self.assertAllClose(sess_results["final_decoder_state"], sess_results["final_rnn_state"]) - def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNWithSeqLen( - self): - self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( use_sequence_length=True) - def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNNoSeqLen( - self): - self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN( + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( use_sequence_length=False) if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py new file mode 100644 index 0000000000..eac2c179b2 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -0,0 +1,121 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A class of Decoders that may sample to generate the next input. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.rnn import core_rnn_cell +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import nest + + +__all__ = [ + "BasicDecoderOutput", + "BasicDecoder", +] + + +class BasicDecoderOutput( + collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))): + pass + + +class BasicDecoder(decoder.Decoder): + """Basic sampling decoder.""" + + def __init__(self, cell, helper, initial_state): + """Initialize BasicDecoder. + + Args: + cell: An `RNNCell` instance. + helper: A `Helper` instance. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell` or `helper` + is not an instance of `Helper`. 
+ """ + if not isinstance(cell, core_rnn_cell.RNNCell): + raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + if not isinstance(helper, helper_py.Helper): + raise TypeError("helper must be a Helper, received: %s" % type(helper)) + self._cell = cell + self._helper = helper + self._initial_state = initial_state + + @property + def batch_size(self): + return self._helper.batch_size + + @property + def output_size(self): + # Return the cell output and the id + return BasicDecoderOutput( + rnn_output=self._cell.output_size, + sample_id=tensor_shape.TensorShape([])) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_state)[0].dtype + return BasicDecoderOutput( + nest.map_structure(lambda _: dtype, self._cell.output_size), + dtypes.int32) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, first_inputs, initial_state)`. + """ + return self._helper.initialize() + (self._initial_state,) + + def step(self, time, inputs, state, name=None): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + name: Name scope for any created operations. + + Returns: + `(outputs, next_state, next_inputs, finished)`. 
+ """ + with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)): + cell_outputs, cell_state = self._cell(inputs, state) + sample_ids = self._helper.sample( + time=time, outputs=cell_outputs, state=cell_state) + (finished, next_inputs, next_state) = self._helper.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids) + outputs = BasicDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 3cd986cb04..46d0563fe0 100644 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A class of Decoders that may sample to generate the next input. +"""A library of helpers for use with SamplingDecoders. 
""" from __future__ import absolute_import @@ -20,16 +20,13 @@ from __future__ import division from __future__ import print_function import abc -import collections import six from tensorflow.contrib.distributions.python.ops import categorical -from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -39,17 +36,19 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest __all__ = [ - "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder", - "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler", - "ScheduledEmbeddingTrainingSampler", + "Helper", + "TrainingHelper", + "GreedyEmbeddingHelper", + "CustomHelper", + "ScheduledEmbeddingTrainingHelper", ] _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access @six.add_metaclass(abc.ABCMeta) -class Sampler(object): - """Sampler interface. Sampler instances are used by BasicSamplingDecoder.""" +class Helper(object): + """Helper interface. Helper instances are used by SamplingDecoder.""" @abc.abstractproperty def batch_size(self): @@ -72,93 +71,7 @@ class Sampler(object): pass -class SamplingDecoderOutput( - collections.namedtuple("SamplingDecoderOutput", - ("rnn_output", "sample_id"))): - pass - - -class BasicSamplingDecoder(decoder.Decoder): - """Basic sampling decoder.""" - - def __init__(self, cell, sampler, initial_state): - """Initialize BasicSamplingDecoder. - - Args: - cell: An `RNNCell` instance. - sampler: A `Sampler` instance. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. 
- - Raises: - TypeError: if `cell` is not an instance of `RNNCell` or `sampler` - is not an instance of `Sampler`. - """ - if not isinstance(cell, core_rnn_cell.RNNCell): - raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) - if not isinstance(sampler, Sampler): - raise TypeError("sampler must be a Sampler, received: %s" % - type(sampler)) - self._cell = cell - self._sampler = sampler - self._initial_state = initial_state - - @property - def batch_size(self): - return self._sampler.batch_size - - @property - def output_size(self): - # Return the cell output and the id - return SamplingDecoderOutput( - rnn_output=self._cell.output_size, - sample_id=tensor_shape.TensorShape([])) - - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_state)[0].dtype - return SamplingDecoderOutput( - nest.map_structure(lambda _: dtype, self._cell.output_size), - dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, first_inputs, initial_state)`. - """ - return self._sampler.initialize() + (self._initial_state,) - - def step(self, time, inputs, state, name=None): - """Perform a decoding step. - - Args: - time: scalar `int32` tensor. - inputs: A (structure of) input tensors. - state: A (structure of) state tensors and TensorArrays. - name: Name scope for any created operations. - - Returns: - `(outputs, next_state, next_inputs, finished)`. 
- """ - with ops.name_scope( - name, "BasicSamplingDecoderStep", (time, inputs, state)): - cell_outputs, cell_state = self._cell(inputs, state) - sample_ids = self._sampler.sample( - time=time, outputs=cell_outputs, state=cell_state) - (finished, next_inputs, next_state) = self._sampler.next_inputs( - time=time, outputs=cell_outputs, state=cell_state, - sample_ids=sample_ids) - outputs = SamplingDecoderOutput(cell_outputs, sample_ids) - return (outputs, next_state, next_inputs, finished) - - -class CustomSampler(Sampler): +class CustomHelper(Helper): """Base abstract class that allows the user to customize sampling.""" def __init__(self, initialize_fn, sample_fn, next_inputs_fn): @@ -202,8 +115,8 @@ class CustomSampler(Sampler): time=time, outputs=outputs, state=state, sample_ids=sample_ids) -class BasicTrainingSampler(Sampler): - """A (non-)sampler for use during training. Only reads inputs. +class TrainingHelper(Helper): + """A helper for use during training. Only reads inputs. Returned sample_ids are the argmax of the RNN output logits. """ @@ -221,8 +134,7 @@ class BasicTrainingSampler(Sampler): Raises: ValueError: if `sequence_length` is not a 1D tensor. 
""" - with ops.name_scope( - name, "BasicTrainingSampler", [inputs, sequence_length]): + with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): inputs = ops.convert_to_tensor(inputs, name="inputs") if not time_major: inputs = nest.map_structure(_transpose_batch_time, inputs) @@ -250,7 +162,7 @@ class BasicTrainingSampler(Sampler): return self._batch_size def initialize(self, name=None): - with ops.name_scope(name, "BasicTrainingSamplerInitialize"): + with ops.name_scope(name, "TrainingHelperInitialize"): finished = math_ops.equal(0, self._sequence_length) all_finished = math_ops.reduce_all(finished) next_inputs = control_flow_ops.cond( @@ -259,15 +171,15 @@ class BasicTrainingSampler(Sampler): return (finished, next_inputs) def sample(self, time, outputs, name=None, **unused_kwargs): - with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]): + with ops.name_scope(name, "TrainingHelperSample", [time, outputs]): sample_ids = math_ops.cast( math_ops.argmax(outputs, axis=-1), dtypes.int32) return sample_ids def next_inputs(self, time, outputs, state, name=None, **unused_kwargs): - """next_inputs_fn for BasicTrainingSampler.""" - with ops.name_scope( - name, "BasicTrainingSamplerNextInputs", [time, outputs, state]): + """next_inputs_fn for TrainingHelper.""" + with ops.name_scope(name, "TrainingHelperNextInputs", + [time, outputs, state]): next_time = time + 1 finished = (next_time >= self._sequence_length) all_finished = math_ops.reduce_all(finished) @@ -279,8 +191,8 @@ class BasicTrainingSampler(Sampler): return (finished, next_inputs, state) -class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler): - """A training sampler that adds scheduled sampling. +class ScheduledEmbeddingTrainingHelper(TrainingHelper): + """A training helper that adds scheduled sampling. Returns -1s for sample_ids where no sampling took place; valid sample id values elsewhere. 
@@ -322,18 +234,17 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler): "saw shape: %s" % (self._sampling_probability.get_shape())) self._seed = seed self._scheduling_seed = scheduling_seed - super(ScheduledEmbeddingTrainingSampler, self).__init__( + super(ScheduledEmbeddingTrainingHelper, self).__init__( inputs=inputs, sequence_length=sequence_length, time_major=time_major, name=name) def initialize(self, name=None): - return super(ScheduledEmbeddingTrainingSampler, self).initialize( - name=name) + return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name) def sample(self, time, outputs, state, name=None): - with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", + with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state]): # Return -1s where we did not sample, and sample_ids elsewhere select_sample_noise = random_ops.random_uniform( @@ -346,11 +257,14 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler): array_ops.tile([-1], [self.batch_size])) def next_inputs(self, time, outputs, state, sample_ids, name=None): - with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample", + with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state, sample_ids]): (finished, base_next_inputs, state) = ( - super(ScheduledEmbeddingTrainingSampler, self).next_inputs( - time=time, outputs=outputs, state=state, sample_ids=sample_ids, + super(ScheduledEmbeddingTrainingHelper, self).next_inputs( + time=time, + outputs=outputs, + state=state, + sample_ids=sample_ids, name=name)) def maybe_sample(): @@ -379,8 +293,8 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler): return (finished, next_inputs, state) -class GreedyEmbeddingSampler(Sampler): - """A (non-)sampler for use during inference. +class GreedyEmbeddingHelper(Helper): + """A helper for use during inference. 
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input. @@ -424,7 +338,7 @@ class GreedyEmbeddingSampler(Sampler): return (finished, self._start_inputs) def sample(self, time, outputs, state, name=None): - """sample for GreedyEmbeddingSampler.""" + """sample for GreedyEmbeddingHelper.""" del time, state # unused by sample_fn # Outputs are logits, use argmax to get the most probable id if not isinstance(outputs, ops.Tensor): @@ -435,7 +349,7 @@ class GreedyEmbeddingSampler(Sampler): return sample_ids def next_inputs(self, time, outputs, state, sample_ids, name=None): - """next_inputs_fn for GreedyEmbeddingSampler.""" + """next_inputs_fn for GreedyEmbeddingHelper.""" del time, outputs # unused by next_inputs_fn finished = math_ops.equal(sample_ids, self._end_token) all_finished = math_ops.reduce_all(finished) |