aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops.py')
-rw-r--r--tensorflow/python/ops/control_flow_ops.py93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index a2d605532a..b4bfc0fe47 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -23,6 +23,7 @@ See the @{$python/control_flow_ops} guide.
@@no_op
@@count_up_to
@@cond
+@@smart_cond
@@case
@@while_loop
@@logical_and
@@ -2129,6 +2130,61 @@ def cond(pred,
# pylint: enable=redefined-outer-name
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `true_fn` or `false_fn`.
+
+ Raises:
+ TypeError: If `true_fn` or `false_fn` is not callable.
+ """
+ if not callable(true_fn):
+ raise TypeError("`true_fn` must be callable.")
+ if not callable(false_fn):
+ raise TypeError("`false_fn` must be callable.")
+
+ pred_value = smart_constant_value(pred)
+ if pred_value is not None:
+ if pred_value:
+ return true_fn()
+ else:
+ return false_fn()
+ else:
+ return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
+
+
+def smart_constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or tensor.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: If `pred` is not a Tensor or bool.
+ """
+ if isinstance(pred, bool):
+ pred_value = pred
+ elif isinstance(pred, ops.Tensor):
+ pred_value = tensor_util.constant_value(pred)
+ else:
+ raise TypeError("`pred` must be a Tensor or a Python bool.")
+ return pred_value
+
+
def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
@@ -3126,6 +3182,43 @@ def while_loop(cond,
shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
```
+ Example which demonstrates non-strict semantics: In the following
+ example, the final value of the counter `i` does not depend on `x`. So
+ the `while_loop` can increment the counter parallel to updates of `x`.
+ However, because the loop counter at one loop iteration depends
+ on the value at the previous iteration, the loop counter itself cannot
+ be incremented in parallel. Hence if we just want the final value of the
+ counter (which we print on the line `print(sess.run(i))`), then
+ `x` will never be incremented, but the counter will be updated on a
+ single thread. Conversely, if we want the value of the output (which we
+ print on the line `print(sess.run(out).shape)`), then the counter may be
+ incremented on its own thread, while `x` can be incremented in
+ parallel on a separate thread. In the extreme case, it is conceivable
+ that the thread incrementing the counter runs until completion before
+ `x` is incremented even a single time. The only thing that can never
+ happen is that the thread updating `x` can never get ahead of the
+ counter thread because the thread incrementing `x` depends on the value
+ of the counter.
+ ```python
+ import tensorflow as tf
+
+ n = 10000
+ x = tf.constant(list(range(n)))
+ c = lambda i, x: i < n
+ b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
+ i, out = tf.while_loop(c, b, (0, x))
+ with tf.Session() as sess:
+ print(sess.run(i)) # prints [0] ... [9999]
+
+ # The following line may increment the counter and x in parallel.
+ # The counter thread may get ahead of the other thread, but not the
+ # other way around. So you may see things like
+ # [9996] x:[9987]
+ # meaning that the counter thread is on iteration 9996,
+ # while the other thread is on iteration 9987
+ print(sess.run(out).shape)
+ ```
+
"""
with ops.name_scope(name, "while", loop_vars):
if not loop_vars: