author     Pavithra Vijay <psv@google.com>                    2018-06-29 18:02:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-06-29 18:05:25 -0700
commit     c290930ec1beacbcac414b43b3367dd44ffbd303
tree       b1136e7c32718a6f1f9ebfde3073c88546078de6 /tensorflow/contrib/optimizer_v2
parent     a520735d205ca5678fc8a371ea1add00413907fe
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ.
Add enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation-method strings in the distribution strategies with the enum values.
Update MirroredStrategy to use these parameters to decide whether a variable should be Mirrored or TowerLocal.
Update the different distribution strategy value types to use the `VariableAggregation` enum.
PiperOrigin-RevId: 202736077
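
For context, here is a minimal usage sketch of the new `get_variable()` arguments. It assumes the enums are reachable as `tf.VariableSynchronization` and `tf.VariableAggregation`, which is how they were exported in later TF 1.x releases; within this commit they live in `tensorflow.python.ops.variable_scope`.

```python
# A minimal sketch (TF 1.x graph mode), assuming the enums are exported
# as tf.VariableSynchronization / tf.VariableAggregation as in later 1.x
# releases; in this commit they live in tensorflow.python.ops.variable_scope.
import tensorflow as tf

# Tower-local accumulator: each replica updates its own copy (ON_READ),
# and a read outside the replica context aggregates with SUM.
counter = tf.get_variable(
    "counter",
    shape=[],
    initializer=tf.zeros_initializer(),
    trainable=False,
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM)

# The defaults (AUTO, NONE) preserve the previous behavior; under
# MirroredStrategy this yields a mirrored variable kept in sync across towers.
weights = tf.get_variable("weights", shape=[10, 10])
```

As the commit message notes, MirroredStrategy uses these two arguments to decide whether to create a Mirrored or a TowerLocal variable.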
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index c6f3bd6ee1..8c11d8bcfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -766,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
         # *after* loss() is evaluated, so we know what loss reduction it uses.
         if scale_loss_by_num_towers is None:
           scale_loss_by_num_towers = (
-              distribute_lib.get_loss_reduction() == "mean")
+              distribute_lib.get_loss_reduction() ==
+              variable_scope.VariableAggregation.MEAN)
         if scale_loss_by_num_towers:
           num_towers = distribute_lib.get_distribution_strategy().num_towers
           if num_towers > 1:
@@ -784,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
     # Scale loss for number of towers (non-callable-loss case).
     if scale_loss_by_num_towers is None:
       scale_loss_by_num_towers = (
-          distribute_lib.get_loss_reduction() == "mean")
+          distribute_lib.get_loss_reduction() ==
+          variable_scope.VariableAggregation.MEAN)
     if scale_loss_by_num_towers:
       num_towers = distribute_lib.get_distribution_strategy().num_towers
       if num_towers > 1:
@@ -896,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
 
   def _distributed_apply(self, distribution, grads_and_vars, global_step,
                          name):
     """`apply_gradients` for use with a `DistributionStrategy`."""
-    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+    reduced_grads = distribution.batch_reduce(
+        variable_scope.VariableAggregation.SUM, grads_and_vars)
     var_list = [v for _, v in grads_and_vars]
     grads_and_vars = zip(reduced_grads, var_list)
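
The motivation for replacing the string literals is easy to demonstrate in isolation: a Python enum member never compares equal to its name string, so once `get_loss_reduction()` returns `VariableAggregation` members, a leftover `== "mean"` check would silently evaluate to False and skip the loss scaling. A standalone sketch (plain Python, no TensorFlow required; the integer values are illustrative):

```python
import enum

# Mirrors the VariableAggregation enum added by this change; the integer
# values here are illustrative only.
class VariableAggregation(enum.Enum):
    NONE = 0
    SUM = 1
    MEAN = 2

loss_reduction = VariableAggregation.MEAN

# An enum member is not equal to its name string.
print(loss_reduction == "mean")                    # False
print(loss_reduction == VariableAggregation.MEAN)  # True

# The guarded scaling from the hunks above: divide the loss across towers
# only when the reduction is a MEAN.
num_towers, loss_value = 4, 8.0
if loss_reduction == VariableAggregation.MEAN and num_towers > 1:
    loss_value *= 1.0 / num_towers
print(loss_value)  # 2.0
```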