diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py index 2e754b7f22..cfb964d885 100644 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py @@ -113,8 +113,7 @@ class BasicSamplingDecoder(decoder.Decoder): dtypes.int32) def initialize(self, name=None): - with ops.name_scope("basic_sampling_decoder_initialize"): - return self._sampler.initialize() + (self._initial_state,) + return self._sampler.initialize() + (self._initial_state,) def step(self, time, inputs, state): """Perform a decoding step. @@ -127,12 +126,11 @@ class BasicSamplingDecoder(decoder.Decoder): Returns: `(outputs, next_state, next_inputs, finished)`. """ - with ops.name_scope("basic_sampling_decoder_step"): - cell_outputs, next_state = self._cell(inputs, state) - (sample_id, finished, next_inputs) = self._sampler.sample( - time=time, outputs=cell_outputs, state=next_state) - outputs = SamplingDecoderOutput(cell_outputs, sample_id) - return (outputs, next_state, next_inputs, finished) + cell_outputs, next_state = self._cell(inputs, state) + (sample_id, finished, next_inputs) = self._sampler.sample( + time=time, outputs=cell_outputs, state=next_state) + outputs = SamplingDecoderOutput(cell_outputs, sample_id) + return (outputs, next_state, next_inputs, finished) class BasicTrainingSampler(Sampler): |