diff options
author | Derek Murray <mrry@google.com> | 2016-04-14 10:26:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-14 11:33:07 -0700 |
commit | f34cf72a0063c130d909a285560cd9831f5fa04d (patch) | |
tree | 573609b9b4286725a097ecb28bf131a5dc52171a /tensorflow/python/ops/op_def_library.py | |
parent | 5d2075f12b0b43a09a432e077334a87394933cd2 (diff) |
Add explicit colocation between ref producer and ref consumer in Python client.
This prevents unsatisfiable placements resulting from code of the form:
```python
with tf.device("/job:ps"):
v = tf.Variable(0)
with tf.device("/job:worker"):
# ...
# Should have device "/job:ps" but would currently get "/job:worker"
assign_op = v.assign(1)
```
Change: 119871989
Diffstat (limited to 'tensorflow/python/ops/op_def_library.py')
-rw-r--r-- | tensorflow/python/ops/op_def_library.py | 47 |
1 files changed, 38 insertions, 9 deletions
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index d50507a4e1..d657cc1333 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib + import six from tensorflow.core.framework import attr_value_pb2 @@ -238,6 +240,27 @@ class _OpInfo(object): (arg.number_attr, op_def.name, attr_type)) +# pylint: disable=g-doc-return-or-yield +@contextlib.contextmanager +def _MaybeColocateWith(inputs): + """A context manager for (maybe) colocating with a list of input tensors. + + Args: + inputs: A list of `Tensor` or `Operation` objects. + + Returns: + A context manager. + """ + if not inputs: + yield + else: + # NOTE(mrry): The `ops.colocate_with()` function accepts only a single + # op or tensor, so we create one context manager per element in the list. + with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]): + yield +# pylint: enable=g-doc-return-or-yield + + class OpDefLibrary(object): """Holds a collection of OpDefs, can add the corresponding Ops to a graph.""" @@ -648,14 +671,20 @@ class OpDefLibrary(object): raise TypeError("apply_op() got unexpected keyword arguments: " + ", ".join(sorted(keywords.keys()))) - # Add Op to graph - if output_structure: - op = g.create_op(op_type_name, inputs, output_types, name=scope, - input_types=input_types, attrs=attr_protos, - op_def=op_def) - outputs = op.outputs - return _Restructure(ops.convert_n_to_tensor(outputs), output_structure) - else: - return g.create_op(op_type_name, inputs, output_types, name=scope, + # NOTE(mrry): We add an explicit colocation constraint between + # the newly created op and any of its reference-typed inputs. + must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs) + if arg.is_ref] + with _MaybeColocateWith(must_colocate_inputs): + # Add Op to graph + if output_structure: + op = g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) + outputs = op.outputs + return _Restructure(ops.convert_n_to_tensor(outputs), + output_structure) + else: + return g.create_op(op_type_name, inputs, output_types, name=scope, + input_types=input_types, attrs=attr_protos, + op_def=op_def) |