aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/utils.py')
-rw-r--r--tensorflow/python/layers/utils.py83
1 files changed, 33 insertions, 50 deletions
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py
index 1195284024..484c6fc466 100644
--- a/tensorflow/python/layers/utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -179,73 +179,56 @@ def deconv_output_length(input_length, filter_size, padding, stride):
return input_length
-def smart_cond(pred, fn1, fn2, name=None):
- """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
- If `pred` is a bool or has a constant value, we return either `fn1()`
- or `fn2()`, otherwise we use `tf.cond` to dynamically route to both.
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
Arguments:
- pred: A scalar determining whether to return the result of `fn1` or `fn2`.
- fn1: The callable to be performed if pred is true.
- fn2: The callable to be performed if pred is false.
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
name: Optional name prefix when using `tf.cond`.
Returns:
- Tensors returned by the call to either `fn1` or `fn2`.
+ Tensors returned by the call to either `true_fn` or `false_fn`.
Raises:
- TypeError: If `fn1` or `fn2` is not callable.
+ TypeError: If `true_fn` or `false_fn` is not callable.
"""
- if not callable(fn1):
- raise TypeError('`fn1` must be callable.')
- if not callable(fn2):
- raise TypeError('`fn2` must be callable.')
-
- if context.in_eager_mode():
- if pred:
- return fn1()
- else:
- return fn2()
-
- pred_value = constant_value(pred)
- if pred_value is not None:
- if pred_value:
- return fn1()
- else:
- return fn2()
- else:
- return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name)
+ if isinstance(pred, variables.Variable):
+ return control_flow_ops.cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
+ return control_flow_ops.smart_cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
def constant_value(pred):
"""Return the bool value for `pred`, or None if `pred` had a dynamic value.
- Arguments:
- pred: A scalar, either a Python bool or a TensorFlow boolean variable
- or tensor, or the Python integer 1 or 0.
+ Arguments:
+ pred: A scalar, either a Python bool or a TensorFlow boolean variable
+ or tensor, or the Python integer 1 or 0.
- Returns:
- True or False if `pred` has a constant boolean value, None otherwise.
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
- Raises:
- TypeError: If `pred` is not a Variable, Tensor or bool.
- """
+ Raises:
+ TypeError: If `pred` is not a Variable, Tensor or bool, or Python
+ interger 1 or 0.
+ """
# Allow integer booleans.
- if pred == 0:
- pred = False
- elif pred == 1:
- pred = True
-
- if isinstance(pred, bool):
- pred_value = pred
- elif isinstance(pred, variables.Variable):
- pred_value = None
- elif isinstance(pred, ops.Tensor):
- pred_value = tensor_util.constant_value(pred)
- else:
- raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
- return pred_value
+ if isinstance(pred, int):
+ if pred == 1:
+ pred = True
+ elif pred == 0:
+ pred = False
+
+ if isinstance(pred, variables.Variable):
+ return None
+ return control_flow_ops.smart_constant_value(pred)
def object_list_uid(object_list):