aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/checkpoint_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r--tensorflow/python/training/checkpoint_utils.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 52d092bc22..e7f88de1d2 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -290,7 +290,11 @@ def _set_checkpoint_initializer(variable,
name: Name of the operation.
"""
base_type = variable.dtype.base_dtype
- with ops.colocate_with(variable.op):
+ # Do not colocate with variable since RestoreV2 op only runs on CPU and
+ # colocation will force variable (and other ops that colocate with variable)
+ # to be on CPU as well. It is okay to place the variable's initializer op on
+ # CPU since it will only be run once at the start.
+ with ops.device(variable.device), ops.device("/cpu:0"):
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
if isinstance(variable, resource_variable_ops.ResourceVariable):