diff options
Diffstat (limited to 'tensorflow/python/layers/utils.py')
-rw-r--r-- | tensorflow/python/layers/utils.py | 83 |
1 files changed, 33 insertions, 50 deletions
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 1195284024..484c6fc466 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -179,73 +179,56 @@ def deconv_output_length(input_length, filter_size, padding, stride): return input_length -def smart_cond(pred, fn1, fn2, name=None): - """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. - If `pred` is a bool or has a constant value, we return either `fn1()` - or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. Arguments: - pred: A scalar determining whether to return the result of `fn1` or `fn2`. - fn1: The callable to be performed if pred is true. - fn2: The callable to be performed if pred is false. + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. name: Optional name prefix when using `tf.cond`. Returns: - Tensors returned by the call to either `fn1` or `fn2`. + Tensors returned by the call to either `true_fn` or `false_fn`. Raises: - TypeError: If `fn1` or `fn2` is not callable. + TypeError: If `true_fn` or `false_fn` is not callable. """ - if not callable(fn1): - raise TypeError('`fn1` must be callable.') - if not callable(fn2): - raise TypeError('`fn2` must be callable.') - - if context.in_eager_mode(): - if pred: - return fn1() - else: - return fn2() - - pred_value = constant_value(pred) - if pred_value is not None: - if pred_value: - return fn1() - else: - return fn2() - else: - return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name) + if isinstance(pred, variables.Variable): + return control_flow_ops.cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) + return control_flow_ops.smart_cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) def constant_value(pred): """Return the bool value for `pred`, or None if `pred` had a dynamic value. - Arguments: - pred: A scalar, either a Python bool or a TensorFlow boolean variable - or tensor, or the Python integer 1 or 0. + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor, or the Python integer 1 or 0. - Returns: - True or False if `pred` has a constant boolean value, None otherwise. + Returns: + True or False if `pred` has a constant boolean value, None otherwise. - Raises: - TypeError: If `pred` is not a Variable, Tensor or bool. - """ + Raises: + TypeError: If `pred` is not a Variable, Tensor or bool, or Python + interger 1 or 0. + """ # Allow integer booleans. - if pred == 0: - pred = False - elif pred == 1: - pred = True - - if isinstance(pred, bool): - pred_value = pred - elif isinstance(pred, variables.Variable): - pred_value = None - elif isinstance(pred, ops.Tensor): - pred_value = tensor_util.constant_value(pred) - else: - raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') - return pred_value + if isinstance(pred, int): + if pred == 1: + pred = True + elif pred == 0: + pred = False + + if isinstance(pred, variables.Variable): + return None + return control_flow_ops.smart_constant_value(pred) def object_list_uid(object_list): |