Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/decoder.py')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py | 29
1 file changed, 14 insertions(+), 15 deletions(-)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index e69725ff8a..f58268eff5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import abc
 
 import six
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -182,19 +183,20 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
-  def _is_xla_tensor(tensor):
-    try:
-      op = tensor.op
-    except AttributeError:
-      return False
-    if control_flow_util.IsInXLAContext(op):
-      return True
-    return False
-
   with variable_scope.variable_scope(scope, "decoder") as varscope:
-    # Properly cache variable values inside the while_loop
-    if varscope.caching_device is None:
-      varscope.set_caching_device(lambda op: op.device)
+    # Determine context types.
+    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
+    in_while_loop = (
+        control_flow_util.GetContainingWhileContext(ctxt) is not None)
+    # Properly cache variable values inside the while_loop.
+    # Don't set a caching device when running in a loop, since it is possible
+    # that train steps could be wrapped in a tf.while_loop. In that scenario
+    # caching prevents forward computations in loop iterations from re-reading
+    # the updated weights.
+    if not context.executing_eagerly() and not in_while_loop:
+      if varscope.caching_device is None:
+        varscope.set_caching_device(lambda op: op.device)
 
     if maximum_iterations is not None:
       maximum_iterations = ops.convert_to_tensor(
@@ -208,9 +210,6 @@ def dynamic_decode(decoder,
                                         decoder.output_dtype,
                                         decoder.batch_size)
 
-    is_xla = False
-    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
-      is_xla = True
     if is_xla and maximum_iterations is None:
       raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
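
For readers following the change: the patch replaces the old per-tensor _is_xla_tensor probe over nest.flatten(initial_inputs) with a single up-front query against the graph's control flow context. Below is a minimal standalone sketch of that pattern; the describe_current_context helper is hypothetical and not part of the patch, but the control_flow_util calls are the same ones the new code uses.

# Hypothetical helper (not in the patch) illustrating the context checks
# that dynamic_decode now performs once, up front.
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_util


def describe_current_context():
  """Classifies the innermost control flow context of the default graph."""
  if context.executing_eagerly():
    # No graph, hence no XLA or while_loop context to inspect.
    return dict(eager=True, xla=False, while_loop=False)
  ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
  return dict(
      eager=False,
      # Walks outward from ctxt looking for an enclosing XLA context.
      xla=control_flow_util.GetContainingXLAContext(ctxt) is not None,
      # Walks outward from ctxt looking for an enclosing WhileContext.
      while_loop=control_flow_util.GetContainingWhileContext(ctxt) is not None)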
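
The user-visible requirement is unchanged but now triggers based on the graph context rather than on the initial inputs: inside an XLA compilation context, dynamic_decode requires maximum_iterations. A hedged TF 1.x usage sketch follows; the vocabulary size, batch size, cell width, token ids, and decoder setup are illustrative assumptions, not from this diff.

# Illustrative TF 1.x setup; all sizes and token ids are made up.
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, vocab_size, num_units = 32, 1000, 128

cell = tf.nn.rnn_cell.LSTMCell(num_units)
embedding = tf.get_variable("embedding", [vocab_size, num_units])
helper = seq2seq.GreedyEmbeddingHelper(
    embedding,
    start_tokens=tf.fill([batch_size], 1),  # assumed GO id
    end_token=2)                            # assumed EOS id
decoder = seq2seq.BasicDecoder(
    cell, helper,
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(vocab_size))

# Under XLA, omitting maximum_iterations raises:
#   ValueError: maximum_iterations is required for XLA compilation.
outputs, final_state, lengths = seq2seq.dynamic_decode(
    decoder, maximum_iterations=50)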