aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-26 11:05:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 11:08:11 -0700
commit 005b8aa42c273a0152642279d0c57aa9e08ccbe0 (patch)
tree 3231b4bfed877b52beeb91c8d66d74fc270fa8f1 /tensorflow/contrib/seq2seq
parent 9d8779eebdf0e813748fa1b81b975f443f84f73a (diff)
Fixes an issue with calling tf.contrib.seq2seq.dynamic_decode with an extended BasicDecoder that, for example, returns a tf.contrib.seq2seq.AttentionWrapperState.
In this case the internal while-loop fails when trying to store an instance of tf.contrib.seq2seq.AttentionWrapperState in the internal TensorArray. PiperOrigin-RevId: 190491787
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- tensorflow/contrib/seq2seq/python/ops/decoder.py 15
1 files changed, 5 insertions, 10 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index f14974b9d5..898493662d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
@@ -39,6 +40,7 @@ __all__ = ["Decoder", "dynamic_decode"]
_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access
+_zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=protected-access
@six.add_metaclass(abc.ABCMeta)
@@ -133,16 +135,8 @@ class Decoder(object):
def _create_zero_outputs(size, dtype, batch_size):
"""Create a zero outputs Tensor structure."""
- def _t(s):
- return (s if isinstance(s, ops.Tensor) else constant_op.constant(
- tensor_shape.TensorShape(s).as_list(),
- dtype=dtypes.int32,
- name="zero_suffix_shape"))
-
def _create(s, d):
- return array_ops.zeros(
- array_ops.concat(
- ([batch_size], _t(s)), axis=0), dtype=d)
+ return _zero_state_tensors(s, batch_size, d)
return nest.map_structure(_create, size, dtype)
@@ -212,7 +206,8 @@ def dynamic_decode(decoder,
initial_time = constant_op.constant(0, dtype=dtypes.int32)
def _shape(batch_size, from_shape):
- if not isinstance(from_shape, tensor_shape.TensorShape):
+ if (not isinstance(from_shape, tensor_shape.TensorShape) or
+ from_shape.ndims == 0):
return tensor_shape.TensorShape(None)
else:
batch_size = tensor_util.constant_value(