aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Ruoxin Sang <rxsang@google.com>2018-10-04 19:07:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 19:12:18 -0700
commit5608454c31bb298096bb6aa463b33baa2fa68f08 (patch)
tree83dfeef619e1c02a83ed3967ec95e97a24f3e981 /tensorflow/contrib/distribute/python/values.py
parent83ff640fa5026b8bd3cb9c2ceff9e99e8e03823a (diff)
Add 'device' property to TPUMirroredVariable, so tf.train.init_from_checkpoint can be supported.
PiperOrigin-RevId: 215843249
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 18ceba42c2..0dd78ba185 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -571,6 +571,10 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
ValueError("Device %s not found in %s (current device %s)" %
(device, self._index.keys(), device_util.current())), e)
+ @property
+ def device(self):
+ return self._get().device
+
# The arguments to update() are automatically unwrapped so the update()
# function would normally see regular variables, not MirroredVariables.
# However, the update function can still operate on wrapped MirroredVariables