author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-29 10:17:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-29 10:22:42 -0700
commit    aca93368a979419360c1fd84b53b1766b19ba81a (patch)
tree      2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/contrib/distribute/python
parent    8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff)
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global step
counter. This allows us to get rid of the increment_var() function and just
use a standard assign_add().

PiperOrigin-RevId: 210743165
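
For context, a minimal sketch of the usage this change enables, modeled on the
tests added below. The device list, variable name, and helper function names
are illustrative assumptions, not part of the commit:

from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.ops import variable_scope

devices = ["/device:GPU:0", "/device:CPU:0"]
dist = mirrored_strategy.MirroredStrategy(devices)

with dist.scope():
  def create_fn():
    # With ONLY_FIRST_TOWER aggregation, reducing this variable simply takes
    # the value from the first tower instead of summing or averaging.
    return variable_scope.variable(
        0,
        name="global_step",
        trainable=False,
        aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)

  global_step = dist.call_for_each_tower(create_fn, run_concurrently=False)

  def increment_fn():
    # A plain assign_add() now does the right thing; no increment_var() helper.
    return global_step.assign_add(1)

  increment_op = dist.group(
      dist.call_for_each_tower(increment_fn, run_concurrently=False))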
Diffstat (limited to 'tensorflow/contrib/distribute/python')
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                              |   3
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py               |  31
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 102
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy.py       |   8
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy_test.py  |  18
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py                    |   5
-rw-r--r--  tensorflow/contrib/distribute/python/values.py                          |   2
7 files changed, 149 insertions, 20 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 94deb2a432..c524d8b394 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -279,10 +279,11 @@ cuda_py_test(
":strategy_test_lib",
"//tensorflow/python:distribute",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
- "//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index b44edfbd27..b4233a5eed 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -65,7 +65,7 @@ class _RequestedStop(Exception):
pass
-# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# _call_for_each_tower and _reduce_non_distributed_value are not members of
# MirroredStrategy so that they are generally not allowed to use anything
# specific to MirroredStrategy and thus can be shared with other distribution
# strategies.
@@ -197,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and equal to 0.
if value == 0:
return 0
- # If the aggregation type is MEAN, then this essentially means that the same
- # value should be on all destinations.
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return distribution.broadcast(value, destinations)
+ # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
+ # essentially means that the same value should be on all destinations.
+ if aggregation in (
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
+ return value
cross_tower_ops_lib.validate_destinations(destinations)
# We do not support an aggregation type of SUM if the value is the same across
@@ -208,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value cannot be reduced with the "
- "given aggregation.")
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given aggregation %s." % (value, aggregation))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
@@ -254,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
- if aggregation not in [
+ if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ ):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -591,10 +594,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ value = value.get(self._devices[0])
+ if isinstance(value, (int, float)):
+ return value
+ return self.broadcast(value, destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._devices[0]), d)
+ for v, d in value_destination_pairs]
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index a12ff662db..830681a4ce 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
@@ -128,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes
+ def testReduceOnlyFirstTowerUpdates(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ def run_fn(device_id):
+ return constant_op.constant(3 + 5 * device_id)
+
+ dist = self._get_distribution_strategy()
+ with dist.scope():
+ result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER,
+ result,
+ destinations="/device:CPU:0")
+ unwrapped = dist.unwrap(reduced)
+ self.assertEqual(1, len(unwrapped))
+ self.assertEqual(3, self.evaluate(unwrapped[0]))
+
@test_util.run_in_graph_and_eager_modes()
def testReduceToMultipleDestinations(self):
if not GPU_TEST:
@@ -384,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
v3.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testOnlyFirstTowerUpdatesVariables(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def create_fn():
+ aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ v0 = variable_scope.variable(
+ 2.0,
+ name="on_read",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ v1 = variable_scope.variable(
+ 3.0,
+ name="on_write",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=aggregation)
+ return v0, v1
+
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False)
+ self.evaluate(v0.initializer)
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0, self.evaluate(dist.read_var(v0)))
+ self.evaluate(v1.initializer)
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using the assign_add member function.
+ def update_member_fn(device_id):
+ update0 = v0.assign_add(5.0 * (device_id + 1))
+ update1 = v1.assign_add(7.0 * (device_id + 1))
+ return update0, update1
+
+ update0a, update1a = dist.call_for_each_tower(
+ update_member_fn, dist.worker_device_index, run_concurrently=False)
+
+ # Update "sync on read" variable.
+ self.evaluate(dist.group(update0a))
+ self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
+ # Writes are not synchronized for "sync on read" variables,
+ # so device[1] can end up with a different value.
+ self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
+ # Always reads from device 0.
+ self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1a))
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
+ # Writes are synchronized for v1, only the argument to assign_add on
+ # device[0] is used.
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using state_ops.assign_add global function.
+ def update_state_ops_fn(device_id):
+ update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1))
+ update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1))
+ return update0, update1
+
+ update0b, update1b = dist.call_for_each_tower(
+ update_state_ops_fn, dist.worker_device_index, run_concurrently=False)
+ self.evaluate(dist.group(update0b))
+
+ # Update "sync on read" variable.
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1b))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testNoneSynchronizationWithGetVariable(self):
self._skip_eager_if_gpus_less_than(1)
devices = ["/device:CPU:0", "/device:GPU:0"]
@@ -804,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase):
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
- ValueError, "A non-DistributedValues value cannot be reduced with "
- "the given aggregation."):
+ ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
+ "with the given aggregation VariableAggregation.SUM."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 361c8be590..0f439f6c1f 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -235,7 +235,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if aggregation not in (
vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
- vs.VariableAggregation.MEAN
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER
):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -302,10 +303,15 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self.broadcast(value.get(self._compute_devices[0]), destinations)
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._compute_devices[0]), d)
+ for v, d in value_destination_pairs]
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_tower_ops.batch_reduce(aggregation,
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 0e2bfcec5f..cbf18bf1d2 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -286,18 +286,22 @@ class ParameterServerStrategyTestBase(
y = variable_scope.get_variable(
'y', initializer=20.0,
aggregation=variable_scope.VariableAggregation.SUM)
+ z = variable_scope.get_variable(
+ 'z', initializer=30.0,
+ aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)
# We explicitly make a constant tensor here to avoid complaints about
# summing non-distributed values.
one = constant_op.constant(1.0)
x_add = x.assign_add(one, use_locking=True)
y_add = y.assign_add(one, use_locking=True)
+ z_add = z.assign_add(one, use_locking=True)
- train_op = control_flow_ops.group([x_add, y_add])
- return x, y, train_op
+ train_op = control_flow_ops.group(x_add, y_add, z_add)
+ return x, y, z, train_op
- x, y, train_op = d.call_for_each_tower(model_fn)
- train_op = d.group(d.unwrap(train_op))
+ x, y, z, train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(train_op)
if context.num_gpus() < d._num_gpus_per_worker:
return True
@@ -323,11 +327,13 @@ class ParameterServerStrategyTestBase(
self._finish_condition.notify_all()
self._finish_condition.release()
- x_val, y_val = sess.run([x, y])
+ x_val, y_val, z_val = sess.run([x, y, z])
self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers)
self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers)
+ self.assertEqual(z_val, 30.0 + 1.0 * num_workers)
return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
- y_val == 20.0 + 1.0 * num_workers * d.num_towers)
+ y_val == 20.0 + 1.0 * num_workers * d.num_towers and
+ z_val == 30.0 + 1.0 * num_workers)
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 6202a0750a..d0dbbd0da8 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -238,6 +238,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ 'Currently only support sum & mean in TPUStrategy.')
return tpu_ops.cross_replica_sum(value)
cf_context = cf_context.outer_context
@@ -251,6 +254,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
else:
raise ValueError('Multiple devices are not supported for TPUStrategy')
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return value[0]
output = math_ops.add_n(value)
if aggregation == vs.VariableAggregation.MEAN:
return output * (1. / len(value))
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 3ccaa2690e..479b7f39d6 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -523,6 +523,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self._aggregation
def _get_cross_tower(self):
+ if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self._primary_var
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)