about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Youlong Cheng <ylc@google.com> 2018-08-24 17:46:49 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-24 17:50:11 -0700
commit: be7a5e4b6e50842bc3c841daaa8dadadc793dd5f (patch)
tree: 3a681838c50942bfc2b3b4a6604055a842701fcf
parent: b42c222b19cde1a8a72fdd81c483bd5a2b1f674e (diff)
Support Input partition in Predict mode.
PiperOrigin-RevId: 210185805
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py 9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 2e4050bd99..1ff04f5c26 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -804,11 +804,14 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
per_host_sharded_inputs.append(flattened_inputs)
if inputs_structure_recorder.flattened_input_dims:
+ input_partition_dims = inputs_structure_recorder.flattened_input_dims
+ if signals:
+ input_partition_dims += [None] * len(signals)
# pylint: disable=protected-access
infeed_queue = tpu_feed._PartitionedInfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]),
host_id=host_id,
- input_partition_dims=inputs_structure_recorder.flattened_input_dims,
+ input_partition_dims=input_partition_dims,
device_assignment=ctx.device_assignment)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs)
@@ -2821,8 +2824,6 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
- num_cores = ctx.num_cores
-
(single_tpu_predict_step, host_calls, captured_scaffold_fn,
captured_predict_hooks
) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
@@ -2841,7 +2842,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
(dummy_predict_op,) = tpu.shard(
multi_tpu_predict_steps_on_single_shard,
inputs=[],
- num_shards=num_cores,
+ num_shards=ctx.num_replicas,
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)