author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-28 16:27:34 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-28 16:30:55 -0700
commit     c81830af5d488de600a4f62392c63059e310c017 (patch)
tree       555b5e6d2b8520234fad9c2fe034bd93d96bf2fc /tensorflow/contrib/seq2seq
parent     2273cda3e0209d17fc4f2f055a28d27377b65988 (diff)
Don't cache RNN weights in while loops, since it's possible that train steps could be updating the weights. This is specifically true on TPUs (tpu.repeat). Also, fix the `testDynamicRnnTrainLoop` unit test.
PiperOrigin-RevId: 202565323
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py  29
1 file changed, 14 insertions(+), 15 deletions(-)
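Background for the change below: the old code in dynamic_decode unconditionally set a caching device on the decoder's variable scope, so weight reads inside the decode loop are served from a one-time snapshot. The following is a minimal sketch of the failure mode (illustrative only, not part of this commit; it assumes TF 1.x graph mode, and the variable and loop names are made up): when the surrounding train step is itself a tf.while_loop, the cached read is loop-invariant, so later iterations keep seeing stale weights.

    import tensorflow as tf

    # Variables created in this scope cache their reads on the consuming
    # op's device -- the same caching_device the old dynamic_decode set.
    with tf.variable_scope("demo", caching_device=lambda op: op.device):
      w = tf.get_variable("w", initializer=1.0)

    def body(i, total):
      total += w  # forward read of the weight inside the loop
      train_step = tf.assign_add(w, 1.0)  # stand-in for an optimizer update
      with tf.control_dependencies([train_step]):
        return i + 1, tf.identity(total)

    _, total = tf.while_loop(lambda i, t: i < 3, body,
                             [tf.constant(0), tf.constant(0.0)])

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # With the cached read, every iteration can see the pre-loop value of
      # `w` (total == 3.0) rather than the updated weights (total == 6.0) --
      # the stale-weight behavior this commit avoids.
      print(sess.run(total))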
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index e69725ff8a..f58268eff5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import abc
 import six
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -182,19 +183,20 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
-  def _is_xla_tensor(tensor):
-    try:
-      op = tensor.op
-    except AttributeError:
-      return False
-    if control_flow_util.IsInXLAContext(op):
-      return True
-    return False
-
   with variable_scope.variable_scope(scope, "decoder") as varscope:
-    # Properly cache variable values inside the while_loop
-    if varscope.caching_device is None:
-      varscope.set_caching_device(lambda op: op.device)
+    # Determine context types.
+    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
+    in_while_loop = (
+        control_flow_util.GetContainingWhileContext(ctxt) is not None)
+    # Properly cache variable values inside the while_loop.
+    # Don't set a caching device when running in a loop, since it is possible
+    # that train steps could be wrapped in a tf.while_loop. In that scenario
+    # caching prevents forward computations in loop iterations from re-reading
+    # the updated weights.
+    if not context.executing_eagerly() and not in_while_loop:
+      if varscope.caching_device is None:
+        varscope.set_caching_device(lambda op: op.device)
 
     if maximum_iterations is not None:
       maximum_iterations = ops.convert_to_tensor(
@@ -208,9 +210,6 @@ def dynamic_decode(decoder,
         decoder.output_dtype,
         decoder.batch_size)
 
-    is_xla = False
-    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
-      is_xla = True
     if is_xla and maximum_iterations is None:
       raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
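The new context checks in the second hunk can be read as a standalone helper. A minimal sketch (the helper name is hypothetical, not part of this commit; like the patch, it relies on the private _get_control_flow_context accessor):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_util

    def _control_flow_flags():
      """Returns (is_xla, in_while_loop) for the current graph context."""
      ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
      in_while_loop = control_flow_util.GetContainingWhileContext(ctxt) is not None
      return is_xla, in_while_loop

With these flags, dynamic_decode caches weight reads only when it is safe to do so (graph mode and not nested in an outer while loop), and it can fail fast when XLA compilation requires maximum_iterations.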