aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/control_flow_ops.py16
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py2
-rw-r--r--tensorflow/python/ops/while_v2.py4
3 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..9d7d31df22 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -3211,6 +3220,13 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(cond, body, loop_vars, name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..c2751e529a 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -611,7 +611,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..6791e1cd61 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +43,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges