aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-02-26 11:43:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 11:50:25 -0800
commit0f8ee19ef830fc7d28ae611194bcd66f4383b038 (patch)
tree9ec24a53e29ccc8b9d4be225a73323112f2f1c83 /tensorflow/python
parente5b73fc9a8df0d87cb964ed49e946d2477c73e19 (diff)
Actually expose smart_cond and smart_constant_value in tf.contrib.framework
Also moves these methods into their own file in python/framework. This avoids further bloating control_flow_ops.py and makes the BUILD deps easier for a future change I'm working on. PiperOrigin-RevId: 187055501
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD26
-rw-r--r--tensorflow/python/framework/smart_cond.py79
-rw-r--r--tensorflow/python/framework/smart_cond_test.py66
-rw-r--r--tensorflow/python/layers/utils.py5
-rw-r--r--tensorflow/python/ops/control_flow_ops.py56
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py36
6 files changed, 174 insertions, 94 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 4c8c73548c..b0cb48c80c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -766,6 +766,31 @@ py_library(
)
py_library(
+ name = "smart_cond",
+ srcs = ["framework/smart_cond.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ ":tensor_util",
+ ],
+)
+
+py_test(
+ name = "smart_cond_test",
+ size = "small",
+ srcs = ["framework/smart_cond_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":framework_ops",
+ ":math_ops",
+ ":session",
+ ":smart_cond",
+ ],
+)
+
+py_library(
name = "sparse_tensor",
srcs = ["framework/sparse_tensor.py"],
srcs_version = "PY2AND3",
@@ -4091,6 +4116,7 @@ py_library(
":control_flow_ops",
":framework_for_generated_wrappers",
":platform",
+ ":smart_cond",
":tensor_util",
":util",
":variable_scope",
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
new file mode 100644
index 0000000000..f97bb01f54
--- /dev/null
+++ b/tensorflow/python/framework/smart_cond.py
@@ -0,0 +1,79 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""smart_cond and related utilties."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+
+
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `true_fn` or `false_fn`.
+
+ Raises:
+ TypeError: If `true_fn` or `false_fn` is not callable.
+ """
+ if not callable(true_fn):
+ raise TypeError("`true_fn` must be callable.")
+ if not callable(false_fn):
+ raise TypeError("`false_fn` must be callable.")
+
+ pred_value = smart_constant_value(pred)
+ if pred_value is not None:
+ if pred_value:
+ return true_fn()
+ else:
+ return false_fn()
+ else:
+ return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
+ name=name)
+
+
+def smart_constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or tensor.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: If `pred` is not a Tensor or bool.
+ """
+ if isinstance(pred, bool):
+ pred_value = pred
+ elif isinstance(pred, ops.Tensor):
+ pred_value = tensor_util.constant_value(pred)
+ else:
+ raise TypeError("`pred` must be a Tensor or a Python bool.")
+ return pred_value
diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py
new file mode 100644
index 0000000000..b682506da0
--- /dev/null
+++ b/tensorflow/python/framework/smart_cond_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+@test_util.with_c_api
+class SmartCondTest(test_util.TensorFlowTestCase):
+
+ def testSmartCondTrue(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 5))
+ self.assertEqual(z.eval(), 32)
+
+ def testSmartCondFalse(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(4)
+ y = constant_op.constant(3)
+ z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 3))
+ self.assertEqual(z.eval(), 9)
+
+ def testSmartCondMissingArg1(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ smart_cond.smart_cond(True, false_fn=lambda: x)
+
+ def testSmartCondMissingArg2(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ smart_cond.smart_cond(True, lambda: x)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py
index 484c6fc466..3b156c36a2 100644
--- a/tensorflow/python/layers/utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -24,6 +24,7 @@ from tensorflow.python.eager import context
from tensorflow.python.ops import variables
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_util
from tensorflow.python.util import nest
@@ -201,7 +202,7 @@ def smart_cond(pred, true_fn=None, false_fn=None, name=None):
if isinstance(pred, variables.Variable):
return control_flow_ops.cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
- return control_flow_ops.smart_cond(
+ return smart_module.smart_cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
@@ -228,7 +229,7 @@ def constant_value(pred):
if isinstance(pred, variables.Variable):
return None
- return control_flow_ops.smart_constant_value(pred)
+ return smart_module.smart_constant_value(pred)
def object_list_uid(object_list):
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index c78a5aa8c2..8d5ab72670 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -23,7 +23,6 @@ See the @{$python/control_flow_ops} guide.
@@no_op
@@count_up_to
@@cond
-@@smart_cond
@@case
@@while_loop
@@logical_and
@@ -2130,61 +2129,6 @@ def cond(pred,
# pylint: enable=redefined-outer-name
-def smart_cond(pred, true_fn=None, false_fn=None, name=None):
- """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
-
- If `pred` is a bool or has a constant value, we return either `true_fn()`
- or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
-
- Arguments:
- pred: A scalar determining whether to return the result of `true_fn` or
- `false_fn`.
- true_fn: The callable to be performed if pred is true.
- false_fn: The callable to be performed if pred is false.
- name: Optional name prefix when using `tf.cond`.
-
- Returns:
- Tensors returned by the call to either `true_fn` or `false_fn`.
-
- Raises:
- TypeError: If `true_fn` or `false_fn` is not callable.
- """
- if not callable(true_fn):
- raise TypeError("`true_fn` must be callable.")
- if not callable(false_fn):
- raise TypeError("`false_fn` must be callable.")
-
- pred_value = smart_constant_value(pred)
- if pred_value is not None:
- if pred_value:
- return true_fn()
- else:
- return false_fn()
- else:
- return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
-
-
-def smart_constant_value(pred):
- """Return the bool value for `pred`, or None if `pred` had a dynamic value.
-
- Arguments:
- pred: A scalar, either a Python bool or tensor.
-
- Returns:
- True or False if `pred` has a constant boolean value, None otherwise.
-
- Raises:
- TypeError: If `pred` is not a Tensor or bool.
- """
- if isinstance(pred, bool):
- pred_value = pred
- elif isinstance(pred, ops.Tensor):
- pred_value = tensor_util.constant_value(pred)
- else:
- raise TypeError("`pred` must be a Tensor or a Python bool.")
- return pred_value
-
-
def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index adc8c51e11..f22f3059d1 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -350,42 +350,6 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
@test_util.with_c_api
-class SmartCondTest(test_util.TensorFlowTestCase):
-
- def testSmartCondTrue(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(2)
- y = constant_op.constant(5)
- z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16),
- lambda: math_ops.multiply(y, 5))
- self.assertEqual(z.eval(), 32)
-
- def testSmartCondFalse(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(4)
- y = constant_op.constant(3)
- z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16),
- lambda: math_ops.multiply(y, 3))
- self.assertEqual(z.eval(), 9)
-
- def testSmartCondMissingArg1(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.smart_cond(True, false_fn=lambda: x)
-
- def testSmartCondMissingArg2(self):
- with ops.Graph().as_default():
- with session.Session():
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.smart_cond(True, lambda: x)
-
-
-@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):
def testCondTrue(self):