diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 14:14:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 14:14:29 -0700 |
commit | 8ef276fd2181fb71c2e232f60aa45ee96cb5905b (patch) | |
tree | 030350464c3e9449653db32ad7b0927ca209bfe1 /tensorflow/contrib/opt | |
parent | 8cf8afefdb4c240f74a05e24246c8cd2dcce9d54 (diff) | |
parent | ce035c2493c060b38e53ca7a63c66b26e265b210 (diff) |
Merge pull request #19661 from jinxin0924:ma_easgd
PiperOrigin-RevId: 211519911
Diffstat (limited to 'tensorflow/contrib/opt')
3 files changed, 35 insertions, 27 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index bbafd59aae..6c203e5519 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -128,12 +128,14 @@ class ElasticAverageCustomGetter(object): = list(global_center_variable)[i] return local_var else: - return getter( - name, - trainable=trainable, - collections=collections, - *args, - **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) + class ElasticAverageOptimizer(optimizer.Optimizer): diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index b6b10e500b..746df77ba2 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object): self._local_2_global[local_var] = global_variable return local_var else: - return getter(name, trainable, collections, *args, **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) class ModelAverageOptimizer(optimizer.Optimizer): diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py index 3acd940268..b1fc50a21f 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py @@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers): var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") - with ops.device("/job:worker/task:" + str(worker_id)): - if worker_id == 0: - grads_0 = constant_op.constant(-1.0) - grads_1 = constant_op.constant(-1.0) - else: - grads_0 = constant_op.constant(-2.0) - grads_1 = constant_op.constant(-2.0) - sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) - opt = model_average_optimizer.ModelAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - ma_custom_getter=ma_coustom, - is_chief=is_chief, - interval_steps=steps) - train_op = [ - opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], - global_step) - ] - easgd_hook = opt.make_session_run_hook() + with ops.device("/job:worker/task:" + str(worker_id)): + if worker_id == 0: + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + else: + grads_0 = constant_op.constant(-2.0) + grads_1 = constant_op.constant(-2.0) + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = model_average_optimizer.ModelAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + ma_custom_getter=ma_coustom, + is_chief=is_chief, + interval_steps=steps) + train_op = [ + opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], + global_step) + ] + ma_hook = opt.make_session_run_hook() # Creates MonitoredSession sess = training.MonitoredTrainingSession( - workers[worker_id].target, hooks=[easgd_hook]) + workers[worker_id].target, hooks=[ma_hook]) sessions.append(sess) graphs.append(graph) |