path: root/tensorflow/contrib/opt
author     weidankong <kongweidan84@gmail.com>  2018-08-27 15:59:54 -0700
committer  weidankong <kongweidan84@gmail.com>  2018-08-27 15:59:54 -0700
commit     8d226fe074d18aadf98a869755e7d432341ba882 (patch)
tree       70f53fdf2279c4f768753a194c72192b9b445e16 /tensorflow/contrib/opt
parent     607004e583ecbd9fb788aaf9b360a8d85cf167ac (diff)
AGN: use variable_creator_scope to move variables from GLOBAL_VARIABLES to LOCAL_VARIABLES
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/python/training/agn_optimizer.py       | 15
-rw-r--r--  tensorflow/contrib/opt/python/training/agn_optimizer_test.py  | 12
2 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer.py b/tensorflow/contrib/opt/python/training/agn_optimizer.py
index f47ef5acc5..8f415c75b9 100644
--- a/tensorflow/contrib/opt/python/training/agn_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer.py
@@ -166,12 +166,17 @@ class AGNOptimizer(optimizer.Optimizer):
     """
     local_vars = [v for g, v in grads_and_vars if g is not None]
     grads = [g for g, v in grads_and_vars if g is not None]
+    def _variable_creator(next_creator, collections, **kwargs):
+      if not collections:
+        collections = [ops.GraphKeys.LOCAL_VARIABLES]
+      elif ops.GraphKeys.GLOBAL_VARIABLES in collections:
+        collections = list(collections)
+        collections.append(ops.GraphKeys.LOCAL_VARIABLES)
+        collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
+      return next_creator(collections=collections, **kwargs)
     # theta = theta - lr * grad
-    global_old = set(n.op.name for n in variables.global_variables())
-    local_update_op = self._opt.apply_gradients(grads_and_vars)
-    global_new = set(n.op.name for n in variables.global_variables())
-
-    self._adjust_optimizer_variable_collection(global_new - global_old)
+    with variable_scope.variable_creator_scope(_variable_creator):
+      local_update_op = self._opt.apply_gradients(grads_and_vars)
     # a = a + grad
     update_ops = []
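
The change above replaces the old approach of snapshotting global_variables() before and after apply_gradients() and then moving the newly created variables, with a variable creator that reroutes collections at creation time. Below is a rough standalone sketch of that pattern, for illustration only: the _variable_creator body mirrors the diff, while the toy graph, the AdamOptimizer, and the final assertion are made up and are not part of this commit.

    # Illustrative sketch only; not AGNOptimizer code.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import variable_scope
    from tensorflow.python.ops import variables
    from tensorflow.python.training import adam

    def _variable_creator(next_creator, collections, **kwargs):
      # Variables with no explicit collections would default to GLOBAL_VARIABLES,
      # so send them to LOCAL_VARIABLES; otherwise swap GLOBAL for LOCAL.
      if not collections:
        collections = [ops.GraphKeys.LOCAL_VARIABLES]
      elif ops.GraphKeys.GLOBAL_VARIABLES in collections:
        collections = list(collections)
        collections.append(ops.GraphKeys.LOCAL_VARIABLES)
        collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
      return next_creator(collections=collections, **kwargs)

    with ops.Graph().as_default():
      w = variables.Variable([1.0, 2.0], name='w')  # created outside the scope, stays global
      loss = math_ops.reduce_sum(w * w)
      opt = adam.AdamOptimizer(learning_rate=0.1)
      grads_and_vars = opt.compute_gradients(loss)
      with variable_scope.variable_creator_scope(_variable_creator):
        train_op = opt.apply_gradients(grads_and_vars)  # slot variables created here
      # The optimizer's slots/accumulators should now be local, not global.
      global_names = [v.op.name for v in variables.global_variables()]
      assert all('Adam' not in n and 'beta' not in n for n in global_names)

The same scope is what AGNOptimizer now opens around self._opt.apply_gradients(grads_and_vars), so the wrapped optimizer's auxiliary variables no longer pollute GLOBAL_VARIABLES on the workers.
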
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
index 4e2200fa1a..a2302d2f11 100644
--- a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
@@ -23,10 +23,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
-from tensorflow.python.training import momentum
+from tensorflow.python.training import adam
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -100,7 +101,7 @@ def _get_workers(num_workers, period, workers, num_ps=1):
         grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
         optimizer = \
-            momentum.MomentumOptimizer(learning_rate=0.1, momentum=0.0)
+            adam.AdamOptimizer(learning_rate=0.1, beta1=0.0, beta2=0.0)
         opt = AGNOptimizer(
             optimizer,
             num_worker=num_workers,
@@ -152,6 +153,13 @@ class AGNOptimizerTest(test.TestCase):
     var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
     var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+    # verify adam/beta variables not in global collection
+    with graphs[0].as_default():
+      for ele in variables.global_variables():
+        self.assertTrue(ele.op.name.find('beta') < 0)
+        if ele.op.name.find('global_center_variable') < 0:
+          self.assertTrue(ele.op.name.find('Adam') < 0)
+
     # Verify the initialized value.
     self.assertAllEqual(0.0, sessions[0].run(var_0))
     self.assertAllEqual(0.5, sessions[0].run(var_1))
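
Switching the test's inner optimizer from MomentumOptimizer to AdamOptimizer makes the check stricter: in TF 1.x graph mode Adam creates two slot variables per trainable variable plus the beta1_power/beta2_power accumulators, all of which would previously have ended up in GLOBAL_VARIABLES on the worker, and those are exactly the names the new assertions scan for. A small illustrative sketch of what Adam adds without the creator scope (not part of the commit; the example names in the comment are a best-effort recollection of TF 1.x Adam naming, not taken from this change):

    # Illustrative only: run WITHOUT the variable_creator_scope change.
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import variables
    from tensorflow.python.training import adam

    with ops.Graph().as_default():
      w = variables.Variable([1.0, 2.0], name='w')
      opt = adam.AdamOptimizer(learning_rate=0.1)
      opt.minimize(math_ops.reduce_sum(w * w))
      print([v.op.name for v in variables.global_variables()])
      # Expected to include something like:
      #   'w', 'w/Adam', 'w/Adam_1', 'beta1_power', 'beta2_power'
      # which is why the test asserts that no 'beta' or 'Adam' names remain global.
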