author    Priya Gupta <priyag@google.com>  2018-08-14 11:22:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-14 11:31:36 -0700
commit    77fabbeabb5b9061d8c606050c1ea79aec990c03 (patch)
tree      1495d6acb396eebd40c703b891a4f2e7437a8532 /tensorflow/python/training/checkpoint_utils.py
parent    cea262e16a004d73295259c42f21e2655da3df13 (diff)
1. Move the distribution strategy context utility methods to a separate file with few dependencies. This allows us to import them in some places without creating circular dependencies, since the original file imported many things.
2. Move the stack used by the distribution strategy context onto the graph. This allows us to use different strategies in different graphs (e.g., in train and eval). This fixes #21412 and #21180.

PiperOrigin-RevId: 208680454
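To illustrate point 2, a minimal sketch of the kind of setup the per-graph stack enables, assuming the contrib-era TF 1.x API of this commit (tf.contrib.distribute.MirroredStrategy, OneDeviceStrategy, and DistributionStrategy.scope()); the build_train_ops/build_eval_ops helpers are hypothetical placeholders, not part of this change:

import tensorflow as tf

# Hypothetical model-building helpers; stand-ins for real graph construction.
def build_train_ops():
  pass

def build_eval_ops():
  pass

train_graph = tf.Graph()
with train_graph.as_default():
  # Entering scope() pushes this strategy onto *this graph's* strategy stack.
  with tf.contrib.distribute.MirroredStrategy().scope():
    build_train_ops()

eval_graph = tf.Graph()
with eval_graph.as_default():
  # A different graph can now carry a different strategy without the two
  # stacks interfering (the behavior reported in #21412 / #21180).
  with tf.contrib.distribute.OneDeviceStrategy("/cpu:0").scope():
    build_eval_ops()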
Diffstat (limited to 'tensorflow/python/training/checkpoint_utils.py')
-rw-r--r--  tensorflow/python/training/checkpoint_utils.py | 6
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 9b72b09f08..e6118177fd 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -180,10 +180,10 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
-  if distribute_lib.get_cross_tower_context():
+  if distribution_strategy_context.get_cross_tower_context():
    _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
  else:
-    distribute_lib.get_tower_context().merge_call(
+    distribution_strategy_context.get_tower_context().merge_call(
        _init_from_checkpoint, ckpt_dir_or_file, assignment_map)
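For reference, the dispatch pattern that the second hunk switches to the new module looks roughly like the sketch below; it is not the real init_from_checkpoint body. The _restore_from_checkpoint helper and init_like wrapper are hypothetical, while get_cross_tower_context(), get_tower_context(), and merge_call() are the functions named in the patch:

from tensorflow.python.training import distribution_strategy_context


def _restore_from_checkpoint(strategy, ckpt_dir_or_file, assignment_map):
  # Hypothetical worker performing the actual variable initialization;
  # merge_call passes the current DistributionStrategy as the first argument.
  del strategy, ckpt_dir_or_file, assignment_map


def init_like(ckpt_dir_or_file, assignment_map):
  if distribution_strategy_context.get_cross_tower_context():
    # Already in cross-tower (cross-replica) context: run the restore directly.
    _restore_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
  else:
    # In a per-tower context: merge into a single cross-tower call so the
    # restore logic runs once rather than once per tower.
    distribution_strategy_context.get_tower_context().merge_call(
        _restore_from_checkpoint, ckpt_dir_or_file, assignment_map)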