diff options
-rw-r--r-- | tensorflow/python/kernel_tests/check_ops_test.py | 49 | ||||
-rw-r--r-- | tensorflow/python/ops/check_ops.py | 49 |
2 files changed, 96 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index a2df4cb2a7..935519d074 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -124,6 +124,55 @@ class AssertEqualTest(test.TestCase): out.eval() +class AssertNoneEqualTest(test.TestCase): + + def test_doesnt_raise_when_not_equal(self): + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([10, 20], name="small") + with ops.control_dependencies( + [check_ops.assert_none_equal(big, small)]): + out = array_ops.identity(small) + out.eval() + + def test_raises_when_equal(self): + with self.test_session(): + small = constant_op.constant([3, 1], name="small") + with ops.control_dependencies( + [check_ops.assert_none_equal(small, small)]): + out = array_ops.identity(small) + with self.assertRaisesOpError("x != y did not hold"): + out.eval() + + def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([3], name="big") + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + out.eval() + + def test_raises_when_not_equal_but_non_broadcastable_shapes(self): + with self.test_session(): + small = constant_op.constant([1, 1, 1], name="small") + big = constant_op.constant([10, 10], name="big") + with self.assertRaisesRegexp(ValueError, "must be"): + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + out.eval() + + def test_doesnt_raise_when_both_empty(self): + with self.test_session(): + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies( + [check_ops.assert_none_equal(larry, curly)]): + out = array_ops.identity(larry) + out.eval() + + class AssertLessTest(test.TestCase): def test_raises_when_equal(self): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 77664f19c4..0439e7b860 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -19,11 +19,10 @@ See the @{$python/check_ops} guide. @@assert_negative @@assert_positive -@@assert_proper_iterable @@assert_non_negative @@assert_non_positive @@assert_equal -@@assert_integer +@@assert_none_equal @@assert_less @@assert_less_equal @@assert_greater @@ -31,6 +30,8 @@ See the @{$python/check_ops} guide. @@assert_rank @@assert_rank_at_least @@assert_type +@@assert_integer +@@assert_proper_iterable @@is_non_decreasing @@is_numeric_tensor @@is_strictly_increasing @@ -63,6 +64,7 @@ __all__ = [ 'assert_non_negative', 'assert_non_positive', 'assert_equal', + 'assert_none_equal', 'assert_integer', 'assert_less', 'assert_less_equal', @@ -285,6 +287,49 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): return control_flow_ops.Assert(condition, data, summarize=summarize) +def assert_none_equal( + x, y, data=None, summarize=None, message=None, name=None): + """Assert the condition `x != y` holds for all elements. + + Example of adding a dependency to an operation: + + ```python + with tf.control_dependencies([tf.assert_none_equal(x, y)]): + output = tf.reduce_sum(x) + ``` + + This condition holds if for every pair of (possibly broadcast) elements + `x[i]`, `y[i]`, we have `x[i] != y[i]`. + If both `x` and `y` are empty, this is trivially satisfied. + + Args: + x: Numeric `Tensor`. + y: Numeric `Tensor`, same dtype as and broadcastable to `x`. + data: The tensors to print out if the condition is False. Defaults to + error message and first few entries of `x`, `y`. + summarize: Print this many entries of each tensor. + message: A string to prefix to the default message. + name: A name for this operation (optional). + Defaults to "assert_none_equal". + + Returns: + Op that raises `InvalidArgumentError` if `x != y` is ever False. + """ + message = message or '' + with ops.name_scope(name, 'assert_none_equal', [x, y, data]): + x = ops.convert_to_tensor(x, name='x') + y = ops.convert_to_tensor(y, name='y') + if data is None: + data = [ + message, + 'Condition x != y did not hold for every single element: x = ', + x.name, x, + 'y = ', y.name, y + ] + condition = math_ops.reduce_all(math_ops.not_equal(x, y)) + return control_flow_ops.Assert(condition, data, summarize=summarize) + + def assert_less(x, y, data=None, summarize=None, message=None, name=None): """Assert the condition `x < y` holds element-wise. |