Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/decoder.py')
 tensorflow/contrib/seq2seq/python/ops/decoder.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index e69725ff8a..f58268eff5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import abc
 
 import six
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -182,19 +183,20 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
-  def _is_xla_tensor(tensor):
-    try:
-      op = tensor.op
-    except AttributeError:
-      return False
-    if control_flow_util.IsInXLAContext(op):
-      return True
-    return False
-
   with variable_scope.variable_scope(scope, "decoder") as varscope:
-    # Properly cache variable values inside the while_loop
-    if varscope.caching_device is None:
-      varscope.set_caching_device(lambda op: op.device)
+    # Determine context types.
+    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
+    in_while_loop = (
+        control_flow_util.GetContainingWhileContext(ctxt) is not None)
+    # Properly cache variable values inside the while_loop.
+    # Don't set a caching device when running in a loop, since it is possible
+    # that train steps could be wrapped in a tf.while_loop. In that scenario
+    # caching prevents forward computations in loop iterations from re-reading
+    # the updated weights.
+    if not context.executing_eagerly() and not in_while_loop:
+      if varscope.caching_device is None:
+        varscope.set_caching_device(lambda op: op.device)
 
     if maximum_iterations is not None:
       maximum_iterations = ops.convert_to_tensor(
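
For readability, the guard added in this hunk can be read as a self-contained helper. The sketch below is illustrative, not part of the commit (the helper name _maybe_set_caching_device is invented here); it calls the same internal TensorFlow APIs the patch uses and assumes TF 1.x graph mode:

    # Only pin variable reads to the consuming op's device when caching is
    # safe: not under eager execution, and not inside an enclosing
    # tf.while_loop, where a cached read would keep returning the weights
    # captured on the first read even after a train step updates them.
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_util


    def _maybe_set_caching_device(varscope):
      ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      in_while_loop = (
          control_flow_util.GetContainingWhileContext(ctxt) is not None)
      if not context.executing_eagerly() and not in_while_loop:
        if varscope.caching_device is None:
          varscope.set_caching_device(lambda op: op.device)
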
@@ -208,9 +210,6 @@
                                         decoder.output_dtype,
                                         decoder.batch_size)
 
-    is_xla = False
-    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
-      is_xla = True
     if is_xla and maximum_iterations is None:
       raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
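
This hunk pairs with the detection moved earlier in the function: instead of probing every flattened initial input tensor for an op inside an XLA context, the graph's current control-flow context is asked once, before the decoder even initializes. A minimal sketch of that check (illustrative; _in_xla_context is an invented name, while the helpers are the same internal ones the patch calls):

    # Ask the default graph's control-flow context whether an XLA
    # compilation context encloses the current scope.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_util


    def _in_xla_context():
      ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      return control_flow_util.GetContainingXLAContext(ctxt) is not None

The maximum_iterations requirement survives the rewrite because, as the retained error message states, XLA compilation needs a statically known bound for the decode loop rather than a data-dependent stopping condition.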