diff options
author | weidankong <kongweidan84@gmail.com> | 2018-08-27 15:59:54 -0700 |
---|---|---|
committer | weidankong <kongweidan84@gmail.com> | 2018-08-27 15:59:54 -0700 |
commit | 8d226fe074d18aadf98a869755e7d432341ba882 (patch) | |
tree | 70f53fdf2279c4f768753a194c72192b9b445e16 /tensorflow/contrib/opt | |
parent | 607004e583ecbd9fb788aaf9b360a8d85cf167ac (diff) |
AGN: use variable_creator_scope to move variables from GLOBAL_VARIABLES to LOCAL VARIABLES
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/agn_optimizer.py | 15 | ||||
-rw-r--r-- | tensorflow/contrib/opt/python/training/agn_optimizer_test.py | 12 |
2 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer.py b/tensorflow/contrib/opt/python/training/agn_optimizer.py index f47ef5acc5..8f415c75b9 100644 --- a/tensorflow/contrib/opt/python/training/agn_optimizer.py +++ b/tensorflow/contrib/opt/python/training/agn_optimizer.py @@ -166,12 +166,17 @@ class AGNOptimizer(optimizer.Optimizer): """ local_vars = [v for g, v in grads_and_vars if g is not None] grads = [g for g, v in grads_and_vars if g is not None] + def _variable_creator(next_creator, collections, **kwargs): + if not collections: + collections = [ops.GraphKeys.LOCAL_VARIABLES] + elif ops.GraphKeys.GLOBAL_VARIABLES in collections: + collections = list(collections) + collections.append(ops.GraphKeys.LOCAL_VARIABLES) + collections.remove(ops.GraphKeys.GLOBAL_VARIABLES) + return next_creator(collections=collections, **kwargs) # theta = theta - lr * grad - global_old = set(n.op.name for n in variables.global_variables()) - local_update_op = self._opt.apply_gradients(grads_and_vars) - global_new = set(n.op.name for n in variables.global_variables()) - - self._adjust_optimizer_variable_collection(global_new - global_old) + with variable_scope.variable_creator_scope(_variable_creator): + local_update_op = self._opt.apply_gradients(grads_and_vars) # a = a + grad update_ops = [] diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py index 4e2200fa1a..a2302d2f11 100644 --- a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py @@ -23,10 +23,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import device_setter -from tensorflow.python.training import momentum +from tensorflow.python.training import adam from tensorflow.python.training import server_lib from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -100,7 +101,7 @@ def _get_workers(num_workers, period, workers, num_ps=1): grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]]) optimizer = \ - momentum.MomentumOptimizer(learning_rate=0.1, momentum=0.0) + adam.AdamOptimizer(learning_rate=0.1, beta1=0.0, beta2=0.0) opt = AGNOptimizer( optimizer, num_worker=num_workers, @@ -152,6 +153,13 @@ class AGNOptimizerTest(test.TestCase): var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0") var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0") + # verify adam/beta variables not in global collection + with graphs[0].as_default(): + for ele in variables.global_variables(): + self.assertTrue(ele.op.name.find('beta') < 0) + if ele.op.name.find('global_center_variable') < 0: + self.assertTrue(ele.op.name.find('Adam') < 0) + # Verify the initialized value. self.assertAllEqual(0.0, sessions[0].run(var_0)) self.assertAllEqual(0.5, sessions[0].run(var_1)) |