author	TensorFlower Gardener <gardener@tensorflow.org>	2018-09-04 14:14:24 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-09-04 14:14:29 -0700
commit	8ef276fd2181fb71c2e232f60aa45ee96cb5905b (patch)
tree	030350464c3e9449653db32ad7b0927ca209bfe1 /tensorflow/contrib/opt
parent	8cf8afefdb4c240f74a05e24246c8cd2dcce9d54 (diff)
parent	ce035c2493c060b38e53ca7a63c66b26e265b210 (diff)
Merge pull request #19661 from jinxin0924:ma_easgd
PiperOrigin-RevId: 211519911
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/python/training/elastic_average_optimizer.py   | 14
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer.py     |  8
-rw-r--r--  tensorflow/contrib/opt/python/training/model_average_optimizer_test.py | 40
3 files changed, 35 insertions, 27 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index bbafd59aae..6c203e5519 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -128,12 +128,14 @@ class ElasticAverageCustomGetter(object):
= list(global_center_variable)[i]
       return local_var
     else:
-      return getter(
-          name,
-          trainable=trainable,
-          collections=collections,
-          *args,
-          **kwargs)
+      kwargs['trainable'] = trainable
+      kwargs['collections'] = collections
+      if ops.GraphKeys.LOCAL_VARIABLES in collections:
+        with ops.device(self._worker_device):
+          return getter(name, *args, **kwargs)
+      else:
+        return getter(name, *args, **kwargs)
+
class ElasticAverageOptimizer(optimizer.Optimizer):
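
The rewritten else-branch forwards `trainable` and `collections` through kwargs and pins any variable that belongs to LOCAL_VARIABLES onto the worker device. As a rough illustration (not code from this commit; the scope name, variable name, and worker device string below are made up, and TensorFlow 1.x APIs are assumed), a custom getter with the same routing behavior can be exercised like this:

    import tensorflow as tf

    worker_device = "/job:worker/task:0"  # illustrative device string

    def routing_getter(getter, name, trainable=True, collections=None,
                       *args, **kwargs):
      # Same idea as the merged change: pass trainable/collections through
      # kwargs and place LOCAL_VARIABLES on the worker device.
      kwargs['trainable'] = trainable
      kwargs['collections'] = collections
      if collections and tf.GraphKeys.LOCAL_VARIABLES in collections:
        with tf.device(worker_device):
          return getter(name, *args, **kwargs)
      return getter(name, *args, **kwargs)

    with tf.variable_scope("demo", custom_getter=routing_getter):
      counter = tf.get_variable(
          "counter", shape=[], initializer=tf.zeros_initializer(),
          collections=[tf.GraphKeys.LOCAL_VARIABLES])
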
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index b6b10e500b..746df77ba2 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object):
       self._local_2_global[local_var] = global_variable
       return local_var
     else:
-      return getter(name, trainable, collections, *args, **kwargs)
+      kwargs['trainable'] = trainable
+      kwargs['collections'] = collections
+      if ops.GraphKeys.LOCAL_VARIABLES in collections:
+        with ops.device(self._worker_device):
+          return getter(name, *args, **kwargs)
+      else:
+        return getter(name, *args, **kwargs)
class ModelAverageOptimizer(optimizer.Optimizer):
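
The same device-routing change is mirrored here for ModelAverageCustomGetter. For context, a minimal sketch of how the getter is normally attached so that this code path runs (cluster layout, scope name, and variable shape are illustrative assumptions, not taken from this commit):

    import tensorflow as tf
    from tensorflow.contrib.opt.python.training import model_average_optimizer

    worker_device = "/job:worker/task:0"  # illustrative

    # The custom getter intercepts get_variable() inside the scope, so the
    # branch patched above decides where each requested variable is placed.
    ma_getter = model_average_optimizer.ModelAverageCustomGetter(
        worker_device=worker_device)

    with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device, ps_tasks=1)), \
        tf.variable_scope("", custom_getter=ma_getter):
      weights = tf.get_variable(
          "w", shape=[10, 10],
          initializer=tf.truncated_normal_initializer())
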
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index 3acd940268..b1fc50a21f 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers):
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
- with ops.device("/job:worker/task:" + str(worker_id)):
- if worker_id == 0:
- grads_0 = constant_op.constant(-1.0)
- grads_1 = constant_op.constant(-1.0)
- else:
- grads_0 = constant_op.constant(-2.0)
- grads_1 = constant_op.constant(-2.0)
- sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = model_average_optimizer.ModelAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- ma_custom_getter=ma_coustom,
- is_chief=is_chief,
- interval_steps=steps)
- train_op = [
- opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
- global_step)
- ]
- easgd_hook = opt.make_session_run_hook()
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ if worker_id == 0:
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ else:
+ grads_0 = constant_op.constant(-2.0)
+ grads_1 = constant_op.constant(-2.0)
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
+ train_op = [
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
+ ]
+ ma_hook = opt.make_session_run_hook()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
- workers[worker_id].target, hooks=[easgd_hook])
+ workers[worker_id].target, hooks=[ma_hook])
sessions.append(sess)
graphs.append(graph)
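
For reference, a caller can drive the helper above roughly as follows. This is a sketch under stated assumptions: it presumes _get_workers returns the sessions, graphs, and train ops it accumulates, and it uses tf.test.create_local_cluster to stand up in-process servers; the worker and step counts are placeholders, not values from this diff.

    import tensorflow as tf

    # Two in-process workers and one parameter server, matching the helper's
    # expectation of a `workers` list whose elements expose a `.target`.
    workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=1)
    sessions, graphs, train_ops = _get_workers(
        num_workers=2, steps=2, workers=workers)

    # Each train op wraps opt.apply_gradients(...) as built above; running it
    # advances global_step, and every `interval_steps` steps the
    # ModelAverageOptimizer averages the local variables into the global copies.
    sessions[0].run(train_ops[0])
    sessions[1].run(train_ops[1])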