diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 50 |
1 files changed, 41 insertions, 9 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index c25e29b0f4..ed0bf1afe0 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -44,6 +44,7 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes +from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import registry @@ -454,7 +455,7 @@ class Tensor(_TensorLike): def __iter__(self): if not context.executing_eagerly(): raise TypeError( - "Tensor objects are not iterable when eager execution is not " + "Tensor objects are only iterable when eager execution is " "enabled. To iterate over this tensor use tf.map_fn.") shape = self._shape_tuple() if shape is None: @@ -3292,6 +3293,36 @@ class Graph(object): self._create_op_helper(ret, compute_device=compute_device) return ret + def _make_colocation_conflict_message(self, op, colocation_op): + """Return detailed error message about device conflict due to colocation.""" + # Example error message: + # Tried to colocate op 'a' (defined at file1.py:149) having device + # '/device:GPU:0' with op 'b' (defined at file2:96) which had an + # incompatible device '/device:CPU:0'. + # + # No node-device colocations were active during op 'a' creation. + # Device assignments active during op 'a' creation: + # with tf.device(/device:GPU:0): file1.py:148> + # + # Node-device colocations active during op 'b' creation: + # with tf.colocate_with(a): file2.py:93> + # Device assignments active during op 'b' creation: + # with tf.device(/cpu:0): file2.py:94 + op_info = error_interpolation.compute_field_dict(op) + coloc_op_info = error_interpolation.compute_field_dict(colocation_op) + msg = ("Tried to colocate op '{op_name}'{op_loc} having device '{op_dev}' " + "with op '{coloc_op_name}'{coloc_op_loc} which had an incompatible " + "device '{coloc_op_dev}'.\n\n{op_summary}\n\n{coloc_op_summary}" + .format(op_name=op.name, + op_loc=op_info["defined_at"], + op_dev=op.device, + op_summary=op_info["devs_and_colocs"], + coloc_op_name=colocation_op.name, + coloc_op_loc=coloc_op_info["defined_at"], + coloc_op_dev=colocation_op.device, + coloc_op_summary=coloc_op_info["devs_and_colocs"])) + return msg + def _create_op_helper(self, op, compute_device=True): """Common logic for creating an op in this graph.""" # Apply any additional attributes requested. Do not overwrite any existing @@ -3332,20 +3363,22 @@ class Graph(object): if compute_device: self._apply_device_functions(op) + # Snapshot the colocation stack metadata before we might generate error + # messages using it. Note that this snapshot depends on the actual stack + # and is independent of the op's _class attribute. + # pylint: disable=protected-access + op._colocation_code_locations = self._snapshot_colocation_stack_metadata() + # pylint: enable=protected-access + if self._colocation_stack: all_colocation_groups = [] for colocation_op in self._colocation_stack.peek_objs(): all_colocation_groups.extend(colocation_op.colocation_groups()) if colocation_op.device: - # Make this device match the device of the colocated op, to provide - # consistency between the device and the colocation property. if (op.device and pydev.canonical_name(op.device) != pydev.canonical_name(colocation_op.device)): - logging.warning("Tried to colocate %s with an op %s that had " - "a different device: %s vs %s. Postponing " - "error-checking until all devices are assigned.", - op.name, colocation_op.name, op.device, - colocation_op.device) + msg = self._make_colocation_conflict_message(op, colocation_op) + logging.warning(msg) else: op._set_device(colocation_op.device) # pylint: disable=protected-access @@ -3353,7 +3386,6 @@ class Graph(object): # pylint: disable=protected-access op._set_attr("_class", attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) - op._colocation_code_locations = self._snapshot_colocation_stack_metadata() # pylint: enable=protected-access # Sets "container" attribute if |