aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/combinations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/combinations.py')
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index a1efbcaf9a..aeec9c44d7 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -56,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adam
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import tf_inspect
@@ -320,7 +320,7 @@ class NamedDistribution(object):
# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
"OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),