diff options
author | Youlong Cheng <ylc@google.com> | 2018-08-03 14:37:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 14:56:32 -0700 |
commit | 1e1522f9f3837106383e4d465e21d55a87f23325 (patch) | |
tree | 1d17a28978e11eb76ce0c2fc8813999ffefa7e05 | |
parent | d6e18992832a4b2ac7daf8d32b1593d61d5e006e (diff) |
PUBLIC: PREDICT mode should respect ctx.device_assignment.
PiperOrigin-RevId: 207326276
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index bd8f2c99a8..c104b2403c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2886,7 +2886,8 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=num_cores, - outputs_from_all_shards=False) + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() |