diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-06-13 15:48:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-13 15:51:15 -0700 |
commit | 88ad9949ef4ea6e07105a326a1d21c108cb2883a (patch) | |
tree | 912ce0a69fb2f04d1f4170f5584c6fe3abc39d6f /tensorflow/contrib/distribute/python/values.py | |
parent | e2213af0f25d17c5d91337aaf1ad5815ed5d2871 (diff) |
Make ops.colocate_with work with tower-local variables as well.
PiperOrigin-RevId: 200467472
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 9572ade8e4..aca544b7e7 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -238,17 +238,6 @@ class DistributedVariable(DistributedDelegate): pass -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) ops.register_dense_tensor_like_type(DistributedVariable) @@ -342,6 +331,20 @@ class MirroredVariable(DistributedVariable, Mirrored, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(MirroredVariable, + _tensor_conversion_mirrored) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -431,6 +434,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function for TowerLocalVariable which allows as_ref to +# be true. +def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(TowerLocalVariable, + _tensor_conversion_tower_local) + + def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) |