aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distribute/__init__.py3
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py4
-rw-r--r--tensorflow/contrib/distribute/python/estimator_integration_test.py16
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py52
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py5
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py11
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py12
-rw-r--r--tensorflow/contrib/distribute/python/values.py39
-rw-r--r--tensorflow/contrib/metrics/python/metrics/classification.py4
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py11
-rw-r--r--tensorflow/python/BUILD5
-rw-r--r--tensorflow/python/eager/function.py9
-rw-r--r--tensorflow/python/estimator/keras.py6
-rw-r--r--tensorflow/python/framework/ops.py12
-rw-r--r--tensorflow/python/keras/layers/normalization.py10
-rw-r--r--tensorflow/python/keras/metrics.py4
-rw-r--r--tensorflow/python/keras/optimizers.py4
-rw-r--r--tensorflow/python/ops/metrics_impl.py41
-rw-r--r--tensorflow/python/ops/summary_op_util.py4
-rw-r--r--tensorflow/python/training/checkpoint_utils.py6
-rw-r--r--tensorflow/python/training/distribute.py198
-rw-r--r--tensorflow/python/training/distribute_test.py53
-rw-r--r--tensorflow/python/training/distribution_strategy_context.py203
-rw-r--r--tensorflow/python/training/optimizer.py16
-rw-r--r--tensorflow/python/training/slot_creator.py8
25 files changed, 431 insertions, 305 deletions
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 9123ca749b..5fa57f494c 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -22,13 +22,14 @@ from __future__ import print_function
from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
-from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
+from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
+from tensorflow.python.training.distribution_strategy_context import *
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index a1efbcaf9a..aeec9c44d7 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -56,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adam
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import tf_inspect
@@ -320,7 +320,7 @@ class NamedDistribution(object):
# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
"OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index 3e00cf4332..cc626c33bf 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
@@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
combinations.one_device_strategy,
combinations.mirrored_strategy_with_gpu_and_cpu,
combinations.mirrored_strategy_with_two_gpus
- ]))
- def test_complete_flow_with_mode(self, distribution):
+ ],
+ use_train_and_evaluate=[True, False]))
+ def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):
label_dimension = 2
input_dimension = label_dimension
batch_size = 10
@@ -103,9 +105,15 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
train_distribute=distribution, eval_distribute=distribution))
num_steps = 10
- estimator.train(train_input_fn, steps=num_steps)
+ if use_train_and_evaluate:
+ scores, _ = training.train_and_evaluate(
+ estimator,
+ training.TrainSpec(train_input_fn, max_steps=num_steps),
+ training.EvalSpec(eval_input_fn))
+ else:
+ estimator.train(train_input_fn, steps=num_steps)
+ scores = estimator.evaluate(eval_input_fn)
- scores = estimator.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index e064cfe37d..9a4cc0a897 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -164,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# This variable should be created only once across the threads because of
# special variable_creator functions used by `dist.call_for_each_tower`.
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -181,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -201,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -223,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -245,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
v = variable_scope.variable(1.0, name="foo_" + str(device_id))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -268,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias),
@@ -300,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
@@ -343,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
@@ -453,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -470,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -570,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
@@ -591,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
@@ -619,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -651,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -833,8 +840,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -898,8 +906,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -963,8 +972,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index a066adf124..5db2fff239 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -24,7 +24,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase):
v = variable_scope.variable(1.0)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index cf29c0ed91..02eb68227d 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
@@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
last_part_device = 'device:CPU:0'
else:
last_part_device = (
- 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ 'device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
@@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
tower_compute_device = '/device:CPU:0'
else:
tower_compute_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_compute_device = device_util.canonicalize(tower_compute_device)
if 'CPU' in variable_device:
tower_variable_device = '/device:CPU:0'
else:
tower_variable_device = (
- '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id)
+ '/device:GPU:%d' %
+ distribution_strategy_context.get_tower_context().tower_id)
tower_variable_device = device_util.canonicalize(tower_variable_device)
a = constant_op.constant(1.0)
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index baed0ebaae..371b97ba96 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -28,7 +28,7 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer
@@ -45,7 +45,8 @@ def _raise_exception_fn(_=None):
# Must be the argument to a distribution.call_for_each_tower() call, calls a
# get_tower_context().merge_call() that raises an exception.
def _merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_raise_exception_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _raise_exception_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -58,7 +59,7 @@ def _call_raises_fn(dist):
# calls a get_tower_context().merge_call() that calls a
# call_for_each_tower() that raises an exception.
def _merge_call_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn)
# Must be the argument to a get_tower_context().merge_call() call, calls
@@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist):
# get_tower_context().merge_call() that calls a call_for_each_tower() that
# calls a get_tower_context().merge_call() that raises an exception.
def _merge_call_merge_raises_fn():
- distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn)
+ distribution_strategy_context.get_tower_context().merge_call(
+ _call_merge_raises_fn)
class DistributionTestBase(test.TestCase):
@@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase):
expected_devices = [False] * len(d.worker_devices)
def mark_devices_fn():
- tower_id = distribute_lib.get_tower_context().tower_id
+ tower_id = distribution_strategy_context.get_tower_context().tower_id
self.assertLess(tower_id, len(d.worker_devices))
self.assertFalse(expected_devices[tower_id])
expected_devices[tower_id] = True
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 5fd4c9de69..8548a86421 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -56,7 +57,7 @@ class DistributedValues(object):
def get(self, device=None):
"""Returns the value for the current device or raises a ValueError."""
if device is None:
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context:
device = tower_context.device
else:
@@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate):
# We want cross-tower code that does some var.op.X calls
# to work (even if the current device isn't in self.devices), but
# other uses of var.op in a cross-tower context to fail.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return DistributedVarOp(self._primary_var.op.name,
self._primary_var.op.graph,
self._primary_var.op.type)
return self.get().op
def read_value(self):
- return distribute_lib.get_distribution_strategy().read_var(self)
+ return distribution_strategy_context.get_distribution_strategy().read_var(
+ self)
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
@@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# update several non-slot variables in one call.
def _assign_func(self, *args, **kwargs):
f = kwargs.pop("f")
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
# We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
@@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribute_lib.get_distribution_strategy().update(
+ return distribution_strategy_context.get_distribution_strategy().update(
self, f, *args, **kwargs)
else:
_assert_tower_context()
@@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored,
aggregation=self._aggregation, value=value, destinations=self),
*other_args, **other_kwargs)
- return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
- **kwargs)
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
def assign_sub(self, *args, **kwargs):
assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
@@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._primary_var._as_graph_element()
return self.get()._as_graph_element()
@@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().read_var(
+ return distribution_strategy_context.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
def _assert_tower_context():
- if not distribute_lib.get_tower_context():
+ if not distribution_strategy_context.get_tower_context():
raise RuntimeError(
"Tower-local variables may only be assigned in a tower context.")
@@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
@@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
def _as_graph_element(self):
# pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
return self._get_cross_tower()
return self.get()._as_graph_element()
@@ -994,12 +996,12 @@ class MultiStepContext(object):
outputs as already reduced or not.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._last_step_outputs_aggregations[name] = aggregation
if aggregation is variables_lib.VariableAggregation.NONE:
self._last_step_outputs[name] = output
else:
- distribution = distribute_lib.get_distribution_strategy()
+ distribution = distribution_strategy_context.get_distribution_strategy()
self._last_step_outputs[name] = distribution.reduce(
aggregation, output, destinations="/device:CPU:0")
else:
@@ -1011,7 +1013,9 @@ class MultiStepContext(object):
# context object, so it's more robust to set it only once (even if all
# the towers are trying to set the same value).
self._last_step_outputs_aggregations[name] = aggregation
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
@property
def non_tensor_outputs(self):
@@ -1020,14 +1024,15 @@ class MultiStepContext(object):
def set_non_tensor_output(self, name, output):
"""Set `output` with `name` to be captured as a non tensor output."""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
self._non_tensor_outputs[name] = output
else:
def merge_fn(distribution, value):
# NOTE(priyag): For non tensor outputs, we simply return all the values
# in a list as aggregation doesn't make sense on non tensors.
self._non_tensor_outputs[name] = distribution.unwrap(value)
- distribute_lib.get_tower_context().merge_call(merge_fn, output)
+ distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, output)
def value_container(val):
diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py
index e553612269..7053907da0 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification.py
@@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
# TODO(nsilberman): move into metrics/python/ops/
@@ -174,7 +174,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200,
ops.add_to_collections(metrics_collections, best_f1)
return best_f1
- best_f1 = distribute_lib.get_tower_context().merge_call(
+ best_f1 = distribution_strategy_context.get_tower_context().merge_call(
f1_across_towers, values)
update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 8c11d8bcfd..f6ecaba834 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Map from graph_key to state for that graph. We use the graph_key
# since it works in both eager and graph mode, and gives the outer
# graph inside functions.
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None:
# In a cross-tower context for a DistributionStrategy, which means
# only one Optimizer will be created, not one per tower.
@@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= 1. / num_towers
@@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= 1. / num_towers
@@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, filtered, global_step=global_step, name=name)
def _get_or_create_state(self, var_list=None):
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2a71eaf030..a6bb6158e6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3340,7 +3340,10 @@ py_library(
py_library(
name = "distribute",
- srcs = ["training/distribute.py"],
+ srcs = [
+ "training/distribute.py",
+ "training/distribution_strategy_context.py",
+ ],
srcs_version = "PY2AND3",
deps = [
":array_ops",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 8b9b62e034..4958ba56c5 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -43,7 +43,7 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -228,6 +228,9 @@ class FuncGraph(CapturingGraph):
self.get_collection_ref(collection)[:] = graph.get_collection(
collection)
+ # Copy distribution strategy scope from the containing graph as well.
+ self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
+
if context.executing_eagerly():
self.seed = context.global_seed()
else:
@@ -569,7 +572,7 @@ class GraphModeFunction(object):
# Find the variables that are components of something distributed and
# put them into a {handle_tensor -> distributed variable object} map.
self._distributed_variables = {}
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for variable in self._variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
@@ -901,7 +904,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# the function is run on a different device). Thus, instead of storing
# the specific captured variable, we replace it with its distributed
# container.
- strategy = distribute.get_distribution_strategy()
+ strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index a5f07fea3b..e4ce5339d0 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -43,7 +43,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -361,7 +361,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
"""model_fn for keras Estimator."""
# Raise an error when users use DistributionStrategy with native Keras
# optimizers. Currently we only support native TensorFlow optimizers.
- if distribute_lib.has_distribution_strategy() and \
+ if distribution_strategy_context.has_distribution_strategy() and \
not isinstance(keras_model.optimizer,
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TensorFlow native optimizers are supported with '
@@ -373,7 +373,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
# We need to make sure that the output names of the last layer in the model
# is the same for each of the cloned models. This is required for mirrored
# strategy when we call regroup.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
for name in model.output_names:
name = re.compile(r'_\d$').sub('', name)
model_output_names.append(name)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 98a1802490..5527f52860 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4860,6 +4860,18 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
+ @property
+ def _distribution_strategy_stack(self):
+ """A stack to maintain distribution strategy context for each thread."""
+ if not hasattr(self._thread_local, "_distribution_strategy_stack"):
+ self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access
+ return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access
+
+ @_distribution_strategy_stack.setter
+ def _distribution_strategy_stack(self, _distribution_strategy_stack):
+ self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access
+ _distribution_strategy_stack)
+
def _mutation_lock(self):
"""Returns a lock to guard code that creates & mutates ops.
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index a7835bc0a2..cd26e04c39 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -36,7 +36,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -345,16 +345,16 @@ class BatchNormalization(Layer):
aggregation=variable_scope.VariableAggregation.MEAN)
return var
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_mean):
+ with distribution_strategy_context.get_distribution_strategy(
+ ).colocate_vars_with(self.moving_mean):
self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
# We initialize renorm_stddev to 0, and maintain the (0-initialized)
# renorm_stddev_weight. This allows us to (1) mix the average
# stddev with the minibatch stddev early in training, and (2) compute
# the unbiased average stddev by dividing renorm_stddev by the weight.
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_variance):
+ with distribution_strategy_context.get_distribution_strategy(
+ ).colocate_vars_with(self.moving_variance):
self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
())
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 2dde9ee41f..9b87170ebe 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -55,7 +55,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import weights_broadcast_ops
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
@@ -111,7 +111,7 @@ def result_wrapper(result_fn):
def decorated(metric_obj, *args):
"""Decorated function with merge_call."""
- tower_context = distribute_lib.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
if tower_context is None: # if in cross tower context already
result_t = result_fn(*args)
else:
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 4f97442e82..f339a7e047 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -28,7 +28,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -705,7 +705,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
return self.optimizer.compute_gradients(loss, params)
def get_updates(self, loss, params):
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
self.updates = []
if not params:
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 3aedeb6acd..9461a01515 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
Furthermore, the final answer should be computed once instead of
in every replica/tower. Both of these are accomplished by
running the computation of the final result value inside
- `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+ `distribution_strategy_context.get_tower_context(
+ ).merge_call(fn)`.
Inside the `merge_call()`, ops are only added to the graph once
and access to a tower-local variable in a computation returns
the sum across all replicas/towers.
@@ -373,7 +374,7 @@ def mean(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -618,7 +619,7 @@ def _aggregate_variable(v, collections):
ops.add_to_collections(collections, value)
return value
- return distribute_lib.get_tower_context().merge_call(f, v)
+ return distribution_strategy_context.get_tower_context().merge_call(f, v)
@tf_export('metrics.auc')
@@ -813,7 +814,7 @@ def auc(labels,
ops.add_to_collections(metrics_collections, auc_value)
return auc_value
- auc_value = distribute_lib.get_tower_context().merge_call(
+ auc_value = distribution_strategy_context.get_tower_context().merge_call(
aggregate_auc, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1053,8 +1054,8 @@ def mean_per_class_accuracy(labels,
ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
- aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1160,7 +1161,7 @@ def mean_iou(labels,
ops.add_to_collections(metrics_collections, mean_iou_v)
return mean_iou_v
- mean_iou_v = distribute_lib.get_tower_context().merge_call(
+ mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
mean_iou_across_towers, total_cm)
if updates_collections:
@@ -1376,7 +1377,7 @@ def mean_tensor(values,
ops.add_to_collections(metrics_collections, mean_t)
return mean_t
- mean_t = distribute_lib.get_tower_context().merge_call(
+ mean_t = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -2008,7 +2009,7 @@ def precision(labels,
ops.add_to_collections(metrics_collections, p)
return p
- p = distribute_lib.get_tower_context().merge_call(
+ p = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_p)
update_op = compute_precision(true_positives_update_op,
@@ -2092,7 +2093,7 @@ def precision_at_thresholds(labels,
ops.add_to_collections(metrics_collections, prec)
return prec
- prec = distribute_lib.get_tower_context().merge_call(
+ prec = distribution_strategy_context.get_tower_context().merge_call(
precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
@@ -2188,7 +2189,7 @@ def recall(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
@@ -2627,7 +2628,7 @@ def recall_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fn)
update = math_ops.div(
@@ -2708,7 +2709,7 @@ def recall_at_thresholds(labels,
ops.add_to_collections(metrics_collections, rec)
return rec
- rec = distribute_lib.get_tower_context().merge_call(
+ rec = distribution_strategy_context.get_tower_context().merge_call(
recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
@@ -2783,7 +2784,7 @@ def root_mean_squared_error(labels,
ops.add_to_collections(metrics_collections, rmse)
return rmse
- rmse = distribute_lib.get_tower_context().merge_call(
+ rmse = distribution_strategy_context.get_tower_context().merge_call(
once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2886,7 +2887,7 @@ def sensitivity_at_specificity(labels,
ops.add_to_collections(metrics_collections, sensitivity)
return sensitivity
- sensitivity = distribute_lib.get_tower_context().merge_call(
+ sensitivity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_sensitivity_at_specificity(
@@ -3162,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, mean_average_precision)
return mean_average_precision
- mean_average_precision = distribute_lib.get_tower_context().merge_call(
- aggregate_across_towers, total_var, max_var)
+ mean_average_precision = distribution_strategy_context.get_tower_context(
+ ).merge_call(aggregate_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3448,7 +3449,7 @@ def precision_at_top_k(labels,
ops.add_to_collections(metrics_collections, metric)
return metric
- metric = distribute_lib.get_tower_context().merge_call(
+ metric = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, tp, fp)
update = math_ops.div(
@@ -3687,7 +3688,7 @@ def specificity_at_sensitivity(labels,
ops.add_to_collections(metrics_collections, specificity)
return specificity
- specificity = distribute_lib.get_tower_context().merge_call(
+ specificity = distribution_strategy_context.get_tower_context().merge_call(
aggregate_across_towers, values)
update_op = compute_specificity_at_sensitivity(
diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py
index a793f634bd..b382c3b7ce 100644
--- a/tensorflow/python/ops/summary_op_util.py
+++ b/tensorflow/python/ops/summary_op_util.py
@@ -23,7 +23,7 @@ import re
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
-from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
def collect(val, collections, default_collections):
@@ -49,7 +49,7 @@ def skip_summary():
# TODO(priyag): Add a new optional argument that will provide multiple
# alternatives to override default behavior. (e.g. run on last tower,
# compute sum or mean across towers).
- tower_context = distribute.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
return tower_context and tower_context.tower_id > 0
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 9b72b09f08..e6118177fd 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -180,10 +180,10 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
ValueError: If missing variables in current graph.
"""
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
_init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
else:
- distribute_lib.get_tower_context().merge_call(
+ distribution_strategy_context.get_tower_context().merge_call(
_init_from_checkpoint, ckpt_dir_or_file, assignment_map)
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 581db45e80..0d8d74a096 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import threading
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
+from tensorflow.python.eager import context as eager_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -31,71 +31,11 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import nest
# ------------------------------------------------------------------------------
-# Internal API for setting the current thread mode as being either in a
-# tower or cross-tower context for a particular distribution strategy.
-
-
-class _ThreadMode(object):
-
- def __init__(self, dist, cross, tower):
- self.distribution_strategy = dist
- self.cross_tower_context = cross
- self.tower_context = tower
-
-
-class _CrossTowerThreadMode(_ThreadMode):
-
- def __init__(self, distribution_strategy):
- _ThreadMode.__init__(
- self, distribution_strategy, distribution_strategy, None)
-
-
-class _InTowerThreadMode(_ThreadMode):
-
- def __init__(self, tower_ctx):
- _ThreadMode.__init__(
- self, tower_ctx.distribution_strategy, None, tower_ctx)
-
-
-_per_thread_mode = threading.local()
-
-
-def _push_per_thread_mode(context):
- if not hasattr(_per_thread_mode, "stack"):
- _per_thread_mode.stack = []
- _per_thread_mode.stack.append(context)
-
-
-def _pop_per_thread_mode():
- _per_thread_mode.stack.pop(-1)
-
-
-class _DefaultTowerThreadMode(_ThreadMode):
- """Type of default value returned by `_get_per_thread_mode()`.
-
- Used when the thread-local stack is empty.
- """
-
- def __init__(self):
- # _default_distribution_strategy and _default_tower_context are
- # defined at the bottom of this file.
- _ThreadMode.__init__(
- self, _default_distribution_strategy, None, _default_tower_context)
-
-
-def _get_per_thread_mode():
- try:
- return _per_thread_mode.stack[-1]
- except (AttributeError, IndexError):
- # _default_tower_mode is defined at the bottom of this file.
- return _default_tower_mode
-
-
-# ------------------------------------------------------------------------------
# Context tracking whether in a distribution.update() or .update_non_slot()
# call.
@@ -128,96 +68,6 @@ class UpdateContext(object):
# ------------------------------------------------------------------------------
-# Public API for accessing the current thread mode
-
-
-def get_tower_context():
- """Returns the current TowerContext or None if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context (this function
- will return the default TowerContext object);
- 2. switches to cross-tower context (in which case this will return
- None) when entering a `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context (and again
- this function will return None).
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context, in a tower context you should use the
- `TowerContext` API instead.
-
- Returns:
- The current `TowerContext` object when in a tower context scope, else None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().tower_context
-
-
-def get_cross_tower_context():
- """Returns the current DistributionStrategy if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
- `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context.
-
- Returns:
- Returns the current `DistributionStrategy` object in a cross-tower
- context, or None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().cross_tower_context
-
-
-def get_distribution_strategy():
- """Returns the current `DistributionStrategy` object.
-
- Prefer to use `get_tower_context()` or `get_cross_tower_context()`
- instead when possible.
-
- Returns:
- A `DistributionStrategy` object. Inside a
- `with distribution_strategy.scope()` block, it returns
- `distribution_strategy`, otherwise it returns the default
- (single-tower) `DistributionStrategy` object.
- """
- return _get_per_thread_mode().distribution_strategy
-
-
-def has_distribution_strategy():
- """Return if there is a current non-default `DistributionStrategy`.
-
- Returns:
- True if inside a `with distribution_strategy.scope():`.
- """
- return get_distribution_strategy() is not _default_distribution_strategy
-
-
-# ------------------------------------------------------------------------------
# Public utility functions.
@@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy):
if context.cross_tower_context is distribution_strategy: return
# We have an error to report, figure out the right message.
if context.distribution_strategy is not distribution_strategy:
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy):
context = _get_per_thread_mode()
if context.distribution_strategy is distribution_strategy: return
# We have an error to report, figure out the right message.
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -295,7 +147,8 @@ class _CurrentDistributionContext(object):
var_creator_scope,
var_scope=None,
default_device=None):
- self._context = _CrossTowerThreadMode(distribution_strategy)
+ self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ distribution_strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
if default_device:
@@ -588,7 +441,7 @@ class DistributionStrategy(object):
Returns:
A context manager.
"""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
_require_cross_tower_context(self)
return _SameScopeAgainContext(self)
@@ -740,7 +593,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -757,7 +610,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -1106,7 +959,8 @@ class TowerContext(object):
def __init__(self, distribution_strategy, tower_id):
self._distribution_strategy = distribution_strategy
- self._thread_context = _InTowerThreadMode(self)
+ self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access
+ self)
self._tower_id = tower_id
def __enter__(self):
@@ -1149,7 +1003,8 @@ class TowerContext(object):
def _merge_call(self, merge_fn, *args, **kwargs):
"""Default implementation for single tower."""
_push_per_thread_mode( # thread-local, so not needed with multiple threads
- _CrossTowerThreadMode(self._distribution_strategy))
+ distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ self._distribution_strategy))
try:
return merge_fn(self._distribution_strategy, *args, **kwargs)
finally:
@@ -1196,7 +1051,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def scope(self):
"""Context manager setting a variable creator and `self` as current."""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise RuntimeError("Must not nest DistributionStrategy scopes.")
def creator(next_creator, *args, **kwargs):
@@ -1277,6 +1132,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
raise RuntimeError("worker_device_index() method unsupported by "
"_DefaultDistributionStrategy.")
+
# ------------------------------------------------------------------------------
# Common operations
@@ -1292,20 +1148,11 @@ def increment_var(v, amount=1):
def merge_fn(dist, vm):
return dist.group(dist.update(vm, update))
- tower_context = get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
# ------------------------------------------------------------------------------
-# Singletons
-
-_default_distribution_strategy = _DefaultDistributionStrategy()
-_default_tower_context = TowerContext(
- _default_distribution_strategy, tower_id=0)
-_default_tower_mode = _DefaultTowerThreadMode()
-
-
-# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
@@ -1314,7 +1161,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None):
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using"
"distributed strategies.")
@@ -1323,3 +1170,10 @@ def _from_proto_fn(v, import_scope=None):
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
+
+
+#-------------------------------------------------------------------------------
+# Shorthand for some methods from distribution_strategy_context.
+_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
+_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
+_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 694145ede7..f03bd39100 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import distribute
+from tensorflow.python.training import distribution_strategy_context
class _TestTowerContext(distribute.TowerContext):
@@ -49,12 +50,12 @@ class _TestStrategy(distribute.DistributionStrategy):
def _assert_in_default_state(t):
- t.assertIs(distribute._default_tower_context,
- distribute.get_tower_context())
- t.assertIs(None, distribute.get_cross_tower_context())
- t.assertIs(distribute._default_distribution_strategy,
- distribute.get_distribution_strategy())
- t.assertFalse(distribute.has_distribution_strategy())
+ t.assertIs(distribution_strategy_context._get_default_tower_context(),
+ distribution_strategy_context.get_tower_context())
+ t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
+ t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
+ distribution_strategy_context.get_distribution_strategy())
+ t.assertFalse(distribution_strategy_context.has_distribution_strategy())
class TestStrategyTest(test.TestCase):
@@ -64,11 +65,13 @@ class TestStrategyTest(test.TestCase):
dist = _TestStrategy()
def run_fn():
- tower_context = distribute.get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
self.assertTrue(tower_context is not None)
- self.assertIs(None, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
expected_value = _get_test_variable(
"bar", variable_scope.VariableSynchronization.AUTO,
@@ -86,10 +89,12 @@ class TestStrategyTest(test.TestCase):
_assert_in_default_state(self)
dist = _TestStrategy()
with dist.scope():
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertTrue(distribute.has_distribution_strategy())
- self.assertIs(dist, distribute.get_distribution_strategy())
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertTrue(distribution_strategy_context.has_distribution_strategy())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
expected_value = _get_test_variable(
"baz", variable_scope.VariableSynchronization.AUTO,
variable_scope.VariableAggregation.NONE)
@@ -120,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase):
_assert_in_default_state(self)
def merge_fn(dist, s):
- self.assertIs(distribute._default_distribution_strategy, dist)
- self.assertIs(None, distribute.get_tower_context())
- self.assertIs(dist, distribute.get_cross_tower_context())
- self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertFalse(distribute.has_distribution_strategy())
+ self.assertIs(
+ distribution_strategy_context._get_default_distribution_strategy(),
+ dist)
+ self.assertIs(None, distribution_strategy_context.get_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_cross_tower_context())
+ self.assertIs(dist,
+ distribution_strategy_context.get_distribution_strategy())
+ self.assertFalse(
+ distribution_strategy_context.has_distribution_strategy())
return "foo_" + s
- tower_ctx = distribute.get_tower_context()
- self.assertIs(distribute._default_tower_context, tower_ctx)
+ tower_ctx = distribution_strategy_context.get_tower_context()
+ self.assertIs(distribution_strategy_context._get_default_tower_context(),
+ tower_ctx)
self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
_assert_in_default_state(self)
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
new file mode 100644
index 0000000000..998b5c35ce
--- /dev/null
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -0,0 +1,203 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to get distribution strategy related contexts."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+
+# There is a circular dependency between this module and the `distribute`
+# module, so `distribute` is loaded lazily (through `LazyLoader`) to work
+# around it. Code in this file refers to it as `distribute_lib`.
+distribute_lib = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.training.distribute")
+
+# ------------------------------------------------------------------------------
+# Internal API for setting the current thread mode as being either in a
+# tower or cross-tower context for a particular distribution strategy.
+
+
+# Plain record of the current mode: the active distribution strategy plus the
+# current cross-tower context and tower context. The public getters in this
+# file (`get_distribution_strategy()`, `get_cross_tower_context()`,
+# `get_tower_context()`) return these three attributes.
+class _ThreadMode(object):
+
+ def __init__(self, dist, cross, tower):
+ # The three arguments are stored unchanged; no validation is done here.
+ self.distribution_strategy = dist
+ self.cross_tower_context = cross
+ self.tower_context = tower
+
+
+# Mode describing a cross-tower context: the given strategy is recorded both as
+# the distribution strategy and as the cross-tower context, and the tower
+# context is None.
+class _CrossTowerThreadMode(_ThreadMode):
+
+ def __init__(self, distribution_strategy):
+ _ThreadMode.__init__(
+ self, distribution_strategy, distribution_strategy, None)
+
+
+# Mode describing a tower context: the strategy is taken from
+# `tower_ctx.distribution_strategy`, the cross-tower context is None, and
+# `tower_ctx` itself is recorded as the tower context.
+class _InTowerThreadMode(_ThreadMode):
+
+ def __init__(self, tower_ctx):
+ _ThreadMode.__init__(
+ self, tower_ctx.distribution_strategy, None, tower_ctx)
+
+
+def _push_per_thread_mode(context):
+ # Appends `context` to the default graph's `_distribution_strategy_stack`.
+ # `_get_per_thread_mode()` returns the last entry of that stack, so the pushed
+ # value becomes the current mode.
+ ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
+
+
+def _pop_per_thread_mode():
+ # Removes the last entry of the default graph's
+ # `_distribution_strategy_stack`. The popped value is discarded (nothing is
+ # returned to the caller).
+ ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
+
+
+class _DefaultTowerThreadMode(_ThreadMode):
+ """Type of default value returned by `_get_per_thread_mode()`.
+
+ Used when the thread-local stack is empty.
+ """
+
+ def __init__(self):
+ # Default strategy, no cross-tower context, default tower context. Both
+ # defaults come from the lazily-initialized getters defined in this file.
+ _ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
+ _get_default_tower_context())
+
+
+def _get_per_thread_mode():
+ # Returns the last entry of the default graph's
+ # `_distribution_strategy_stack`, or the cached default tower mode when that
+ # lookup raises one of the exceptions handled below.
+ try:
+ return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
+ # IndexError: the stack has no entries. AttributeError: presumably the graph
+ # object does not expose `_distribution_strategy_stack` -- TODO confirm.
+ except (AttributeError, IndexError):
+ return _get_default_tower_mode()
+
+
+# ------------------------------------------------------------------------------
+# Public API for accessing the current thread mode
+
+
+def get_tower_context():
+ """Returns the current TowerContext or None if in a cross-tower context.
+
+ Note that execution:
+ 1. starts in the default (single-tower) tower context (this function
+ will return the default TowerContext object);
+ 2. switches to cross-tower context (in which case this will return
+ None) when entering a `with DistributionStrategy.scope():` block;
+ 3. switches to a (non-default) tower context inside
+ `call_for_each_tower(fn, ...)`;
+ 4. if `fn` calls `get_tower_context().merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-tower context (and again
+ this function will return None).
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-tower context for the default `DistributionStrategy`. You may
+ also switch from the cross-tower context of 4 to a tower context by
+ calling `call_for_each_tower()`, jumping back to step 3.
+
+ Most `DistributionStrategy` methods may only be executed in
+ a cross-tower context; in a tower context you should use the
+ `TowerContext` API instead.
+
+ Returns:
+ The current `TowerContext` object when in a tower context scope, else None.
+
+ Exactly one of `get_tower_context()` and `get_cross_tower_context()`
+ will return None in a particular block.
+ """
+ return _get_per_thread_mode().tower_context
+
+
+def get_cross_tower_context():
+ """Returns the current DistributionStrategy if in a cross-tower context.
+
+ Note that execution:
+ 1. starts in the default (single-tower) tower context;
+ 2. switches to cross-tower context when entering a
+ `with DistributionStrategy.scope():` block;
+ 3. switches to a (non-default) tower context inside
+ `call_for_each_tower(fn, ...)`;
+ 4. if `fn` calls `get_tower_context().merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-tower context.
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-tower context for the default `DistributionStrategy`. You may
+ also switch from the cross-tower context of 4 to a tower context by
+ calling `call_for_each_tower()`, jumping back to step 3.
+
+ Most `DistributionStrategy` methods may only be executed in
+ a cross-tower context.
+
+ Returns:
+ The current `DistributionStrategy` object when in a cross-tower
+ context, else None.
+
+ Exactly one of `get_tower_context()` and `get_cross_tower_context()`
+ will return None in a particular block.
+ """
+ return _get_per_thread_mode().cross_tower_context
+
+
+def get_distribution_strategy():
+ """Returns the current `DistributionStrategy` object.
+
+ Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+ instead when possible.
+
+ Returns:
+ A `DistributionStrategy` object. Inside a
+ `with distribution_strategy.scope()` block, it returns
+ `distribution_strategy`, otherwise it returns the default
+ (single-tower) `DistributionStrategy` object.
+ """
+ # When no mode has been pushed, `_get_per_thread_mode()` yields the default
+ # tower mode, whose strategy is the default distribution strategy.
+ return _get_per_thread_mode().distribution_strategy
+
+
+def has_distribution_strategy():
+ """Returns whether there is a current non-default `DistributionStrategy`.
+
+ Returns:
+ True if inside a `with distribution_strategy.scope():`.
+ """
+ # Identity comparison against the single cached default strategy object.
+ return get_distribution_strategy() is not _get_default_distribution_strategy()
+
+
+# ------------------------------------------------------------------------------
+# Defaults that are used when no distribution strategy is explicitly created.
+# We create them lazily in a function so that we can work around the circular
+# dependency on distribute_lib. See lazy loader at the top of this file.
+
+# Cache for the lazily-created defaults. Each entry starts as None and is
+# filled in by the matching `_get_default_*()` function below on first use.
+_defaults = {
+ "distribution_strategy": None,
+ "tower_context": None,
+ "tower_mode": None
+}
+
+
+def _get_default_distribution_strategy():
+ # Creates the default strategy on the first call and caches it in
+ # `_defaults`; later calls return the same cached object.
+ if _defaults["distribution_strategy"] is None:
+ _defaults["distribution_strategy"] = (
+ distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
+ return _defaults["distribution_strategy"]
+
+
+def _get_default_tower_context():
+ # Creates, on the first call, a `TowerContext` bound to the default
+ # distribution strategy with `tower_id=0`, caches it in `_defaults`, and
+ # returns the cached object on every call.
+ if _defaults["tower_context"] is None:
+ _defaults["tower_context"] = distribute_lib.TowerContext(
+ _get_default_distribution_strategy(), tower_id=0)
+ return _defaults["tower_context"]
+
+
+def _get_default_tower_mode():
+ # Creates the `_DefaultTowerThreadMode` instance on the first call and caches
+ # it in `_defaults`; `_get_per_thread_mode()` uses it as its fallback value.
+ if _defaults["tower_mode"] is None:
+ _defaults["tower_mode"] = _DefaultTowerThreadMode()
+ return _defaults["tower_mode"]
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6d95b144d5..1b6bce2865 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -464,7 +465,8 @@ class Optimizer(
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -482,7 +484,8 @@ class Optimizer(
# Scale loss if using a "mean" loss reduction and multiple towers.
if (distribute_lib.get_loss_reduction() ==
variable_scope.VariableAggregation.MEAN):
- num_towers = distribute_lib.get_distribution_strategy().num_towers
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -548,15 +551,15 @@ class Optimizer(
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
# Handle DistributionStrategy case.
- if distribute_lib.get_cross_tower_context():
+ if distribution_strategy_context.get_cross_tower_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-tower context.")
# TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
# always calling _distributed_apply(), using the default distribution
# as needed.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
- return distribute_lib.get_tower_context().merge_call(
+ return distribution_strategy_context.get_tower_context().merge_call(
self._distributed_apply, grads_and_vars, global_step, name)
# No DistributionStrategy case.
@@ -799,7 +802,8 @@ class Optimizer(
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(colocate_with):
if eager:
restored_initial_value = self._preload_simple_restoration(
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 258a6f045d..d76b22acd8 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
@@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True):
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, val, "", validate_shape, None, None)
else:
@@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
prefix = primary.op.name
with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
- distribution_strategy = distribute_lib.get_distribution_strategy()
+ distribution_strategy = (
+ distribution_strategy_context.get_distribution_strategy())
with distribution_strategy.colocate_vars_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)