author     A. Unique TensorFlower <gardener@tensorflow.org>    2017-05-26 16:55:34 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2017-05-26 20:29:29 -0700
commit     068275d9d9c05d5834b39381d5707d3b920a46a6 (patch)
tree       3ba9139b0cad2b24bf1bdd07d8c051cb201eff16
parent     04fede8bde23381166fa4f5d9582cf7e05985393 (diff)
Make AdagradDA with global step work with tf.learn models.
PiperOrigin-RevId: 157278849
-rw-r--r--  tensorflow/python/training/adagrad_da.py  29
1 file changed, 13 insertions(+), 16 deletions(-)
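For context before the diff: AdagradDA takes the training global step at construction time, and that same counter is what tf.learn / Estimator-style training loops pass to minimize(). A minimal TF 1.x usage sketch follows; the variable, loss, and hyperparameter values below are illustrative placeholders, not part of this change.

    import tensorflow as tf

    # Illustrative toy model; AdagradDA expects an int64 global step tensor.
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False,
                              name="global_step")
    weights = tf.Variable([1.0, 2.0], name="weights")
    loss = tf.reduce_sum(tf.square(weights))

    optimizer = tf.train.AdagradDAOptimizer(
        learning_rate=0.01,
        global_step=global_step,
        initial_gradient_squared_accumulator_value=0.1,
        l1_regularization_strength=0.0,
        l2_regularization_strength=0.0)

    # minimize() also increments the same counter once per training step.
    train_op = optimizer.minimize(loss, global_step=global_step)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(3):
        sess.run(train_op)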
diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py
index ba0e4e1b9d..1aa513f13e 100644
--- a/tensorflow/python/training/adagrad_da.py
+++ b/tensorflow/python/training/adagrad_da.py
@@ -80,6 +80,7 @@ class AdagradDAOptimizer(optimizer.Optimizer):
     self._l1_regularization_strength = l1_regularization_strength
     self._l2_regularization_strength = l2_regularization_strength
     self._global_step = global_step
+    self._global_step_on_worker = None

   def _create_slots(self, var_list):
     for v in var_list:
@@ -97,14 +98,16 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _prepare(self):
     self._learning_rate_tensor = ops.convert_to_tensor(
         self._learning_rate, name="learning_rate")
+    # Performance optimization so that worker creates a copy of the global step
+    # to avoid overloading the parameter server holding the global step.
+    with ops.colocate_with(self._learning_rate_tensor):
+      self._global_step_on_worker = array_ops.identity(self._global_step) + 1

   def _apply_dense(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.apply_adagrad_da(
         var,
         g_acc,
@@ -119,10 +122,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _resource_apply_dense(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.resource_apply_adagrad_da(
         var.handle,
         g_acc.handle,
@@ -137,10 +138,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _apply_sparse(self, grad, var):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.sparse_apply_adagrad_da(
         var,
         g_acc,
@@ -156,10 +155,8 @@ class AdagradDAOptimizer(optimizer.Optimizer):
   def _resource_apply_sparse(self, grad, var, indices):
     g_acc = self.get_slot(var, "gradient_accumulator")
     gg_acc = self.get_slot(var, "gradient_squared_accumulator")
-    # Performance optimization so that worker creates a copy of the global step
-    # to avoid overloading the parameter server holding the global step.
-    with ops.device(grad[0].device):
-      global_step = array_ops.identity(self._global_step) + 1
+    with ops.device(var.device):
+      global_step = array_ops.identity(self._global_step_on_worker)
     return training_ops.resource_sparse_apply_adagrad_da(
         var.handle,
         g_acc.handle,
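Net effect of the change: instead of each _apply_* method deriving its own copy of the parameter-server-hosted global step from grad[0].device, the copy is now created once in _prepare(), colocated with the learning-rate tensor on the worker, and every per-variable update reads that cached copy from the variable's device. A standalone sketch of the placement pattern, assuming a TF 1.x graph (the variable names below are illustrative, not taken from the commit):

    import tensorflow as tf
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    # The "real" step counter, typically hosted on a parameter server.
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False,
                              name="global_step")
    learning_rate_tensor = ops.convert_to_tensor(0.01, name="learning_rate")

    # Read the counter once per step, pinned next to a worker-local tensor;
    # identity() materializes a local copy, and +1 reflects that the update
    # runs before the counter itself is incremented.
    with ops.colocate_with(learning_rate_tensor):
      global_step_on_worker = array_ops.identity(global_step) + 1

    # Each per-variable update then reads the cheap local copy from the
    # variable's device instead of fetching the remote counter again.
    w = tf.Variable([0.0, 0.0], name="w")
    with ops.device(w.device):
      step_for_update = array_ops.identity(global_step_on_worker)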