aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util_test.py')
-rw-r--r--tensorflow/python/framework/test_util_test.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
new file mode 100644
index 0000000000..e0618cfea4
--- /dev/null
+++ b/tensorflow/python/framework/test_util_test.py
@@ -0,0 +1,128 @@
+"""Tests for tensorflow.ops.test_util."""
+import threading
+
+import tensorflow.python.platform
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.platform import googletest
+from tensorflow.python.ops import logging_ops
+
+class TestUtilTest(test_util.TensorFlowTestCase):
+
+ def testIsGoogleCudaEnabled(self):
+ # The test doesn't assert anything. It ensures the py wrapper
+ # function is generated correctly.
+ if test_util.IsGoogleCudaEnabled():
+ print "GoogleCuda is enabled"
+ else:
+ print "GoogleCuda is disabled"
+
+ def testAssertProtoEqualsStr(self):
+
+ graph_str = "node { name: 'w1' op: 'params' }"
+ graph_def = graph_pb2.GraphDef()
+ text_format.Merge(graph_str, graph_def)
+
+ # test string based comparison
+ self.assertProtoEquals(graph_str, graph_def)
+
+ # test original comparison
+ self.assertProtoEquals(graph_def, graph_def)
+
+ def testNDArrayNear(self):
+ a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
+ self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
+ self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
+
+ def testCheckedThreadSucceeds(self):
+ def noop(ev):
+ ev.set()
+
+ event_arg = threading.Event()
+
+ self.assertFalse(event_arg.is_set())
+ t = self.checkedThread(target=noop, args=(event_arg,))
+ t.start()
+ t.join()
+ self.assertTrue(event_arg.is_set())
+
+ def testCheckedThreadFails(self):
+ def err_func():
+ return 1 / 0
+
+ t = self.checkedThread(target=err_func)
+ t.start()
+ with self.assertRaises(self.failureException) as fe:
+ t.join()
+ self.assertTrue("integer division or modulo by zero"
+ in fe.exception.message)
+
+ def testCheckedThreadWithWrongAssertionFails(self):
+ x = 37
+
+ def err_func():
+ self.assertTrue(x < 10)
+
+ t = self.checkedThread(target=err_func)
+ t.start()
+ with self.assertRaises(self.failureException) as fe:
+ t.join()
+ self.assertTrue("False is not true" in fe.exception.message)
+
+ def testMultipleThreadsWithOneFailure(self):
+ def err_func(i):
+ self.assertTrue(i != 7)
+
+ threads = [self.checkedThread(target=err_func, args=(i,))
+ for i in range(10)]
+ for t in threads:
+ t.start()
+ for i, t in enumerate(threads):
+ if i == 7:
+ with self.assertRaises(self.failureException):
+ t.join()
+ else:
+ t.join()
+
+ def _WeMustGoDeeper(self, msg):
+ with self.assertRaisesOpError(msg):
+ node_def = ops._NodeDef("op_type", "name")
+ node_def_orig = ops._NodeDef("op_type_orig", "orig")
+ op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
+ op = ops.Operation(node_def, ops.get_default_graph(), original_op=op_orig)
+ raise errors.UnauthenticatedError(node_def, op, "true_err")
+
+ def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
+ with self.assertRaises(AssertionError):
+ self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
+
+ self._WeMustGoDeeper("true_err")
+ self._WeMustGoDeeper("name")
+ self._WeMustGoDeeper("orig")
+
+ def testAllCloseScalars(self):
+ self.assertAllClose(7, 7 + 1e-8)
+ with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
+ self.assertAllClose(7, 8)
+
+ def testForceGPU(self):
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Cannot assign a device to node"):
+ with self.test_session(force_gpu=True):
+ # this relies on us not having a GPU implementation for assert, which
+ # seems sensible
+ x = [True]
+ y = [15]
+ logging_ops.Assert(x, y).run()
+
+if __name__ == "__main__":
+ googletest.main()