aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py50
1 files changed, 41 insertions, 9 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index c25e29b0f4..ed0bf1afe0 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -44,6 +44,7 @@ from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
@@ -454,7 +455,7 @@ class Tensor(_TensorLike):
def __iter__(self):
if not context.executing_eagerly():
raise TypeError(
- "Tensor objects are not iterable when eager execution is not "
+ "Tensor objects are only iterable when eager execution is "
"enabled. To iterate over this tensor use tf.map_fn.")
shape = self._shape_tuple()
if shape is None:
@@ -3292,6 +3293,36 @@ class Graph(object):
self._create_op_helper(ret, compute_device=compute_device)
return ret
+ def _make_colocation_conflict_message(self, op, colocation_op):
+ """Return detailed error message about device conflict due to colocation."""
+ # Example error message:
+ # Tried to colocate op 'a' (defined at file1.py:149) having device
+ # '/device:GPU:0' with op 'b' (defined at file2:96) which had an
+ # incompatible device '/device:CPU:0'.
+ #
+ # No node-device colocations were active during op 'a' creation.
+ # Device assignments active during op 'a' creation:
+ # with tf.device(/device:GPU:0): file1.py:148>
+ #
+ # Node-device colocations active during op 'b' creation:
+ # with tf.colocate_with(a): file2.py:93>
+ # Device assignments active during op 'b' creation:
+ # with tf.device(/cpu:0): file2.py:94
+ op_info = error_interpolation.compute_field_dict(op)
+ coloc_op_info = error_interpolation.compute_field_dict(colocation_op)
+ msg = ("Tried to colocate op '{op_name}'{op_loc} having device '{op_dev}' "
+ "with op '{coloc_op_name}'{coloc_op_loc} which had an incompatible "
+ "device '{coloc_op_dev}'.\n\n{op_summary}\n\n{coloc_op_summary}"
+ .format(op_name=op.name,
+ op_loc=op_info["defined_at"],
+ op_dev=op.device,
+ op_summary=op_info["devs_and_colocs"],
+ coloc_op_name=colocation_op.name,
+ coloc_op_loc=coloc_op_info["defined_at"],
+ coloc_op_dev=colocation_op.device,
+ coloc_op_summary=coloc_op_info["devs_and_colocs"]))
+ return msg
+
def _create_op_helper(self, op, compute_device=True):
"""Common logic for creating an op in this graph."""
# Apply any additional attributes requested. Do not overwrite any existing
@@ -3332,20 +3363,22 @@ class Graph(object):
if compute_device:
self._apply_device_functions(op)
+ # Snapshot the colocation stack metadata before we might generate error
+ # messages using it. Note that this snapshot depends on the actual stack
+ # and is independent of the op's _class attribute.
+ # pylint: disable=protected-access
+ op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
+ # pylint: enable=protected-access
+
if self._colocation_stack:
all_colocation_groups = []
for colocation_op in self._colocation_stack.peek_objs():
all_colocation_groups.extend(colocation_op.colocation_groups())
if colocation_op.device:
- # Make this device match the device of the colocated op, to provide
- # consistency between the device and the colocation property.
if (op.device and pydev.canonical_name(op.device) !=
pydev.canonical_name(colocation_op.device)):
- logging.warning("Tried to colocate %s with an op %s that had "
- "a different device: %s vs %s. Postponing "
- "error-checking until all devices are assigned.",
- op.name, colocation_op.name, op.device,
- colocation_op.device)
+ msg = self._make_colocation_conflict_message(op, colocation_op)
+ logging.warning(msg)
else:
op._set_device(colocation_op.device) # pylint: disable=protected-access
@@ -3353,7 +3386,6 @@ class Graph(object):
# pylint: disable=protected-access
op._set_attr("_class", attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
- op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
# pylint: enable=protected-access
# Sets "container" attribute if