aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-01-14 10:46:16 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-14 18:56:23 -0800
commit11a8da42cc9aa30983be5c5e0569b09625febf42 (patch)
treee736814d543b1e0d8f9e1966210142924f3a773c
parent90644c9db115ee99a2a14ba462d70b822ec0804d (diff)
Add tf.cond to public API
There was some concern that it used weird stuff like switch in the implementation, but it's widely used and the weird stuff can remain *not* part of the public API even if this is. Change: 112166819
-rw-r--r--tensorflow/g3doc/api_docs/python/control_flow_ops.md41
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md1
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py61
-rw-r--r--tensorflow/python/ops/control_flow_ops.py1
-rw-r--r--tensorflow/python/ops/standard_ops.py1
5 files changed, 72 insertions, 33 deletions
diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
index 45cb06b81b..111a426dd4 100644
--- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md
+++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md
@@ -139,6 +139,47 @@ easier to chain operations that need to use the updated value.
input, the values produced will all be distinct.
+- - -
+
+### `tf.cond(pred, fn1, fn2, name=None)` {#cond}
+
+Return either 'fn1()' or 'fn2()' based on the boolean predicate 'pred'.
+
+`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
+the same number and type of outputs.
+
+##### Args:
+
+
+* <b>`pred`</b>: A scalar determining whether to return the result of `fn1` or `fn2`.
+* <b>`fn1`</b>: The function to be performed if pred is true.
+* <b>`fn2`</b>: The function to be performed if pref is false.
+* <b>`name`</b>: Optional name prefix for the returned tensors.
+
+##### Returns:
+
+ Tensors returned by the call to either `fn1` or `fn2`. If the functions
+ return a singleton list, the element is extracted from the list.
+
+##### Raises:
+
+
+* <b>`TypeError`</b>: if `fn1` or `fn2` is not callable.
+* <b>`ValueError`</b>: if `fn1` and `fn2` do not return the same number of tensors, or
+ return tensors of different types.
+
+
+* <b>`Example`</b>:
+```python
+ x = constant(2)
+ y = constant(5)
+ def f1(): return constant(17)
+ def f2(): return constant(23)
+ r = cond(math_ops.less(x, y), f1, f2)
+ # r is set to f1()
+```
+
+
## Logical Operators
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 388f8d8c0d..c0d6dda076 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -184,6 +184,7 @@
* [`add_check_numerics_ops`](../../api_docs/python/control_flow_ops.md#add_check_numerics_ops)
* [`Assert`](../../api_docs/python/control_flow_ops.md#Assert)
* [`check_numerics`](../../api_docs/python/control_flow_ops.md#check_numerics)
+ * [`cond`](../../api_docs/python/control_flow_ops.md#cond)
* [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to)
* [`equal`](../../api_docs/python/control_flow_ops.md#equal)
* [`greater`](../../api_docs/python/control_flow_ops.md#greater)
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 80b0d0ef4f..bbe3bbd8b4 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -335,7 +335,7 @@ class ControlFlowTest(tf.test.TestCase):
fn1 = lambda: tf.add(values, 1)
fn2 = lambda: tf.sub(values, 1)
with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
- _ = control_flow_ops.cond(False, fn1, fn2)
+ _ = tf.cond(False, fn1, fn2)
def testCondIndexedSlices(self):
with self.test_session():
@@ -345,7 +345,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), indices)
fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), indices)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
val = r.values.eval()
ind = r.indices.eval()
@@ -362,7 +362,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: tf.IndexedSlices(tf.add(x.values, 1), i_32)
fn2 = lambda: tf.IndexedSlices(tf.sub(x.values, 1), i_64)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
val = r.values.eval()
ind = r.indices.eval()
@@ -380,7 +380,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1.0, 2.0)
fn1 = lambda: tf.add(v, 1.0)
fn2 = lambda: tf.sub(x, 1.0)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
for op in x.graph.get_operations():
if op.name == "cond/Add/Switch":
@@ -392,7 +392,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: tf.add(x, 1)
fn2 = lambda: tf.sub(x, 1)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
result = r.eval()
self.assertTrue(check_op_order(x.graph))
@@ -405,8 +405,7 @@ class ControlFlowTest(tf.test.TestCase):
def testCond_2(self):
with self.test_session():
x = tf.constant(10)
- r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(x, 1),
- lambda: tf.sub(x, 1))
+ r = tf.cond(tf.less(1, 0), lambda: tf.add(x, 1), lambda: tf.sub(x, 1))
result = r.eval()
self.assertTrue(check_op_order(x.graph))
self.assertAllEqual(9, result)
@@ -417,8 +416,8 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: tf.add(x, 1)
fn2 = lambda: tf.sub(x, 1)
- fn3 = lambda: tf.add(control_flow_ops.cond(pred, fn1, fn2), 1)
- r = control_flow_ops.cond(pred, fn3, fn2)
+ fn3 = lambda: tf.add(tf.cond(pred, fn1, fn2), 1)
+ r = tf.cond(pred, fn3, fn2)
result = r.eval()
self.assertTrue(check_op_order(x.graph))
@@ -435,7 +434,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.greater(age, max_age)
fn1 = lambda: [tf.assign(v1, 1).op, tf.assign(v2, 2).op]
fn2 = lambda: [tf.assign(v3, 3).op, tf.constant(10).op]
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
tf.initialize_all_variables().run()
self.assertEqual(len(r), 2)
@@ -452,7 +451,7 @@ class ControlFlowTest(tf.test.TestCase):
count = tf.constant(0, name="count")
def body(i):
- return control_flow_ops.cond(
+ return tf.cond(
alive, lambda: [tf.less(i, 3), tf.add(count, 1)],
lambda: [alive, count])
@@ -468,7 +467,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.greater(age, 4)
fn1 = lambda: age
fn2 = lambda: v1
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
tf.initialize_all_variables().run()
result = r.eval()
@@ -481,7 +480,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: [tf.add(x, 1), tf.add(x, 2)]
fn2 = lambda: [y, y]
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
self.assertAllEqual([11, 12], sess.run(r))
@@ -491,7 +490,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(1, 2)
fn1 = lambda: tf.identity(x)
fn2 = lambda: tf.identity(x)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
grad = tf.gradients(r, [x])[0]
result = grad.eval()
@@ -504,7 +503,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(c, 2)
fn1 = lambda: tf.mul(x, 42.0)
fn2 = lambda: tf.mul(x, 3.0)
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
grad = tf.gradients(r, [x])[0]
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
@@ -517,7 +516,7 @@ class ControlFlowTest(tf.test.TestCase):
pred = tf.less(c, 2)
fn1 = lambda: tf.identity(v1)
fn2 = lambda: tf.gather(v1, [1, 1])
- r = control_flow_ops.cond(pred, fn1, fn2)
+ r = tf.cond(pred, fn1, fn2)
grad = tf.gradients(r, [v1])[0]
tf.initialize_all_variables().run()
# Should just be [1, 1], but possibly a sparse representation
@@ -806,9 +805,9 @@ class ControlFlowTest(tf.test.TestCase):
n = tf.convert_to_tensor(0, name="n")
c = lambda x: tf.less(x, 10)
b = lambda x: tf.add(x, 1)
- r = control_flow_ops.cond(tf.less(0, 1),
- lambda: control_flow_ops.While(c, b, [n]),
- lambda: n)
+ r = tf.cond(tf.less(0, 1),
+ lambda: control_flow_ops.While(c, b, [n]),
+ lambda: n)
result = r.eval()
self.assertTrue(check_op_order(n.graph))
@@ -819,8 +818,8 @@ class ControlFlowTest(tf.test.TestCase):
n = tf.convert_to_tensor(0)
c = lambda x: tf.less(x, 10)
b = lambda x: tf.add(x, 1)
- r = control_flow_ops.cond(tf.less(1, 0), lambda: tf.add(n, 1),
- lambda: control_flow_ops.While(c, b, [n]))
+ r = tf.cond(tf.less(1, 0), lambda: tf.add(n, 1),
+ lambda: control_flow_ops.While(c, b, [n]))
result = r.eval()
self.assertTrue(check_op_order(n.graph))
@@ -832,9 +831,8 @@ class ControlFlowTest(tf.test.TestCase):
n = tf.convert_to_tensor(10, name="n")
one = tf.convert_to_tensor(1, name="one")
c = lambda x: tf.less(x, n)
- b = lambda x: control_flow_ops.cond(tf.constant(True),
- lambda: tf.add(x, one),
- lambda: tf.sub(x, one))
+ b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, one),
+ lambda: tf.sub(x, one))
r = control_flow_ops.While(c, b, [i])
result = r.eval()
@@ -845,9 +843,7 @@ class ControlFlowTest(tf.test.TestCase):
with self.test_session():
n = tf.convert_to_tensor(0, name="n")
c = lambda x: tf.less(x, 10)
- b = lambda x: control_flow_ops.cond(tf.constant(True),
- lambda: tf.add(x, 1),
- lambda: n)
+ b = lambda x: tf.cond(tf.constant(True), lambda: tf.add(x, 1), lambda: n)
r = control_flow_ops.While(c, b, [n])
result = r.eval()
@@ -858,9 +854,8 @@ class ControlFlowTest(tf.test.TestCase):
with self.test_session():
n = tf.convert_to_tensor(0)
c = lambda x: tf.less(x, 10)
- b = lambda x: control_flow_ops.cond(tf.less(0, 1),
- lambda: tf.add(x, 1),
- lambda: tf.sub(x, 1))
+ b = lambda x: tf.cond(tf.less(0, 1), lambda: tf.add(x, 1),
+ lambda: tf.sub(x, 1))
r = control_flow_ops.While(c, b, [n])
result = r.eval()
@@ -1097,7 +1092,7 @@ class ControlFlowTest(tf.test.TestCase):
one = tf.convert_to_tensor(1, name="one")
two = tf.convert_to_tensor(2, name="two")
p = tf.greater_equal(c, 1)
- i = control_flow_ops.cond(p, lambda: one, lambda: two)
+ i = tf.cond(p, lambda: one, lambda: two)
self.assertTrue(isinstance(i, tf.Tensor))
# True case: c = 2 is >= 1
@@ -1117,7 +1112,7 @@ class ControlFlowTest(tf.test.TestCase):
def l1():
return tf.reduce_sum(tf.abs(x))
- i = control_flow_ops.cond(tf.equal(d, 2), l2, l1)
+ i = tf.cond(tf.equal(d, 2), l2, l1)
self.assertEqual(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
@@ -1135,7 +1130,7 @@ class ControlFlowTest(tf.test.TestCase):
def b():
return tf.assign(v, two)
- i = control_flow_ops.cond(p, a, b)
+ i = tf.cond(p, a, b)
self.assertTrue(isinstance(i, tf.Tensor))
tf.initialize_all_variables().run()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 3c51ccb357..7b5cc3e73f 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -23,6 +23,7 @@ the execution of operations and add conditional dependencies to your graph.
@@group
@@no_op
@@count_up_to
+@@cond
## Logical Operators
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index e2180737df..9af536b971 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.clip_ops import *
from tensorflow.python.ops.control_flow_ops import group
from tensorflow.python.ops.control_flow_ops import no_op
from tensorflow.python.ops.control_flow_ops import tuple
+from tensorflow.python.ops.control_flow_ops import cond
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.gradients import *
from tensorflow.python.ops.init_ops import *