aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py177
1 files changed, 83 insertions, 94 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index b07c57d265..6a5c44e4d9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import linecache
import os
import re
import sys
@@ -49,7 +48,9 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
+from tensorflow.python.util import tf_stack
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
@@ -1711,10 +1712,14 @@ class Operation(object):
# This will be set by self.inputs.
self._inputs_val = None
- self._id_value = self._graph._next_id() # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ self._id_value = self._graph._next_id()
self._original_op = original_op
- self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
- self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access
+ self._traceback = tf_stack.extract_stack()
+ # List of traceable_stack.TraceableObjects for colocation context managers.
+ self._colocation_code_locations = None
+ self._control_flow_context = self.graph._get_control_flow_context()
+ # pylint: enable=protected-access
# Initialize self._c_op.
if c_op:
@@ -1853,6 +1858,42 @@ class Operation(object):
return c_api.TF_OperationDevice(self._c_op)
@property
+ def _colocation_dict(self):
+ """Code locations for colocation context managers active at op creation.
+
+ This property will return a dictionary for which the keys are nodes with
+ which this Operation is colocated, and for which the values are
+ traceable_stack.TraceableObject instances. The TraceableObject instances
+ record the location of the relevant colocation context manager but have the
+ "obj" field set to None to prevent leaking private data.
+
+ For example, suppose file_a contained these lines:
+
+ file_a.py:
+ 14: node_a = tf.constant(3, name='NODE_A')
+ 15: with tf.colocate_with(node_a):
+ 16: node_b = tf.constant(4, name='NODE_B')
+
+ Then a TraceableObject t_obj representing the colocation context manager
+ would have these member values:
+
+ t_obj.obj -> None
+ t_obj.name = 'NODE_A'
+ t_obj.filename = 'file_a.py'
+ t_obj.lineno = 15
+
+ and node_b.op._colocation_code_locations would return the dictionary
+
+ { 'NODE_A': t_obj }
+
+ Returns:
+ {str: traceable_stack.TraceableObject} as per this method's description,
+ above.
+ """
+ locations_dict = self._colocation_code_locations or {}
+ return locations_dict.copy()
+
+ @property
def _output_types(self):
"""List this operation's output types.
@@ -2154,7 +2195,7 @@ class Operation(object):
@property
def traceback(self):
"""Returns the call stack from when this operation was constructed."""
- return self._graph._convert_stack(self._traceback) # pylint: disable=protected-access
+ return tf_stack.convert_stack(self._traceback)
@property
def traceback_with_start_lines(self):
@@ -2163,9 +2204,8 @@ class Operation(object):
Returns:
A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
"""
- return self._graph._convert_stack( # pylint: disable=protected-access
- self._traceback,
- include_func_start_lineno=True)
+ return tf_stack.convert_stack(self._traceback,
+ include_func_start_lineno=True)
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
@@ -2617,7 +2657,6 @@ def _name_from_scope_name(name):
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
-
@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2726,7 +2765,7 @@ class Graph(object):
self._building_function = False
# Stack of colocate_with ops. After switch_to_thread_local(),
# self._thread_local._colocation_stack is used instead.
- self._graph_colocation_stack = []
+ self._graph_colocation_stack = traceable_stack.TraceableStack()
# Set of tensors that are dangerous to feed!
self._unfeedable_tensors = set()
# Set of operations that are dangerous to fetch!
@@ -2766,36 +2805,6 @@ class Graph(object):
"""Temporary hack; can be overridden to force C API usage."""
return _USE_C_API
- def _convert_stack(self, stack, include_func_start_lineno=False):
- """Converts a stack extracted using _extract_stack() to a traceback stack.
-
- Args:
- stack: A list of n 5-tuples,
- (filename, lineno, name, frame_globals, func_start_lineno).
- include_func_start_lineno: True if function start line number should be
- included as the 5th entry in return tuples.
-
- Returns:
- A list of n 4-tuples or 5-tuples
- (filename, lineno, name, code, [optional: func_start_lineno]), where the
- code tuple element is calculated from the corresponding elements of the
- input tuple.
- """
- ret = []
- for (filename, lineno, name, frame_globals, func_start_lineno,
- unused_frame_info) in stack:
- linecache.checkcache(filename)
- line = linecache.getline(filename, lineno, frame_globals)
- if line:
- line = line.strip()
- else:
- line = None
- if include_func_start_lineno:
- ret.append((filename, lineno, name, line, func_start_lineno))
- else:
- ret.append((filename, lineno, name, line))
- return ret
-
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@tf_contextlib.contextmanager
@@ -2803,63 +2812,23 @@ class Graph(object):
# This step makes a copy of the existing stack, and it also initializes
# self._thread_local._variable_creator_stack if it doesn't exist yet.
old = list(self._variable_creator_stack)
- self._thread_local._variable_creator_stack.append(creator)
+ self._thread_local._variable_creator_stack.append(creator) # pylint: disable=protected-access
try:
yield
finally:
- self._thread_local._variable_creator_stack = old
+ self._thread_local._variable_creator_stack = old # pylint: disable=protected-access
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@property
def _variable_creator_stack(self):
if not hasattr(self._thread_local, "_variable_creator_stack"):
- self._thread_local._variable_creator_stack = []
- return list(self._thread_local._variable_creator_stack)
+ self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access
+ return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access
@_variable_creator_stack.setter
def _variable_creator_stack(self, variable_creator_stack):
- self._thread_local._variable_creator_stack = variable_creator_stack
-
- def _extract_stack(self):
- """A lightweight, extensible re-implementation of traceback.extract_stack.
-
- NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
- each stack frame using linecache, which results in an abundance of stat()
- calls. This implementation does not retrieve the code, and any consumer
- should apply _convert_stack to the result to obtain a traceback that can
- be formatted etc. using traceback methods.
-
- Derived classes can implement _extract_frame_info() to add extra information
- to the traceback.
-
- Returns:
- A list of 6-tuples
- (filename, lineno, name, frame_globals, func_start_lineno, custom_info)
- corresponding to the call stack of the current thread.
- """
- try:
- raise ZeroDivisionError
- except ZeroDivisionError:
- f = sys.exc_info()[2].tb_frame.f_back
- ret = []
- while f is not None:
- lineno = f.f_lineno
- co = f.f_code
- filename = co.co_filename
- name = co.co_name
- frame_globals = f.f_globals
- func_start_lineno = co.co_firstlineno
- frame_info = self._extract_frame_info(f)
- ret.append((filename, lineno, name, frame_globals, func_start_lineno,
- frame_info))
- f = f.f_back
- ret.reverse()
- return ret
-
- def _extract_frame_info(self, frame): # pylint: disable=unused-argument
- """Extracts custom information from a frame in an op traceback."""
- return None
+ self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -3301,7 +3270,7 @@ class Graph(object):
if self._colocation_stack:
all_colocation_groups = []
- for colocation_op in self._colocation_stack:
+ for colocation_op in self._colocation_stack.peek_objs():
all_colocation_groups.extend(colocation_op.colocation_groups())
if colocation_op.device:
# Make this device match the device of the colocated op, to provide
@@ -3320,6 +3289,7 @@ class Graph(object):
# pylint: disable=protected-access
op._set_attr("_class", attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
+ op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
# pylint: enable=protected-access
# Sets "container" attribute if
@@ -3629,9 +3599,13 @@ class Graph(object):
This method should be used if you want to create multiple graphs
in the same process. For convenience, a global default graph is
provided, and all ops will be added to this graph if you do not
- create a new graph explicitly. Use this method with the `with` keyword
- to specify that ops created within the scope of a block should be
- added to this graph.
+ create a new graph explicitly.
+
+ Use this method with the `with` keyword to specify that ops created within
+ the scope of a block should be added to this graph. In this case, once
+ the scope of the `with` is exited, the previous default graph is set again
+ as default. There is a stack, so it's ok to have multiple nested levels
+ of `as_default` calls.
The default graph is a property of the current thread. If you
create a new thread, and wish to use the default graph in that
@@ -4074,10 +4048,13 @@ class Graph(object):
if ignore_existing:
current_stack = self._colocation_stack
- self._colocation_stack = []
+ self._colocation_stack = traceable_stack.TraceableStack()
if op is not None:
- self._colocation_stack.append(op)
+ # offset refers to the stack frame used for storing code location.
+ # We use 4, the sum of 1 to use our caller's stack frame and 3
+ # to jump over layers of context managers above us.
+ self._colocation_stack.push_obj(op, offset=4)
try:
yield
@@ -4085,7 +4062,7 @@ class Graph(object):
# Restore device function stack
self._device_function_stack = device_fn_tmp
if op is not None:
- self._colocation_stack.pop()
+ self._colocation_stack.pop_obj()
# Reset the colocation stack if requested.
if ignore_existing:
@@ -4712,15 +4689,24 @@ class Graph(object):
@property
def _colocation_stack(self):
+ """Return thread-local copy of colocation stack."""
if self._stack_state_is_thread_local:
# This may be called from a thread where colocation_stack doesn't yet
# exist.
if not hasattr(self._thread_local, "_colocation_stack"):
- self._thread_local._colocation_stack = self._graph_colocation_stack[:]
+ stack_copy_for_this_thread = self._graph_colocation_stack.copy()
+ # pylint: disable=protected-access
+ self._thread_local._colocation_stack = stack_copy_for_this_thread
+ # pylint: enable=protected-access
return self._thread_local._colocation_stack
else:
return self._graph_colocation_stack
+ def _snapshot_colocation_stack_metadata(self):
+ """Return colocation stack metadata as a dictionary."""
+ traceable_objects = self._colocation_stack.peek_traceable_objs()
+ return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
+
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):
if self._stack_state_is_thread_local:
@@ -5251,7 +5237,10 @@ def enable_eager_execution(config=None,
to this function.
"""
return enable_eager_execution_internal(
- config, device_policy, execution_mode, None)
+ config=config,
+ device_policy=device_policy,
+ execution_mode=execution_mode,
+ server_def=None)
def enable_eager_execution_internal(config=None,