diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 177 |
1 files changed, 83 insertions, 94 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b07c57d265..6a5c44e4d9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import copy -import linecache import os import re import sys @@ -49,7 +48,9 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import versions +from tensorflow.python.util import tf_stack from tensorflow.python.ops import control_flow_util from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging @@ -1711,10 +1712,14 @@ class Operation(object): # This will be set by self.inputs. self._inputs_val = None - self._id_value = self._graph._next_id() # pylint: disable=protected-access + # pylint: disable=protected-access + self._id_value = self._graph._next_id() self._original_op = original_op - self._traceback = self._graph._extract_stack() # pylint: disable=protected-access - self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access + self._traceback = tf_stack.extract_stack() + # List of traceable_stack.TraceableObjects for colocation context managers. + self._colocation_code_locations = None + self._control_flow_context = self.graph._get_control_flow_context() + # pylint: enable=protected-access # Initialize self._c_op. if c_op: @@ -1853,6 +1858,42 @@ class Operation(object): return c_api.TF_OperationDevice(self._c_op) @property + def _colocation_dict(self): + """Code locations for colocation context managers active at op creation. + + This property will return a dictionary for which the keys are nodes with + which this Operation is colocated, and for which the values are + traceable_stack.TraceableObject instances. The TraceableObject instances + record the location of the relevant colocation context manager but have the + "obj" field set to None to prevent leaking private data. + + For example, suppose file_a contained these lines: + + file_a.py: + 14: node_a = tf.constant(3, name='NODE_A') + 15: with tf.colocate_with(node_a): + 16: node_b = tf.constant(4, name='NODE_B') + + Then a TraceableObject t_obj representing the colocation context manager + would have these member values: + + t_obj.obj -> None + t_obj.name = 'NODE_A' + t_obj.filename = 'file_a.py' + t_obj.lineno = 15 + + and node_b.op._colocation_code_locations would return the dictionary + + { 'NODE_A': t_obj } + + Returns: + {str: traceable_stack.TraceableObject} as per this method's description, + above. + """ + locations_dict = self._colocation_code_locations or {} + return locations_dict.copy() + + @property def _output_types(self): """List this operation's output types. @@ -2154,7 +2195,7 @@ class Operation(object): @property def traceback(self): """Returns the call stack from when this operation was constructed.""" - return self._graph._convert_stack(self._traceback) # pylint: disable=protected-access + return tf_stack.convert_stack(self._traceback) @property def traceback_with_start_lines(self): @@ -2163,9 +2204,8 @@ class Operation(object): Returns: A list of 5-tuples (filename, lineno, name, code, func_start_lineno). """ - return self._graph._convert_stack( # pylint: disable=protected-access - self._traceback, - include_func_start_lineno=True) + return tf_stack.convert_stack(self._traceback, + include_func_start_lineno=True) def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" @@ -2617,7 +2657,6 @@ def _name_from_scope_name(name): _MUTATION_LOCK_GROUP = 0 _SESSION_RUN_LOCK_GROUP = 1 - @tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2726,7 +2765,7 @@ class Graph(object): self._building_function = False # Stack of colocate_with ops. After switch_to_thread_local(), # self._thread_local._colocation_stack is used instead. - self._graph_colocation_stack = [] + self._graph_colocation_stack = traceable_stack.TraceableStack() # Set of tensors that are dangerous to feed! self._unfeedable_tensors = set() # Set of operations that are dangerous to fetch! @@ -2766,36 +2805,6 @@ class Graph(object): """Temporary hack; can be overridden to force C API usage.""" return _USE_C_API - def _convert_stack(self, stack, include_func_start_lineno=False): - """Converts a stack extracted using _extract_stack() to a traceback stack. - - Args: - stack: A list of n 5-tuples, - (filename, lineno, name, frame_globals, func_start_lineno). - include_func_start_lineno: True if function start line number should be - included as the 5th entry in return tuples. - - Returns: - A list of n 4-tuples or 5-tuples - (filename, lineno, name, code, [optional: func_start_lineno]), where the - code tuple element is calculated from the corresponding elements of the - input tuple. - """ - ret = [] - for (filename, lineno, name, frame_globals, func_start_lineno, - unused_frame_info) in stack: - linecache.checkcache(filename) - line = linecache.getline(filename, lineno, frame_globals) - if line: - line = line.strip() - else: - line = None - if include_func_start_lineno: - ret.append((filename, lineno, name, line, func_start_lineno)) - else: - ret.append((filename, lineno, name, line)) - return ret - # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. @tf_contextlib.contextmanager @@ -2803,63 +2812,23 @@ class Graph(object): # This step makes a copy of the existing stack, and it also initializes # self._thread_local._variable_creator_stack if it doesn't exist yet. old = list(self._variable_creator_stack) - self._thread_local._variable_creator_stack.append(creator) + self._thread_local._variable_creator_stack.append(creator) # pylint: disable=protected-access try: yield finally: - self._thread_local._variable_creator_stack = old + self._thread_local._variable_creator_stack = old # pylint: disable=protected-access # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. @property def _variable_creator_stack(self): if not hasattr(self._thread_local, "_variable_creator_stack"): - self._thread_local._variable_creator_stack = [] - return list(self._thread_local._variable_creator_stack) + self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access + return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access @_variable_creator_stack.setter def _variable_creator_stack(self, variable_creator_stack): - self._thread_local._variable_creator_stack = variable_creator_stack - - def _extract_stack(self): - """A lightweight, extensible re-implementation of traceback.extract_stack. - - NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for - each stack frame using linecache, which results in an abundance of stat() - calls. This implementation does not retrieve the code, and any consumer - should apply _convert_stack to the result to obtain a traceback that can - be formatted etc. using traceback methods. - - Derived classes can implement _extract_frame_info() to add extra information - to the traceback. - - Returns: - A list of 6-tuples - (filename, lineno, name, frame_globals, func_start_lineno, custom_info) - corresponding to the call stack of the current thread. - """ - try: - raise ZeroDivisionError - except ZeroDivisionError: - f = sys.exc_info()[2].tb_frame.f_back - ret = [] - while f is not None: - lineno = f.f_lineno - co = f.f_code - filename = co.co_filename - name = co.co_name - frame_globals = f.f_globals - func_start_lineno = co.co_firstlineno - frame_info = self._extract_frame_info(f) - ret.append((filename, lineno, name, frame_globals, func_start_lineno, - frame_info)) - f = f.f_back - ret.reverse() - return ret - - def _extract_frame_info(self, frame): # pylint: disable=unused-argument - """Extracts custom information from a frame in an op traceback.""" - return None + self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access def _check_not_finalized(self): """Check if the graph is finalized. @@ -3301,7 +3270,7 @@ class Graph(object): if self._colocation_stack: all_colocation_groups = [] - for colocation_op in self._colocation_stack: + for colocation_op in self._colocation_stack.peek_objs(): all_colocation_groups.extend(colocation_op.colocation_groups()) if colocation_op.device: # Make this device match the device of the colocated op, to provide @@ -3320,6 +3289,7 @@ class Graph(object): # pylint: disable=protected-access op._set_attr("_class", attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) + op._colocation_code_locations = self._snapshot_colocation_stack_metadata() # pylint: enable=protected-access # Sets "container" attribute if @@ -3629,9 +3599,13 @@ class Graph(object): This method should be used if you want to create multiple graphs in the same process. For convenience, a global default graph is provided, and all ops will be added to this graph if you do not - create a new graph explicitly. Use this method with the `with` keyword - to specify that ops created within the scope of a block should be - added to this graph. + create a new graph explicitly. + + Use this method with the `with` keyword to specify that ops created within + the scope of a block should be added to this graph. In this case, once + the scope of the `with` is exited, the previous default graph is set again + as default. There is a stack, so it's ok to have multiple nested levels + of `as_default` calls. The default graph is a property of the current thread. If you create a new thread, and wish to use the default graph in that @@ -4074,10 +4048,13 @@ class Graph(object): if ignore_existing: current_stack = self._colocation_stack - self._colocation_stack = [] + self._colocation_stack = traceable_stack.TraceableStack() if op is not None: - self._colocation_stack.append(op) + # offset refers to the stack frame used for storing code location. + # We use 4, the sum of 1 to use our caller's stack frame and 3 + # to jump over layers of context managers above us. + self._colocation_stack.push_obj(op, offset=4) try: yield @@ -4085,7 +4062,7 @@ class Graph(object): # Restore device function stack self._device_function_stack = device_fn_tmp if op is not None: - self._colocation_stack.pop() + self._colocation_stack.pop_obj() # Reset the colocation stack if requested. if ignore_existing: @@ -4712,15 +4689,24 @@ class Graph(object): @property def _colocation_stack(self): + """Return thread-local copy of colocation stack.""" if self._stack_state_is_thread_local: # This may be called from a thread where colocation_stack doesn't yet # exist. if not hasattr(self._thread_local, "_colocation_stack"): - self._thread_local._colocation_stack = self._graph_colocation_stack[:] + stack_copy_for_this_thread = self._graph_colocation_stack.copy() + # pylint: disable=protected-access + self._thread_local._colocation_stack = stack_copy_for_this_thread + # pylint: enable=protected-access return self._thread_local._colocation_stack else: return self._graph_colocation_stack + def _snapshot_colocation_stack_metadata(self): + """Return colocation stack metadata as a dictionary.""" + traceable_objects = self._colocation_stack.peek_traceable_objs() + return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} + @_colocation_stack.setter def _colocation_stack(self, colocation_stack): if self._stack_state_is_thread_local: @@ -5251,7 +5237,10 @@ def enable_eager_execution(config=None, to this function. """ return enable_eager_execution_internal( - config, device_policy, execution_mode, None) + config=config, + device_policy=device_policy, + execution_mode=execution_mode, + server_def=None) def enable_eager_execution_internal(config=None, |