aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-02-09 19:21:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-09 19:46:05 -0800
commit5f600c2b1daa004d45b4d63df112f85be1ee5e4b (patch)
tree8f4b3f3ffa642388f847fb761223d1c5fb1bd95a
parent3b7b39ac5dd2dceebe4b80b5e0b12316720a924b (diff)
Refactor some code in the new seq2seq api:
1. Rename samplers to helpers and move them to helper.py. 2. Remove the redundant name "Basic" from all helpers' names. 3. Rename SamplingDecoder to BasicDecoder. Change: 147112283
-rw-r--r--tensorflow/contrib/seq2seq/BUILD4
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py (renamed from tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py)63
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py29
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/basic_decoder.py121
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py (renamed from tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py)150
5 files changed, 202 insertions, 165 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index c10d77f7ef..d5a9dcdce1 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -80,9 +80,9 @@ cuda_py_test(
)
cuda_py_test(
- name = "sampling_decoder_test",
+ name = "basic_decoder_test",
size = "medium",
- srcs = ["python/kernel_tests/sampling_decoder_test.py"],
+ srcs = ["python/kernel_tests/basic_decoder_test.py"],
additional_deps = [
":seq2seq_py",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
index 3f8b4c077d..7ef0095b2e 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/sampling_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for contrib.seq2seq.python.seq2seq.sampling_decoder."""
+"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
@@ -30,7 +30,8 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
-from tensorflow.contrib.seq2seq.python.ops import sampling_decoder
+from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
+from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
@@ -39,9 +40,9 @@ from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
-class BasicSamplingDecoderTest(test.TestCase):
+class BasicDecoderTest(test.TestCase):
- def testStepWithBasicTrainingSampler(self):
+ def testStepWithTrainingHelper(self):
sequence_length = [3, 4, 3, 1, 0]
batch_size = 5
max_time = 8
@@ -52,21 +53,21 @@ class BasicSamplingDecoderTest(test.TestCase):
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(cell_depth)
- sampler = sampling_decoder.BasicTrainingSampler(
+ helper = helper_py.TrainingHelper(
inputs, sequence_length, time_major=False)
- my_decoder = sampling_decoder.BasicSamplingDecoder(
+ my_decoder = basic_decoder.BasicDecoder(
cell=cell,
- sampler=sampler,
+ helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(cell_depth,
- tensor_shape.TensorShape([])),
+ basic_decoder.BasicDecoderOutput(cell_depth,
+ tensor_shape.TensorShape([])),
output_size)
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+ basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
output_dtype)
(first_finished, first_inputs, first_state) = my_decoder.initialize()
@@ -78,7 +79,7 @@ class BasicSamplingDecoderTest(test.TestCase):
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(
- isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+ isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
self.assertEqual((batch_size,), step_outputs[1].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
@@ -106,7 +107,7 @@ class BasicSamplingDecoderTest(test.TestCase):
np.argmax(sess_results["step_outputs"].rnn_output, -1),
sess_results["step_outputs"].sample_id)
- def testStepWithGreedyEmbeddingSampler(self):
+ def testStepWithGreedyEmbeddingHelper(self):
batch_size = 5
vocabulary_size = 7
cell_depth = vocabulary_size # cell's logits must match vocabulary size
@@ -118,21 +119,21 @@ class BasicSamplingDecoderTest(test.TestCase):
embeddings = np.random.randn(vocabulary_size,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(vocabulary_size)
- sampler = sampling_decoder.GreedyEmbeddingSampler(
- embeddings, start_tokens, end_token)
- my_decoder = sampling_decoder.BasicSamplingDecoder(
+ helper = helper_py.GreedyEmbeddingHelper(embeddings, start_tokens,
+ end_token)
+ my_decoder = basic_decoder.BasicDecoder(
cell=cell,
- sampler=sampler,
+ helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(cell_depth,
- tensor_shape.TensorShape([])),
+ basic_decoder.BasicDecoderOutput(cell_depth,
+ tensor_shape.TensorShape([])),
output_size)
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+ basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
output_dtype)
(first_finished, first_inputs, first_state) = my_decoder.initialize()
@@ -144,7 +145,7 @@ class BasicSamplingDecoderTest(test.TestCase):
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(
- isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+ isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
self.assertEqual((batch_size,), step_outputs[1].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
@@ -176,7 +177,7 @@ class BasicSamplingDecoderTest(test.TestCase):
self.assertAllEqual(expected_step_next_inputs,
sess_results["step_next_inputs"])
- def testStepWithScheduledEmbeddingTrainingSampler(self):
+ def testStepWithScheduledEmbeddingTrainingHelper(self):
sequence_length = [3, 4, 3, 1, 0]
batch_size = 5
max_time = 8
@@ -190,23 +191,25 @@ class BasicSamplingDecoderTest(test.TestCase):
vocabulary_size, input_depth).astype(np.float32)
half = constant_op.constant(0.5)
cell = core_rnn_cell.LSTMCell(vocabulary_size)
- sampler = sampling_decoder.ScheduledEmbeddingTrainingSampler(
- inputs=inputs, sequence_length=sequence_length,
- embedding=embeddings, sampling_probability=half,
+ helper = helper_py.ScheduledEmbeddingTrainingHelper(
+ inputs=inputs,
+ sequence_length=sequence_length,
+ embedding=embeddings,
+ sampling_probability=half,
time_major=False)
- my_decoder = sampling_decoder.BasicSamplingDecoder(
+ my_decoder = basic_decoder.BasicDecoder(
cell=cell,
- sampler=sampler,
+ helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(vocabulary_size,
- tensor_shape.TensorShape([])),
+ basic_decoder.BasicDecoderOutput(vocabulary_size,
+ tensor_shape.TensorShape([])),
output_size)
self.assertEqual(
- sampling_decoder.SamplingDecoderOutput(dtypes.float32, dtypes.int32),
+ basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
output_dtype)
(first_finished, first_inputs, first_state) = my_decoder.initialize()
@@ -218,7 +221,7 @@ class BasicSamplingDecoderTest(test.TestCase):
self.assertTrue(isinstance(first_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(step_state, core_rnn_cell.LSTMStateTuple))
self.assertTrue(
- isinstance(step_outputs, sampling_decoder.SamplingDecoderOutput))
+ isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, vocabulary_size),
step_outputs[0].get_shape())
self.assertEqual((batch_size,), step_outputs[1].get_shape())
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index 851eb6ef92..76de154b48 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -31,7 +31,8 @@ import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
-from tensorflow.contrib.seq2seq.python.ops import sampling_decoder
+from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
+from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variables
@@ -59,11 +60,11 @@ class DynamicDecodeRNNTest(test.TestCase):
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
cell = core_rnn_cell.LSTMCell(cell_depth)
- sampler = sampling_decoder.BasicTrainingSampler(
+ helper = helper_py.TrainingHelper(
inputs, sequence_length, time_major=time_major)
- my_decoder = sampling_decoder.BasicSamplingDecoder(
+ my_decoder = basic_decoder.BasicDecoder(
cell=cell,
- sampler=sampler,
+ helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
@@ -77,7 +78,7 @@ class DynamicDecodeRNNTest(test.TestCase):
return shape
self.assertTrue(
- isinstance(final_outputs, sampling_decoder.SamplingDecoderOutput))
+ isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
self.assertEqual(
@@ -116,7 +117,7 @@ class DynamicDecodeRNNTest(test.TestCase):
def testDynamicDecodeRNNOneMaxIter(self):
self._testDynamicDecodeRNN(time_major=True, maximum_iterations=1)
- def _testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN(
+ def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
self, use_sequence_length):
sequence_length = [3, 4, 3, 1, 0]
batch_size = 5
@@ -131,9 +132,9 @@ class DynamicDecodeRNNTest(test.TestCase):
cell = core_rnn_cell.LSTMCell(cell_depth)
zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size)
- sampler = sampling_decoder.BasicTrainingSampler(inputs, sequence_length)
- my_decoder = sampling_decoder.BasicSamplingDecoder(
- cell=cell, sampler=sampler, initial_state=zero_state)
+ helper = helper_py.TrainingHelper(inputs, sequence_length)
+ my_decoder = basic_decoder.BasicDecoder(
+ cell=cell, helper=helper, initial_state=zero_state)
# Match the variable scope of dynamic_rnn below so we end up
# using the same variables
@@ -169,14 +170,12 @@ class DynamicDecodeRNNTest(test.TestCase):
self.assertAllClose(sess_results["final_decoder_state"],
sess_results["final_rnn_state"])
- def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNWithSeqLen(
- self):
- self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN(
+ def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self):
+ self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
use_sequence_length=True)
- def testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNNNoSeqLen(
- self):
- self._testDynamicDecodeRNNWithBasicTrainingSamplerMatchesDynamicRNN(
+ def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self):
+ self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
use_sequence_length=False)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
new file mode 100644
index 0000000000..eac2c179b2
--- /dev/null
+++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
@@ -0,0 +1,121 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A class of Decoders that may sample to generate the next input.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.contrib.rnn import core_rnn_cell
+from tensorflow.contrib.seq2seq.python.ops import decoder
+from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import nest
+
+
+__all__ = [
+ "BasicDecoderOutput",
+ "BasicDecoder",
+]
+
+
+class BasicDecoderOutput(
+ collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
+ pass
+
+
+class BasicDecoder(decoder.Decoder):
+ """Basic sampling decoder."""
+
+ def __init__(self, cell, helper, initial_state):
+ """Initialize BasicDecoder.
+
+ Args:
+ cell: An `RNNCell` instance.
+ helper: A `Helper` instance.
+ initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
+
+ Raises:
+ TypeError: if `cell` is not an instance of `RNNCell` or `helper`
+ is not an instance of `Helper`.
+ """
+ if not isinstance(cell, core_rnn_cell.RNNCell):
+ raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+ if not isinstance(helper, helper_py.Helper):
+ raise TypeError("helper must be a Helper, received: %s" % type(helper))
+ self._cell = cell
+ self._helper = helper
+ self._initial_state = initial_state
+
+ @property
+ def batch_size(self):
+ return self._helper.batch_size
+
+ @property
+ def output_size(self):
+ # Return the cell output and the id
+ return BasicDecoderOutput(
+ rnn_output=self._cell.output_size,
+ sample_id=tensor_shape.TensorShape([]))
+
+ @property
+ def output_dtype(self):
+ # Assume the dtype of the cell is the output_size structure
+ # containing the input_state's first component's dtype.
+ # Return that structure and int32 (the id)
+ dtype = nest.flatten(self._initial_state)[0].dtype
+ return BasicDecoderOutput(
+ nest.map_structure(lambda _: dtype, self._cell.output_size),
+ dtypes.int32)
+
+ def initialize(self, name=None):
+ """Initialize the decoder.
+
+ Args:
+ name: Name scope for any created operations.
+
+ Returns:
+ `(finished, first_inputs, initial_state)`.
+ """
+ return self._helper.initialize() + (self._initial_state,)
+
+ def step(self, time, inputs, state, name=None):
+ """Perform a decoding step.
+
+ Args:
+ time: scalar `int32` tensor.
+ inputs: A (structure of) input tensors.
+ state: A (structure of) state tensors and TensorArrays.
+ name: Name scope for any created operations.
+
+ Returns:
+ `(outputs, next_state, next_inputs, finished)`.
+ """
+ with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
+ cell_outputs, cell_state = self._cell(inputs, state)
+ sample_ids = self._helper.sample(
+ time=time, outputs=cell_outputs, state=cell_state)
+ (finished, next_inputs, next_state) = self._helper.next_inputs(
+ time=time,
+ outputs=cell_outputs,
+ state=cell_state,
+ sample_ids=sample_ids)
+ outputs = BasicDecoderOutput(cell_outputs, sample_ids)
+ return (outputs, next_state, next_inputs, finished)
diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index 3cd986cb04..46d0563fe0 100644
--- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A class of Decoders that may sample to generate the next input.
+"""A library of helpers for use with SamplingDecoders.
"""
from __future__ import absolute_import
@@ -20,16 +20,13 @@ from __future__ import division
from __future__ import print_function
import abc
-import collections
import six
from tensorflow.contrib.distributions.python.ops import categorical
-from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
@@ -39,17 +36,19 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
__all__ = [
- "Sampler", "SamplingDecoderOutput", "BasicSamplingDecoder",
- "BasicTrainingSampler", "GreedyEmbeddingSampler", "CustomSampler",
- "ScheduledEmbeddingTrainingSampler",
+ "Helper",
+ "TrainingHelper",
+ "GreedyEmbeddingHelper",
+ "CustomHelper",
+ "ScheduledEmbeddingTrainingHelper",
]
_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)
-class Sampler(object):
- """Sampler interface. Sampler instances are used by BasicSamplingDecoder."""
+class Helper(object):
+ """Helper interface. Helper instances are used by SamplingDecoder."""
@abc.abstractproperty
def batch_size(self):
@@ -72,93 +71,7 @@ class Sampler(object):
pass
-class SamplingDecoderOutput(
- collections.namedtuple("SamplingDecoderOutput",
- ("rnn_output", "sample_id"))):
- pass
-
-
-class BasicSamplingDecoder(decoder.Decoder):
- """Basic sampling decoder."""
-
- def __init__(self, cell, sampler, initial_state):
- """Initialize BasicSamplingDecoder.
-
- Args:
- cell: An `RNNCell` instance.
- sampler: A `Sampler` instance.
- initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
-
- Raises:
- TypeError: if `cell` is not an instance of `RNNCell` or `sampler`
- is not an instance of `Sampler`.
- """
- if not isinstance(cell, core_rnn_cell.RNNCell):
- raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
- if not isinstance(sampler, Sampler):
- raise TypeError("sampler must be a Sampler, received: %s" %
- type(sampler))
- self._cell = cell
- self._sampler = sampler
- self._initial_state = initial_state
-
- @property
- def batch_size(self):
- return self._sampler.batch_size
-
- @property
- def output_size(self):
- # Return the cell output and the id
- return SamplingDecoderOutput(
- rnn_output=self._cell.output_size,
- sample_id=tensor_shape.TensorShape([]))
-
- @property
- def output_dtype(self):
- # Assume the dtype of the cell is the output_size structure
- # containing the input_state's first component's dtype.
- # Return that structure and int32 (the id)
- dtype = nest.flatten(self._initial_state)[0].dtype
- return SamplingDecoderOutput(
- nest.map_structure(lambda _: dtype, self._cell.output_size),
- dtypes.int32)
-
- def initialize(self, name=None):
- """Initialize the decoder.
-
- Args:
- name: Name scope for any created operations.
-
- Returns:
- `(finished, first_inputs, initial_state)`.
- """
- return self._sampler.initialize() + (self._initial_state,)
-
- def step(self, time, inputs, state, name=None):
- """Perform a decoding step.
-
- Args:
- time: scalar `int32` tensor.
- inputs: A (structure of) input tensors.
- state: A (structure of) state tensors and TensorArrays.
- name: Name scope for any created operations.
-
- Returns:
- `(outputs, next_state, next_inputs, finished)`.
- """
- with ops.name_scope(
- name, "BasicSamplingDecoderStep", (time, inputs, state)):
- cell_outputs, cell_state = self._cell(inputs, state)
- sample_ids = self._sampler.sample(
- time=time, outputs=cell_outputs, state=cell_state)
- (finished, next_inputs, next_state) = self._sampler.next_inputs(
- time=time, outputs=cell_outputs, state=cell_state,
- sample_ids=sample_ids)
- outputs = SamplingDecoderOutput(cell_outputs, sample_ids)
- return (outputs, next_state, next_inputs, finished)
-
-
-class CustomSampler(Sampler):
+class CustomHelper(Helper):
"""Base abstract class that allows the user to customize sampling."""
def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
@@ -202,8 +115,8 @@ class CustomSampler(Sampler):
time=time, outputs=outputs, state=state, sample_ids=sample_ids)
-class BasicTrainingSampler(Sampler):
- """A (non-)sampler for use during training. Only reads inputs.
+class TrainingHelper(Helper):
+ """A helper for use during training. Only reads inputs.
Returned sample_ids are the argmax of the RNN output logits.
"""
@@ -221,8 +134,7 @@ class BasicTrainingSampler(Sampler):
Raises:
ValueError: if `sequence_length` is not a 1D tensor.
"""
- with ops.name_scope(
- name, "BasicTrainingSampler", [inputs, sequence_length]):
+ with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
inputs = ops.convert_to_tensor(inputs, name="inputs")
if not time_major:
inputs = nest.map_structure(_transpose_batch_time, inputs)
@@ -250,7 +162,7 @@ class BasicTrainingSampler(Sampler):
return self._batch_size
def initialize(self, name=None):
- with ops.name_scope(name, "BasicTrainingSamplerInitialize"):
+ with ops.name_scope(name, "TrainingHelperInitialize"):
finished = math_ops.equal(0, self._sequence_length)
all_finished = math_ops.reduce_all(finished)
next_inputs = control_flow_ops.cond(
@@ -259,15 +171,15 @@ class BasicTrainingSampler(Sampler):
return (finished, next_inputs)
def sample(self, time, outputs, name=None, **unused_kwargs):
- with ops.name_scope(name, "BasicTrainingSamplerSample", [time, outputs]):
+ with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
sample_ids = math_ops.cast(
math_ops.argmax(outputs, axis=-1), dtypes.int32)
return sample_ids
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
- """next_inputs_fn for BasicTrainingSampler."""
- with ops.name_scope(
- name, "BasicTrainingSamplerNextInputs", [time, outputs, state]):
+ """next_inputs_fn for TrainingHelper."""
+ with ops.name_scope(name, "TrainingHelperNextInputs",
+ [time, outputs, state]):
next_time = time + 1
finished = (next_time >= self._sequence_length)
all_finished = math_ops.reduce_all(finished)
@@ -279,8 +191,8 @@ class BasicTrainingSampler(Sampler):
return (finished, next_inputs, state)
-class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
- """A training sampler that adds scheduled sampling.
+class ScheduledEmbeddingTrainingHelper(TrainingHelper):
+ """A training helper that adds scheduled sampling.
Returns -1s for sample_ids where no sampling took place; valid sample id
values elsewhere.
@@ -322,18 +234,17 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
"saw shape: %s" % (self._sampling_probability.get_shape()))
self._seed = seed
self._scheduling_seed = scheduling_seed
- super(ScheduledEmbeddingTrainingSampler, self).__init__(
+ super(ScheduledEmbeddingTrainingHelper, self).__init__(
inputs=inputs,
sequence_length=sequence_length,
time_major=time_major,
name=name)
def initialize(self, name=None):
- return super(ScheduledEmbeddingTrainingSampler, self).initialize(
- name=name)
+ return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name)
def sample(self, time, outputs, state, name=None):
- with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample",
+ with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
[time, outputs, state]):
# Return -1s where we did not sample, and sample_ids elsewhere
select_sample_noise = random_ops.random_uniform(
@@ -346,11 +257,14 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
array_ops.tile([-1], [self.batch_size]))
def next_inputs(self, time, outputs, state, sample_ids, name=None):
- with ops.name_scope(name, "ScheduledEmbeddingTrainingSamplerSample",
+ with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
[time, outputs, state, sample_ids]):
(finished, base_next_inputs, state) = (
- super(ScheduledEmbeddingTrainingSampler, self).next_inputs(
- time=time, outputs=outputs, state=state, sample_ids=sample_ids,
+ super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
+ time=time,
+ outputs=outputs,
+ state=state,
+ sample_ids=sample_ids,
name=name))
def maybe_sample():
@@ -379,8 +293,8 @@ class ScheduledEmbeddingTrainingSampler(BasicTrainingSampler):
return (finished, next_inputs, state)
-class GreedyEmbeddingSampler(Sampler):
- """A (non-)sampler for use during inference.
+class GreedyEmbeddingHelper(Helper):
+ """A helper for use during inference.
Uses the argmax of the output (treated as logits) and passes the
result through an embedding layer to get the next input.
@@ -424,7 +338,7 @@ class GreedyEmbeddingSampler(Sampler):
return (finished, self._start_inputs)
def sample(self, time, outputs, state, name=None):
- """sample for GreedyEmbeddingSampler."""
+ """sample for GreedyEmbeddingHelper."""
del time, state # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not isinstance(outputs, ops.Tensor):
@@ -435,7 +349,7 @@ class GreedyEmbeddingSampler(Sampler):
return sample_ids
def next_inputs(self, time, outputs, state, sample_ids, name=None):
- """next_inputs_fn for GreedyEmbeddingSampler."""
+ """next_inputs_fn for GreedyEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
finished = math_ops.equal(sample_ids, self._end_token)
all_finished = math_ops.reduce_all(finished)