aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-12 09:51:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 09:56:39 -0700
commitb0a15f21d2009ead9c8ed5e245a02b5c42355853 (patch)
treec510b2d9996def66f345b2022120c49a89701a4c
parent3b4f4164663da4c65807c34e7188e43c9d7d7535 (diff)
Make the return value of `read_var` consistently a tensor instead of
sometimes a variable. PiperOrigin-RevId: 200231463
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py2
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py2
-rw-r--r--tensorflow/python/training/distribute.py4
3 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 403e47d94f..900aa10e93 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -349,7 +349,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if isinstance(tower_local_var, values.TowerLocalVariable):
return math_ops.add_n(self.unwrap(tower_local_var))
assert isinstance(tower_local_var, values.Mirrored)
- return tower_local_var.get()
+ return array_ops.identity(tower_local_var.get())
def _fetch(self, val, destination, fn):
"""Return a copy of `val` or `fn(val)` on `destination`."""
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 6378af32bd..7f4bab9d93 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -104,7 +104,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
- return tower_local_var
+ return array_ops.identity(tower_local_var)
def _fetch(self, val, destination, fn):
"""Return a copy of `val` or `fn(val)` on `destination`."""
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 29198e48fa..caffd042a0 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -652,7 +652,7 @@ class DistributionStrategy(object):
"""Reads the value of a variable.
Returns the aggregate value of a tower-local variable, or the
- (possibly read-only) value of any other variable.
+ (read-only) value of any other variable.
Args:
v: A variable allocated within the scope of this `DistributionStrategy`.
@@ -1217,7 +1217,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
return fn(*args, **kwargs)
def read_var(self, tower_local_var):
- return tower_local_var
+ return array_ops.identity(tower_local_var)
def _fetch(self, var, destination, fn):
with ops.colocate_with(var):