diff options
25 files changed, 431 insertions, 305 deletions
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 9123ca749b..5fa57f494c 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -22,13 +22,14 @@ from __future__ import print_function from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy -from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * +from tensorflow.python.training.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index a1efbcaf9a..aeec9c44d7 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -56,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -320,7 +320,7 @@ class NamedDistribution(object): # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 3e00cf4332..cc626c33bf 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export @@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ])) - def test_complete_flow_with_mode(self, distribution): + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): label_dimension = 2 input_dimension = label_dimension batch_size = 10 @@ -103,9 +105,15 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 - estimator.train(train_input_fn, steps=num_steps) + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, + training.TrainSpec(train_input_fn, max_steps=num_steps), + training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) - scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index e064cfe37d..9a4cc0a897 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -40,7 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context GPU_TEST = "test_gpu" in sys.argv[0] @@ -164,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # This variable should be created only once across the threads because of # special variable_creator functions used by `dist.call_for_each_tower`. v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -181,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -201,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -223,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -245,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(device_id): v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -268,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), @@ -300,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -343,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -453,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -470,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -570,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(1.0, name="b") return a, b @@ -591,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(2.0, name="b") return a, b @@ -619,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -651,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -833,8 +840,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( @@ -898,8 +906,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(1.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign_add(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( @@ -963,8 +972,9 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) def model_fn(): - value = math_ops.cast(distribute_lib.get_tower_context().tower_id, - mirrored_var.dtype) + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) return mirrored_var.assign_sub(value) self.evaluate(dist.unwrap(dist.call_for_each_tower( diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a066adf124..5db2fff239 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -24,7 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase): v = variable_scope.variable(1.0) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) return v def main_thread_creator(next_creator, *args, **kwargs): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index cf29c0ed91..02eb68227d 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -37,7 +37,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, @@ -101,7 +101,8 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, last_part_device = 'device:CPU:0' else: last_part_device = ( - 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + 'device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -192,14 +193,16 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, tower_compute_device = '/device:CPU:0' else: tower_compute_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_compute_device = device_util.canonicalize(tower_compute_device) if 'CPU' in variable_device: tower_variable_device = '/device:CPU:0' else: tower_variable_device = ( - '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) tower_variable_device = device_util.canonicalize(tower_variable_device) a = constant_op.constant(1.0) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index baed0ebaae..371b97ba96 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -28,7 +28,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,7 +45,8 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_tower() call, calls a # get_tower_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + distribution_strategy_context.get_tower_context().merge_call( + _raise_exception_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -58,7 +59,7 @@ def _call_raises_fn(dist): # calls a get_tower_context().merge_call() that calls a # call_for_each_tower() that raises an exception. def _merge_call_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_raises_fn) + distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -72,7 +73,8 @@ def _call_merge_raises_fn(dist): # get_tower_context().merge_call() that calls a call_for_each_tower() that # calls a get_tower_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + distribution_strategy_context.get_tower_context().merge_call( + _call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -208,7 +210,7 @@ class DistributionTestBase(test.TestCase): expected_devices = [False] * len(d.worker_devices) def mark_devices_fn(): - tower_id = distribute_lib.get_tower_context().tower_id + tower_id = distribution_strategy_context.get_tower_context().tower_id self.assertLess(tower_id, len(d.worker_devices)) self.assertFalse(expected_devices[tower_id]) expected_devices[tower_id] = True diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 5fd4c9de69..8548a86421 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -56,7 +57,7 @@ class DistributedValues(object): def get(self, device=None): """Returns the value for the current device or raises a ValueError.""" if device is None: - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context: device = tower_context.device else: @@ -289,14 +290,15 @@ class DistributedVariable(DistributedDelegate): # We want cross-tower code that does some var.op.X calls # to work (even if the current device isn't in self.devices), but # other uses of var.op in a cross-tower context to fail. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return DistributedVarOp(self._primary_var.op.name, self._primary_var.op.graph, self._primary_var.op.type) return self.get().op def read_value(self): - return distribute_lib.get_distribution_strategy().read_var(self) + return distribution_strategy_context.get_distribution_strategy().read_var( + self) def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" @@ -362,7 +364,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # update several non-slot variables in one call. def _assign_func(self, *args, **kwargs): f = kwargs.pop("f") - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): update_device = distribute_lib.get_update_device() # We are calling update on the mirrored variable in cross tower context. if update_device is not None: @@ -371,7 +373,7 @@ class MirroredVariable(DistributedVariable, Mirrored, v = self.get(device=update_device) return f(v, *args, **kwargs) - return distribute_lib.get_distribution_strategy().update( + return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: _assert_tower_context() @@ -392,8 +394,8 @@ class MirroredVariable(DistributedVariable, Mirrored, aggregation=self._aggregation, value=value, destinations=self), *other_args, **other_kwargs) - return distribute_lib.get_tower_context().merge_call(merge_fn, *args, - **kwargs) + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) @@ -419,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._primary_var._as_graph_element() return self.get()._as_graph_element() @@ -459,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().read_var( + return distribution_strategy_context.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -475,7 +477,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def _assert_tower_context(): - if not distribute_lib.get_tower_context(): + if not distribution_strategy_context.get_tower_context(): raise RuntimeError( "Tower-local variables may only be assigned in a tower context.") @@ -498,7 +500,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): # To preserve the sum across save and restore, we have to divide the # total across all devices when restoring a variable that was summed # when saving. @@ -526,7 +528,7 @@ class TowerLocalVariable(DistributedVariable, PerDevice, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._get_cross_tower() return self.get()._as_graph_element() @@ -994,12 +996,12 @@ class MultiStepContext(object): outputs as already reduced or not. """ - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._last_step_outputs_aggregations[name] = aggregation if aggregation is variables_lib.VariableAggregation.NONE: self._last_step_outputs[name] = output else: - distribution = distribute_lib.get_distribution_strategy() + distribution = distribution_strategy_context.get_distribution_strategy() self._last_step_outputs[name] = distribution.reduce( aggregation, output, destinations="/device:CPU:0") else: @@ -1011,7 +1013,9 @@ class MultiStepContext(object): # context object, so it's more robust to set it only once (even if all # the towers are trying to set the same value). self._last_step_outputs_aggregations[name] = aggregation - distribute_lib.get_tower_context().merge_call(merge_fn, output) + + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) @property def non_tensor_outputs(self): @@ -1020,14 +1024,15 @@ class MultiStepContext(object): def set_non_tensor_output(self, name, output): """Set `output` with `name` to be captured as a non tensor output.""" - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): self._non_tensor_outputs[name] = output else: def merge_fn(distribution, value): # NOTE(priyag): For non tensor outputs, we simply return all the values # in a list as aggregation doesn't make sense on non tensors. self._non_tensor_outputs[name] = distribution.unwrap(value) - distribute_lib.get_tower_context().merge_call(merge_fn, output) + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) def value_container(val): diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index e553612269..7053907da0 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ @@ -174,7 +174,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200, ops.add_to_collections(metrics_collections, best_f1) return best_f1 - best_f1 = distribute_lib.get_tower_context().merge_call( + best_f1 = distribution_strategy_context.get_tower_context().merge_call( f1_across_towers, values) update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 8c11d8bcfd..f6ecaba834 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -620,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Map from graph_key to state for that graph. We use the graph_key # since it works in both eager and graph mode, and gives the outer # graph inside functions. - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context is None: # In a cross-tower context for a DistributionStrategy, which means # only one Optimizer will be created, not one per tower. @@ -769,7 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer): distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss_value *= 1. / num_towers @@ -788,7 +790,8 @@ class OptimizerV2(optimizer_v1.Optimizer): distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss *= 1. / num_towers @@ -862,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer): if not filtered: raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) - return distribute_lib.get_tower_context().merge_call( + return distribution_strategy_context.get_tower_context().merge_call( self._distributed_apply, filtered, global_step=global_step, name=name) def _get_or_create_state(self, var_list=None): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2a71eaf030..a6bb6158e6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3340,7 +3340,10 @@ py_library( py_library( name = "distribute", - srcs = ["training/distribute.py"], + srcs = [ + "training/distribute.py", + "training/distribution_strategy_context.py", + ], srcs_version = "PY2AND3", deps = [ ":array_ops", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 8b9b62e034..4958ba56c5 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -43,7 +43,7 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -228,6 +228,9 @@ class FuncGraph(CapturingGraph): self.get_collection_ref(collection)[:] = graph.get_collection( collection) + # Copy distribution strategy scope from the containing graph as well. + self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access + if context.executing_eagerly(): self.seed = context.global_seed() else: @@ -569,7 +572,7 @@ class GraphModeFunction(object): # Find the variables that are components of something distributed and # put them into a {handle_tensor -> distributed variable object} map. self._distributed_variables = {} - strategy = distribute.get_distribution_strategy() + strategy = distribution_strategy_context.get_distribution_strategy() for variable in self._variables: # If variable is not distributed, unwrap returns [variable]. component_variables = strategy.unwrap(variable) @@ -901,7 +904,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, # the function is run on a different device). Thus, instead of storing # the specific captured variable, we replace it with its distributed # container. - strategy = distribute.get_distribution_strategy() + strategy = distribution_strategy_context.get_distribution_strategy() for i, variable in enumerate(variables): # If variable is not distributed value_container returns itself. variables[i] = strategy.value_container(variable) diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index a5f07fea3b..e4ce5339d0 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -361,7 +361,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): """model_fn for keras Estimator.""" # Raise an error when users use DistributionStrategy with native Keras # optimizers. Currently we only support native TensorFlow optimizers. - if distribute_lib.has_distribution_strategy() and \ + if distribution_strategy_context.has_distribution_strategy() and \ not isinstance(keras_model.optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): raise ValueError('Only TensorFlow native optimizers are supported with ' @@ -373,7 +373,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): # We need to make sure that the output names of the last layer in the model # is the same for each of the cloned models. This is required for mirrored # strategy when we call regroup. - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): for name in model.output_names: name = re.compile(r'_\d$').sub('', name) model_output_names.append(name) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 98a1802490..5527f52860 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4860,6 +4860,18 @@ class Graph(object): else: self._graph_control_dependencies_stack = control_dependencies + @property + def _distribution_strategy_stack(self): + """A stack to maintain distribution strategy context for each thread.""" + if not hasattr(self._thread_local, "_distribution_strategy_stack"): + self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access + return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access + + @_distribution_strategy_stack.setter + def _distribution_strategy_stack(self, _distribution_strategy_stack): + self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access + _distribution_strategy_stack) + def _mutation_lock(self): """Returns a lock to guard code that creates & mutates ops. diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index a7835bc0a2..cd26e04c39 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.tf_export import tf_export @@ -345,16 +345,16 @@ class BatchNormalization(Layer): aggregation=variable_scope.VariableAggregation.MEAN) return var - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_mean): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_mean): self.renorm_mean = _renorm_variable('renorm_mean', param_shape) self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) # We initialize renorm_stddev to 0, and maintain the (0-initialized) # renorm_stddev_weight. This allows us to (1) mix the average # stddev with the minibatch stddev early in training, and (2) compute # the unbiased average stddev by dividing renorm_stddev by the weight. - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_variance): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_variance): self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight', ()) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 2dde9ee41f..9b87170ebe 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -55,7 +55,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import weights_broadcast_ops -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import tf_export @@ -111,7 +111,7 @@ def result_wrapper(result_fn): def decorated(metric_obj, *args): """Decorated function with merge_call.""" - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context is None: # if in cross tower context already result_t = result_fn(*args) else: diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index 4f97442e82..f339a7e047 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -28,7 +28,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base as checkpointable @@ -705,7 +705,7 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase): return self.optimizer.compute_gradients(loss, params) def get_updates(self, loss, params): - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): self.updates = [] if not params: diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 3aedeb6acd..9461a01515 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): Furthermore, the final answer should be computed once instead of in every replica/tower. Both of these are accomplished by running the computation of the final result value inside - `tf.contrib.distribute.get_tower_context().merge_call(fn)`. + `tf.contrib.distribution_strategy_context.get_tower_context( + ).merge_call(fn)`. Inside the `merge_call()`, ops are only added to the graph once and access to a tower-local variable in a computation returns the sum across all replicas/towers. @@ -373,7 +374,7 @@ def mean(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -618,7 +619,7 @@ def _aggregate_variable(v, collections): ops.add_to_collections(collections, value) return value - return distribute_lib.get_tower_context().merge_call(f, v) + return distribution_strategy_context.get_tower_context().merge_call(f, v) @tf_export('metrics.auc') @@ -813,7 +814,7 @@ def auc(labels, ops.add_to_collections(metrics_collections, auc_value) return auc_value - auc_value = distribute_lib.get_tower_context().merge_call( + auc_value = distribution_strategy_context.get_tower_context().merge_call( aggregate_auc, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') @@ -1053,8 +1054,8 @@ def mean_per_class_accuracy(labels, ops.add_to_collections(metrics_collections, mean_accuracy_v) return mean_accuracy_v - mean_accuracy_v = distribute_lib.get_tower_context().merge_call( - aggregate_mean_accuracy, count, total) + mean_accuracy_v = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_mean_accuracy, count, total) update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: @@ -1160,7 +1161,7 @@ def mean_iou(labels, ops.add_to_collections(metrics_collections, mean_iou_v) return mean_iou_v - mean_iou_v = distribute_lib.get_tower_context().merge_call( + mean_iou_v = distribution_strategy_context.get_tower_context().merge_call( mean_iou_across_towers, total_cm) if updates_collections: @@ -1376,7 +1377,7 @@ def mean_tensor(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -2008,7 +2009,7 @@ def precision(labels, ops.add_to_collections(metrics_collections, p) return p - p = distribute_lib.get_tower_context().merge_call( + p = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_p) update_op = compute_precision(true_positives_update_op, @@ -2092,7 +2093,7 @@ def precision_at_thresholds(labels, ops.add_to_collections(metrics_collections, prec) return prec - prec = distribute_lib.get_tower_context().merge_call( + prec = distribution_strategy_context.get_tower_context().merge_call( precision_across_towers, values) update_op = compute_precision(update_ops['tp'], update_ops['fp'], @@ -2188,7 +2189,7 @@ def recall(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_n) update_op = compute_recall(true_positives_update_op, @@ -2627,7 +2628,7 @@ def recall_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fn) update = math_ops.div( @@ -2708,7 +2709,7 @@ def recall_at_thresholds(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( recall_across_towers, values) update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') @@ -2783,7 +2784,7 @@ def root_mean_squared_error(labels, ops.add_to_collections(metrics_collections, rmse) return rmse - rmse = distribute_lib.get_tower_context().merge_call( + rmse = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, mse) update_rmse_op = math_ops.sqrt(update_mse_op) @@ -2886,7 +2887,7 @@ def sensitivity_at_specificity(labels, ops.add_to_collections(metrics_collections, sensitivity) return sensitivity - sensitivity = distribute_lib.get_tower_context().merge_call( + sensitivity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_sensitivity_at_specificity( @@ -3162,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels, ops.add_to_collections(metrics_collections, mean_average_precision) return mean_average_precision - mean_average_precision = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total_var, max_var) + mean_average_precision = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_across_towers, total_var, max_var) update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: @@ -3448,7 +3449,7 @@ def precision_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fp) update = math_ops.div( @@ -3687,7 +3688,7 @@ def specificity_at_sensitivity(labels, ops.add_to_collections(metrics_collections, specificity) return specificity - specificity = distribute_lib.get_tower_context().merge_call( + specificity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_specificity_at_sensitivity( diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py index a793f634bd..b382c3b7ce 100644 --- a/tensorflow/python/ops/summary_op_util.py +++ b/tensorflow/python/ops/summary_op_util.py @@ -23,7 +23,7 @@ import re from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging -from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context def collect(val, collections, default_collections): @@ -49,7 +49,7 @@ def skip_summary(): # TODO(priyag): Add a new optional argument that will provide multiple # alternatives to override default behavior. (e.g. run on last tower, # compute sum or mean across towers). - tower_context = distribute.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() return tower_context and tower_context.tower_id > 0 diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index 9b72b09f08..e6118177fd 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -180,10 +180,10 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): tf.errors.OpError: If missing checkpoints or tensors in checkpoints. ValueError: If missing variables in current graph. """ - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map) else: - distribute_lib.get_tower_context().merge_call( + distribution_strategy_context.get_tower_context().merge_call( _init_from_checkpoint, ckpt_dir_or_file, assignment_map) diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 581db45e80..0d8d74a096 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -21,7 +21,7 @@ from __future__ import print_function import threading from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context +from tensorflow.python.eager import context as eager_context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -31,71 +31,11 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import nest # ------------------------------------------------------------------------------ -# Internal API for setting the current thread mode as being either in a -# tower or cross-tower context for a particular distribution strategy. - - -class _ThreadMode(object): - - def __init__(self, dist, cross, tower): - self.distribution_strategy = dist - self.cross_tower_context = cross - self.tower_context = tower - - -class _CrossTowerThreadMode(_ThreadMode): - - def __init__(self, distribution_strategy): - _ThreadMode.__init__( - self, distribution_strategy, distribution_strategy, None) - - -class _InTowerThreadMode(_ThreadMode): - - def __init__(self, tower_ctx): - _ThreadMode.__init__( - self, tower_ctx.distribution_strategy, None, tower_ctx) - - -_per_thread_mode = threading.local() - - -def _push_per_thread_mode(context): - if not hasattr(_per_thread_mode, "stack"): - _per_thread_mode.stack = [] - _per_thread_mode.stack.append(context) - - -def _pop_per_thread_mode(): - _per_thread_mode.stack.pop(-1) - - -class _DefaultTowerThreadMode(_ThreadMode): - """Type of default value returned by `_get_per_thread_mode()`. - - Used when the thread-local stack is empty. - """ - - def __init__(self): - # _default_distribution_strategy and _default_tower_context are - # defined at the bottom of this file. - _ThreadMode.__init__( - self, _default_distribution_strategy, None, _default_tower_context) - - -def _get_per_thread_mode(): - try: - return _per_thread_mode.stack[-1] - except (AttributeError, IndexError): - # _default_tower_mode is defined at the bottom of this file. - return _default_tower_mode - - -# ------------------------------------------------------------------------------ # Context tracking whether in a distribution.update() or .update_non_slot() # call. @@ -128,96 +68,6 @@ class UpdateContext(object): # ------------------------------------------------------------------------------ -# Public API for accessing the current thread mode - - -def get_tower_context(): - """Returns the current TowerContext or None if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context (this function - will return the default TowerContext object); - 2. switches to cross-tower context (in which case this will return - None) when entering a `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context (and again - this function will return None). - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context, in a tower context you should use the - `TowerContext` API instead. - - Returns: - The current `TowerContext` object when in a tower context scope, else None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().tower_context - - -def get_cross_tower_context(): - """Returns the current DistributionStrategy if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context; - 2. switches to cross-tower context when entering a - `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context. - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context. - - Returns: - Returns the current `DistributionStrategy` object in a cross-tower - context, or None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().cross_tower_context - - -def get_distribution_strategy(): - """Returns the current `DistributionStrategy` object. - - Prefer to use `get_tower_context()` or `get_cross_tower_context()` - instead when possible. - - Returns: - A `DistributionStrategy` object. Inside a - `with distribution_strategy.scope()` block, it returns - `distribution_strategy`, otherwise it returns the default - (single-tower) `DistributionStrategy` object. - """ - return _get_per_thread_mode().distribution_strategy - - -def has_distribution_strategy(): - """Return if there is a current non-default `DistributionStrategy`. - - Returns: - True if inside a `with distribution_strategy.scope():`. - """ - return get_distribution_strategy() is not _default_distribution_strategy - - -# ------------------------------------------------------------------------------ # Public utility functions. @@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy): if context.cross_tower_context is distribution_strategy: return # We have an error to report, figure out the right message. if context.distribution_strategy is not distribution_strategy: - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy): context = _get_per_thread_mode() if context.distribution_strategy is distribution_strategy: return # We have an error to report, figure out the right message. - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -295,7 +147,8 @@ class _CurrentDistributionContext(object): var_creator_scope, var_scope=None, default_device=None): - self._context = _CrossTowerThreadMode(distribution_strategy) + self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + distribution_strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope if default_device: @@ -588,7 +441,7 @@ class DistributionStrategy(object): Returns: A context manager. """ - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): _require_cross_tower_context(self) return _SameScopeAgainContext(self) @@ -740,7 +593,7 @@ class DistributionStrategy(object): In eager mode, returns `None`. In graph mode, a list of ops to execute. Empty list if nothing to be done. """ - if context.executing_eagerly(): + if eager_context.executing_eagerly(): return else: return [] @@ -757,7 +610,7 @@ class DistributionStrategy(object): In eager mode, returns `None`. In graph mode, a list of ops to execute. Empty list if nothing to be done. """ - if context.executing_eagerly(): + if eager_context.executing_eagerly(): return else: return [] @@ -1106,7 +959,8 @@ class TowerContext(object): def __init__(self, distribution_strategy, tower_id): self._distribution_strategy = distribution_strategy - self._thread_context = _InTowerThreadMode(self) + self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access + self) self._tower_id = tower_id def __enter__(self): @@ -1149,7 +1003,8 @@ class TowerContext(object): def _merge_call(self, merge_fn, *args, **kwargs): """Default implementation for single tower.""" _push_per_thread_mode( # thread-local, so not needed with multiple threads - _CrossTowerThreadMode(self._distribution_strategy)) + distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + self._distribution_strategy)) try: return merge_fn(self._distribution_strategy, *args, **kwargs) finally: @@ -1196,7 +1051,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): def scope(self): """Context manager setting a variable creator and `self` as current.""" - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise RuntimeError("Must not nest DistributionStrategy scopes.") def creator(next_creator, *args, **kwargs): @@ -1277,6 +1132,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): raise RuntimeError("worker_device_index() method unsupported by " "_DefaultDistributionStrategy.") + # ------------------------------------------------------------------------------ # Common operations @@ -1292,20 +1148,11 @@ def increment_var(v, amount=1): def merge_fn(dist, vm): return dist.group(dist.update(vm, update)) - tower_context = get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() return tower_context.merge_call(merge_fn, v) # ------------------------------------------------------------------------------ -# Singletons - -_default_distribution_strategy = _DefaultDistributionStrategy() -_default_tower_context = TowerContext( - _default_distribution_strategy, tower_id=0) -_default_tower_mode = _DefaultTowerThreadMode() - - -# ------------------------------------------------------------------------------ # We haven't yet implemented deserialization for DistributedVariables. # So here we catch any attempts to deserialize variables # when using distribution strategies. @@ -1314,7 +1161,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn def _from_proto_fn(v, import_scope=None): - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise NotImplementedError( "Deserialization of variables is not yet supported when using" "distributed strategies.") @@ -1323,3 +1170,10 @@ def _from_proto_fn(v, import_scope=None): resource_variable_ops._from_proto_fn = _from_proto_fn # pylint: enable=protected-access + + +#------------------------------------------------------------------------------- +# Shorthand for some methods from distribution_strategy_context. +_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access +_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access +_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py index 694145ede7..f03bd39100 100644 --- a/tensorflow/python/training/distribute_test.py +++ b/tensorflow/python/training/distribute_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context class _TestTowerContext(distribute.TowerContext): @@ -49,12 +50,12 @@ class _TestStrategy(distribute.DistributionStrategy): def _assert_in_default_state(t): - t.assertIs(distribute._default_tower_context, - distribute.get_tower_context()) - t.assertIs(None, distribute.get_cross_tower_context()) - t.assertIs(distribute._default_distribution_strategy, - distribute.get_distribution_strategy()) - t.assertFalse(distribute.has_distribution_strategy()) + t.assertIs(distribution_strategy_context._get_default_tower_context(), + distribution_strategy_context.get_tower_context()) + t.assertIs(None, distribution_strategy_context.get_cross_tower_context()) + t.assertIs(distribution_strategy_context._get_default_distribution_strategy(), + distribution_strategy_context.get_distribution_strategy()) + t.assertFalse(distribution_strategy_context.has_distribution_strategy()) class TestStrategyTest(test.TestCase): @@ -64,11 +65,13 @@ class TestStrategyTest(test.TestCase): dist = _TestStrategy() def run_fn(): - tower_context = distribute.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() self.assertTrue(tower_context is not None) - self.assertIs(None, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertIs(None, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo")) expected_value = _get_test_variable( "bar", variable_scope.VariableSynchronization.AUTO, @@ -86,10 +89,12 @@ class TestStrategyTest(test.TestCase): _assert_in_default_state(self) dist = _TestStrategy() with dist.scope(): - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) expected_value = _get_test_variable( "baz", variable_scope.VariableSynchronization.AUTO, variable_scope.VariableAggregation.NONE) @@ -120,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase): _assert_in_default_state(self) def merge_fn(dist, s): - self.assertIs(distribute._default_distribution_strategy, dist) - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertIs(dist, distribute.get_distribution_strategy()) - self.assertFalse(distribute.has_distribution_strategy()) + self.assertIs( + distribution_strategy_context._get_default_distribution_strategy(), + dist) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) + self.assertFalse( + distribution_strategy_context.has_distribution_strategy()) return "foo_" + s - tower_ctx = distribute.get_tower_context() - self.assertIs(distribute._default_tower_context, tower_ctx) + tower_ctx = distribution_strategy_context.get_tower_context() + self.assertIs(distribution_strategy_context._get_default_tower_context(), + tower_ctx) self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar")) _assert_in_default_state(self) diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py new file mode 100644 index 0000000000..998b5c35ce --- /dev/null +++ b/tensorflow/python/training/distribution_strategy_context.py @@ -0,0 +1,203 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to get distribution strategy related contexts.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.util.lazy_loader import LazyLoader + + +# There is a circular dependency between this and `distribute` module. So we +# load it lazily to workaround this. +distribute_lib = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training.distribute") + +# ------------------------------------------------------------------------------ +# Internal API for setting the current thread mode as being either in a +# tower or cross-tower context for a particular distribution strategy. + + +class _ThreadMode(object): + + def __init__(self, dist, cross, tower): + self.distribution_strategy = dist + self.cross_tower_context = cross + self.tower_context = tower + + +class _CrossTowerThreadMode(_ThreadMode): + + def __init__(self, distribution_strategy): + _ThreadMode.__init__( + self, distribution_strategy, distribution_strategy, None) + + +class _InTowerThreadMode(_ThreadMode): + + def __init__(self, tower_ctx): + _ThreadMode.__init__( + self, tower_ctx.distribution_strategy, None, tower_ctx) + + +def _push_per_thread_mode(context): + ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access + + +def _pop_per_thread_mode(): + ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access + + +class _DefaultTowerThreadMode(_ThreadMode): + """Type of default value returned by `_get_per_thread_mode()`. + + Used when the thread-local stack is empty. + """ + + def __init__(self): + _ThreadMode.__init__(self, _get_default_distribution_strategy(), None, + _get_default_tower_context()) + + +def _get_per_thread_mode(): + try: + return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access + except (AttributeError, IndexError): + return _get_default_tower_mode() + + +# ------------------------------------------------------------------------------ +# Public API for accessing the current thread mode + + +def get_tower_context(): + """Returns the current TowerContext or None if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context (this function + will return the default TowerContext object); + 2. switches to cross-tower context (in which case this will return + None) when entering a `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context (and again + this function will return None). + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context, in a tower context you should use the + `TowerContext` API instead. + + Returns: + The current `TowerContext` object when in a tower context scope, else None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().tower_context + + +def get_cross_tower_context(): + """Returns the current DistributionStrategy if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context; + 2. switches to cross-tower context when entering a + `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context. + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context. + + Returns: + Returns the current `DistributionStrategy` object in a cross-tower + context, or None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().cross_tower_context + + +def get_distribution_strategy(): + """Returns the current `DistributionStrategy` object. + + Prefer to use `get_tower_context()` or `get_cross_tower_context()` + instead when possible. + + Returns: + A `DistributionStrategy` object. Inside a + `with distribution_strategy.scope()` block, it returns + `distribution_strategy`, otherwise it returns the default + (single-tower) `DistributionStrategy` object. + """ + return _get_per_thread_mode().distribution_strategy + + +def has_distribution_strategy(): + """Return if there is a current non-default `DistributionStrategy`. + + Returns: + True if inside a `with distribution_strategy.scope():`. + """ + return get_distribution_strategy() is not _get_default_distribution_strategy() + + +# ------------------------------------------------------------------------------ +# Defaults that are used when no distribution strategy is explicitly created. +# We create them lazily in a function so that we can workaround the circular +# dependency on distribute_lib. See lazy loader at the top of this file. + +_defaults = { + "distribution_strategy": None, + "tower_context": None, + "tower_mode": None +} + + +def _get_default_distribution_strategy(): + if _defaults["distribution_strategy"] is None: + _defaults["distribution_strategy"] = ( + distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access + return _defaults["distribution_strategy"] + + +def _get_default_tower_context(): + if _defaults["tower_context"] is None: + _defaults["tower_context"] = distribute_lib.TowerContext( + _get_default_distribution_strategy(), tower_id=0) + return _defaults["tower_context"] + + +def _get_default_tower_mode(): + if _defaults["tower_mode"] is None: + _defaults["tower_mode"] = _DefaultTowerThreadMode() + return _defaults["tower_mode"] diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 6d95b144d5..1b6bce2865 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -464,7 +465,8 @@ class Optimizer( # TODO(josh11b): Test that we handle weight decay in a reasonable way. if (distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN): - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss_value *= (1. / num_towers) @@ -482,7 +484,8 @@ class Optimizer( # Scale loss if using a "mean" loss reduction and multiple towers. if (distribute_lib.get_loss_reduction() == variable_scope.VariableAggregation.MEAN): - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss *= (1. / num_towers) @@ -548,15 +551,15 @@ class Optimizer( # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse(). # Handle DistributionStrategy case. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): raise RuntimeError("Use `_distributed_apply()` instead of " "`apply_gradients()` in a cross-tower context.") # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by # always calling _distributed_apply(), using the default distribution # as needed. - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)() - return distribute_lib.get_tower_context().merge_call( + return distribution_strategy_context.get_tower_context().merge_call( self._distributed_apply, grads_and_vars, global_step, name) # No DistributionStrategy case. @@ -799,7 +802,8 @@ class Optimizer( v = self._non_slot_dict.get(key, None) if v is None: self._maybe_initialize_checkpointable() - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(colocate_with): if eager: restored_initial_value = self._preload_simple_restoration( diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 258a6f045d..d76b22acd8 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): @@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) else: @@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name, prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype) |