diff options
author | Alexander Gorban <gorban@google.com> | 2017-12-12 20:09:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-12 20:12:52 -0800 |
commit | 216878ea3dafdc5fbe6a15d389edb003ad2fd4b4 (patch) | |
tree | 9607fd086c774caeeb96cd5d98219bc2b86e8cc5 | |
parent | b7308e3bd69349e9023497948a6bf55d3b0895d9 (diff) |
Simplify tf.case implementation.
PiperOrigin-RevId: 178853258
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 268 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 14 |
3 files changed, 118 insertions, 168 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 35ae89ed33..5b0abaa2eb 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -2279,8 +2279,7 @@ class ControlFlowTest(test.TestCase): # Duplicate events cause an error if exclusive = True r4 = control_flow_ops.case( [(x < y, f1), (x < y, f2)], default=f3, exclusive=True) - with self.assertRaisesOpError( - "More than one condition evaluated as True but exclusive=True."): + with self.assertRaisesOpError("Input error:"): r4.eval() # Check that the default is called if none of the others are @@ -3045,5 +3044,6 @@ class EagerTest(test.TestCase): default=f3, exclusive=True) self.assertAllEqual(r1.numpy(), 17) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 6e97fe00bd..3418f33717 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -52,6 +52,7 @@ from __future__ import division from __future__ import print_function import collections +import functools import six from six.moves import xrange # pylint: disable=redefined-builtin @@ -3148,23 +3149,105 @@ def tuple(tensors, name=None, control_inputs=None): return tpl -def _assert_exclusive(preds): - """Returns an Assert op that checks that the predicates are exclusive.""" - preds_c = array_ops.stack(preds, name="preds_c") +def _assert_at_most_n_true(predicates, n, msg): + """Returns an Assert op that checks that at most n predicates are True. + + Args: + predicates: list of bool scalar tensors. + n: maximum number of true predicates allowed. + msg: Error message. + """ + preds_c = array_ops.stack(predicates, name="preds_c") num_true_conditions = math_ops.reduce_sum( math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") - at_most_one_true_condition = math_ops.less( - num_true_conditions, constant_op.constant(2, name="two_true_conds")) + condition = math_ops.less_equal(num_true_conditions, + constant_op.constant(n, name="n_true_conds")) + preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) + error_msg = [ + "%s: more than %d conditions (%s) evaluated as True:" % + (msg, n, preds_names), preds_c + ] + return Assert(condition, data=error_msg, summarize=len(predicates)) + - error_msg = [("More than one condition evaluated as True but " - "exclusive=True. Conditions: (%s), Values:" - % ", ".join([p.name for p in preds])), - preds_c] - return Assert(condition=at_most_one_true_condition, data=error_msg, - summarize=len(preds)) +def _case_create_default_action(predicates, actions): + """Creates default action for a list of actions and their predicates. + It uses the input actions to select an arbitrary as default and makes sure + that corresponding predicates have valid values. -def case(pred_fn_pairs, default=None, exclusive=False, strict=False, + Args: + predicates: a list of bool scalar tensors + actions: a list of callable objects which return tensors. + + Returns: + a callable + """ + k = len(predicates) - 1 # could pick any + predicate, action = predicates[k], actions[k] + other_predicates, other_actions = predicates[:k], actions[:k] + + def default_action(): + others_msg = ("Implementation error: " + "selected default action #%d was called, but some of other " + "predicates are True: " % k) + default_msg = ("Input error: " + "None of conditions evaluated as True:", + array_ops.stack(predicates, name="preds_c")) + with ops.control_dependencies([ + _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), + Assert(predicate, data=default_msg) + ]): + return action() + + return default_action, other_predicates, other_actions + + +def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name): + """Verifies input arguments for the case function. + + Args: + pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a + callable which returns a list of tensors. + exclusive: True iff at most one predicate is allowed to evaluate to `True`. + name: A name for the case operation. + + Raises: + TypeError: If `pred_fn_pairs` is not a list/dictionary. + TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. + TypeError: If `fns[i]` is not callable for any i, or `default` is not + callable. + + Returns: + a tuple <list of scalar bool tensors, list of callables>. + """ + if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): + raise TypeError("fns must be a list, tuple, or dict") + + if isinstance(pred_fn_pairs, collections.OrderedDict): + pred_fn_pairs = pred_fn_pairs.items() + elif isinstance(pred_fn_pairs, dict): + pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name) + if not exclusive: + logging.warn("%s: An unordered dictionary of predicate/fn pairs was " + "provided, but exclusive=False. The order of conditional " + "tests is deterministic but not guaranteed.", name) + for pred_fn_pair in pred_fn_pairs: + if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: + raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") + pred, fn = pred_fn_pair + if pred.dtype != dtypes.bool: + raise TypeError("pred must be of type bool: %s", pred.name) + if not callable(fn): + raise TypeError("fn for pred %s must be callable." % pred.name) + predicates, actions = zip(*pred_fn_pairs) + return predicates, actions + + +def case(pred_fn_pairs, + default=None, + exclusive=False, + strict=False, name="case"): """Create a case operation. @@ -3249,152 +3332,27 @@ def case(pred_fn_pairs, default=None, exclusive=False, strict=False, TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. TypeError: If `fns[i]` is not callable for any i, or `default` is not callable. - ValueError: If in eager mode and all predicates are false and no - default is provided. - ValueError: If in eager mode and is passed a dictionary. """ - pfp = pred_fn_pairs # For readability - if not (isinstance(pfp, list) or isinstance(pfp, _basetuple) - or isinstance(pfp, dict)): - raise TypeError("fns must be a list, tuple, or dict") - if isinstance(pfp, dict): - if context.in_eager_mode(): - raise ValueError( - "In eager mode the predicates must be a list, not a dictionary.") - if isinstance(pfp, collections.OrderedDict): - pfp = pfp.items() - else: - pfp = sorted(pfp.items(), key=lambda item: item[0].name) - if not exclusive: - logging.warn("%s: An unordered dictionary of predicate/fn pairs was " - "provided, but exclusive=False. The order of conditional " - "tests is deterministic but not guaranteed.", name) - for tup in pfp: - if not isinstance(tup, _basetuple) or len(tup) != 2: - raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") - pred, fn = tup - if pred.dtype != dtypes.bool: - raise TypeError("pred must be of type bool: %s", pred.name) - if not callable(fn): - raise TypeError("fn for pred %s must be callable." % pred.name) - - if default is not None and not callable(default): - raise TypeError("default must be callable.") - - if context.in_eager_mode(): - for pred, fn in pfp: - if pred: - return fn() - if default is None: - raise ValueError("tf.case received all false predicates and no default.") - return default() - - preds, fns = map(list, zip(*pfp)) - del pfp # From now on, preds and fns form the source of truth. - - with ops.name_scope(name, "case", [preds]): - exclusivity_assert = _assert_exclusive(preds) if exclusive else None - # If no default is provided, then we remove one of the (predicate, function) - # pairs and define the default to be the removed function with an additional - # control dependency that asserts that the removed predicate holds. + predicates, actions = _case_verify_and_canonicalize_args( + pred_fn_pairs, exclusive, name) + with ops.name_scope(name, "case", [predicates]): if default is None: - all_preds = _basetuple(preds) # For the error message. - last_pred, last_fn = preds.pop(), fns.pop() - def new_default(): - preds_c = array_ops.stack(all_preds, name="preds_c") - error_msg = [ - ("None of the conditions evaluated as True. Conditions: (%s), " - "Values:" % ", ".join([p.name for p in all_preds])), - preds_c] - assertion = Assert(condition=last_pred, - data=error_msg, summarize=len(all_preds)) - with ops.control_dependencies([assertion]): - return last_fn() - default = new_default - - if not preds: - return default() - not_preds = [] - for i, p in enumerate(preds): - with ops.name_scope("not_%d" % i): - not_preds.append(math_ops.logical_not(p)) - and_not_preds = [constant_op.constant(True, name="always_true")] - for i, notp in enumerate(not_preds): - with ops.name_scope("and_not_%d" % i): - and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp)) - - # preds = [p1, p2, p3] - # fns = [f1, f2, f3] - # not_preds = [~p1, ~p2, ~p3] - # and_not_preds = [True, ~p1, ~p1 & ~p2, ~p1 & ~p2 & ~p3] - # case_preds = [p1, - # p2 & ~p1, - # p3 & ~p2 & ~p1, - # ~p3 & ~p2 & ~p1] - - case_preds = [] - for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds[:-1])): - with ops.name_scope("case_%d" % i): - case_preds.append(math_ops.logical_and(p, and_not_p_prev)) - with ops.name_scope("case_none_are_true"): - case_preds.append(and_not_preds[-1]) - - # Create an empty tensor, or list, with the right type and shape - with ops.name_scope("case_create_empty"): - def _create_empty_constant(dtype, shape): - value = ("" if dtype == dtypes.string else dtype.as_numpy_dtype()) - if shape.ndims is None: - return array_ops.constant(value, dtype=dtype) - else: - temp_shape = [1 if x.value is None else x.value for x in shape] - result = array_ops.constant(value, shape=temp_shape, dtype=dtype) - result._shape = shape # pylint: disable=protected-access - return result - - def _correct_empty(v): - if isinstance(v, ops.Operation): - return no_op() - elif isinstance(v, tensor_array_ops.TensorArray): - return v - elif not hasattr(v, "dtype"): - return ops.convert_to_tensor(v) - elif isinstance(v, sparse_tensor.SparseTensor): - return sparse_tensor.SparseTensor(indices=[[0] * len(v.get_shape())], - values=[v.dtype.as_numpy_dtype()], - dense_shape=v.get_shape()) - else: - return _create_empty_constant(v.dtype, v.get_shape()) - - empty = lambda: nest.map_structure(_correct_empty, default()) - - # case_sequence = [ - # cond(~p3 & ~p2 & ~p1, default, empty), - # cond(p3 & ~p2 & ~p1, f3, lambda: case_sequence[0]), - # cond(p2 & ~p1, f2, lambda: case_sequence[1]), - # cond(p1, f1, lambda: case_sequence[2]) - # ] - # - # And the return value will be case_sequence[-1] - def _build_case(): - all_fns = [fn for fn in fns] - all_fns.append(default) - prev_case = None - for i, (cp, fn) in enumerate(list(zip(case_preds, all_fns))[::-1]): - prev_case = cond( - cp, fn, - empty if i == 0 else lambda: prev_case, - strict=strict, name="If_%d" % i) - return prev_case - - if exclusivity_assert is not None: - with ops.control_dependencies([exclusivity_assert]): - case_seq = _build_case() + default, predicates, actions = _case_create_default_action( + predicates, actions) + fn = default + # To eval conditions in direct order we create nested conditions in reverse: + # cond(c[0], true_fn=.., false_fn=cond(c[1], ...)) + for predicate, action in reversed(list(zip(predicates, actions))): + fn = functools.partial( + cond, predicate, true_fn=action, false_fn=fn, strict=strict) + if exclusive: + with ops.control_dependencies([ + _assert_at_most_n_true( + predicates, n=1, msg="Input error: exclusive=True") + ]): + return fn() else: - case_seq = _build_case() - - if not strict: - case_seq = _UnpackIfSingleton(case_seq) - return case_seq + return fn() ops.register_proto_function(ops.GraphKeys.COND_CONTEXT, diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index a666fd33a2..cc5a42bf3d 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -883,8 +883,7 @@ class CaseTest(test_util.TensorFlowTestCase): with self.test_session() as sess: self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "More than one condition evaluated as True"): + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): sess.run(output, feed_dict={x: 2}) def testCase_multiple_matches_non_exclusive(self): @@ -909,11 +908,7 @@ class CaseTest(test_util.TensorFlowTestCase): self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r"\[None of the conditions evaluated as True. " - r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] " - r"\[0 0 0\]"): + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): sess.run(output, feed_dict={x: 4}) def testCase_withoutDefault_oneCondition(self): @@ -922,10 +917,7 @@ class CaseTest(test_util.TensorFlowTestCase): output = control_flow_ops.case(conditions, exclusive=True) with self.test_session() as sess: self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r"\[None of the conditions evaluated as True. " - r"Conditions: \(Equal:0\), Values:\] \[0\]"): + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): sess.run(output, feed_dict={x: 4}) |