aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index bd8f2c99a8..c104b2403c 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -2886,7 +2886,8 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
multi_tpu_predict_steps_on_single_shard,
inputs=[],
num_shards=num_cores,
- outputs_from_all_shards=False)
+ outputs_from_all_shards=False,
+ device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()