diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-28 12:46:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 12:53:29 -0700 |
commit | 5e66d25666aad9fa76ed8cc0d2b162db76ea0cc8 (patch) | |
tree | a5b810c506ee8eb61b707e890ae41b71ac4cb8bd /tensorflow/python/ops | |
parent | e00954e8626c74b263b90527e0c020cfd64136b2 (diff) |
Add flag for enabling while_v2.
Add a single test flag for enabling v2 control flow in tests since we do not plan to support v2 ops with legacy control flow.
We have 2 test decorators now:
@with_control_flow_v2: Enables all tests in a class to run with v2 control flow.
@disable_control_flow_v2: Disables a test function from running in v2. I have removed the skiptests to avoid setup/teardown overheads.
Enable tests in control_flow_ops_py_test that run with control_flow_v2.
PiperOrigin-RevId: 214980108
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/while_v2.py | 4 |
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 87f8bd85a5..9d7d31df22 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -60,8 +60,17 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export +# The while_v2 module. +_while_v2 = None ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" +# Note: Setting this to True is not sufficient to switch to the v2 while_loop. +# Users must also import the while_v2 module to set the _while_v2 module +# variable above. We do this to avoid a circular dependency: +# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops +# A ValueError is raised in tf.while_loop if this is set to True and the +# `_while_v2` module is not set. +ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" # We override the 'tuple' for a control flow op, so we keep python's @@ -3211,6 +3220,13 @@ def while_loop(cond, ``` """ + if ENABLE_WHILE_V2 and not context.executing_eagerly(): + if not _while_v2: + raise ValueError("The while_v2 module is not set. Did you forget to " + "import tensorflow.python.ops." + "while_v2?") + return _while_v2.while_loop(cond, body, loop_vars, name) + with ops.name_scope(name, "while", loop_vars): if not loop_vars: raise ValueError("No loop variables provided") diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 875be31602..6791e1cd61 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -24,6 +24,7 @@ from __future__ import division from __future__ import print_function import collections +import sys from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import function @@ -33,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2_impl as cond_v2 +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import gradients_impl @@ -41,6 +43,8 @@ from tensorflow.python.util import nest # pylint: disable=protected-access +control_flow_ops._while_v2 = sys.modules[__name__] + # TODO(b/79881896): Handle external control dependencies. tf.while_loop allows # control dependencies on external nodes with at least 1 output. # Another idea is to create const nodes outside the loop and add control edges |