diff options
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r-- | tensorflow/python/training/checkpoint_utils.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 52d092bc22..e7f88de1d2 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -290,7 +290,11 @@ def _set_checkpoint_initializer(variable, name: Name of the operation. """ base_type = variable.dtype.base_dtype - with ops.colocate_with(variable.op): + # Do not colocate with variable since RestoreV2 op only runs on CPU and + # colocation will force variable (and other ops that colocate with variable) + # to be on CPU as well. It is okay to place the variable's initializer op on + # CPU since it will only be run once at the start. + with ops.device(variable.device), ops.device("/cpu:0"): restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] if isinstance(variable, resource_variable_ops.ResourceVariable): |