aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Youlong Cheng <ylc@google.com>2018-10-01 12:28:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 12:32:35 -0700
commitf0c219d095f38f7ce6febfb68d4f84d64aa1829a (patch)
treeb930a9435b6c5af2e8ac5029fa3ffa0784119b4f /tensorflow/contrib/tpu
parent61a872068ece1355945ef2d88659e99de2fe7591 (diff)
Expose tpu_host_placement_function().
PiperOrigin-RevId: 215259803
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 7cfb6c38fa..da6bdf67d6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -154,6 +154,20 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
+ @property
+ def tpu_host_placement_function(self):
+ """Returns the TPU host place function.
+
+ The place function takes host_id as the input and returns the TF device
+ for the correspoding host.
+ """
+
+ def _placement_function(host_id):
+ """Return the host device given host_id."""
+ return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
+
+ return _placement_function
+
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.