aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-06 17:43:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-06 19:05:33 -0700
commit983fdfe955a16e6a67204b825782e1a30141482a (patch)
treee748d9311f944d58610ec490215634f77caa067f
parent7f28f166092e8f6621bc264e12a7201a22f76997 (diff)
Colocate ResourceVariable reads with their handles.
Change: 152455939
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 86e0cae27a..77f0468c01 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -197,8 +197,10 @@ class ResourceVariable(object):
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value, name=n)
with ops.name_scope("Read"), ops.colocate_with(self._handle):
- value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ # Manually assign reads to the handle's device to avoid log messages.
+ with ops.device(self._handle.device):
+ value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
self._graph_element = value
if caching_device is not None:
# Variables may be created in a tf.device() or ops.colocate_with()
@@ -276,8 +278,9 @@ class ResourceVariable(object):
"""A cached operation which reads the value of this variable."""
if self._cached_value is not None:
return self._cached_value
- return gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ with ops.device(self._handle.device):
+ return gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
def _as_graph_element(self):
"""Conversion function for Graph.as_graph_element()."""
@@ -318,8 +321,9 @@ class ResourceVariable(object):
the read operation.
"""
with ops.name_scope("Read"):
- value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ with ops.device(self._handle.device):
+ value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
# Return an identity so it can get placed on whatever device the context
# specifies instead of the device where the variable is.
return array_ops.identity(value)