aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-06-13 15:48:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 15:51:15 -0700
commit88ad9949ef4ea6e07105a326a1d21c108cb2883a (patch)
tree912ce0a69fb2f04d1f4170f5584c6fe3abc39d6f /tensorflow/contrib/distribute/python/values.py
parente2213af0f25d17c5d91337aaf1ad5815ed5d2871 (diff)
Make ops.colocate_with work with tower-local variables as well.
PiperOrigin-RevId: 200467472
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py36
1 files changed, 25 insertions, 11 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 9572ade8e4..aca544b7e7 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -238,17 +238,6 @@ class DistributedVariable(DistributedDelegate):
pass
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
- # Try to avoid assignments to and other mutations of MirroredVariable
- # state except through a DistributionStrategy.update() call.
- assert not as_ref
- return ops.internal_convert_to_tensor(
- var.get(), dtype=dtype, name=name, as_ref=as_ref)
-
-
-ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(DistributedVariable)
@@ -342,6 +331,20 @@ class MirroredVariable(DistributedVariable, Mirrored,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
+ # Try to avoid assignments to and other mutations of MirroredVariable
+ # state except through a DistributionStrategy.update() call.
+ assert not as_ref
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(MirroredVariable,
+ _tensor_conversion_mirrored)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
@@ -431,6 +434,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function for TowerLocalVariable which allows as_ref to
+# be true.
+def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(TowerLocalVariable,
+ _tensor_conversion_tower_local)
+
+
def _devices_match(d1, d2):
return device_util.canonicalize(d1) == device_util.canonicalize(d2)