aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py49
-rw-r--r--tensorflow/python/ops/check_ops.py49
2 files changed, 96 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index a2df4cb2a7..935519d074 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -124,6 +124,55 @@ class AssertEqualTest(test.TestCase):
out.eval()
+class AssertNoneEqualTest(test.TestCase):
+
+ def test_doesnt_raise_when_not_equal(self):
+ with self.test_session():
+ small = constant_op.constant([1, 2], name="small")
+ big = constant_op.constant([10, 20], name="small")
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(big, small)]):
+ out = array_ops.identity(small)
+ out.eval()
+
+ def test_raises_when_equal(self):
+ with self.test_session():
+ small = constant_op.constant([3, 1], name="small")
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, small)]):
+ out = array_ops.identity(small)
+ with self.assertRaisesOpError("x != y did not hold"):
+ out.eval()
+
+ def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
+ with self.test_session():
+ small = constant_op.constant([1, 2], name="small")
+ big = constant_op.constant([3], name="big")
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, big)]):
+ out = array_ops.identity(small)
+ out.eval()
+
+ def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
+ with self.test_session():
+ small = constant_op.constant([1, 1, 1], name="small")
+ big = constant_op.constant([10, 10], name="big")
+ with self.assertRaisesRegexp(ValueError, "must be"):
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, big)]):
+ out = array_ops.identity(small)
+ out.eval()
+
+ def test_doesnt_raise_when_both_empty(self):
+ with self.test_session():
+ larry = constant_op.constant([])
+ curly = constant_op.constant([])
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(larry, curly)]):
+ out = array_ops.identity(larry)
+ out.eval()
+
+
class AssertLessTest(test.TestCase):
def test_raises_when_equal(self):
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 77664f19c4..0439e7b860 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -19,11 +19,10 @@ See the @{$python/check_ops} guide.
@@assert_negative
@@assert_positive
-@@assert_proper_iterable
@@assert_non_negative
@@assert_non_positive
@@assert_equal
-@@assert_integer
+@@assert_none_equal
@@assert_less
@@assert_less_equal
@@assert_greater
@@ -31,6 +30,8 @@ See the @{$python/check_ops} guide.
@@assert_rank
@@assert_rank_at_least
@@assert_type
+@@assert_integer
+@@assert_proper_iterable
@@is_non_decreasing
@@is_numeric_tensor
@@is_strictly_increasing
@@ -63,6 +64,7 @@ __all__ = [
'assert_non_negative',
'assert_non_positive',
'assert_equal',
+ 'assert_none_equal',
'assert_integer',
'assert_less',
'assert_less_equal',
@@ -285,6 +287,49 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
+def assert_none_equal(
+ x, y, data=None, summarize=None, message=None, name=None):
+ """Assert the condition `x != y` holds for all elements.
+
+ Example of adding a dependency to an operation:
+
+ ```python
+ with tf.control_dependencies([tf.assert_none_equal(x, y)]):
+ output = tf.reduce_sum(x)
+ ```
+
+ This condition holds if for every pair of (possibly broadcast) elements
+ `x[i]`, `y[i]`, we have `x[i] != y[i]`.
+ If both `x` and `y` are empty, this is trivially satisfied.
+
+ Args:
+ x: Numeric `Tensor`.
+ y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
+ data: The tensors to print out if the condition is False. Defaults to
+ error message and first few entries of `x`, `y`.
+ summarize: Print this many entries of each tensor.
+ message: A string to prefix to the default message.
+ name: A name for this operation (optional).
+ Defaults to "assert_none_equal".
+
+ Returns:
+ Op that raises `InvalidArgumentError` if `x != y` is ever False.
+ """
+ message = message or ''
+ with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
+ x = ops.convert_to_tensor(x, name='x')
+ y = ops.convert_to_tensor(y, name='y')
+ if data is None:
+ data = [
+ message,
+ 'Condition x != y did not hold for every single element: x = ',
+ x.name, x,
+ 'y = ', y.name, y
+ ]
+ condition = math_ops.reduce_all(math_ops.not_equal(x, y))
+ return control_flow_ops.Assert(condition, data, summarize=summarize)
+
+
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < y` holds element-wise.