aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-28 12:46:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 12:53:29 -0700
commit5e66d25666aad9fa76ed8cc0d2b162db76ea0cc8 (patch)
treea5b810c506ee8eb61b707e890ae41b71ac4cb8bd /tensorflow/python/ops
parente00954e8626c74b263b90527e0c020cfd64136b2 (diff)
Add flag for enabling while_v2.
Add a single test flag for enabling v2 control flow in tests since we do not plan to support v2 ops with legacy control flow. We have 2 test decorators now: @with_control_flow_v2: Enables all tests in a class to run with v2 control flow. @disable_control_flow_v2: Disables a test function from running in v2. I have removed the skiptests to avoid setup/teardown overheads. Enable tests in control_flow_ops_py_test that run with control_flow_v2. PiperOrigin-RevId: 214980108
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/control_flow_ops.py16
-rw-r--r--tensorflow/python/ops/while_v2.py4
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..9d7d31df22 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -3211,6 +3220,13 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(cond, body, loop_vars, name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..6791e1cd61 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +43,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges