aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r--tensorflow/python/framework/ops_test.py67
1 files changed, 61 insertions, 6 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index f848b69782..48328a7f58 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import gc
+import os
import threading
import weakref
@@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase):
self.assertEqual(3, flops_total.value)
+class DeviceStackTest(test_util.TensorFlowTestCase):
+
+ def testBasicDeviceAssignmentMetadata(self):
+
+ def device_func(unused_op):
+ return "/cpu:*"
+
+ const_zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.device(device_func):
+ const_three = constant_op.constant(3.0, name="three")
+
+ self.assertEqual(0, len(const_zero.op._device_assignments))
+
+ one_list = const_one.op._device_assignments
+ self.assertEqual(1, len(one_list))
+ self.assertEqual("/cpu", one_list[0].obj)
+ self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))
+
+ two_list = const_two.op._device_assignments
+ self.assertEqual(2, len(two_list))
+ devices = [t.obj for t in two_list]
+ self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))
+
+ three_list = const_three.op._device_assignments
+ self.assertEqual(1, len(three_list))
+ func_description = three_list[0].obj
+ expected_regex = r"device_func<.*ops_test.py, [0-9]+"
+ self.assertRegexpMatches(func_description, expected_regex)
+
+ def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
+
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.get_default_graph().device("/cpu"):
+ const_two = constant_op.constant([2.0], name="two")
+
+ one_metadata = const_one.op._device_assignments[0]
+ two_metadata = const_two.op._device_assignments[0]
+
+ # Verify both types of device assignment return the right stack info.
+ self.assertRegexpMatches("ops_test.py",
+ os.path.basename(one_metadata.filename))
+ self.assertEqual(one_metadata.filename, two_metadata.filename)
+ self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
+
+
class ColocationGroupTest(test_util.TensorFlowTestCase):
def testBasic(self):
@@ -2554,13 +2605,17 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
c.op.get_attr("_class")
- # Roughly test that stack information is being saved correctly for the op.
- locations_dict = b.op._colocation_dict
- self.assertIn("a", locations_dict)
- metadata = locations_dict["a"]
+ def testBasicColocationMetadata(self):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.colocate_with(const_two.op):
+ const_three = constant_op.constant(3.0, name="three")
+ locations_dict = const_three.op._colocation_dict
+ self.assertIn("two", locations_dict)
+ metadata = locations_dict["two"]
self.assertIsNone(metadata.obj)
- basename = metadata.filename.split("/")[-1]
- self.assertEqual("ops_test.py", basename)
+ # Check that this test's filename is recorded as the file containing the
+ # colocation statement.
+ self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
def testColocationDeviceInteraction(self):
with ops.device("/cpu:0"):