author    weidankong <kongweidan84@gmail.com>  2018-08-13 18:53:24 -0700
committer weidankong <kongweidan84@gmail.com>  2018-08-13 18:53:24 -0700
commit    cd89c1bc76474cc0e5179ff647a81deb51bba25b (patch)
tree      36e6fa5d8e026fe3a19005b879754395ce8c2883 /tensorflow/contrib/opt
parent    167487ebf7e50e13779fb344038b2002056e9b81 (diff)
update according to review comments
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/python/training/elastic_average_optimizer.py       36
-rw-r--r--  tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py   3
2 files changed, 16 insertions, 23 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index be72ef3767..0554c43c18 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -169,7 +169,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
ea_custom_getter,
communication_period=10,
moving_rate=None,
- rho=0.0,
+ rho=None,
use_locking=True,
sync_flag=False,
name='ElasticAverageOptimizer'):
@@ -183,11 +183,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
communication_period: An int point value to controls the frequency
of the communication between every worker and the ps.
moving_rate: A floating point value to control the elastic difference.
- rho: the amount of exploration we allow ine the model. The default
+ rho: the amount of exploration we allow in the model. The default
value is moving_rate/learning_rate
+ rho=0.0 is suggested in async mode.
use_locking: If True use locks for update operations.
- sync_flag: Add_sync_queues_and_barrier or not, default to False, in case of
- restarting a worker,the worker won't hung there.
+ sync_flag: Add_sync_queues_and_barrier or not.
+ True: all workers wait for each other before starting training.
+ False: a worker can start training as soon as its initialization is done;
+ there is no need to wait until every worker is ready.
+ In case one worker is restarted, it can rejoin and continue
+ training without being blocked.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "ElasticAverageOptimizer".
"""
@@ -291,29 +296,28 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
+ global_old = set(n.op.name for n in variables.global_variables())
apply_updates = self._opt.apply_gradients(grads_and_vars)
+ global_new = set(n.op.name for n in variables.global_variables())
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
self._local_step, 1, name='local_step_update').op
# this is for place the variables created by optimizer to local collection
# e.g., AdamOptimizer will create beta as global variables
- def _adjust_optimizer_variable_collection():
+ def _adjust_optimizer_variable_collection(opt_vars):
g = ops.get_default_graph()
- # global to local & clear global
idx = 0
for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
var = g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
name = var.op.name
- if GLOBAL_STEP not in name.split('/') \
- and var not in ops.get_collection(GLOBAL_SHARE_VARS) \
- and name.find(GLOBAL_VARIABLE_NAME) == -1:
+ if name in opt_vars:
ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
del g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
else:
idx += 1
- _adjust_optimizer_variable_collection()
+ _adjust_optimizer_variable_collection(global_new - global_old)
# update global variables.
def _Update_global_variables():
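A standalone sketch of the bookkeeping pattern this hunk introduces: snapshot the global-variable names before and after the wrapped optimizer's apply_gradients, then move only the variables the optimizer created (e.g. Adam's beta powers and slot variables) into the local collection. The toy variable and loss below are illustrative and not part of the diff.

# Illustrative, self-contained sketch of the set-difference technique above
# (TF 1.x graph mode). The toy variable and loss are placeholders.
import tensorflow as tf

x = tf.get_variable('x', initializer=1.0)
loss = tf.square(x)
adam = tf.train.AdamOptimizer(0.01)

before = set(v.op.name for v in tf.global_variables())
train_op = adam.apply_gradients(adam.compute_gradients(loss))
after = set(v.op.name for v in tf.global_variables())

opt_created = after - before   # names of variables Adam created during apply_gradients

g = tf.get_default_graph()
global_list = g.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
for var in list(global_list):
    if var.op.name in opt_created:
        tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
        global_list.remove(var)   # leave only pre-existing model variables global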
@@ -432,14 +436,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
swapped_var_list = {}
- has_global_step = False
for key, var in var_list.items():
tensor = var
- if False == has_global_step\
- and GLOBAL_STEP in key.split('/'):
- has_global_step = True
- if isinstance(var, list) == False:
+ if not isinstance(var, list):
for tvar in variables.trainable_variables():
if tvar.op.name == var.op.name:
tensor = self._global_map.get(tvar, var)
@@ -449,12 +449,6 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
swapped_var_list[key] = tensor
- # find global_step and add it if missing
- if False == has_global_step:
- for ele in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
- if GLOBAL_STEP in ele.op.name.split('/'):
- swapped_var_list[ele.op.name] = ele
-
return saver.Saver(swapped_var_list, name=name, **kwargs)
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
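For context, a hedged sketch of the variable swap performed by the saver-construction hunk above: the checkpoint keys from the caller's var_list are kept, but each trainable variable is replaced by its global ("center") copy before the Saver is built. The function name and the `global_map` argument are illustrative stand-ins for the optimizer's internal self._global_map, not the API of this module.

# Hedged sketch of the saver-side swap above; `build_swapping_saver` and
# `global_map` are hypothetical names used only for illustration.
from tensorflow.python.ops import variables
from tensorflow.python.training import saver as saver_lib

def build_swapping_saver(var_list, global_map, name='swapping_saver', **kwargs):
  var_dict = saver_lib.BaseSaverBuilder.OpListToDict(var_list)
  swapped = {}
  for key, var in var_dict.items():
    tensor = var
    if not isinstance(var, list):
      for tvar in variables.trainable_variables():
        if tvar.op.name == var.op.name:
          # Save/restore the global (center) copy under the original key.
          tensor = global_map.get(tvar, var)
          break
    swapped[key] = tensor
  return saver_lib.Saver(swapped, name=name, **kwargs)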
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 8a8f7ab080..acb663d628 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -207,8 +207,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
v0 = variable_scope.get_variable(initializer=0.0, name="v0")
v1 = variable_scope.get_variable(initializer=1.0, name="v1")
sess.run(variables.local_variables_initializer())
- global_step = training_util.get_or_create_global_step()
- saver_opt = saver.Saver(var_list=[v1, v0, global_step])
+ saver_opt = saver.Saver(var_list=[v1, v0])
saver_opt.restore(sess, './model/model')
self.assertAllEqual(2.0, sess.run(v0))
self.assertAllEqual(3.0, sess.run(v1))