diff options
author | 2016-01-14 10:46:16 -0800 | |
---|---|---|
committer | 2016-01-14 18:56:23 -0800 | |
commit | 11a8da42cc9aa30983be5c5e0569b09625febf42 (patch) | |
tree | e736814d543b1e0d8f9e1966210142924f3a773c | |
parent | 90644c9db115ee99a2a14ba462d70b822ec0804d (diff) |
Add tf.cond to public API
There was some concern that it used weird stuff like switch in the
implementation, but it's widely used and the weird stuff can remain *not* part
of the public API even if this is.
Change: 112166819
-rw-r--r-- | tensorflow/g3doc/api_docs/python/control_flow_ops.md | 41 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/index.md | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 61 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/standard_ops.py | 1 |
5 files changed, 72 insertions, 33 deletions
diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md index 45cb06b81b..111a426dd4 100644 --- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md +++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md @@ -139,6 +139,47 @@ easier to chain operations that need to use the updated value. input, the values produced will all be distinct. +- - - + +### `tf.cond(pred, fn1, fn2, name=None)` {#cond} + +Return either 'fn1()' or 'fn2()' based on the boolean predicate 'pred'. + +`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have +the same number and type of outputs. + +##### Args: + + +* <b>`pred`</b>: A scalar determining whether to return the result of `fn1` or `fn2`. +* <b>`fn1`</b>: The function to be performed if pred is true. +* <b>`fn2`</b>: The function to be performed if pref is false. +* <b>`name`</b>: Optional name prefix for the returned tensors. + +##### Returns: + + Tensors returned by the call to either `fn1` or `fn2`. If the functions + return a singleton list, the element is extracted from the list. + +##### Raises: + + +* <b>`TypeError`</b>: if `fn1` or `fn2` is not callable. +* <b>`ValueError`</b>: if `fn1` and `fn2` do not return the same number of tensors, or + return tensors of different types. + + +* <b>`Example`</b>: +```python + x = constant(2) + y = constant(5) + def f1(): return constant(17) + def f2(): return constant(23) + r = cond(math_ops.less(x, y), f1, f2) + # r is set to f1() +``` + + ## Logical Operators diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 388f8d8c0d..c0d6dda076 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -184,6 +184,7 @@ * [`add_check_numerics_ops`](../../api_docs/python/control_flow_ops.md#add_check_numerics_ops) * [`Assert`](../../api_docs/python/control_flow_ops.md#Assert) * [`check_numerics`](../../api_docs/python/control_flow_ops.md#check_numerics) + * [`cond`](../../api_docs/python/control_flow_ops.md#cond) * [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to) * [`equal`](../../api_docs/python/control_flow_ops.md#equal) * [`greater`](../../api_docs/python/control_flow_ops.md#greater) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 80b0d0ef4f..bbe3bbd8b4 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -335,7 +335,7 @@ class ControlFlowTest(tf.test.TestCase): fn1 = lambda: tf.add(values, 1) fn2 = lambda: tf.sub(values, 1) with self.assertRaisesRegexp(TypeError, "must not be a Python bool"): - _ = control_flow_ops.cond(False, fn1, fn2) + _ = tf.cond(False, fn1, fn2) def testCondIndexedSlices(self): with self.test_session(): @@ -345,7 +345,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), indices) fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), indices) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) val = r.values.eval() ind = r.indices.eval() @@ -362,7 +362,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), i_32) fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), i_64) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) val = r.values.eval() ind = r.indices.eval() @@ -380,7 +380,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1.0, 2.0) fn1 = lambda: tf.add(v, 1.0) fn2 = lambda: tf.sub(x, 1.0) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) for op in x.graph.get_operations(): if op.name == "cond/Add/Switch": @@ -392,7 +392,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: tf.add(x, 1) fn2 = lambda: tf.sub(x, 1) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) result = r.eval() self.assertTrue(check_op_order(x.graph)) @@ -405,8 +405,7 @@ class ControlFlowTest(tf.test.TestCase): def testCond_2(self): with self.test_session(): x = tf.constant(10) - r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(x, 1), - lambda: tf.sub(x, 1)) + r = tf.cond(tf.less(1, 0), lambda: tf.add(x, 1), lambda: tf.sub(x, 1)) result = r.eval() self.assertTrue(check_op_order(x.graph)) self.assertAllEqual(9, result) @@ -417,8 +416,8 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: tf.add(x, 1) fn2 = lambda: tf.sub(x, 1) - fn3 = lambda: tf.add(control_flow_ops.cond(pred, fn1, fn2), 1) - r = control_flow_ops.cond(pred, fn3, fn2) + fn3 = lambda: tf.add(tf.cond(pred, fn1, fn2), 1) + r = tf.cond(pred, fn3, fn2) result = r.eval() self.assertTrue(check_op_order(x.graph)) @@ -435,7 +434,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.greater(age, max_age) fn1 = lambda: [tf.assign(v1, 1).op, tf.assign(v2, 2).op] fn2 = lambda: [tf.assign(v3, 3).op, tf.constant(10).op] - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) tf.initialize_all_variables().run() self.assertEqual(len(r), 2) @@ -452,7 +451,7 @@ class ControlFlowTest(tf.test.TestCase): count = tf.constant(0, name="count") def body(i): - return control_flow_ops.cond( + return tf.cond( alive, lambda: [tf.less(i, 3), tf.add(count, 1)], lambda: [alive, count]) @@ -468,7 +467,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.greater(age, 4) fn1 = lambda: age fn2 = lambda: v1 - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) tf.initialize_all_variables().run() result = r.eval() @@ -481,7 +480,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: [tf.add(x, 1), tf.add(x, 2)] fn2 = lambda: [y, y] - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) self.assertAllEqual([11, 12], sess.run(r)) @@ -491,7 +490,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(1, 2) fn1 = lambda: tf.identity(x) fn2 = lambda: tf.identity(x) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) grad = tf.gradients(r, [x])[0] result = grad.eval() @@ -504,7 +503,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(c, 2) fn1 = lambda: tf.mul(x, 42.0) fn2 = lambda: tf.mul(x, 3.0) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) grad = tf.gradients(r, [x])[0] self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1})) @@ -517,7 +516,7 @@ class ControlFlowTest(tf.test.TestCase): pred = tf.less(c, 2) fn1 = lambda: tf.identity(v1) fn2 = lambda: tf.gather(v1, [1, 1]) - r = control_flow_ops.cond(pred, fn1, fn2) + r = tf.cond(pred, fn1, fn2) grad = tf.gradients(r, [v1])[0] tf.initialize_all_variables().run() # Should just be [1, 1], but possibly a sparse representation @@ -806,9 +805,9 @@ class ControlFlowTest(tf.test.TestCase): n = tf.convert_to_tensor(0, name="n") c = lambda x: tf.less(x, 10) b = lambda x: tf.add(x, 1) - r = control_flow_ops.cond(tf.less(0, 1), - lambda: control_flow_ops.While(c, b, [n]), - lambda: n) + r = tf.cond(tf.less(0, 1), + lambda: control_flow_ops.While(c, b, [n]), + lambda: n) result = r.eval() self.assertTrue(check_op_order(n.graph)) @@ -819,8 +818,8 @@ class ControlFlowTest(tf.test.TestCase): n = tf.convert_to_tensor(0) c = lambda x: tf.less(x, 10) b = lambda x: tf.add(x, 1) - r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(n, 1), - lambda: control_flow_ops.While(c, b, [n])) + r = tf.cond(tf.less(1, 0), lambda: tf.add(n, 1), + lambda: control_flow_ops.While(c, b, [n])) result = r.eval() self.assertTrue(check_op_order(n.graph)) @@ -832,9 +831,8 @@ class ControlFlowTest(tf.test.TestCase): n = tf.convert_to_tensor(10, name="n") one = tf.convert_to_tensor(1, name="one") c = lambda x: tf.less(x, n) - b = lambda x: control_flow_ops.cond(tf.constant(True), - lambda: tf.add(x, one), - lambda: tf.sub(x, one)) + b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, one), + lambda: tf.sub(x, one)) r = control_flow_ops.While(c, b, [i]) result = r.eval() @@ -845,9 +843,7 @@ class ControlFlowTest(tf.test.TestCase): with self.test_session(): n = tf.convert_to_tensor(0, name="n") c = lambda x: tf.less(x, 10) - b = lambda x: control_flow_ops.cond(tf.constant(True), - lambda: tf.add(x, 1), - lambda: n) + b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, 1), lambda: n) r = control_flow_ops.While(c, b, [n]) result = r.eval() @@ -858,9 +854,8 @@ class ControlFlowTest(tf.test.TestCase): with self.test_session(): n = tf.convert_to_tensor(0) c = lambda x: tf.less(x, 10) - b = lambda x: control_flow_ops.cond(tf.less(0, 1), - lambda: tf.add(x, 1), - lambda: tf.sub(x, 1)) + b = lambda x: tf.cond(tf.less(0, 1), lambda: tf.add(x, 1), + lambda: tf.sub(x, 1)) r = control_flow_ops.While(c, b, [n]) result = r.eval() @@ -1097,7 +1092,7 @@ class ControlFlowTest(tf.test.TestCase): one = tf.convert_to_tensor(1, name="one") two = tf.convert_to_tensor(2, name="two") p = tf.greater_equal(c, 1) - i = control_flow_ops.cond(p, lambda: one, lambda: two) + i = tf.cond(p, lambda: one, lambda: two) self.assertTrue(isinstance(i, tf.Tensor)) # True case: c = 2 is >= 1 @@ -1117,7 +1112,7 @@ class ControlFlowTest(tf.test.TestCase): def l1(): return tf.reduce_sum(tf.abs(x)) - i = control_flow_ops.cond(tf.equal(d, 2), l2, l1) + i = tf.cond(tf.equal(d, 2), l2, l1) self.assertEqual(4.0, i.eval(feed_dict={d: 1})) self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) @@ -1135,7 +1130,7 @@ class ControlFlowTest(tf.test.TestCase): def b(): return tf.assign(v, two) - i = control_flow_ops.cond(p, a, b) + i = tf.cond(p, a, b) self.assertTrue(isinstance(i, tf.Tensor)) tf.initialize_all_variables().run() diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 3c51ccb357..7b5cc3e73f 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -23,6 +23,7 @@ the execution of operations and add conditional dependencies to your graph. @@group @@no_op @@count_up_to +@@cond ## Logical Operators diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index e2180737df..9af536b971 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -32,6 +32,7 @@ from tensorflow.python.ops.clip_ops import * from tensorflow.python.ops.control_flow_ops import group from tensorflow.python.ops.control_flow_ops import no_op from tensorflow.python.ops.control_flow_ops import tuple +from tensorflow.python.ops.control_flow_ops import cond from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.gradients import * from tensorflow.python.ops.init_ops import * |