aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adam Roberts <adarob@google.com>2017-07-11 14:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-11 14:04:05 -0700
commita80035d761d5563c00624a93a65b4accd7df22d1 (patch)
tree6db0da18ad77d2addf31a1daa241122153fedbc3
parent1de8cce6643bb75edd1441fa01ffb6a0e5258c5f (diff)
Make ScheduledEmbeddingTrainingHelper more readable and consistent with output ScheduledOutputTrainingHelper helper. Also results in minor speedup.
PiperOrigin-RevId: 161577148
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py19
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index 6b8cad7fd7..9d3f8ad441 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -30,8 +30,8 @@ from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
@@ -258,14 +258,15 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper):
with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
[time, outputs, state]):
# Return -1s where we did not sample, and sample_ids elsewhere
- select_sample_noise = random_ops.random_uniform(
- [self.batch_size], seed=self._scheduling_seed)
- select_sample = (self._sampling_probability > select_sample_noise)
+ select_sampler = bernoulli.Bernoulli(
+ probs=self._sampling_probability, dtype=dtypes.bool)
+ select_sample = select_sampler.sample(
+ sample_shape=self.batch_size, seed=self._scheduling_seed)
sample_id_sampler = categorical.Categorical(logits=outputs)
return array_ops.where(
select_sample,
sample_id_sampler.sample(seed=self._seed),
- array_ops.tile([-1], [self.batch_size]))
+ gen_array_ops.fill([self.batch_size], -1))
def next_inputs(self, time, outputs, state, sample_ids, name=None):
with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
@@ -284,11 +285,9 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper):
array_ops.where(sample_ids > -1), dtypes.int32)
where_not_sampling = math_ops.cast(
array_ops.where(sample_ids <= -1), dtypes.int32)
- where_sampling_flat = array_ops.reshape(where_sampling, [-1])
- where_not_sampling_flat = array_ops.reshape(where_not_sampling, [-1])
- sample_ids_sampling = array_ops.gather(sample_ids, where_sampling_flat)
- inputs_not_sampling = array_ops.gather(
- base_next_inputs, where_not_sampling_flat)
+ sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
+ inputs_not_sampling = array_ops.gather_nd(
+ base_next_inputs, where_not_sampling)
sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
base_shape = array_ops.shape(base_next_inputs)
return (array_ops.scatter_nd(indices=where_sampling,