aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/op_def_library.py
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-04-14 10:26:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-14 11:33:07 -0700
commitf34cf72a0063c130d909a285560cd9831f5fa04d (patch)
tree573609b9b4286725a097ecb28bf131a5dc52171a /tensorflow/python/ops/op_def_library.py
parent5d2075f12b0b43a09a432e077334a87394933cd2 (diff)
Add explicit colocation between ref producer and ref consumer in Python client.
This prevents unsatisfiable placements resulting from code of the form: ```python with tf.device("/job:ps"): v = tf.Variable(0) with tf.device("/job:worker"): # ... # Should have device "/job:ps" but would currently get "/job:worker" assign_op = v.assign(1) ``` Change: 119871989
Diffstat (limited to 'tensorflow/python/ops/op_def_library.py')
-rw-r--r--tensorflow/python/ops/op_def_library.py47
1 files changed, 38 insertions, 9 deletions
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index d50507a4e1..d657cc1333 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
+
import six
from tensorflow.core.framework import attr_value_pb2
@@ -238,6 +240,27 @@ class _OpInfo(object):
(arg.number_attr, op_def.name, attr_type))
+# pylint: disable=g-doc-return-or-yield
+@contextlib.contextmanager
+def _MaybeColocateWith(inputs):
+ """A context manager for (maybe) colocating with a list of input tensors.
+
+ Args:
+ inputs: A list of `Tensor` or `Operation` objects.
+
+ Returns:
+ A context manager.
+ """
+ if not inputs:
+ yield
+ else:
+ # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
+ # op or tensor, so we create one context manager per element in the list.
+ with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
+ yield
+# pylint: enable=g-doc-return-or-yield
+
+
class OpDefLibrary(object):
"""Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
@@ -648,14 +671,20 @@ class OpDefLibrary(object):
raise TypeError("apply_op() got unexpected keyword arguments: " +
", ".join(sorted(keywords.keys())))
- # Add Op to graph
- if output_structure:
- op = g.create_op(op_type_name, inputs, output_types, name=scope,
- input_types=input_types, attrs=attr_protos,
- op_def=op_def)
- outputs = op.outputs
- return _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
- else:
- return g.create_op(op_type_name, inputs, output_types, name=scope,
+ # NOTE(mrry): We add an explicit colocation constraint between
+ # the newly created op and any of its reference-typed inputs.
+ must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
+ if arg.is_ref]
+ with _MaybeColocateWith(must_colocate_inputs):
+ # Add Op to graph
+ if output_structure:
+ op = g.create_op(op_type_name, inputs, output_types, name=scope,
input_types=input_types, attrs=attr_protos,
op_def=op_def)
+ outputs = op.outputs
+ return _Restructure(ops.convert_n_to_tensor(outputs),
+ output_structure)
+ else:
+ return g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)