author     weidankong <kongweidan84@gmail.com>  2018-08-27 17:03:47 -0700
committer  weidankong <kongweidan84@gmail.com>  2018-08-27 17:03:47 -0700
commit     540ca4a8755a3670920b49647860d085df834a00 (patch)
tree       c12407c0e5a8e4eabf62738da9caad18e86e7981 /tensorflow/contrib/opt
parent     8d226fe074d18aadf98a869755e7d432341ba882 (diff)
AGN: fix Sanity test
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/python/training/agn_optimizer.py      | 19
-rw-r--r--  tensorflow/contrib/opt/python/training/agn_optimizer_test.py | 37
2 files changed, 21 insertions(+), 35 deletions(-)
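
Note: the main substantive hunk below deletes the private _adjust_optimizer_variable_collection helper (along with the unused init_ops import); the remaining hunks are whitespace and indentation fixes for the sanity check. As a rough sketch only, not part of this commit, the same "move optimizer-created variables to the local collection" step could be written against the public collection API instead of the private g._collections dict; opt_var_names (a set of variable op names) is an assumed input:

    # Illustrative sketch, not from the commit: same effect as the
    # deleted helper, via the public collection API.
    from tensorflow.python.framework import ops

    def move_opt_vars_to_local(opt_var_names):
      # get_collection_ref returns the live list, so removals persist.
      global_vars = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
      for var in list(global_vars):  # iterate a copy while mutating
        if var.op.name in opt_var_names:
          ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
          global_vars.remove(var)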
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer.py b/tensorflow/contrib/opt/python/training/agn_optimizer.py
index 8f415c75b9..9fb5be56e6 100644
--- a/tensorflow/contrib/opt/python/training/agn_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer.py
@@ -19,7 +19,6 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@@ -132,20 +131,6 @@ class AGNOptimizer(optimizer.Optimizer):
name='local_step')
self._opt._prepare()
- def _adjust_optimizer_variable_collection(self, opt_vars):
- """ Move optimizer created variables to local collection
- """
- g = ops.get_default_graph()
- idx = 0
- for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
- var = g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
- name = var.op.name
- if name in opt_vars:
- ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
- del g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
- else:
- idx += 1
-
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to global variables.
@@ -182,7 +167,7 @@ class AGNOptimizer(optimizer.Optimizer):
update_ops = []
update_ops.append(local_update_op)
grad_vars = [self._grad_map[var] for var in local_vars]
- for g, grad_var in zip (grads, grad_vars):
+ for g, grad_var in zip(grads, grad_vars):
update_ops.append(state_ops.assign_add(grad_var, g))
global_center_vars = [self._global_map[var] for var in local_vars]
@@ -215,7 +200,7 @@ class AGNOptimizer(optimizer.Optimizer):
return variable_update
local_update = state_ops.assign_add(
- self._local_step, 1, name='local_step_update').op
+ self._local_step, 1, name='local_step_update').op
with ops.control_dependencies([local_update]):
condition = math_ops.equal(
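
Note on the assign_add hunk above: in AGN, each worker adds its gradients into persistent accumulator variables (self._grad_map) rather than applying them to the weights directly. A minimal standalone sketch of that accumulate step, with illustrative names and TF1-style graph code:

    import tensorflow as tf

    grad = tf.constant([0.1, 0.2])                       # stand-in gradient
    accum = tf.Variable(tf.zeros([2]), trainable=False)  # persistent accumulator
    accumulate_op = tf.assign_add(accum, grad)           # grad_var += g, as in the diff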
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
index a2302d2f11..28732c2a1d 100644
--- a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
@@ -72,9 +72,9 @@ def _get_workers(num_workers, period, workers, num_ps=1):
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
ps_device = device_setter.replica_device_setter(
- worker_device=worker_device,
- ps_device="/job:ps/task:0/cpu:0",
- ps_tasks=1)
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)
agn_getter = AGNCustomGetter(worker_device=worker_device)
with variable_scope.variable_scope(
"", custom_getter=agn_getter), ops.device(ps_device):
@@ -82,7 +82,8 @@ def _get_workers(num_workers, period, workers, num_ps=1):
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=0.5, name="v1")
if num_ps > 1:
- with variable_scope.variable_scope("",
+ with variable_scope.variable_scope(
+ "",
partitioner=partitioned_variables.fixed_size_partitioner(
num_ps, axis=0),
custom_getter=agn_getter), ops.device(ps_device):
@@ -109,12 +110,12 @@ def _get_workers(num_workers, period, workers, num_ps=1):
custom_getter=agn_getter)
if num_ps == 1:
train_op = [
- opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
global_step)
]
else:
train_op = [
- opt.apply_gradients(([grads_0, var_0],
+ opt.apply_gradients(([grads_0, var_0],
[grads_1, var_1],
[grads_part_0, part_0],
[grads_part_1, part_1]),
@@ -232,20 +233,20 @@ class AGNOptimizerTest(test.TestCase):
sessions[0].run(train_ops[0])
self.assertNear(0.1, sessions[0].run(var_0_g), 1e-6)
self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1],
- sessions[0].run(part_0_g),
- 1e-6)
+ sessions[0].run(part_0_g),
+ 1e-6)
self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1],
- sessions[0].run(part_1_g),
- 1e-6)
+ sessions[0].run(part_1_g),
+ 1e-6)
sessions[1].run(train_ops[1])
self.assertNear(0.2, sessions[0].run(var_0_g), 1e-6)
self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2],
- sessions[0].run(part_0_g),
- 1e-6)
+ sessions[0].run(part_0_g),
+ 1e-6)
self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2],
- sessions[0].run(part_1_g),
- 1e-6)
+ sessions[0].run(part_1_g),
+ 1e-6)
sessions[0].run(train_ops[0])
sessions[1].run(train_ops[1])
@@ -254,11 +255,11 @@ class AGNOptimizerTest(test.TestCase):
sessions[1].run(train_ops[1])
self.assertNear(0.6, sessions[0].run(var_0_g), 1e-6)
self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6],
- sessions[0].run(part_0_g),
- 1e-6)
+ sessions[0].run(part_0_g),
+ 1e-6)
self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6],
- sessions[0].run(part_1_g),
- 1e-6)
+ sessions[0].run(part_1_g),
+ 1e-6)
def testAGNCustomGetter(self):
cluster_spec = server_lib.ClusterSpec({
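
Note: Optimizer.apply_gradients, as exercised by the test above, takes an iterable of (gradient, variable) pairs plus an optional global_step to increment. A minimal sketch of that call pattern with a stock TF1 optimizer (values illustrative):

    import tensorflow as tf

    var_0 = tf.Variable(0.0, name="v0")
    var_1 = tf.Variable(0.5, name="v1")
    global_step = tf.train.get_or_create_global_step()

    opt = tf.train.GradientDescentOptimizer(learning_rate=1.0)
    train_op = opt.apply_gradients(
        [(tf.constant(0.1), var_0), (tf.constant(0.1), var_1)],
        global_step=global_step)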