author    Pavithra Vijay <psv@google.com>            2018-06-29 18:02:18 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-29 18:05:25 -0700
commit c290930ec1beacbcac414b43b3367dd44ffbd303 (patch)
tree   b1136e7c32718a6f1f9ebfde3073c88546078de6 /tensorflow/contrib/optimizer_v2
parent a520735d205ca5678fc8a371ea1add00413907fe (diff)
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ. Add enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation-method strings in distribution strategy with the enum values. Update MirroredStrategy to use these parameters to decide whether a variable should be Mirrored or TowerLocal. Update the distribution strategy value types to use the `VariableAggregation` enum.
PiperOrigin-RevId: 202736077
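
For context, a minimal usage sketch of the new arguments (assuming the TF 1.9-era `tf.get_variable` this commit extends; the variable name, shape, and initializer are illustrative):

    import tensorflow as tf

    # A tower-local accumulator: each tower updates its copy independently
    # (ON_READ synchronization), and reads aggregate across towers by MEAN.
    v = tf.get_variable(
        "loss_accumulator",  # illustrative name
        shape=[],
        initializer=tf.zeros_initializer(),
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.MEAN)

Per the commit message, MirroredStrategy uses these two parameters to decide whether the variable is Mirrored or TowerLocal; the defaults (AUTO synchronization, NONE aggregation) preserve the previous behavior.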
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py | 9
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index c6f3bd6ee1..8c11d8bcfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -766,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
       # *after* loss() is evaluated, so we know what loss reduction it uses.
       if scale_loss_by_num_towers is None:
         scale_loss_by_num_towers = (
-            distribute_lib.get_loss_reduction() == "mean")
+            distribute_lib.get_loss_reduction() ==
+            variable_scope.VariableAggregation.MEAN)
       if scale_loss_by_num_towers:
         num_towers = distribute_lib.get_distribution_strategy().num_towers
         if num_towers > 1:
@@ -784,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
     # Scale loss for number of towers (non-callable-loss case).
     if scale_loss_by_num_towers is None:
       scale_loss_by_num_towers = (
-          distribute_lib.get_loss_reduction() == "mean")
+          distribute_lib.get_loss_reduction() ==
+          variable_scope.VariableAggregation.MEAN)
     if scale_loss_by_num_towers:
       num_towers = distribute_lib.get_distribution_strategy().num_towers
       if num_towers > 1:
@@ -896,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
   def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
     """`apply_gradients` for use with a `DistributionStrategy`."""
-    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+    reduced_grads = distribution.batch_reduce(
+        variable_scope.VariableAggregation.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
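
The hunks above swap the bare strings "mean" and "sum" for the new enum. A minimal sketch of the comparison after this change (module aliases as imported in optimizer_v2.py at the time; treat the import paths as assumptions about the TF 1.9-era layout):

    from tensorflow.python.ops import variable_scope
    from tensorflow.python.training import distribute as distribute_lib

    # get_loss_reduction() now reports a VariableAggregation member, so the
    # loss-scaling check compares enum values instead of string literals.
    scale_loss_by_num_towers = (
        distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN)

Comparing enum members rather than strings catches typos at attribute-lookup time instead of silently evaluating to False.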