path: root/tensorflow/python/framework
author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-07-26 01:27:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-07-26 01:31:05 -0700
commit     e5cc33df74ec4f761da26c87bb785edfa3fb8280 (patch)
tree       67dedd2463f703845bd0444fb6fac1b7c12a80aa /tensorflow/python/framework
parent     b1dc68e816e2bf6b8acd3651077c890f2f2f3b7b (diff)
Convert device function stack into TraceableStack for use in error message interpolation.
PiperOrigin-RevId: 206120307
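
In practical terms, each active device context manager is now recorded together with its call site, so a ${devices} tag in an error-message template can be expanded into a readable summary for the named op. A minimal sketch of that usage, assuming TF 1.x graph mode and the tag format exercised by the tests in this change (the file/line shown in the output comment are placeholders):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import error_interpolation
    from tensorflow.python.framework import ops

    with ops.device("/cpu:0"):
        two = constant_op.constant([2.0], name="two")

    # "^^node:two:${devices}^^" asks the interpolator to splice in the device
    # assignments that were recorded when the op named "two" was created.
    message = "^^node:two:${devices}^^"
    print(error_interpolation.interpolate(message, two.graph))
    # Expected shape of the output (file and line depend on the caller):
    #   Device assignments active during op creation:
    #     with tf.device(/cpu:0): <caller_script.py:5>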
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--  tensorflow/python/framework/error_interpolation.py        53
-rw-r--r--  tensorflow/python/framework/error_interpolation_test.py  105
-rw-r--r--  tensorflow/python/framework/function.py                    26
-rw-r--r--  tensorflow/python/framework/ops.py                        137
-rw-r--r--  tensorflow/python/framework/ops_test.py                    67
5 files changed, 319 insertions, 69 deletions
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index a79073b748..7719d03019 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -87,6 +87,53 @@ def _parse_message(message):
return seps, tags
+def _compute_device_summary_from_list(device_assignment_list, prefix=""):
+ """Return a summary of an op's device function stack.
+
+ Args:
+ device_assignment_list: The op._device_assignments list.
+ prefix: An optional string prefix used before each line of the multi-
+ line string returned by this function.
+
+ Returns:
+ A multi-line string similar to:
+ Device assignments active during op creation:
+ with tf.device(/cpu:0): <test_1.py:27>
+ with tf.device(some_func<foo.py, 123>): <test_2.py:38>
+ The first line will have no padding to its left by default. Subsequent
+ lines will have two spaces of left-padding. Use the prefix argument
+ to increase indentation.
+ """
+ if not device_assignment_list:
+ message = "No device assignments were active during op creation."
+ return prefix + message
+
+ str_list = []
+ str_list.append("%sDevice assignments active during op creation:" % prefix)
+
+ for traceable_obj in device_assignment_list:
+ location_summary = "<{file}:{line}>".format(file=traceable_obj.filename,
+ line=traceable_obj.lineno)
+ subs = {
+ "prefix": prefix,
+ "indent": " ",
+ "dev_name": traceable_obj.obj,
+ "loc": location_summary,
+ }
+ str_list.append(
+ "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))
+
+ return "\n".join(str_list)
+
+
+def _compute_device_assignment_summary_from_op(op, prefix=""):
+ if not op:
+ return ""
+ # pylint: disable=protected-access
+ return _compute_device_summary_from_list(op._device_assignments, prefix)
+ # pylint: enable=protected-access
+
+
def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
"""Return a summary of an op's colocation stack.
@@ -203,6 +250,7 @@ def _compute_field_dict(op):
"file": default_value,
"line": default_value,
"colocations": default_value,
+ "devices": default_value,
}
frame = _get_defining_frame_from_op(op)
if frame:
@@ -211,6 +259,9 @@ def _compute_field_dict(op):
colocation_summary = _compute_colocation_summary_from_op(op)
if colocation_summary:
field_dict["colocations"] = colocation_summary
+ device_summary = _compute_device_assignment_summary_from_op(op)
+ if device_summary:
+ field_dict["devices"] = device_summary
return field_dict
@@ -233,6 +284,8 @@ def interpolate(error_message, graph):
node_name_to_substitution_dict = {}
for name in [t.name for t in tags]:
+ if name in node_name_to_substitution_dict:
+ continue
try:
op = graph.get_operation_by_name(name)
except KeyError:
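
For reference, _compute_device_summary_from_list only needs objects exposing .obj, .filename, and .lineno; a direct call looks roughly like the following sketch (the inputs mirror the new unit test below):

    from tensorflow.python.framework import error_interpolation
    from tensorflow.python.framework import traceable_stack

    assignments = [
        traceable_stack.TraceableObject("/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject("/gpu:2", filename="please.py", lineno=42),
    ]
    # Builds the prefix-aware, multi-line summary documented in the docstring above.
    print(error_interpolation._compute_device_summary_from_list(
        assignments, prefix="  "))
    #   Device assignments active during op creation:
    #     with tf.device(/cpu:0): <hope.py:24>
    #     with tf.device(/gpu:2): <please.py:42>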
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index 1e5cb73854..fbf182879b 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -57,13 +57,32 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
op._traceback = stack
-def assert_node_in_colocation_summary(test_obj, colocation_summary_string,
- name, filename="", lineno=""):
- lineno = str(lineno)
- name_phrase = "colocate_with(%s)" % name
- for term in [name_phrase, filename, lineno]:
- test_obj.assertIn(term, colocation_summary_string)
- test_obj.assertNotIn("loc:@", colocation_summary_string)
+class ComputeDeviceSummaryFromOpTest(test.TestCase):
+
+ def testCorrectFormatWithActiveDeviceAssignments(self):
+ assignments = []
+ assignments.append(
+ traceable_stack.TraceableObject("/cpu:0",
+ filename="hope.py",
+ lineno=24))
+ assignments.append(
+ traceable_stack.TraceableObject("/gpu:2",
+ filename="please.py",
+ lineno=42))
+
+ summary = error_interpolation._compute_device_summary_from_list(
+ assignments, prefix=" ")
+
+ self.assertIn("tf.device(/cpu:0)", summary)
+ self.assertIn("<hope.py:24>", summary)
+ self.assertIn("tf.device(/gpu:2)", summary)
+ self.assertIn("<please.py:42>", summary)
+
+ def testCorrectFormatWhenNoColocationsWereActive(self):
+ device_assignment_list = []
+ summary = error_interpolation._compute_device_summary_from_list(
+ device_assignment_list, prefix=" ")
+ self.assertIn("No device assignments", summary)
class ComputeColocationSummaryFromOpTest(test.TestCase):
@@ -81,15 +100,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
}
summary = error_interpolation._compute_colocation_summary_from_dict(
colocation_dict, prefix=" ")
- assert_node_in_colocation_summary(self,
- summary,
- name="test_node_1",
- filename="test_1.py",
- lineno=27)
- assert_node_in_colocation_summary(self, summary,
- name="test_node_2",
- filename="test_2.py",
- lineno=38)
+ self.assertIn("colocate_with(test_node_1)", summary)
+ self.assertIn("<test_1.py:27>", summary)
+ self.assertIn("colocate_with(test_node_2)", summary)
+ self.assertIn("<test_2.py:38>", summary)
def testCorrectFormatWhenNoColocationsWereActive(self):
colocation_dict = {}
@@ -98,9 +112,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
self.assertIn("No node-device colocations", summary)
-class InterpolateTest(test.TestCase):
+class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
+ ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
@@ -177,9 +192,57 @@ class InterpolateTest(test.TestCase):
self.assertRegexpMatches(interpolated_string, expected_regex)
+class InterpolateDeviceSummaryTest(test.TestCase):
+
+ def _fancy_device_function(self, unused_op):
+ return "/cpu:*"
+
+ def setUp(self):
+ ops.reset_default_graph()
+ self.zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ self.one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ self.two = constant_op.constant([2.0], name="two")
+ with ops.device(self._fancy_device_function):
+ self.three = constant_op.constant(3.0, name="three")
+
+ self.graph = self.three.graph
+
+ def testNodeZeroHasNoDeviceSummaryInfo(self):
+ message = "^^node:zero:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ self.assertIn("No device assignments were active", result)
+
+ def testNodeOneHasExactlyOneInterpolatedDevice(self):
+ message = "^^node:one:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(1, num_devices)
+ self.assertIn("tf.device(/cpu)", result)
+
+ def testNodeTwoHasTwoInterpolatedDevice(self):
+ message = "^^node:two:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(2, num_devices)
+ self.assertIn("tf.device(/cpu)", result)
+ self.assertIn("tf.device(/cpu:0)", result)
+
+ def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
+ message = "^^node:three:${devices}^^"
+ result = error_interpolation.interpolate(message, self.graph)
+ num_devices = result.count("tf.device")
+ self.assertEqual(1, num_devices)
+ name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
+ expected_re = r"with tf.device\(.*%s\)" % name_re
+ self.assertRegexpMatches(result, expected_re)
+
+
class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self):
+ ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One")
node_two = constant_op.constant(2, name="Two")
@@ -203,12 +266,12 @@ class InterpolateColocationSummaryTest(test.TestCase):
def testNodeThreeHasColocationInterpolation(self):
message = "^^node:Three_with_one:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="One")
+ self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
message = "^^node:Four_with_three:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="Three_with_one")
+ self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
"Node One should not appear in Four_with_three's summary:\n%s"
@@ -217,8 +280,8 @@ class InterpolateColocationSummaryTest(test.TestCase):
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
message = "^^node:Five_with_one_with_two:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
- assert_node_in_colocation_summary(self, result, name="One")
- assert_node_in_colocation_summary(self, result, name="Two")
+ self.assertIn("colocate_with(One)", result)
+ self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
message = "^^node:One:${colocations}^^"
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 6525607fae..c76743d2c6 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -38,8 +38,8 @@ from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
-from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# This is to avoid a circular dependency with cond_v2_impl.
@@ -255,9 +255,12 @@ class _DefinedFunction(object):
# Constructed only when C API is enabled, lazily
self._c_func = None
self._sub_functions = dict() # Constructed with _definition or _c_func
- device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
+ # pylint: enable=protected-access
+
# Get the innermost device if possible.
- self._caller_device = device_stack[-1] if device_stack else None
+ self._caller_device = device_funcs[-1] if device_funcs else None
# Cached OpDef for this function. When C API is enabled, this is
# the only part of FunctionDef that we cache in Python. When C API
@@ -354,7 +357,7 @@ class _DefinedFunction(object):
if self._func_name:
base_func_name = self._func_name
else:
- base_func_name = _get_func_name(self._func)
+ base_func_name = function_utils.get_func_name(self._func)
if self._grad_func:
base_func_name += ("_%s" % self._grad_func.name)
kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
@@ -841,7 +844,7 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
ValueError: if func returns None.
"""
if not name:
- name = _get_func_name(func)
+ name = function_utils.get_func_name(func)
func_graph = _FuncGraph(name, capture_by_value)
with func_graph.as_default(), ops.device(device):
@@ -1139,19 +1142,6 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
return attrs
-def _get_func_name(func):
- _, func = tf_decorator.unwrap(func)
- if callable(func):
- if tf_inspect.isfunction(func):
- return func.__name__
- elif tf_inspect.ismethod(func):
- return "%s.%s" % (func.__self__.__name__, func.__name__)
- else: # Probably a class instance with __call__
- return type(func)
- else:
- raise ValueError("Argument must be callable")
-
-
def get_extra_vars():
"""Returns the captured variables by the function.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 0fd028ebf0..197317cad9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -50,14 +50,15 @@ from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
-from tensorflow.python.util import tf_stack
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@@ -73,6 +74,27 @@ def tensor_id(tensor):
return tensor._id # pylint: disable=protected-access
+class _UserDeviceSpec(object):
+ """Store user-specified device and provide computation of merged device."""
+
+ def __init__(self, device_name_or_function):
+ self._device_name_or_function = device_name_or_function
+
+ self.display_name = str(self._device_name_or_function)
+ if callable(self._device_name_or_function):
+ dev_func = self._device_name_or_function
+ func_name = function_utils.get_func_name(dev_func)
+ func_code = function_utils.get_func_code(dev_func)
+ self.display_name = "%s<%s, %d>" % (func_name,
+ func_code.co_filename,
+ func_code.co_firstlineno)
+
+ self.function = self._device_name_or_function
+ if not (self._device_name_or_function is None or
+ callable(self._device_name_or_function)):
+ self.function = pydev.merge_device(self._device_name_or_function)
+
+
class _NullContextmanager(object):
def __enter__(self):
@@ -1719,7 +1741,12 @@ class Operation(object):
self._id_value = self._graph._next_id()
self._original_op = original_op
self._traceback = tf_stack.extract_stack()
- # List of traceable_stack.TraceableObjects for colocation context managers.
+
+ # List of _UserDeviceSpec objects holding the code location of device
+ # context manager invocations and the user's original argument to them.
+ self._device_code_locations = None
+ # Dict mapping op name to file and line information for op colocation
+ # context managers.
self._colocation_code_locations = None
self._control_flow_context = self.graph._get_control_flow_context()
# pylint: enable=protected-access
@@ -1861,6 +1888,37 @@ class Operation(object):
return c_api.TF_OperationDevice(self._c_op)
@property
+ def _device_assignments(self):
+ """Code locations for device context managers active at op creation.
+
+ This property will return a list of traceable_stack.TraceableObject
+ instances where .obj is a string representing the assigned device
+ (or information about the function that would be applied to this op
+ to compute the desired device) and the filename and lineno members
+ record the location of the relevant device context manager.
+
+ For example, suppose file_a contained these lines:
+
+ file_a.py:
+ 15: with tf.device('/gpu:0'):
+ 16: node_b = tf.constant(4, name='NODE_B')
+
+ Then a TraceableObject t_obj representing the device context manager
+ would have these member values:
+
+ t_obj.obj -> '/gpu:0'
+ t_obj.filename = 'file_a.py'
+ t_obj.lineno = 15
+
+ and node_b.op._device_assignments would return the list [t_obj].
+
+ Returns:
+ [str: traceable_stack.TraceableObject, ...] as per this method's
+ description, above.
+ """
+ return self._device_code_locations or []
+
+ @property
def _colocation_dict(self):
"""Code locations for colocation context managers active at op creation.
@@ -1881,11 +1939,10 @@ class Operation(object):
would have these member values:
t_obj.obj -> None
- t_obj.name = 'NODE_A'
t_obj.filename = 'file_a.py'
t_obj.lineno = 15
- and node_b.op._colocation_code_locations would return the dictionary
+ and node_b.op._colocation_dict would return the dictionary
{ 'NODE_A': t_obj }
@@ -2735,7 +2792,7 @@ class Graph(object):
# Functions that will be applied to choose a device if none is specified.
# After switch_to_thread_local(), self._thread_local._device_function_stack
# is used instead.
- self._graph_device_function_stack = []
+ self._graph_device_function_stack = traceable_stack.TraceableStack()
# Default original_op applied to new ops.
self._default_original_op = None
# Current control flow context. It could be either CondContext or
@@ -4047,7 +4104,7 @@ class Graph(object):
# In the future, a caller may specify that device_functions win
# over colocation, in which case we can add support.
device_fn_tmp = self._device_function_stack
- self._device_function_stack = []
+ self._device_function_stack = traceable_stack.TraceableStack()
if ignore_existing:
current_stack = self._colocation_stack
@@ -4071,6 +4128,13 @@ class Graph(object):
if ignore_existing:
self._colocation_stack = current_stack
+ def _add_device_to_stack(self, device_name_or_function, offset=0):
+ """Add device to stack manually, separate from a context manager."""
+ total_offset = 1 + offset
+ spec = _UserDeviceSpec(device_name_or_function)
+ self._device_function_stack.push_obj(spec, offset=total_offset)
+ return spec
+
@tf_contextlib.contextmanager
def device(self, device_name_or_function):
# pylint: disable=line-too-long
@@ -4128,31 +4192,26 @@ class Graph(object):
Yields:
A context manager that specifies the default device to use for newly
created ops.
-
"""
- # pylint: enable=line-too-long
- if (device_name_or_function is not None and
- not callable(device_name_or_function)):
- device_function = pydev.merge_device(device_name_or_function)
- else:
- device_function = device_name_or_function
-
+ self._add_device_to_stack(device_name_or_function, offset=2)
try:
- self._device_function_stack.append(device_function)
yield
finally:
- self._device_function_stack.pop()
+ self._device_function_stack.pop_obj()
def _apply_device_functions(self, op):
"""Applies the current device function stack to the given operation."""
- # Apply any device functions in reverse order, so that the most recently
+ # Apply any device functions in LIFO order, so that the most recently
# pushed function has the first chance to apply a device to the op.
# We apply here because the result can depend on the Operation's
# signature, which is computed in the Operation constructor.
- for device_function in reversed(self._device_function_stack):
- if device_function is None:
+ # pylint: disable=protected-access
+ for device_spec in self._device_function_stack.peek_objs():
+ if device_spec.function is None:
break
- op._set_device(device_function(op)) # pylint: disable=protected-access
+ op._set_device(device_spec.function(op))
+ op._device_code_locations = self._snapshot_device_function_stack_metadata()
+ # pylint: enable=protected-access
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
@@ -4676,17 +4735,45 @@ class Graph(object):
if self._stack_state_is_thread_local:
# This may be called from a thread where device_function_stack doesn't yet
# exist.
+ # pylint: disable=protected-access
if not hasattr(self._thread_local, "_device_function_stack"):
- self._thread_local._device_function_stack = (
- self._graph_device_function_stack[:])
+ stack_copy_for_this_thread = self._graph_device_function_stack.copy()
+ self._thread_local._device_function_stack = stack_copy_for_this_thread
return self._thread_local._device_function_stack
+ # pylint: enable=protected-access
else:
return self._graph_device_function_stack
+ @property
+ def _device_functions_outer_to_inner(self):
+ user_device_specs = self._device_function_stack.peek_objs()
+ device_functions = [spec.function for spec in user_device_specs]
+ device_functions_outer_to_inner = list(reversed(device_functions))
+ return device_functions_outer_to_inner
+
+ def _snapshot_device_function_stack_metadata(self):
+ """Return device function stack as a list of TraceableObjects.
+
+ Returns:
+ [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
+ member is a displayable name for the user's argument to Graph.device, and
+ the filename and lineno members point to the code location where
+ Graph.device was called directly or indirectly by the user.
+ """
+ traceable_objects = self._device_function_stack.peek_traceable_objs()
+ snapshot = []
+ for obj in traceable_objects:
+ obj_copy = obj.copy_metadata()
+ obj_copy.obj = obj.obj.display_name
+ snapshot.append(obj_copy)
+ return snapshot
+
@_device_function_stack.setter
def _device_function_stack(self, device_function_stack):
if self._stack_state_is_thread_local:
+ # pylint: disable=protected-access
self._thread_local._device_function_stack = device_function_stack
+ # pylint: enable=protected-access
else:
self._graph_device_function_stack = device_function_stack
@@ -4696,12 +4783,12 @@ class Graph(object):
if self._stack_state_is_thread_local:
# This may be called from a thread where colocation_stack doesn't yet
# exist.
+ # pylint: disable=protected-access
if not hasattr(self._thread_local, "_colocation_stack"):
stack_copy_for_this_thread = self._graph_colocation_stack.copy()
- # pylint: disable=protected-access
self._thread_local._colocation_stack = stack_copy_for_this_thread
- # pylint: enable=protected-access
return self._thread_local._colocation_stack
+ # pylint: enable=protected-access
else:
return self._graph_colocation_stack
@@ -4713,7 +4800,9 @@ class Graph(object):
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):
if self._stack_state_is_thread_local:
+ # pylint: disable=protected-access
self._thread_local._colocation_stack = colocation_stack
+ # pylint: enable=protected-access
else:
self._graph_colocation_stack = colocation_stack
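
Taken together, Graph.device now pushes a _UserDeviceSpec onto a TraceableStack and _apply_device_functions snapshots that stack's metadata onto each op, so op._device_assignments yields TraceableObjects whose .obj is a displayable device name (or a func<file, line> description for a device function) and whose filename/lineno point at the device call site. A minimal sketch mirroring the _device_assignments docstring above:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops

    with ops.device("/gpu:0"):
        node_b = constant_op.constant(4, name="NODE_B")

    for t_obj in node_b.op._device_assignments:
        # t_obj.obj -> "/gpu:0"; filename and lineno record where the
        # enclosing ops.device(...) context manager was entered.
        print(t_obj.obj, t_obj.filename, t_obj.lineno)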
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index f848b69782..48328a7f58 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import gc
+import os
import threading
import weakref
@@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase):
self.assertEqual(3, flops_total.value)
+class DeviceStackTest(test_util.TensorFlowTestCase):
+
+ def testBasicDeviceAssignmentMetadata(self):
+
+ def device_func(unused_op):
+ return "/cpu:*"
+
+ const_zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.device(device_func):
+ const_three = constant_op.constant(3.0, name="three")
+
+ self.assertEqual(0, len(const_zero.op._device_assignments))
+
+ one_list = const_one.op._device_assignments
+ self.assertEqual(1, len(one_list))
+ self.assertEqual("/cpu", one_list[0].obj)
+ self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))
+
+ two_list = const_two.op._device_assignments
+ self.assertEqual(2, len(two_list))
+ devices = [t.obj for t in two_list]
+ self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))
+
+ three_list = const_three.op._device_assignments
+ self.assertEqual(1, len(three_list))
+ func_description = three_list[0].obj
+ expected_regex = r"device_func<.*ops_test.py, [0-9]+"
+ self.assertRegexpMatches(func_description, expected_regex)
+
+ def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
+
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.get_default_graph().device("/cpu"):
+ const_two = constant_op.constant([2.0], name="two")
+
+ one_metadata = const_one.op._device_assignments[0]
+ two_metadata = const_two.op._device_assignments[0]
+
+ # Verify both types of device assignment return the right stack info.
+ self.assertRegexpMatches("ops_test.py",
+ os.path.basename(one_metadata.filename))
+ self.assertEqual(one_metadata.filename, two_metadata.filename)
+ self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
+
+
class ColocationGroupTest(test_util.TensorFlowTestCase):
def testBasic(self):
@@ -2554,13 +2605,17 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
c.op.get_attr("_class")
- # Roughly test that stack information is being saved correctly for the op.
- locations_dict = b.op._colocation_dict
- self.assertIn("a", locations_dict)
- metadata = locations_dict["a"]
+ def testBasicColocationMetadata(self):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.colocate_with(const_two.op):
+ const_three = constant_op.constant(3.0, name="three")
+ locations_dict = const_three.op._colocation_dict
+ self.assertIn("two", locations_dict)
+ metadata = locations_dict["two"]
self.assertIsNone(metadata.obj)
- basename = metadata.filename.split("/")[-1]
- self.assertEqual("ops_test.py", basename)
+ # Check that this test's filename is recorded as the file containing the
+ # colocation statement.
+ self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
def testColocationDeviceInteraction(self):
with ops.device("/cpu:0"):