author    | weidankong <kongweidan84@gmail.com> | 2018-08-13 18:53:24 -0700
committer | weidankong <kongweidan84@gmail.com> | 2018-08-13 18:53:24 -0700
commit    | cd89c1bc76474cc0e5179ff647a81deb51bba25b (patch)
tree      | 36e6fa5d8e026fe3a19005b879754395ce8c2883 /tensorflow/contrib/opt
parent    | 167487ebf7e50e13779fb344038b2002056e9b81 (diff)
update according to review comments
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/elastic_average_optimizer.py | 36
-rw-r--r-- | tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py | 3
2 files changed, 16 insertions, 23 deletions
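
The main change in elastic_average_optimizer.py is how apply_gradients finds the variables that the wrapped optimizer itself creates (e.g. Adam's beta power accumulators) before relocating them from the global to the local variable collection: the old name-pattern filter (GLOBAL_STEP / GLOBAL_SHARE_VARS / GLOBAL_VARIABLE_NAME) is replaced by snapshots of the GLOBAL_VARIABLES collection taken before and after the wrapped apply_gradients call, so that only the newly created variables are moved. The sketch below illustrates that set-difference technique with the public TF 1.x graph-mode API; the toy model and variable names are illustrative, not taken from the patch.

```python
# Sketch (not the actual patch): detect variables a wrapped optimizer creates
# by diffing the GLOBAL_VARIABLES collection before/after building its update
# op, then move only those new variables into LOCAL_VARIABLES.
import tensorflow as tf  # assumes TensorFlow 1.x graph mode

x = tf.Variable(0.0, name='x')                  # model variable, stays global
loss = tf.square(x - 3.0)

before = set(v.op.name for v in tf.global_variables())
train_op = tf.train.AdamOptimizer(0.1).minimize(loss)  # creates beta*_power and slot vars
after = set(v.op.name for v in tf.global_variables())

new_names = after - before                      # names of optimizer-created variables
graph = tf.get_default_graph()
global_coll = graph.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
for var in list(global_coll):                   # iterate over a copy while mutating
  if var.op.name in new_names:
    graph.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
    global_coll.remove(var)
```

In the patch itself this happens inside ElasticAverageOptimizer.apply_gradients, where the helper receives `global_new - global_old` as `opt_vars`.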
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index be72ef3767..0554c43c18 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -169,7 +169,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
                ea_custom_getter,
                communication_period=10,
                moving_rate=None,
-               rho=0.0,
+               rho=None,
                use_locking=True,
                sync_flag=False,
                name='ElasticAverageOptimizer'):
@@ -183,11 +183,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
       communication_period: An int point value to controls the frequency
         of the communication between every worker and the ps.
       moving_rate: A floating point value to control the elastic difference.
-      rho: the amount of exploration we allow ine the model. The default
+      rho: the amount of exploration we allow in the model. The default
         value is moving_rate/learning_rate
+        rho=0.0 is suggested in async mode.
       use_locking: If True use locks for update operations.
-      sync_flag: Add_sync_queues_and_barrier or not, default to False, in case of
-        restarting a worker,the worker won't hung there.
+      sync_flag: Add_sync_queues_and_barrier or not.
+                 True: all workers will wait for each other before start training
+                 False: worker can start training when its initilization is done,
+                        no need to wait for everyone is ready.
+                        in case one worker is restarted, it can join and continue
+                        training without being blocked.
       name: Optional name prefix for the operations created when applying
         gradients. Defaults to "ElasticAverageOptimizer".
     """
@@ -291,29 +296,28 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
       TypeError: If `grads_and_vars` is malformed.
       ValueError: If none of the variables have gradients.
     """
+    global_old = set(n.op.name for n in variables.global_variables())
     apply_updates = self._opt.apply_gradients(grads_and_vars)
+    global_new = set(n.op.name for n in variables.global_variables())
     with ops.control_dependencies([apply_updates]):
       local_update = state_ops.assign_add(
           self._local_step, 1, name='local_step_update').op

     # this is for place the variables created by optimizer to local collection
     # e.g., AdamOptimizer will create beta as global variables
-    def _adjust_optimizer_variable_collection():
+    def _adjust_optimizer_variable_collection(opt_vars):
       g = ops.get_default_graph()
-      # global to local & clear global
       idx = 0
       for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
         var = g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
         name = var.op.name
-        if GLOBAL_STEP not in name.split('/') \
-            and var not in ops.get_collection(GLOBAL_SHARE_VARS) \
-            and name.find(GLOBAL_VARIABLE_NAME) == -1:
+        if name in opt_vars:
           ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
           del g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
         else:
           idx += 1

-    _adjust_optimizer_variable_collection()
+    _adjust_optimizer_variable_collection(global_new - global_old)

     # update global variables.
     def _Update_global_variables():
@@ -432,14 +436,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
     var_list = saver.BaseSaverBuilder.OpListToDict(var_list)

     swapped_var_list = {}
-    has_global_step = False
     for key, var in var_list.items():
       tensor = var

-      if False == has_global_step\
-          and GLOBAL_STEP in key.split('/'):
-        has_global_step = True
-      if isinstance(var, list) == False:
+      if not isinstance(var, list):
         for tvar in variables.trainable_variables():
           if tvar.op.name == var.op.name:
             tensor = self._global_map.get(tvar, var)
@@ -449,12 +449,6 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
       swapped_var_list[key] = tensor

-    # find global_step and add it if missing
-    if False == has_global_step:
-      for ele in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
-        if GLOBAL_STEP in ele.op.name.split('/'):
-          swapped_var_list[ele.op.name] = ele
-
     return saver.Saver(swapped_var_list, name=name, **kwargs)


 class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 8a8f7ab080..acb663d628 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -207,8 +207,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
       v0 = variable_scope.get_variable(initializer=0.0, name="v0")
       v1 = variable_scope.get_variable(initializer=1.0, name="v1")
       sess.run(variables.local_variables_initializer())
-      global_step = training_util.get_or_create_global_step()
-      saver_opt = saver.Saver(var_list=[v1, v0, global_step])
+      saver_opt = saver.Saver(var_list=[v1, v0])
       saver_opt.restore(sess, './model/model')
       self.assertAllEqual(2.0, sess.run(v0))
       self.assertAllEqual(3.0, sess.run(v1))
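
The swapping_saver part of the patch stops auto-detecting and re-adding the global step to the saver's variable list, and the updated test likewise drops global_step from the Saver's var_list. If I read the change correctly, a caller that still wants the global step in its checkpoints now has to list it explicitly. Below is a self-contained sketch of that caller-side implication; it mirrors the updated test but uses a plain tf.train.Saver for brevity, and the variable names and checkpoint path are illustrative only.

```python
# Sketch only, assuming TensorFlow 1.x: the saver no longer injects
# global_step on its own, so include it in var_list yourself when the
# checkpoint should contain it.
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
  v0 = tf.get_variable(initializer=0.0, name='v0')
  v1 = tf.get_variable(initializer=1.0, name='v1')
  global_step = tf.train.get_or_create_global_step()
  sess.run(tf.global_variables_initializer())

  # Explicitly list global_step alongside the model variables.
  ckpt_saver = tf.train.Saver(var_list=[v0, v1, global_step])
  ckpt_saver.save(sess, './model/model')
```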