diff options
author | 2018-07-26 01:27:42 -0700 | |
---|---|---|
committer | 2018-07-26 01:31:05 -0700 | |
commit | e5cc33df74ec4f761da26c87bb785edfa3fb8280 (patch) | |
tree | 67dedd2463f703845bd0444fb6fac1b7c12a80aa /tensorflow/python/framework | |
parent | b1dc68e816e2bf6b8acd3651077c890f2f2f3b7b (diff) |
Convert device function stack into TraceableStack for use in error message interpolation.
PiperOrigin-RevId: 206120307
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/error_interpolation.py | 53 | ||||
-rw-r--r-- | tensorflow/python/framework/error_interpolation_test.py | 105 | ||||
-rw-r--r-- | tensorflow/python/framework/function.py | 26 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 137 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 67 |
5 files changed, 319 insertions, 69 deletions
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index a79073b748..7719d03019 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -87,6 +87,53 @@ def _parse_message(message): return seps, tags +def _compute_device_summary_from_list(device_assignment_list, prefix=""): + """Return a summary of an op's device function stack. + + Args: + device_assignment_list: The op._device_assignments list. + prefix: An optional string prefix used before each line of the multi- + line string returned by this function. + + Returns: + A multi-line string similar to: + Device assignments active during op creation: + with tf.device(/cpu:0): <test_1.py:27> + with tf.device(some_func<foo.py, 123>): <test_2.py:38> + The first line will have no padding to its left by default. Subsequent + lines will have two spaces of left-padding. Use the prefix argument + to increase indentation. + """ + if not device_assignment_list: + message = "No device assignments were active during op creation." + return prefix + message + + str_list = [] + str_list.append("%sDevice assignments active during op creation:" % prefix) + + for traceable_obj in device_assignment_list: + location_summary = "<{file}:{line}>".format(file=traceable_obj.filename, + line=traceable_obj.lineno) + subs = { + "prefix": prefix, + "indent": " ", + "dev_name": traceable_obj.obj, + "loc": location_summary, + } + str_list.append( + "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs)) + + return "\n".join(str_list) + + +def _compute_device_assignment_summary_from_op(op, prefix=""): + if not op: + return "" + # pylint: disable=protected-access + return _compute_device_summary_from_list(op._device_assignments, prefix) + # pylint: enable=protected-access + + def _compute_colocation_summary_from_dict(colocation_dict, prefix=""): """Return a summary of an op's colocation stack. 
@@ -203,6 +250,7 @@ def _compute_field_dict(op): "file": default_value, "line": default_value, "colocations": default_value, + "devices": default_value, } frame = _get_defining_frame_from_op(op) if frame: @@ -211,6 +259,9 @@ def _compute_field_dict(op): colocation_summary = _compute_colocation_summary_from_op(op) if colocation_summary: field_dict["colocations"] = colocation_summary + device_summary = _compute_device_assignment_summary_from_op(op) + if device_summary: + field_dict["devices"] = device_summary return field_dict @@ -233,6 +284,8 @@ def interpolate(error_message, graph): node_name_to_substitution_dict = {} for name in [t.name for t in tags]: + if name in node_name_to_substitution_dict: + continue try: op = graph.get_operation_by_name(name) except KeyError: diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index 1e5cb73854..fbf182879b 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -57,13 +57,32 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename, op._traceback = stack -def assert_node_in_colocation_summary(test_obj, colocation_summary_string, - name, filename="", lineno=""): - lineno = str(lineno) - name_phrase = "colocate_with(%s)" % name - for term in [name_phrase, filename, lineno]: - test_obj.assertIn(term, colocation_summary_string) - test_obj.assertNotIn("loc:@", colocation_summary_string) +class ComputeDeviceSummaryFromOpTest(test.TestCase): + + def testCorrectFormatWithActiveDeviceAssignments(self): + assignments = [] + assignments.append( + traceable_stack.TraceableObject("/cpu:0", + filename="hope.py", + lineno=24)) + assignments.append( + traceable_stack.TraceableObject("/gpu:2", + filename="please.py", + lineno=42)) + + summary = error_interpolation._compute_device_summary_from_list( + assignments, prefix=" ") + + self.assertIn("tf.device(/cpu:0)", summary) + 
self.assertIn("<hope.py:24>", summary) + self.assertIn("tf.device(/gpu:2)", summary) + self.assertIn("<please.py:42>", summary) + + def testCorrectFormatWhenNoColocationsWereActive(self): + device_assignment_list = [] + summary = error_interpolation._compute_device_summary_from_list( + device_assignment_list, prefix=" ") + self.assertIn("No device assignments", summary) class ComputeColocationSummaryFromOpTest(test.TestCase): @@ -81,15 +100,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase): } summary = error_interpolation._compute_colocation_summary_from_dict( colocation_dict, prefix=" ") - assert_node_in_colocation_summary(self, - summary, - name="test_node_1", - filename="test_1.py", - lineno=27) - assert_node_in_colocation_summary(self, summary, - name="test_node_2", - filename="test_2.py", - lineno=38) + self.assertIn("colocate_with(test_node_1)", summary) + self.assertIn("<test_1.py:27>", summary) + self.assertIn("colocate_with(test_node_2)", summary) + self.assertIn("<test_2.py:38>", summary) def testCorrectFormatWhenNoColocationsWereActive(self): colocation_dict = {} @@ -98,9 +112,10 @@ class ComputeColocationSummaryFromOpTest(test.TestCase): self.assertIn("No node-device colocations", summary) -class InterpolateTest(test.TestCase): +class InterpolateFilenamesAndLineNumbersTest(test.TestCase): def setUp(self): + ops.reset_default_graph() # Add nodes to the graph for retrieval by name later. 
constant_op.constant(1, name="One") constant_op.constant(2, name="Two") @@ -177,9 +192,57 @@ class InterpolateTest(test.TestCase): self.assertRegexpMatches(interpolated_string, expected_regex) +class InterpolateDeviceSummaryTest(test.TestCase): + + def _fancy_device_function(self, unused_op): + return "/cpu:*" + + def setUp(self): + ops.reset_default_graph() + self.zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + self.one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + self.two = constant_op.constant([2.0], name="two") + with ops.device(self._fancy_device_function): + self.three = constant_op.constant(3.0, name="three") + + self.graph = self.three.graph + + def testNodeZeroHasNoDeviceSummaryInfo(self): + message = "^^node:zero:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("No device assignments were active", result) + + def testNodeOneHasExactlyOneInterpolatedDevice(self): + message = "^^node:one:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(1, num_devices) + self.assertIn("tf.device(/cpu)", result) + + def testNodeTwoHasTwoInterpolatedDevice(self): + message = "^^node:two:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(2, num_devices) + self.assertIn("tf.device(/cpu)", result) + self.assertIn("tf.device(/cpu:0)", result) + + def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): + message = "^^node:three:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(1, num_devices) + name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>" + expected_re = r"with tf.device\(.*%s\)" % name_re + self.assertRegexpMatches(result, expected_re) + + class 
InterpolateColocationSummaryTest(test.TestCase): def setUp(self): + ops.reset_default_graph() # Add nodes to the graph for retrieval by name later. node_one = constant_op.constant(1, name="One") node_two = constant_op.constant(2, name="Two") @@ -203,12 +266,12 @@ class InterpolateColocationSummaryTest(test.TestCase): def testNodeThreeHasColocationInterpolation(self): message = "^^node:Three_with_one:${colocations}^^" result = error_interpolation.interpolate(message, self.graph) - assert_node_in_colocation_summary(self, result, name="One") + self.assertIn("colocate_with(One)", result) def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): message = "^^node:Four_with_three:${colocations}^^" result = error_interpolation.interpolate(message, self.graph) - assert_node_in_colocation_summary(self, result, name="Three_with_one") + self.assertIn("colocate_with(Three_with_one)", result) self.assertNotIn( "One", result, "Node One should not appear in Four_with_three's summary:\n%s" @@ -217,8 +280,8 @@ class InterpolateColocationSummaryTest(test.TestCase): def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): message = "^^node:Five_with_one_with_two:${colocations}^^" result = error_interpolation.interpolate(message, self.graph) - assert_node_in_colocation_summary(self, result, name="One") - assert_node_in_colocation_summary(self, result, name="Two") + self.assertIn("colocate_with(One)", result) + self.assertIn("colocate_with(Two)", result) def testColocationInterpolationForNodeLackingColocation(self): message = "^^node:One:${colocations}^^" diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 6525607fae..c76743d2c6 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -38,8 +38,8 @@ from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from 
tensorflow.python.util import compat +from tensorflow.python.util import function_utils from tensorflow.python.util import tf_contextlib -from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect # This is to avoid a circular dependency with cond_v2_impl. @@ -255,9 +255,12 @@ class _DefinedFunction(object): # Constructed only when C API is enabled, lazily self._c_func = None self._sub_functions = dict() # Constructed with _definition or _c_func - device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access + # pylint: disable=protected-access + device_funcs = ops.get_default_graph()._device_functions_outer_to_inner + # pylint: enable=protected-access + # Get the innermost device if possible. - self._caller_device = device_stack[-1] if device_stack else None + self._caller_device = device_funcs[-1] if device_funcs else None # Cached OpDef for this function. When C API is enabled, this is # the only part of FunctionDef that we cache in Python. When C API @@ -354,7 +357,7 @@ class _DefinedFunction(object): if self._func_name: base_func_name = self._func_name else: - base_func_name = _get_func_name(self._func) + base_func_name = function_utils.get_func_name(self._func) if self._grad_func: base_func_name += ("_%s" % self._grad_func.name) kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs) @@ -841,7 +844,7 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None, ValueError: if func returns None. 
""" if not name: - name = _get_func_name(func) + name = function_utils.get_func_name(func) func_graph = _FuncGraph(name, capture_by_value) with func_graph.as_default(), ops.device(device): @@ -1139,19 +1142,6 @@ def _parse_kwargs_as_attrs(func_name, **kwargs): return attrs -def _get_func_name(func): - _, func = tf_decorator.unwrap(func) - if callable(func): - if tf_inspect.isfunction(func): - return func.__name__ - elif tf_inspect.ismethod(func): - return "%s.%s" % (func.__self__.__name__, func.__name__) - else: # Probably a class instance with __call__ - return type(func) - else: - raise ValueError("Argument must be callable") - - def get_extra_vars(): """Returns the captured variables by the function. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0fd028ebf0..197317cad9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -50,14 +50,15 @@ from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import versions -from tensorflow.python.util import tf_stack from tensorflow.python.ops import control_flow_util from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import function_utils from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib +from tensorflow.python.util import tf_stack from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -73,6 +74,27 @@ def tensor_id(tensor): return tensor._id # pylint: disable=protected-access +class _UserDeviceSpec(object): + """Store user-specified device and provide computation of merged device.""" + + def __init__(self, device_name_or_function): 
+ self._device_name_or_function = device_name_or_function + + self.display_name = str(self._device_name_or_function) + if callable(self._device_name_or_function): + dev_func = self._device_name_or_function + func_name = function_utils.get_func_name(dev_func) + func_code = function_utils.get_func_code(dev_func) + self.display_name = "%s<%s, %d>" % (func_name, + func_code.co_filename, + func_code.co_firstlineno) + + self.function = self._device_name_or_function + if not (self._device_name_or_function is None or + callable(self._device_name_or_function)): + self.function = pydev.merge_device(self._device_name_or_function) + + class _NullContextmanager(object): def __enter__(self): @@ -1719,7 +1741,12 @@ class Operation(object): self._id_value = self._graph._next_id() self._original_op = original_op self._traceback = tf_stack.extract_stack() - # List of traceable_stack.TraceableObjects for colocation context managers. + + # List of _UserDevSpecs holding code location of device context manager + # invocations and the user's original argument to them. + self._device_code_locations = None + # Dict mapping op name to file and line information for op colocation + # context managers. + self._colocation_code_locations = None + self._control_flow_context = self.graph._get_control_flow_context() # pylint: enable=protected-access @@ -1861,6 +1888,37 @@ class Operation(object): return c_api.TF_OperationDevice(self._c_op) @property + def _device_assignments(self): + """Code locations for device context managers active at op creation. + + This property will return a list of traceable_stack.TraceableObject + instances where .obj is a string representing the assigned device + (or information about the function that would be applied to this op + to compute the desired device) and the filename and lineno members + record the location of the relevant device context manager. 
+ + For example, suppose file_a contained these lines: + + file_a.py: + 15: with tf.device('/gpu:0'): + 16: node_b = tf.constant(4, name='NODE_B') + + Then a TraceableObject t_obj representing the device context manager + would have these member values: + + t_obj.obj -> '/gpu:0' + t_obj.filename = 'file_a.py' + t_obj.lineno = 15 + + and node_b.op._device_assignments would return the list [t_obj]. + + Returns: + [str: traceable_stack.TraceableObject, ...] as per this method's + description, above. + """ + return self._device_code_locations or [] + + @property def _colocation_dict(self): """Code locations for colocation context managers active at op creation. @@ -1881,11 +1939,10 @@ class Operation(object): would have these member values: t_obj.obj -> None - t_obj.name = 'NODE_A' t_obj.filename = 'file_a.py' t_obj.lineno = 15 - and node_b.op._colocation_code_locations would return the dictionary + and node_b.op._colocation_dict would return the dictionary { 'NODE_A': t_obj } @@ -2735,7 +2792,7 @@ class Graph(object): # Functions that will be applied to choose a device if none is specified. # After switch_to_thread_local(), self._thread_local._device_function_stack # is used instead. - self._graph_device_function_stack = [] + self._graph_device_function_stack = traceable_stack.TraceableStack() # Default original_op applied to new ops. self._default_original_op = None # Current control flow context. It could be either CondContext or @@ -4047,7 +4104,7 @@ class Graph(object): # In the future, a caller may specify that device_functions win # over colocation, in which case we can add support. 
device_fn_tmp = self._device_function_stack - self._device_function_stack = [] + self._device_function_stack = traceable_stack.TraceableStack() if ignore_existing: current_stack = self._colocation_stack @@ -4071,6 +4128,13 @@ class Graph(object): if ignore_existing: self._colocation_stack = current_stack + def _add_device_to_stack(self, device_name_or_function, offset=0): + """Add device to stack manually, separate from a context manager.""" + total_offset = 1 + offset + spec = _UserDeviceSpec(device_name_or_function) + self._device_function_stack.push_obj(spec, offset=total_offset) + return spec + @tf_contextlib.contextmanager def device(self, device_name_or_function): # pylint: disable=line-too-long @@ -4128,31 +4192,26 @@ class Graph(object): Yields: A context manager that specifies the default device to use for newly created ops. - """ - # pylint: enable=line-too-long - if (device_name_or_function is not None and - not callable(device_name_or_function)): - device_function = pydev.merge_device(device_name_or_function) - else: - device_function = device_name_or_function - + self._add_device_to_stack(device_name_or_function, offset=2) try: - self._device_function_stack.append(device_function) yield finally: - self._device_function_stack.pop() + self._device_function_stack.pop_obj() def _apply_device_functions(self, op): """Applies the current device function stack to the given operation.""" - # Apply any device functions in reverse order, so that the most recently + # Apply any device functions in LIFO order, so that the most recently # pushed function has the first chance to apply a device to the op. # We apply here because the result can depend on the Operation's # signature, which is computed in the Operation constructor. 
- for device_function in reversed(self._device_function_stack): - if device_function is None: + # pylint: disable=protected-access + for device_spec in self._device_function_stack.peek_objs(): + if device_spec.function is None: break - op._set_device(device_function(op)) # pylint: disable=protected-access + op._set_device(device_spec.function(op)) + op._device_code_locations = self._snapshot_device_function_stack_metadata() + # pylint: enable=protected-access # pylint: disable=g-doc-return-or-yield @tf_contextlib.contextmanager @@ -4676,17 +4735,45 @@ class Graph(object): if self._stack_state_is_thread_local: # This may be called from a thread where device_function_stack doesn't yet # exist. + # pylint: disable=protected-access if not hasattr(self._thread_local, "_device_function_stack"): - self._thread_local._device_function_stack = ( - self._graph_device_function_stack[:]) + stack_copy_for_this_thread = self._graph_device_function_stack.copy() + self._thread_local._device_function_stack = stack_copy_for_this_thread return self._thread_local._device_function_stack + # pylint: enable=protected-access else: return self._graph_device_function_stack + @property + def _device_functions_outer_to_inner(self): + user_device_specs = self._device_function_stack.peek_objs() + device_functions = [spec.function for spec in user_device_specs] + device_functions_outer_to_inner = list(reversed(device_functions)) + return device_functions_outer_to_inner + + def _snapshot_device_function_stack_metadata(self): + """Return device function stack as a list of TraceableObjects. + + Returns: + [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj + member is a displayable name for the user's argument to Graph.device, and + the filename and lineno members point to the code location where + Graph.device was called directly or indirectly by the user. 
+ """ + traceable_objects = self._device_function_stack.peek_traceable_objs() + snapshot = [] + for obj in traceable_objects: + obj_copy = obj.copy_metadata() + obj_copy.obj = obj.obj.display_name + snapshot.append(obj_copy) + return snapshot + @_device_function_stack.setter def _device_function_stack(self, device_function_stack): if self._stack_state_is_thread_local: + # pylint: disable=protected-access self._thread_local._device_function_stack = device_function_stack + # pylint: enable=protected-access else: self._graph_device_function_stack = device_function_stack @@ -4696,12 +4783,12 @@ class Graph(object): if self._stack_state_is_thread_local: # This may be called from a thread where colocation_stack doesn't yet # exist. + # pylint: disable=protected-access if not hasattr(self._thread_local, "_colocation_stack"): stack_copy_for_this_thread = self._graph_colocation_stack.copy() - # pylint: disable=protected-access self._thread_local._colocation_stack = stack_copy_for_this_thread - # pylint: enable=protected-access return self._thread_local._colocation_stack + # pylint: enable=protected-access else: return self._graph_colocation_stack @@ -4713,7 +4800,9 @@ class Graph(object): @_colocation_stack.setter def _colocation_stack(self, colocation_stack): if self._stack_state_is_thread_local: + # pylint: disable=protected-access self._thread_local._colocation_stack = colocation_stack + # pylint: enable=protected-access else: self._graph_colocation_stack = colocation_stack diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index f848b69782..48328a7f58 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import gc +import os import threading import weakref @@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase): self.assertEqual(3, flops_total.value) +class 
DeviceStackTest(test_util.TensorFlowTestCase): + + def testBasicDeviceAssignmentMetadata(self): + + def device_func(unused_op): + return "/cpu:*" + + const_zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + const_two = constant_op.constant([2.0], name="two") + with ops.device(device_func): + const_three = constant_op.constant(3.0, name="three") + + self.assertEqual(0, len(const_zero.op._device_assignments)) + + one_list = const_one.op._device_assignments + self.assertEqual(1, len(one_list)) + self.assertEqual("/cpu", one_list[0].obj) + self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) + + two_list = const_two.op._device_assignments + self.assertEqual(2, len(two_list)) + devices = [t.obj for t in two_list] + self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) + + three_list = const_three.op._device_assignments + self.assertEqual(1, len(three_list)) + func_description = three_list[0].obj + expected_regex = r"device_func<.*ops_test.py, [0-9]+" + self.assertRegexpMatches(func_description, expected_regex) + + def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): + + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.get_default_graph().device("/cpu"): + const_two = constant_op.constant([2.0], name="two") + + one_metadata = const_one.op._device_assignments[0] + two_metadata = const_two.op._device_assignments[0] + + # Verify both types of device assignment return the right stack info. 
+ self.assertRegexpMatches("ops_test.py", + os.path.basename(one_metadata.filename)) + self.assertEqual(one_metadata.filename, two_metadata.filename) + self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) + + class ColocationGroupTest(test_util.TensorFlowTestCase): def testBasic(self): @@ -2554,13 +2605,17 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): c.op.get_attr("_class") - # Roughly test that stack information is being saved correctly for the op. - locations_dict = b.op._colocation_dict - self.assertIn("a", locations_dict) - metadata = locations_dict["a"] + def testBasicColocationMetadata(self): + const_two = constant_op.constant([2.0], name="two") + with ops.colocate_with(const_two.op): + const_three = constant_op.constant(3.0, name="three") + locations_dict = const_three.op._colocation_dict + self.assertIn("two", locations_dict) + metadata = locations_dict["two"] self.assertIsNone(metadata.obj) - basename = metadata.filename.split("/")[-1] - self.assertEqual("ops_test.py", basename) + # Check that this test's filename is recorded as the file containing the + # colocation statement. + self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) def testColocationDeviceInteraction(self): with ops.device("/cpu:0"): |