diff options
author | 2018-02-26 11:43:14 -0800 | |
---|---|---|
committer | 2018-02-26 11:50:25 -0800 | |
commit | 0f8ee19ef830fc7d28ae611194bcd66f4383b038 (patch) | |
tree | 9ec24a53e29ccc8b9d4be225a73323112f2f1c83 /tensorflow/python | |
parent | e5b73fc9a8df0d87cb964ed49e946d2477c73e19 (diff) |
Actually expose smart_cond and smart_constant_value in tf.contrib.framework
Also moves these methods into their own file in python/framework. This avoids further bloating control_flow_ops.py and makes the BUILD deps easier for a future change I'm working on.
PiperOrigin-RevId: 187055501
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/python/framework/smart_cond.py | 79 | ||||
-rw-r--r-- | tensorflow/python/framework/smart_cond_test.py | 66 | ||||
-rw-r--r-- | tensorflow/python/layers/utils.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 56 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 36 |
6 files changed, 174 insertions, 94 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4c8c73548c..b0cb48c80c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -766,6 +766,31 @@ py_library( ) py_library( + name = "smart_cond", + srcs = ["framework/smart_cond.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":tensor_util", + ], +) + +py_test( + name = "smart_cond_test", + size = "small", + srcs = ["framework/smart_cond_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":constant_op", + ":framework_ops", + ":math_ops", + ":session", + ":smart_cond", + ], +) + +py_library( name = "sparse_tensor", srcs = ["framework/sparse_tensor.py"], srcs_version = "PY2AND3", @@ -4091,6 +4116,7 @@ py_library( ":control_flow_ops", ":framework_for_generated_wrappers", ":platform", + ":smart_cond", ":tensor_util", ":util", ":variable_scope", diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py new file mode 100644 index 0000000000..f97bb01f54 --- /dev/null +++ b/tensorflow/python/framework/smart_cond.py @@ -0,0 +1,79 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""smart_cond and related utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops + + +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. + + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `true_fn` or `false_fn`. + + Raises: + TypeError: If `true_fn` or `false_fn` is not callable. + """ + if not callable(true_fn): + raise TypeError("`true_fn` must be callable.") + if not callable(false_fn): + raise TypeError("`false_fn` must be callable.") + + pred_value = smart_constant_value(pred) + if pred_value is not None: + if pred_value: + return true_fn() + else: + return false_fn() + else: + return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, + name=name) + + +def smart_constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: If `pred` is not a Tensor or bool. + """ + if isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + else: + raise TypeError("`pred` must be a Tensor or a Python bool.") + return pred_value diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py new file mode 100644 index 0000000000..b682506da0 --- /dev/null +++ b/tensorflow/python/framework/smart_cond_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +@test_util.with_c_api +class SmartCondTest(test_util.TensorFlowTestCase): + + def testSmartCondTrue(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 5)) + self.assertEqual(z.eval(), 32) + + def testSmartCondFalse(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(4) + y = constant_op.constant(3) + z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 3)) + self.assertEqual(z.eval(), 9) + + def testSmartCondMissingArg1(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, false_fn=lambda: x) + + def testSmartCondMissingArg2(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, lambda: x) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 484c6fc466..3b156c36a2 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -24,6 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import tensor_util from tensorflow.python.util import nest @@ -201,7 +202,7 @@ def smart_cond(pred, true_fn=None, false_fn=None, name=None): if isinstance(pred, variables.Variable): return control_flow_ops.cond( pred, true_fn=true_fn, false_fn=false_fn, name=name) - return control_flow_ops.smart_cond( + return smart_module.smart_cond( pred, true_fn=true_fn, false_fn=false_fn, name=name) @@ -228,7 +229,7 @@ def constant_value(pred): if isinstance(pred, variables.Variable): return None - return control_flow_ops.smart_constant_value(pred) + return smart_module.smart_constant_value(pred) def object_list_uid(object_list): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index c78a5aa8c2..8d5ab72670 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -23,7 +23,6 @@ See the @{$python/control_flow_ops} guide. @@no_op @@count_up_to @@cond -@@smart_cond @@case @@while_loop @@logical_and @@ -2130,61 +2129,6 @@ def cond(pred, # pylint: enable=redefined-outer-name -def smart_cond(pred, true_fn=None, false_fn=None, name=None): - """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. - - If `pred` is a bool or has a constant value, we return either `true_fn()` - or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. - - Arguments: - pred: A scalar determining whether to return the result of `true_fn` or - `false_fn`. - true_fn: The callable to be performed if pred is true. - false_fn: The callable to be performed if pred is false. - name: Optional name prefix when using `tf.cond`. - - Returns: - Tensors returned by the call to either `true_fn` or `false_fn`. - - Raises: - TypeError: If `true_fn` or `false_fn` is not callable. - """ - if not callable(true_fn): - raise TypeError("`true_fn` must be callable.") - if not callable(false_fn): - raise TypeError("`false_fn` must be callable.") - - pred_value = smart_constant_value(pred) - if pred_value is not None: - if pred_value: - return true_fn() - else: - return false_fn() - else: - return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name) - - -def smart_constant_value(pred): - """Return the bool value for `pred`, or None if `pred` had a dynamic value. - - Arguments: - pred: A scalar, either a Python bool or tensor. - - Returns: - True or False if `pred` has a constant boolean value, None otherwise. - - Raises: - TypeError: If `pred` is not a Tensor or bool. - """ - if isinstance(pred, bool): - pred_value = pred - elif isinstance(pred, ops.Tensor): - pred_value = tensor_util.constant_value(pred) - else: - raise TypeError("`pred` must be a Tensor or a Python bool.") - return pred_value - - def _resource_safe_shape(t): """Returns the shape of t or the variable it points to.""" if t.dtype == dtypes.resource: diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index adc8c51e11..f22f3059d1 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -350,42 +350,6 @@ class SwitchTestCase(test_util.TensorFlowTestCase): @test_util.with_c_api -class SmartCondTest(test_util.TensorFlowTestCase): - - def testSmartCondTrue(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(2) - y = constant_op.constant(5) - z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16), - lambda: math_ops.multiply(y, 5)) - self.assertEqual(z.eval(), 32) - - def testSmartCondFalse(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(4) - y = constant_op.constant(3) - z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16), - lambda: math_ops.multiply(y, 3)) - self.assertEqual(z.eval(), 9) - - def testSmartCondMissingArg1(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.smart_cond(True, false_fn=lambda: x) - - def testSmartCondMissingArg2(self): - with ops.Graph().as_default(): - with session.Session(): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.smart_cond(True, lambda: x) - - -@test_util.with_c_api class CondTest(test_util.TensorFlowTestCase): def testCondTrue(self): |