From be7a5e4b6e50842bc3c841daaa8dadadc793dd5f Mon Sep 17 00:00:00 2001 From: Youlong Cheng Date: Fri, 24 Aug 2018 17:46:49 -0700 Subject: Support Input partition in Predict mode. PiperOrigin-RevId: 210185805 --- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2e4050bd99..1ff04f5c26 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -804,11 +804,14 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( per_host_sharded_inputs.append(flattened_inputs) if inputs_structure_recorder.flattened_input_dims: + input_partition_dims = inputs_structure_recorder.flattened_input_dims + if signals: + input_partition_dims += [None] * len(signals) # pylint: disable=protected-access infeed_queue = tpu_feed._PartitionedInfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0]), host_id=host_id, - input_partition_dims=inputs_structure_recorder.flattened_input_dims, + input_partition_dims=input_partition_dims, device_assignment=ctx.device_assignment) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs) @@ -2821,8 +2824,6 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_cores = ctx.num_cores - (single_tpu_predict_step, host_calls, captured_scaffold_fn, captured_predict_hooks ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) @@ -2841,7 +2842,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (dummy_predict_op,) = tpu.shard( multi_tpu_predict_steps_on_single_shard, inputs=[], - num_shards=num_cores, + num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) -- cgit v1.2.3