diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-29 10:17:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 10:22:42 -0700 |
commit | aca93368a979419360c1fd84b53b1766b19ba81a (patch) | |
tree | 2312ef53a30251ec2f5538d43ba066550679f6d9 | |
parent | 8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff) |
Add a new aggregation mode, "ONLY_FIRST_TOWER", and use it for the global
step counter. With this mode, callers no longer need the increment_var()
function and can just use a standard assign_add(); increment_var() itself
is now marked deprecated, pending removal.
PiperOrigin-RevId: 210743165
18 files changed, 210 insertions, 46 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 94deb2a432..c524d8b394 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -279,10 +279,11 @@ cuda_py_test( ":strategy_test_lib", "//tensorflow/python:distribute", "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", + "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", - "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index b44edfbd27..b4233a5eed 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -65,7 +65,7 @@ class _RequestedStop(Exception): pass -# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# _call_for_each_tower and _reduce_non_distributed_value are not members of # MirroredStrategy so that they are generally not allowed to use anything # specific to MirroredStrategy and thus can be shared with other distribution # strategies. @@ -197,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value, # and equal to 0. if value == 0: return 0 - # If the aggregation type is MEAN, then this essentially means that the same - # value should be on all destinations. - if aggregation == variable_scope.VariableAggregation.MEAN: - return distribution.broadcast(value, destinations) + # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this + # essentially means that the same value should be on all destinations. 
+ if aggregation in ( + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER): + return value cross_tower_ops_lib.validate_destinations(destinations) # We do not support an aggregation type of SUM if the value is the same across @@ -208,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value, # and summing up identical values across towers is not clearly defined. if (len(distribution.worker_devices) != 1 or not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value cannot be reduced with the " - "given aggregation.") + raise ValueError("A non-DistributedValues value %s cannot be reduced with " + "the given aggregation %s." % (value, aggregation)) # TODO(anjalisridhar): Moves these methods to a device utility file? devices = cross_tower_ops_lib.get_devices_from(destinations) if len(devices) == 1: @@ -254,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # Get aggregation value aggregation = kwargs.pop("aggregation", variable_scope.VariableAggregation.NONE) - if aggregation not in [ + if aggregation not in ( variable_scope.VariableAggregation.NONE, variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN - ]: + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER + ): raise ValueError("Invalid variable aggregation mode: " + aggregation + " for variable: " + kwargs["name"]) @@ -591,10 +594,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # which case `value` would be a single value or value could be 0. 
return _reduce_non_distributed_value(self, aggregation, value, destinations) + if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: + value = value.get(self._devices[0]) + if isinstance(value, (int, float)): + return value + return self.broadcast(value, destinations) return self._get_cross_tower_ops().reduce( aggregation, value, destinations=destinations) def _batch_reduce(self, aggregation, value_destination_pairs): + if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: + return [self.broadcast(v.get(self._devices[0]), d) + for v, d in value_destination_pairs] return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index a12ff662db..830681a4ce 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -38,6 +38,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util @@ -128,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): expected = sum(range(len(dist.worker_devices))) self.assertEqual(expected, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes + def testReduceOnlyFirstTowerUpdates(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return constant_op.constant(3 + 5 * device_id) + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce( 
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER, + result, + destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + self.assertEqual(3, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes() def testReduceToMultipleDestinations(self): if not GPU_TEST: @@ -384,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v3.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) + def testOnlyFirstTowerUpdatesVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def create_fn(): + aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER + v0 = variable_scope.variable( + 2.0, + name="on_read", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=aggregation) + v1 = variable_scope.variable( + 3.0, + name="on_write", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=aggregation) + return v0, v1 + + devices = ["/device:GPU:0", "/device:CPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False) + self.evaluate(v0.initializer) + self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.evaluate(v1.initializer) + self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + + # Update using the assign_add member function. + def update_member_fn(device_id): + update0 = v0.assign_add(5.0 * (device_id + 1)) + update1 = v1.assign_add(7.0 * (device_id + 1)) + return update0, update1 + + update0a, update1a = dist.call_for_each_tower( + update_member_fn, dist.worker_device_index, run_concurrently=False) + + # Update "sync on read" variable. 
+ self.evaluate(dist.group(update0a)) + self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) + # Writes are not synchronized for "sync on read" variables, + # so device[1] can end up with a different value. + self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) + # Always reads from device 0. + self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + + # Update "sync on write" variable. + self.evaluate(dist.group(update1a)) + self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) + # Writes are synchronized for v1, only the argument to assign_add on + # device[0] is used. + self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + + # Update using state_ops.assign_add global function. + def update_state_ops_fn(device_id): + update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) + update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + return update0, update1 + + update0b, update1b = dist.call_for_each_tower( + update_state_ops_fn, dist.worker_device_index, run_concurrently=False) + self.evaluate(dist.group(update0b)) + + # Update "sync on read" variable. + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + + # Update "sync on write" variable. 
+ self.evaluate(dist.group(update1b)) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) + + @test_util.run_in_graph_and_eager_modes(config=config) def testNoneSynchronizationWithGetVariable(self): self._skip_eager_if_gpus_less_than(1) devices = ["/device:CPU:0", "/device:GPU:0"] @@ -804,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase): return mirrored_var.assign(5.0) with self.assertRaisesRegexp( - ValueError, "A non-DistributedValues value cannot be reduced with " - "the given aggregation."): + ValueError, "A non-DistributedValues value 5.0 cannot be reduced " + "with the given aggregation VariableAggregation.SUM."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 361c8be590..0f439f6c1f 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -235,7 +235,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if aggregation not in ( vs.VariableAggregation.NONE, vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN + vs.VariableAggregation.MEAN, + vs.VariableAggregation.ONLY_FIRST_TOWER ): raise ValueError("Invalid variable aggregation mode: " + aggregation + " for variable: " + kwargs["name"]) @@ -302,10 +303,15 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # pylint: disable=protected-access return mirrored_strategy._reduce_non_distributed_value( self, aggregation, value, destinations) + if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: + return self.broadcast(value.get(self._compute_devices[0]), destinations) return 
self._cross_tower_ops.reduce( aggregation, value, destinations=destinations) def _batch_reduce(self, aggregation, value_destination_pairs): + if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: + return [self.broadcast(v.get(self._compute_devices[0]), d) + for v, d in value_destination_pairs] for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) return self._cross_tower_ops.batch_reduce(aggregation, diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 0e2bfcec5f..cbf18bf1d2 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -286,18 +286,22 @@ class ParameterServerStrategyTestBase( y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) + z = variable_scope.get_variable( + 'z', initializer=30.0, + aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER) # We explicitly make a constant tensor here to avoid complaints about # summing non-distributed values. 
one = constant_op.constant(1.0) x_add = x.assign_add(one, use_locking=True) y_add = y.assign_add(one, use_locking=True) + z_add = z.assign_add(one, use_locking=True) - train_op = control_flow_ops.group([x_add, y_add]) - return x, y, train_op + train_op = control_flow_ops.group(x_add, y_add, z_add) + return x, y, z, train_op - x, y, train_op = d.call_for_each_tower(model_fn) - train_op = d.group(d.unwrap(train_op)) + x, y, z, train_op = d.call_for_each_tower(model_fn) + train_op = d.group(train_op) if context.num_gpus() < d._num_gpus_per_worker: return True @@ -323,11 +327,13 @@ class ParameterServerStrategyTestBase( self._finish_condition.notify_all() self._finish_condition.release() - x_val, y_val = sess.run([x, y]) + x_val, y_val, z_val = sess.run([x, y, z]) self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(z_val, 30.0 + 1.0 * num_workers) return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and - y_val == 20.0 + 1.0 * num_workers * d.num_towers) + y_val == 20.0 + 1.0 * num_workers * d.num_towers and + z_val == 30.0 + 1.0 * num_workers) def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target = self._get_test_objects(task_type, task_id, num_gpus) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 6202a0750a..d0dbbd0da8 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -238,6 +238,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): if aggregation == vs.VariableAggregation.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. value *= (1. 
/ self.num_towers) + elif aggregation != vs.VariableAggregation.SUM: + raise NotImplementedError( + 'Currently only support sum & mean in TPUStrategy.') return tpu_ops.cross_replica_sum(value) cf_context = cf_context.outer_context @@ -251,6 +254,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): else: raise ValueError('Multiple devices are not supported for TPUStrategy') + if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: + return value[0] output = math_ops.add_n(value) if aggregation == vs.VariableAggregation.MEAN: return output * (1. / len(value)) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 3ccaa2690e..479b7f39d6 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -523,6 +523,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self._aggregation def _get_cross_tower(self): + if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: + return self._primary_var all_components = tuple(self._index.values()) # TODO(josh11b): Use a strategy-specific method. 
total = math_ops.add_n(all_components) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index 505c94e971..513feb03b6 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -37,13 +37,13 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer from tensorflow.python.training import saver @@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase): self.assertEquals(0, loss.shape.ndims) if expected_loss is None: if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), @@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase): name='assert_loss') with ops.control_dependencies((assert_loss,)): if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() mock_optimizer = test.mock.NonCallableMock( diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py index e46a3a156d..1df7216ba6 100644 --- a/tensorflow/python/estimator/canned/baseline_test.py +++ 
b/tensorflow/python/estimator/canned/baseline_test.py @@ -42,13 +42,13 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import input as input_lib from tensorflow.python.training import optimizer from tensorflow.python.training import queue_runner @@ -490,7 +490,7 @@ class BaselineRegressorTrainingTest(test.TestCase): self.assertEquals(0, loss.shape.ndims) if expected_loss is None: if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), @@ -498,7 +498,7 @@ class BaselineRegressorTrainingTest(test.TestCase): name='assert_loss') with ops.control_dependencies((assert_loss,)): if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() mock_optimizer = test.mock.NonCallableMock( @@ -693,13 +693,13 @@ class BaselineClassifierTrainingTest(test.TestCase): # Verify loss. We can't check the value directly, so we add an assert op. 
self.assertEquals(0, loss.shape.ndims) if expected_loss is None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), loss, name='assert_loss') with ops.control_dependencies((assert_loss,)): - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op mock_optimizer = test.mock.NonCallableMock( spec=optimizer.Optimizer, diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index ef7c217190..d104c961d3 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -38,7 +38,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import estimator_export @@ -876,7 +875,7 @@ def _bt_model_fn( train_op.append(update_model) with ops.control_dependencies([update_model]): - increment_global = distribute_lib.increment_var(global_step) + increment_global = state_ops.assign_add(global_step, 1).op train_op.append(increment_global) return control_flow_ops.group(train_op, name='train_op') diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 4945c3ba11..62a1adf78c 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -31,10 +31,10 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from 
tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import estimator_export @@ -222,7 +222,7 @@ def _dnn_linear_combined_model_fn(features, train_op = control_flow_ops.group(*train_ops) with ops.control_dependencies([train_op]): - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return head.create_estimator_spec( features=features, diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index de226ed0ef..11f1e93630 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -44,13 +44,13 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.summary import summary as summary_lib from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session from tensorflow.python.training import optimizer as optimizer_lib @@ -222,7 +222,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None): 
testcase.assertEquals(0, loss.shape.ndims) if expected_loss is None: if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), @@ -230,7 +230,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None): name='assert_loss') with ops.control_dependencies((assert_loss,)): if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() optimizer_mock = test.mock.NonCallableMagicMock( diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index c3934c7a80..65cdd50061 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -48,13 +48,13 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.training import input as input_lib from tensorflow.python.training import optimizer as optimizer_lib @@ -756,7 +756,7 @@ class BaseLinearRegressorTrainingTest(object): self.assertEquals(0, loss.shape.ndims) if expected_loss is None: if global_step is not None: - return distribute_lib.increment_var(global_step) + 
return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), @@ -764,7 +764,7 @@ class BaseLinearRegressorTrainingTest(object): name='assert_loss') with ops.control_dependencies((assert_loss,)): if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() mock_optimizer = test.mock.NonCallableMock( @@ -979,13 +979,13 @@ class BaseLinearClassifierTrainingTest(object): # Verify loss. We can't check the value directly, so we add an assert op. self.assertEquals(0, loss.shape.ndims) if expected_loss is None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), loss, name='assert_loss') with ops.control_dependencies((assert_loss,)): - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op mock_optimizer = test.mock.NonCallableMock( spec=optimizer_lib.Optimizer, diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index f7da3f7d64..3383383467 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -78,10 +78,26 @@ class VariableSynchronization(enum.Enum): @tf_export("VariableAggregation") class VariableAggregation(enum.Enum): - """Indicates how a distributed variable will be aggregated.""" + """Indicates how a distributed variable will be aggregated. + + `tf.contrib.distribute.DistributionStrategy` distributes a model by making + multiple copies (called "towers") acting data-parallel on different elements + of the input batch. When performing some variable-update operation, say + `var.assign_add(x)`, in a model, we need to resolve how to combine the + different values for `x` computed in the different towers. 
+ + * `NONE`: This is the default, giving an error if you use a + variable-update operation with multiple towers. + * `SUM`: Add the updates across towers. + * `MEAN`: Take the arithmetic mean ("average") of the updates across towers. + * `ONLY_FIRST_TOWER`: This is for when every tower is performing the same + update, but we only want to perform the update once. Used, e.g., for the + global step counter. + """ NONE = 0 SUM = 1 MEAN = 2 + ONLY_FIRST_TOWER = 3 class VariableMetaclass(type): diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 1ac7c39872..ac92238d57 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -723,7 +724,8 @@ class DistributionStrategy(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + `tf.VariableAggregation.ONLY_FIRST_TOWER`. value: A per-device value with one value per tower. destinations: An optional mirrored variable, a device string, list of device strings. 
The return value will be copied to all @@ -740,7 +742,8 @@ class DistributionStrategy(object): _require_cross_tower_context(self) assert aggregation in [ variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER ] return self._reduce(aggregation, value, destinations) @@ -752,7 +755,8 @@ class DistributionStrategy(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + `tf.VariableAggregation.ONLY_FIRST_TOWER`. value_destination_pairs: A sequence of (value, destinations) pairs. See `reduce()` for a description. @@ -763,7 +767,8 @@ class DistributionStrategy(object): _require_cross_tower_context(self) assert aggregation in [ variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER ] return self._batch_reduce(aggregation, value_destination_pairs) @@ -1168,9 +1173,14 @@ class _DefaultDistributionStrategy(DistributionStrategy): # ------------------------------------------------------------------------------ -# Common operations +# Deprecated, use v.assign_add(amount) instead. Internal API, so expect +# it to be deleted soon. +@deprecation.deprecated(None, + "Use v.assign_add(amount) instead. 
You may need to set " + "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER " + "when creating the variable.") def increment_var(v, amount=1): """`v += amount`, distributed-aware version.""" def update(vu): diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 2ff3eeb153..d998d6af81 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -129,6 +129,7 @@ def create_global_step(graph=None): dtype=dtypes.int64, initializer=init_ops.zeros_initializer(), trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER, collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) # Create in proper graph and base name_scope. @@ -139,6 +140,7 @@ def create_global_step(graph=None): dtype=dtypes.int64, initializer=init_ops.zeros_initializer(), trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER, collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt index 36b534af36..66a20547eb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt @@ -10,6 +10,10 @@ tf_class { mtype: "<enum \'VariableAggregation\'>" } member { + name: "ONLY_FIRST_TOWER" + mtype: "<enum \'VariableAggregation\'>" + } + member { name: "SUM" mtype: "<enum \'VariableAggregation\'>" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt index 36b534af36..66a20547eb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt @@ -10,6 +10,10 @@ tf_class { mtype: "<enum 
\'VariableAggregation\'>" } member { + name: "ONLY_FIRST_TOWER" + mtype: "<enum \'VariableAggregation\'>" + } + member { name: "SUM" mtype: "<enum \'VariableAggregation\'>" } |