aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexander Gorban <gorban@google.com>2017-12-12 20:09:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-12 20:12:52 -0800
commit216878ea3dafdc5fbe6a15d389edb003ad2fd4b4 (patch)
tree9607fd086c774caeeb96cd5d98219bc2b86e8cc5
parentb7308e3bd69349e9023497948a6bf55d3b0895d9 (diff)
Simplify tf.case implementation.
PiperOrigin-RevId: 178853258
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py4
-rw-r--r--tensorflow/python/ops/control_flow_ops.py268
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py14
3 files changed, 118 insertions, 168 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 35ae89ed33..5b0abaa2eb 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2279,8 +2279,7 @@ class ControlFlowTest(test.TestCase):
# Duplicate events cause an error if exclusive = True
r4 = control_flow_ops.case(
[(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
- with self.assertRaisesOpError(
- "More than one condition evaluated as True but exclusive=True."):
+ with self.assertRaisesOpError("Input error:"):
r4.eval()
# Check that the default is called if none of the others are
@@ -3045,5 +3044,6 @@ class EagerTest(test.TestCase):
default=f3, exclusive=True)
self.assertAllEqual(r1.numpy(), 17)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 6e97fe00bd..3418f33717 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -52,6 +52,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
import six
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -3148,23 +3149,105 @@ def tuple(tensors, name=None, control_inputs=None):
return tpl
-def _assert_exclusive(preds):
- """Returns an Assert op that checks that the predicates are exclusive."""
- preds_c = array_ops.stack(preds, name="preds_c")
+def _assert_at_most_n_true(predicates, n, msg):
+ """Returns an Assert op that checks that at most n predicates are True.
+
+ Args:
+ predicates: list of bool scalar tensors.
+ n: maximum number of true predicates allowed.
+ msg: Error message.
+ """
+ preds_c = array_ops.stack(predicates, name="preds_c")
num_true_conditions = math_ops.reduce_sum(
math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
- at_most_one_true_condition = math_ops.less(
- num_true_conditions, constant_op.constant(2, name="two_true_conds"))
+ condition = math_ops.less_equal(num_true_conditions,
+ constant_op.constant(n, name="n_true_conds"))
+ preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
+ error_msg = [
+ "%s: more than %d conditions (%s) evaluated as True:" %
+ (msg, n, preds_names), preds_c
+ ]
+ return Assert(condition, data=error_msg, summarize=len(predicates))
+
- error_msg = [("More than one condition evaluated as True but "
- "exclusive=True. Conditions: (%s), Values:"
- % ", ".join([p.name for p in preds])),
- preds_c]
- return Assert(condition=at_most_one_true_condition, data=error_msg,
- summarize=len(preds))
+def _case_create_default_action(predicates, actions):
+ """Creates default action for a list of actions and their predicates.
+ It uses the input actions to select an arbitrary as default and makes sure
+ that corresponding predicates have valid values.
-def case(pred_fn_pairs, default=None, exclusive=False, strict=False,
+ Args:
+ predicates: a list of bool scalar tensors
+ actions: a list of callable objects which return tensors.
+
+ Returns:
+ a callable
+ """
+ k = len(predicates) - 1 # could pick any
+ predicate, action = predicates[k], actions[k]
+ other_predicates, other_actions = predicates[:k], actions[:k]
+
+ def default_action():
+ others_msg = ("Implementation error: "
+ "selected default action #%d was called, but some of other "
+ "predicates are True: " % k)
+ default_msg = ("Input error: "
+ "None of conditions evaluated as True:",
+ array_ops.stack(predicates, name="preds_c"))
+ with ops.control_dependencies([
+ _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
+ Assert(predicate, data=default_msg)
+ ]):
+ return action()
+
+ return default_action, other_predicates, other_actions
+
+
+def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name):
+ """Verifies input arguments for the case function.
+
+ Args:
+ pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
+ callable which returns a list of tensors.
+ exclusive: True iff at most one predicate is allowed to evaluate to `True`.
+ name: A name for the case operation.
+
+ Raises:
+ TypeError: If `pred_fn_pairs` is not a list/dictionary.
+ TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
+ TypeError: If `fns[i]` is not callable for any i, or `default` is not
+ callable.
+
+ Returns:
+ a tuple <list of scalar bool tensors, list of callables>.
+ """
+ if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
+ raise TypeError("fns must be a list, tuple, or dict")
+
+ if isinstance(pred_fn_pairs, collections.OrderedDict):
+ pred_fn_pairs = pred_fn_pairs.items()
+ elif isinstance(pred_fn_pairs, dict):
+ pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name)
+ if not exclusive:
+ logging.warn("%s: An unordered dictionary of predicate/fn pairs was "
+ "provided, but exclusive=False. The order of conditional "
+ "tests is deterministic but not guaranteed.", name)
+ for pred_fn_pair in pred_fn_pairs:
+ if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
+ raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
+ pred, fn = pred_fn_pair
+ if pred.dtype != dtypes.bool:
+ raise TypeError("pred must be of type bool: %s", pred.name)
+ if not callable(fn):
+ raise TypeError("fn for pred %s must be callable." % pred.name)
+ predicates, actions = zip(*pred_fn_pairs)
+ return predicates, actions
+
+
+def case(pred_fn_pairs,
+ default=None,
+ exclusive=False,
+ strict=False,
name="case"):
"""Create a case operation.
@@ -3249,152 +3332,27 @@ def case(pred_fn_pairs, default=None, exclusive=False, strict=False,
TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
TypeError: If `fns[i]` is not callable for any i, or `default` is not
callable.
- ValueError: If in eager mode and all predicates are false and no
- default is provided.
- ValueError: If in eager mode and is passed a dictionary.
"""
- pfp = pred_fn_pairs # For readability
- if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)
- or isinstance(pfp, dict)):
- raise TypeError("fns must be a list, tuple, or dict")
- if isinstance(pfp, dict):
- if context.in_eager_mode():
- raise ValueError(
- "In eager mode the predicates must be a list, not a dictionary.")
- if isinstance(pfp, collections.OrderedDict):
- pfp = pfp.items()
- else:
- pfp = sorted(pfp.items(), key=lambda item: item[0].name)
- if not exclusive:
- logging.warn("%s: An unordered dictionary of predicate/fn pairs was "
- "provided, but exclusive=False. The order of conditional "
- "tests is deterministic but not guaranteed.", name)
- for tup in pfp:
- if not isinstance(tup, _basetuple) or len(tup) != 2:
- raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
- pred, fn = tup
- if pred.dtype != dtypes.bool:
- raise TypeError("pred must be of type bool: %s", pred.name)
- if not callable(fn):
- raise TypeError("fn for pred %s must be callable." % pred.name)
-
- if default is not None and not callable(default):
- raise TypeError("default must be callable.")
-
- if context.in_eager_mode():
- for pred, fn in pfp:
- if pred:
- return fn()
- if default is None:
- raise ValueError("tf.case received all false predicates and no default.")
- return default()
-
- preds, fns = map(list, zip(*pfp))
- del pfp # From now on, preds and fns form the source of truth.
-
- with ops.name_scope(name, "case", [preds]):
- exclusivity_assert = _assert_exclusive(preds) if exclusive else None
- # If no default is provided, then we remove one of the (predicate, function)
- # pairs and define the default to be the removed function with an additional
- # control dependency that asserts that the removed predicate holds.
+ predicates, actions = _case_verify_and_canonicalize_args(
+ pred_fn_pairs, exclusive, name)
+ with ops.name_scope(name, "case", [predicates]):
if default is None:
- all_preds = _basetuple(preds) # For the error message.
- last_pred, last_fn = preds.pop(), fns.pop()
- def new_default():
- preds_c = array_ops.stack(all_preds, name="preds_c")
- error_msg = [
- ("None of the conditions evaluated as True. Conditions: (%s), "
- "Values:" % ", ".join([p.name for p in all_preds])),
- preds_c]
- assertion = Assert(condition=last_pred,
- data=error_msg, summarize=len(all_preds))
- with ops.control_dependencies([assertion]):
- return last_fn()
- default = new_default
-
- if not preds:
- return default()
- not_preds = []
- for i, p in enumerate(preds):
- with ops.name_scope("not_%d" % i):
- not_preds.append(math_ops.logical_not(p))
- and_not_preds = [constant_op.constant(True, name="always_true")]
- for i, notp in enumerate(not_preds):
- with ops.name_scope("and_not_%d" % i):
- and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp))
-
- # preds = [p1, p2, p3]
- # fns = [f1, f2, f3]
- # not_preds = [~p1, ~p2, ~p3]
- # and_not_preds = [True, ~p1, ~p1 & ~p2, ~p1 & ~p2 & ~p3]
- # case_preds = [p1,
- # p2 & ~p1,
- # p3 & ~p2 & ~p1,
- # ~p3 & ~p2 & ~p1]
-
- case_preds = []
- for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds[:-1])):
- with ops.name_scope("case_%d" % i):
- case_preds.append(math_ops.logical_and(p, and_not_p_prev))
- with ops.name_scope("case_none_are_true"):
- case_preds.append(and_not_preds[-1])
-
- # Create an empty tensor, or list, with the right type and shape
- with ops.name_scope("case_create_empty"):
- def _create_empty_constant(dtype, shape):
- value = ("" if dtype == dtypes.string else dtype.as_numpy_dtype())
- if shape.ndims is None:
- return array_ops.constant(value, dtype=dtype)
- else:
- temp_shape = [1 if x.value is None else x.value for x in shape]
- result = array_ops.constant(value, shape=temp_shape, dtype=dtype)
- result._shape = shape # pylint: disable=protected-access
- return result
-
- def _correct_empty(v):
- if isinstance(v, ops.Operation):
- return no_op()
- elif isinstance(v, tensor_array_ops.TensorArray):
- return v
- elif not hasattr(v, "dtype"):
- return ops.convert_to_tensor(v)
- elif isinstance(v, sparse_tensor.SparseTensor):
- return sparse_tensor.SparseTensor(indices=[[0] * len(v.get_shape())],
- values=[v.dtype.as_numpy_dtype()],
- dense_shape=v.get_shape())
- else:
- return _create_empty_constant(v.dtype, v.get_shape())
-
- empty = lambda: nest.map_structure(_correct_empty, default())
-
- # case_sequence = [
- # cond(~p3 & ~p2 & ~p1, default, empty),
- # cond(p3 & ~p2 & ~p1, f3, lambda: case_sequence[0]),
- # cond(p2 & ~p1, f2, lambda: case_sequence[1]),
- # cond(p1, f1, lambda: case_sequence[2])
- # ]
- #
- # And the return value will be case_sequence[-1]
- def _build_case():
- all_fns = [fn for fn in fns]
- all_fns.append(default)
- prev_case = None
- for i, (cp, fn) in enumerate(list(zip(case_preds, all_fns))[::-1]):
- prev_case = cond(
- cp, fn,
- empty if i == 0 else lambda: prev_case,
- strict=strict, name="If_%d" % i)
- return prev_case
-
- if exclusivity_assert is not None:
- with ops.control_dependencies([exclusivity_assert]):
- case_seq = _build_case()
+ default, predicates, actions = _case_create_default_action(
+ predicates, actions)
+ fn = default
+ # To eval conditions in direct order we create nested conditions in reverse:
+ # cond(c[0], true_fn=.., false_fn=cond(c[1], ...))
+ for predicate, action in reversed(list(zip(predicates, actions))):
+ fn = functools.partial(
+ cond, predicate, true_fn=action, false_fn=fn, strict=strict)
+ if exclusive:
+ with ops.control_dependencies([
+ _assert_at_most_n_true(
+ predicates, n=1, msg="Input error: exclusive=True")
+ ]):
+ return fn()
else:
- case_seq = _build_case()
-
- if not strict:
- case_seq = _UnpackIfSingleton(case_seq)
- return case_seq
+ return fn()
ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index a666fd33a2..cc5a42bf3d 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -883,8 +883,7 @@ class CaseTest(test_util.TensorFlowTestCase):
with self.test_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "More than one condition evaluated as True"):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
sess.run(output, feed_dict={x: 2})
def testCase_multiple_matches_non_exclusive(self):
@@ -909,11 +908,7 @@ class CaseTest(test_util.TensorFlowTestCase):
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r"\[None of the conditions evaluated as True. "
- r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
- r"\[0 0 0\]"):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
sess.run(output, feed_dict={x: 4})
def testCase_withoutDefault_oneCondition(self):
@@ -922,10 +917,7 @@ class CaseTest(test_util.TensorFlowTestCase):
output = control_flow_ops.case(conditions, exclusive=True)
with self.test_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r"\[None of the conditions evaluated as True. "
- r"Conditions: \(Equal:0\), Values:\] \[0\]"):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
sess.run(output, feed_dict={x: 4})