diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/check_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/check_ops_test.py | 311 |
1 files changed, 110 insertions, 201 deletions
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 43785adcee..ed859e3774 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -20,13 +20,10 @@ from __future__ import print_function import numpy as np -from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.platform import test @@ -74,178 +71,110 @@ class AssertProperIterableTest(test.TestCase): class AssertEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_equal(self): - small = constant_op.constant([1, 2], name="small") - with ops.control_dependencies([check_ops.assert_equal(small, small)]): - out = array_ops.identity(small) - self.evaluate(out) - - def test_returns_none_with_eager(self): - with context.eager_mode(): + with self.test_session(): small = constant_op.constant([1, 2], name="small") - x = check_ops.assert_equal(small, small) - assert x is None + with ops.control_dependencies([check_ops.assert_equal(small, small)]): + out = array_ops.identity(small) + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_raises_when_greater(self): - # Static check - static_small = constant_op.constant([1, 2], name="small") - static_big = constant_op.constant([3, 4], name="big") - with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): - check_ops.assert_equal(static_big, static_small, message="fail") - - # Dynamic check - if context.in_graph_mode(): - with self.test_session(): - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies( - [check_ops.assert_equal( - big, small, message="fail")]): - out = array_ops.identity(small) - with self.assertRaisesOpError("fail.*big.*small"): - out.eval(feed_dict={small: [1, 2], big: [3, 4]}) - - def test_error_message_eager(self): - expected_error_msg_full = r"""big does not equal small -Condition x == y did not hold. -Indices of first 6 different values: -\[\[0 0\] - \[1 1\] - \[2 0\]\] -Corresponding x values: -\[2 3 6\] -Corresponding y values: -\[20 30 60\] -First 6 elements of x: -\[2 2 3 3 6 6\] -First 6 elements of y: -\[20 2 3 30 60 6\] -""" - expected_error_msg_short = r"""big does not equal small -Condition x == y did not hold. -Indices of first 2 different values: -\[\[0 0\] - \[1 1\]\] -Corresponding x values: -\[2 3\] -Corresponding y values: -\[20 30\] -First 2 elements of x: -\[2 2\] -First 2 elements of y: -\[20 2\] -""" - with context.eager_mode(): - big = constant_op.constant([[2, 2], [3, 3], [6, 6]]) - small = constant_op.constant([[20, 2], [3, 30], [60, 6]]) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - expected_error_msg_full): - check_ops.assert_equal(big, small, message="big does not equal small", - summarize=10) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - expected_error_msg_short): - check_ops.assert_equal(big, small, message="big does not equal small", - summarize=2) - - @test_util.run_in_graph_and_eager_modes() + with self.test_session(): + # Static check + static_small = constant_op.constant([1, 2], name="small") + static_big = constant_op.constant([3, 4], name="big") + with self.assertRaisesRegexp(ValueError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + # Dynamic check + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies( + [check_ops.assert_equal( + big, small, message="fail")]): + out = array_ops.identity(small) + with self.assertRaisesOpError("fail.*big.*small"): + out.eval(feed_dict={small: [1, 2], big: [3, 4]}) + def test_raises_when_less(self): - # Static check - static_small = constant_op.constant([3, 1], name="small") - static_big = constant_op.constant([4, 2], name="big") - with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): - check_ops.assert_equal(static_big, static_small, message="fail") - - # Dynamic check - if context.in_graph_mode(): - with self.test_session(): - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies([check_ops.assert_equal(small, big)]): - out = array_ops.identity(small) - with self.assertRaisesOpError("small.*big"): - out.eval(feed_dict={small: [3, 1], big: [4, 2]}) + with self.test_session(): + # Static check + static_small = constant_op.constant([3, 1], name="small") + static_big = constant_op.constant([4, 2], name="big") + with self.assertRaisesRegexp(ValueError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + # Dynamic check + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies([check_ops.assert_equal(small, big)]): + out = array_ops.identity(small) + with self.assertRaisesOpError("small.*big"): + out.eval(feed_dict={small: [3, 1], big: [4, 2]}) - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): - small = constant_op.constant([[1, 2], [1, 2]], name="small") - small_2 = constant_op.constant([1, 2], name="small_2") - with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): - out = array_ops.identity(small) - self.evaluate(out) - - @test_util.run_in_graph_and_eager_modes() - def test_raises_when_equal_but_non_broadcastable_shapes(self): - small = constant_op.constant([1, 1, 1], name="small") - small_2 = constant_op.constant([1, 1], name="small_2") - # The exception in eager and non-eager mode is different because - # eager mode relies on shape check done as part of the C++ op, while - # graph mode does shape checks when creating the `Operation` instance. - with self.assertRaisesRegexp( - (errors.InvalidArgumentError, ValueError), - (r"Incompatible shapes: \[3\] vs. \[2\]|" - r"Dimensions must be equal, but are 3 and 2")): + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + small_2 = constant_op.constant([1, 2], name="small_2") with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): out = array_ops.identity(small) - self.evaluate(out) + out.eval() + + def test_raises_when_equal_but_non_broadcastable_shapes(self): + with self.test_session(): + small = constant_op.constant([1, 1, 1], name="small") + small_2 = constant_op.constant([1, 1], name="small_2") + with self.assertRaisesRegexp(ValueError, "must be"): + with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): + out = array_ops.identity(small) + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies([check_ops.assert_equal(larry, curly)]): - out = array_ops.identity(larry) - self.evaluate(out) + with self.test_session(): + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies([check_ops.assert_equal(larry, curly)]): + out = array_ops.identity(larry) + out.eval() class AssertNoneEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_not_equal(self): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([10, 20], name="small") - with ops.control_dependencies( - [check_ops.assert_none_equal(big, small)]): - out = array_ops.identity(small) - self.evaluate(out) - - @test_util.run_in_graph_and_eager_modes() + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([10, 20], name="small") + with ops.control_dependencies( + [check_ops.assert_none_equal(big, small)]): + out = array_ops.identity(small) + out.eval() + def test_raises_when_equal(self): - small = constant_op.constant([3, 1], name="small") - with self.assertRaisesOpError("x != y did not hold"): + with self.test_session(): + small = constant_op.constant([3, 1], name="small") with ops.control_dependencies( [check_ops.assert_none_equal(small, small)]): out = array_ops.identity(small) - self.evaluate(out) + with self.assertRaisesOpError("x != y did not hold"): + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([3], name="big") - with ops.control_dependencies( - [check_ops.assert_none_equal(small, big)]): - out = array_ops.identity(small) - self.evaluate(out) - - @test_util.run_in_graph_and_eager_modes() + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([3], name="big") + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + out.eval() + def test_raises_when_not_equal_but_non_broadcastable_shapes(self): with self.test_session(): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([10, 10], name="big") - # The exception in eager and non-eager mode is different because - # eager mode relies on shape check done as part of the C++ op, while - # graph mode does shape checks when creating the `Operation` instance. - with self.assertRaisesRegexp( - (ValueError, errors.InvalidArgumentError), - (r"Incompatible shapes: \[3\] vs. \[2\]|" - r"Dimensions must be equal, but are 3 and 2")): + with self.assertRaisesRegexp(ValueError, "must be"): with ops.control_dependencies( [check_ops.assert_none_equal(small, big)]): out = array_ops.identity(small) - self.evaluate(out) + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): with self.test_session(): larry = constant_op.constant([]) @@ -253,82 +182,62 @@ class AssertNoneEqualTest(test.TestCase): with ops.control_dependencies( [check_ops.assert_none_equal(larry, curly)]): out = array_ops.identity(larry) - self.evaluate(out) - - def test_returns_none_with_eager(self): - with context.eager_mode(): - t1 = constant_op.constant([1, 2]) - t2 = constant_op.constant([3, 4]) - x = check_ops.assert_none_equal(t1, t2) - assert x is None + out.eval() class AssertLessTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() def test_raises_when_equal(self): - small = constant_op.constant([1, 2], name="small") - with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"): + with self.test_session(): + small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( [check_ops.assert_less( - small, small, message="failure message")]): + small, small, message="fail")]): out = array_ops.identity(small) - self.evaluate(out) + with self.assertRaisesOpError("fail.*small.*small"): + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_raises_when_greater(self): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([3, 4], name="big") - with self.assertRaisesOpError("x < y did not hold"): + with self.test_session(): + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([3, 4], name="big") with ops.control_dependencies([check_ops.assert_less(big, small)]): out = array_ops.identity(small) - self.evaluate(out) + with self.assertRaisesOpError("big.*small"): + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_less(self): - small = constant_op.constant([3, 1], name="small") - big = constant_op.constant([4, 2], name="big") - with ops.control_dependencies([check_ops.assert_less(small, big)]): - out = array_ops.identity(small) - self.evaluate(out) + with self.test_session(): + small = constant_op.constant([3, 1], name="small") + big = constant_op.constant([4, 2], name="big") + with ops.control_dependencies([check_ops.assert_less(small, big)]): + out = array_ops.identity(small) + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_less_and_broadcastable_shapes(self): - small = constant_op.constant([1], name="small") - big = constant_op.constant([3, 2], name="big") - with ops.control_dependencies([check_ops.assert_less(small, big)]): - out = array_ops.identity(small) - self.evaluate(out) - - @test_util.run_in_graph_and_eager_modes() - def test_raises_when_less_but_non_broadcastable_shapes(self): - small = constant_op.constant([1, 1, 1], name="small") - big = constant_op.constant([3, 2], name="big") - # The exception in eager and non-eager mode is different because - # eager mode relies on shape check done as part of the C++ op, while - # graph mode does shape checks when creating the `Operation` instance. - with self.assertRaisesRegexp( - (ValueError, errors.InvalidArgumentError), - (r"Incompatible shapes: \[3\] vs. \[2\]|" - "Dimensions must be equal, but are 3 and 2")): + with self.test_session(): + small = constant_op.constant([1], name="small") + big = constant_op.constant([3, 2], name="big") with ops.control_dependencies([check_ops.assert_less(small, big)]): out = array_ops.identity(small) - self.evaluate(out) + out.eval() + + def test_raises_when_less_but_non_broadcastable_shapes(self): + with self.test_session(): + small = constant_op.constant([1, 1, 1], name="small") + big = constant_op.constant([3, 2], name="big") + with self.assertRaisesRegexp(ValueError, "must be"): + with ops.control_dependencies([check_ops.assert_less(small, big)]): + out = array_ops.identity(small) + out.eval() - @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies([check_ops.assert_less(larry, curly)]): - out = array_ops.identity(larry) - self.evaluate(out) - - def test_returns_none_with_eager(self): - with context.eager_mode(): - t1 = constant_op.constant([1, 2]) - t2 = constant_op.constant([3, 4]) - x = check_ops.assert_less(t1, t2) - assert x is None + with self.test_session(): + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies([check_ops.assert_less(larry, curly)]): + out = array_ops.identity(larry) + out.eval() class AssertLessEqualTest(test.TestCase): |