diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-26 11:05:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 11:08:11 -0700 |
commit | 005b8aa42c273a0152642279d0c57aa9e08ccbe0 (patch) | |
tree | 3231b4bfed877b52beeb91c8d66d74fc270fa8f1 /tensorflow/contrib/seq2seq | |
parent | 9d8779eebdf0e813748fa1b81b975f443f84f73a (diff) |
Fixes an issue when calling tf.contrib.seq2seq.dynamic_decode with an extended BasicDecoder that, for example, returns a tf.contrib.seq2seq.AttentionWrapperState.
In this case the internal while-loop fails when trying to store an instance of tf.contrib.seq2seq.AttentionWrapperState in the internal TensorArray.
PiperOrigin-RevId: 190491787
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/decoder.py | 15 |
1 files changed, 5 insertions, 10 deletions
```diff
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index f14974b9d5..898493662d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import nest
@@ -39,6 +40,7 @@ __all__ = ["Decoder", "dynamic_decode"]
 
 
 _transpose_batch_time = rnn._transpose_batch_time  # pylint: disable=protected-access
+_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -133,16 +135,8 @@ class Decoder(object):
 
 def _create_zero_outputs(size, dtype, batch_size):
   """Create a zero outputs Tensor structure."""
-  def _t(s):
-    return (s if isinstance(s, ops.Tensor) else constant_op.constant(
-        tensor_shape.TensorShape(s).as_list(),
-        dtype=dtypes.int32,
-        name="zero_suffix_shape"))
-
   def _create(s, d):
-    return array_ops.zeros(
-        array_ops.concat(
-            ([batch_size], _t(s)), axis=0), dtype=d)
+    return _zero_state_tensors(s, batch_size, d)
 
   return nest.map_structure(_create, size, dtype)
 
@@ -212,7 +206,8 @@ def dynamic_decode(decoder,
   initial_time = constant_op.constant(0, dtype=dtypes.int32)
 
   def _shape(batch_size, from_shape):
-    if not isinstance(from_shape, tensor_shape.TensorShape):
+    if (not isinstance(from_shape, tensor_shape.TensorShape) or
+        from_shape.ndims == 0):
       return tensor_shape.TensorShape(None)
     else:
       batch_size = tensor_util.constant_value(
```