diff options
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 67 |
1 files changed, 61 insertions, 6 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index f848b69782..48328a7f58 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import gc +import os import threading import weakref @@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase): self.assertEqual(3, flops_total.value) +class DeviceStackTest(test_util.TensorFlowTestCase): + + def testBasicDeviceAssignmentMetadata(self): + + def device_func(unused_op): + return "/cpu:*" + + const_zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + const_two = constant_op.constant([2.0], name="two") + with ops.device(device_func): + const_three = constant_op.constant(3.0, name="three") + + self.assertEqual(0, len(const_zero.op._device_assignments)) + + one_list = const_one.op._device_assignments + self.assertEqual(1, len(one_list)) + self.assertEqual("/cpu", one_list[0].obj) + self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) + + two_list = const_two.op._device_assignments + self.assertEqual(2, len(two_list)) + devices = [t.obj for t in two_list] + self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) + + three_list = const_three.op._device_assignments + self.assertEqual(1, len(three_list)) + func_description = three_list[0].obj + expected_regex = r"device_func<.*ops_test.py, [0-9]+" + self.assertRegexpMatches(func_description, expected_regex) + + def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): + + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.get_default_graph().device("/cpu"): + const_two = constant_op.constant([2.0], name="two") + + one_metadata = const_one.op._device_assignments[0] + two_metadata = const_two.op._device_assignments[0] + + # Verify both types of device assignment return the right stack info. + self.assertRegexpMatches("ops_test.py", + os.path.basename(one_metadata.filename)) + self.assertEqual(one_metadata.filename, two_metadata.filename) + self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) + + class ColocationGroupTest(test_util.TensorFlowTestCase): def testBasic(self): @@ -2554,13 +2605,17 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): c.op.get_attr("_class") - # Roughly test that stack information is being saved correctly for the op. - locations_dict = b.op._colocation_dict - self.assertIn("a", locations_dict) - metadata = locations_dict["a"] + def testBasicColocationMetadata(self): + const_two = constant_op.constant([2.0], name="two") + with ops.colocate_with(const_two.op): + const_three = constant_op.constant(3.0, name="three") + locations_dict = const_three.op._colocation_dict + self.assertIn("two", locations_dict) + metadata = locations_dict["two"] self.assertIsNone(metadata.obj) - basename = metadata.filename.split("/")[-1] - self.assertEqual("ops_test.py", basename) + # Check that this test's filename is recorded as the file containing the + # colocation statement. + self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) def testColocationDeviceInteraction(self): with ops.device("/cpu:0"): |