-rw-r--r--  tensorflow/python/training/adagrad_da.py  29
1 file changed, 13 insertions, 16 deletions
diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py
index ba0e4e1b9d..1aa513f13e 100644
--- a/tensorflow/python/training/adagrad_da.py
+++ b/tensorflow/python/training/adagrad_da.py
@@ -80,6 +80,7 @@ class AdagradDAOptimizer(optimizer.Optimizer):
     self._l1_regularization_strength = l1_regularization_strength
     self._l2_regularization_strength = l2_regularization_strength
     self._global_step = global_step
+    self._global_step_on_worker = None

   def _create_slots(self, var_list):
     for v in var_list:
@@ -97,14 +98,16 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _prepare(self):
     self._learning_rate_tensor = ops.convert_to_tensor(
         self._learning_rate, name="learning_rate")
+    # Performance optimization so that worker creates a copy of the global step
+    # to avoid overloading the parameter server holding the global step.
+    with ops.colocate_with(self._learning_rate_tensor):
+      self._global_step_on_worker = array_ops.identity(self._global_step) + 1

   def _apply_dense(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.apply_adagrad_da(
         var,
         g_acc,
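
For readers skimming the diff: the hunk above hoists the worker-local copy of
the global step out of the per-variable _apply_* methods and into _prepare(),
colocated with the learning-rate tensor, so the parameter server holding the
global step is read once per step instead of once per variable. A minimal
sketch of the pattern (illustrative class and names, not the real optimizer;
assumes TF 1.x graph mode):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    class _GlobalStepCachingSketch(object):
      """Illustrative only: mirrors the caching pattern introduced above."""

      def __init__(self, learning_rate, global_step):
        self._learning_rate = learning_rate
        self._global_step = global_step
        self._global_step_on_worker = None  # populated by _prepare()

      def _prepare(self):
        self._learning_rate_tensor = ops.convert_to_tensor(
            self._learning_rate, name="learning_rate")
        # Single identity read of the global step, placed on the worker
        # alongside the learning-rate tensor.
        with ops.colocate_with(self._learning_rate_tensor):
          self._global_step_on_worker = (
              array_ops.identity(self._global_step) + 1)

      def _apply_dense(self, grad, var):
        # Each update now reads the cached worker-side copy, colocated with
        # the variable being updated, instead of the parameter server.
        with ops.device(var.device):
          global_step = array_ops.identity(self._global_step_on_worker)
        return global_step  # the real method passes this to apply_adagrad_da
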
@@ -119,10 +122,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _resource_apply_dense(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.resource_apply_adagrad_da(
         var.handle,
         g_acc.handle,
@@ -137,10 +138,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _apply_sparse(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.sparse_apply_adagrad_da(
         var,
         g_acc,
@@ -156,10 +155,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _resource_apply_sparse(self, grad, var, indices):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.resource_sparse_apply_adagrad_da(
         var.handle,
         g_acc.handle,
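
For completeness, AdagradDA is one of the few tf.train optimizers that takes
the global step at construction time, which is why the copy is worth creating
once in _prepare(). A minimal TF 1.x usage sketch (the toy variable and loss
are illustrative, not from this commit):

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    w = tf.Variable([0.5, -0.5], name="w")
    loss = tf.reduce_sum(tf.square(w))

    opt = tf.train.AdagradDAOptimizer(learning_rate=0.1, global_step=global_step)
    train_op = opt.minimize(loss, global_step=global_step)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(train_op)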