Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r--  tensorflow/python/ops/rnn.py | 37
1 file changed, 25 insertions(+), 12 deletions(-)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 215140e987..7096e0dd84 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
@@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape):
return shape
+def _should_cache():
+ """Returns True if a default caching device should be set, otherwise False."""
+ if context.executing_eagerly():
+ return False
+ # Don't set a caching device when running in a loop, since it is possible that
+ # train steps could be wrapped in a tf.while_loop. In that scenario caching
+ # prevents forward computations in loop iterations from re-reading the
+ # updated weights.
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ return control_flow_util.GetContainingWhileContext(ctxt) is None
+
+
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
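For context on the new helper: below is a minimal sketch (not part of this change) of how the while-context check behaves. It assumes TF 1.x graph mode, where tensorflow.python.ops.control_flow_util and the private _get_control_flow_context() hook are available; inside a tf.while_loop body the containing WhileContext is non-None, so _should_cache() would return False there. The in_while_context() helper is hypothetical and only mirrors the check above.

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_util

def in_while_context():
  """True if graph construction is currently inside a tf.while_loop body."""
  ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingWhileContext(ctxt) is not None

def body(i):
  assert in_while_context()      # traced inside the loop: _should_cache() -> False
  return i + 1

_ = tf.while_loop(lambda i: i < 3, body, [tf.constant(0)])
assert not in_while_context()    # outside the loop: caching device may be set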
@@ -404,24 +417,24 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
# Backward direction
if not time_major:
- time_dim = 1
- batch_dim = 0
+ time_axis = 1
+ batch_axis = 0
else:
- time_dim = 0
- batch_dim = 1
+ time_axis = 0
+ batch_axis = 1
- def _reverse(input_, seq_lengths, seq_dim, batch_dim):
+ def _reverse(input_, seq_lengths, seq_axis, batch_axis):
if seq_lengths is not None:
return array_ops.reverse_sequence(
input=input_, seq_lengths=seq_lengths,
- seq_dim=seq_dim, batch_dim=batch_dim)
+ seq_axis=seq_axis, batch_axis=batch_axis)
else:
- return array_ops.reverse(input_, axis=[seq_dim])
+ return array_ops.reverse(input_, axis=[seq_axis])
with vs.variable_scope("bw") as bw_scope:
inputs_reverse = _reverse(
inputs, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
+ seq_axis=time_axis, batch_axis=batch_axis)
tmp, output_state_bw = dynamic_rnn(
cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
initial_state=initial_state_bw, dtype=dtype,
@@ -430,7 +443,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
output_bw = _reverse(
tmp, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
+ seq_axis=time_axis, batch_axis=batch_axis)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
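The two hunks above only switch from the deprecated seq_dim/batch_dim keywords to seq_axis/batch_axis on reverse_sequence, plus the matching local renames. A minimal sketch of the reversal the backward direction performs, using the public tf.reverse_sequence API and assuming the default batch-major layout (time_major=False, so time_axis=1 and batch_axis=0); the tensor shapes and lengths here are made up for illustration.

import numpy as np
import tensorflow as tf

# [batch, time, depth] inputs with per-example valid lengths.
inputs = tf.constant(np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3))
sequence_length = tf.constant([4, 2])

# Reverse each example along the time axis, but only over its valid steps;
# padded steps keep their positions, matching what _reverse() does above.
inputs_reverse = tf.reverse_sequence(
    inputs, seq_lengths=sequence_length, seq_axis=1, batch_axis=0)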
@@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1015,7 +1028,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1228,7 +1241,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
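The last three hunks apply the same one-line change to dynamic_rnn, raw_rnn, and static_rnn: the caching device is now set only when _should_cache() holds, not merely when eager execution is off. A sketch of that shared pattern, written against the public TF 1.x variable_scope API; rnn._should_cache is the private helper added at the top of this diff and is reached here purely for illustration.

import tensorflow as tf
from tensorflow.python.ops import rnn  # private module, used only to reach _should_cache

with tf.variable_scope("my_rnn") as varscope:
  # Cache weight reads next to the consuming op, but skip caching when eager
  # or inside a tf.while_loop, where stale cached reads would be a problem.
  if rnn._should_cache():  # pylint: disable=protected-access
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)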