about summary refs log tree commit diff homepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-26 14:41:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-26 15:53:22 -0700
commitb6c105d23a12b844784a7de37abbadf2bb183ff1 (patch)
tree33bc79522105390149c24068cae825cc34840207 /tensorflow
parent7888d8c318e6b0d54d3bfdb44dde47643256a728 (diff)
Check if memory_sequence_length is not None before converting tensor
Change: 154355223
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py5
1 file changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 04b38159bb..2e2b2ebe60 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -73,8 +73,9 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
"""
memory = nest.map_structure(
lambda m: ops.convert_to_tensor(m, name="memory"), memory)
- memory_sequence_length = ops.convert_to_tensor(
- memory_sequence_length, name="memory_sequence_length")
+ if memory_sequence_length is not None:
+ memory_sequence_length = ops.convert_to_tensor(
+ memory_sequence_length, name="memory_sequence_length")
if check_inner_dims_defined:
def _check_dims(m):
if not m.get_shape()[2:].is_fully_defined():