Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_feed.py')
 tensorflow/contrib/tpu/python/tpu/tpu_feed.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index 604e6600c8..a44b4f4622 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -461,7 +461,10 @@ class InfeedQueue(object):
         name=full_name,
         device_ordinal=tpu_ordinal)
 
-  def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None):
+  def generate_enqueue_ops(self,
+                           sharded_inputs,
+                           tpu_ordinal_function=None,
+                           placement_function=None):
     """Generates the host-side Ops to enqueue the shards of a tuple.
 
     sharded_inputs is a list, one for each shard, of lists of
@@ -483,6 +486,9 @@ class InfeedQueue(object):
         shard index as input and returns the ordinal of the TPU device
         the shard's infeed should be placed on. tpu_ordinal_function must be
         set if the inputs are placed on CPU devices.
+      placement_function: if not None, a function that takes the shard index as
+        input and returns the host device on which the enqueue op should be
+        placed.
 
     Returns:
       A list of host-side Ops, one for each shard, that when executed together
@@ -508,8 +514,12 @@ class InfeedQueue(object):
       tpu_ordinal_function = lambda index: -1
     name_prefix = "%s/enqueue" % self._name
     return [
-        self._generate_enqueue_op(shard, name_prefix, index,
-                                  tpu_ordinal=tpu_ordinal_function(index))
+        self._generate_enqueue_op(
+            shard,
+            name_prefix,
+            index,
+            tpu_ordinal=tpu_ordinal_function(index),
+            device=placement_function(index) if placement_function else None)
         for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
     ]
 
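
For context, a minimal sketch of how a caller might drive the new placement_function hook alongside the existing tpu_ordinal_function. The job name, the 8-shard/8-cores-per-host topology, and the placeholder inputs are illustrative assumptions, not part of this patch; only the generate_enqueue_ops keyword arguments come from the change above.

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_feed

num_shards = 8       # assumed: one infeed shard per TPU core
tpus_per_host = 8    # assumed topology: 8 TPU cores per host

# One list of input tensors per shard; placeholders stand in for real inputs.
sharded_inputs = [
    [tf.placeholder(tf.float32, [1024, 128]),
     tf.placeholder(tf.int32, [1024])]
    for _ in range(num_shards)]

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)

def placement_fn(index):
  # Host CPU device on which shard `index`'s enqueue op should run.
  return "/job:tpu_worker/task:%d/device:CPU:0" % (index // tpus_per_host)

def ordinal_fn(index):
  # Ordinal of the TPU core on that host that receives the shard's infeed.
  return index % tpus_per_host

enqueue_ops = queue.generate_enqueue_ops(
    sharded_inputs,
    tpu_ordinal_function=ordinal_fn,
    placement_function=placement_fn)

Without placement_function, each enqueue op is placed wherever its input tensors live; the new hook lets multi-host setups pin every shard's enqueue op to the CPU of the host that feeds the corresponding core, while tpu_ordinal_function continues to select the core within that host.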