aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt
diff options
context:
space:
mode:
authorGravatar weidan.kong <weidan.kong@alibaba-inc.com>2018-08-08 11:50:11 -0700
committerGravatar weidan.kong <weidan.kong@alibaba-inc.com>2018-08-08 11:50:11 -0700
commit167487ebf7e50e13779fb344038b2002056e9b81 (patch)
tree3e7b9d3dabde93d8b297a4021549a93b23fdda7d /tensorflow/contrib/opt
parent2e2486adedb5164b82b0c2fcb8b1d27f987c1428 (diff)
elastic averaging SGD update: support partitioner & more optimizers
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py183
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py107
2 files changed, 246 insertions, 44 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 5763593b81..be72ef3767 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -17,22 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-
-from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import constant_op
LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_SHARE_VARS = 'global_share_var'
+GLOBAL_STEP = 'global_step'
class ElasticAverageCustomGetter(object):
@@ -52,16 +56,32 @@ class ElasticAverageCustomGetter(object):
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
- ps_device="/job:ps/cpu:0",
+ ps_device="/job:ps",
cluster=cluster)),
tf.variable_scope('',custom_getter=ea_custom_getter):
- hid_w = tf.get_variable(
- initializer=tf.truncated_normal(
- [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
- stddev=1.0 / IMAGE_PIXELS),
- name="hid_w")
- hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
- name="hid_b")
+ ...
+ create your model here
+ ...
+ with tf.device(worker_device):
+ opt = tf.train.MomentumOptimizer(...)
+ optimizer = ElasticAverageOptimizer(
+ opt,
+ num_worker=2,
+ moving_rate=0.01, # or use default value
+ communication_period=20,
+ ea_custom_getter=ea_custom_getter)
+ ...
+ train_op = optimizer.apply_gradients(
+ grads_vars,
+ global_step=global_step)
+ ...
+ hooks = [optimizer.make_session_run_hook(is_chief, task_index)]
+ ...
+ with tf.train.MonitoredTrainingSession(master=server.target,
+ is_chief=is_chief,
+ checkpoint_dir=("...),
+ save_checkpoint_secs=600,
+ hooks=hooks) as mon_sess:
"""
def __init__(self, worker_device):
@@ -83,24 +103,50 @@ class ElasticAverageCustomGetter(object):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
*args,
**kwargs)
- global_center_variable = variable_scope.variable(
+ if kwargs['reuse'] == True:
+ return local_var
+ global_center_variable = getter(
name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
with ops.device(self._worker_device):
- local_center_variable = variable_scope.variable(
+ local_center_variable = getter(
name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
-
- self._local_map[local_var] = local_center_variable
- self._global_map[local_var] = global_center_variable
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._local_map[local_var] = local_center_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._local_map[v_list[i]] \
+ = list(local_center_variable)[i]
+ self._global_map[v_list[i]] \
+ = list(global_center_variable)[i]
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ # 1. default to LOCAL_VARIABLES (instead of GLOBAL_VARIABLES)
+ # 2. put to global if explicitly defined (GLOBAL_SHARE_VARS)
+ # 3. other GLOBAL_VARIABLES put to LOCAL_VARIABLES
+ # exept global_step, which must be global
+ if collections is None or len(collections) == 0:
+ collections = [ops.GraphKeys.LOCAL_VARIABLES]
+ elif GLOBAL_SHARE_VARS in collections:
+ collections = list(collections)
+ if ops.GraphKeys.GLOBAL_VARIABLES not in collections:
+ collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ elif ops.GraphKeys.GLOBAL_VARIABLES in collections \
+ and GLOBAL_STEP not in name.split('/'):
+ collections = list(collections)
+ collections.append(ops.GraphKeys.LOCAL_VARIABLES)
+ collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
+ return getter(name, trainable=trainable, collections=collections, *args, **kwargs)
class ElasticAverageOptimizer(optimizer.Optimizer):
@@ -123,8 +169,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
ea_custom_getter,
communication_period=10,
moving_rate=None,
- rho=None,
+ rho=0.0,
use_locking=True,
+ sync_flag=False,
name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
@@ -139,6 +186,8 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
rho: the amount of exploration we allow ine the model. The default
value is moving_rate/learning_rate
use_locking: If True use locks for update operations.
+ sync_flag: Add_sync_queues_and_barrier or not, default to False, in case of
+ restarting a worker,the worker won't hung there.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "ElasticAverageOptimizer".
"""
@@ -148,6 +197,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._period = communication_period
self._local_map = ea_custom_getter._local_map
self._global_map = ea_custom_getter._global_map
+ self._sync_flag = sync_flag
if moving_rate is None:
self._moving_rate = self.BETA / communication_period / num_worker
@@ -246,6 +296,25 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
local_update = state_ops.assign_add(
self._local_step, 1, name='local_step_update').op
+ # this is for place the variables created by optimizer to local collection
+ # e.g., AdamOptimizer will create beta as global variables
+ def _adjust_optimizer_variable_collection():
+ g = ops.get_default_graph()
+ # global to local & clear global
+ idx = 0
+ for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
+ var = g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
+ name = var.op.name
+ if GLOBAL_STEP not in name.split('/') \
+ and var not in ops.get_collection(GLOBAL_SHARE_VARS) \
+ and name.find(GLOBAL_VARIABLE_NAME) == -1:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
+ del g._collections[ops.GraphKeys.GLOBAL_VARIABLES][idx]
+ else:
+ idx += 1
+
+ _adjust_optimizer_variable_collection()
+
# update global variables.
def _Update_global_variables():
local_vars = [v for g, v in grads_and_vars if g is not None]
@@ -290,7 +359,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
variables equal to the global center variables before the training begins"""
def _Add_sync_queues_and_barrier(enqueue_after_list):
- """Adds ops to enqueu on all worker queues"""
+ """Adds ops to enqueue on all worker queues"""
sync_queues = [
data_flow_ops.FIFOQueue(
self._num_worker, [dtypes.bool],
@@ -324,6 +393,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
init_ops.append(state_ops.assign(lc_var, gc_var))
init_op = control_flow_ops.group(*(init_ops))
+ if self._sync_flag == False:
+ return init_op
+
sync_queue_op = _Add_sync_queues_and_barrier([init_op])
return sync_queue_op
@@ -331,6 +403,59 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
"""Creates a hook to handle ElasticAverageOptimizerHook ops such as initialization."""
return _ElasticAverageOptimizerHook(self, is_chief, task_index)
+ def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
+ """Create a saver copy global_center_variable to trainable variables
+ Please call this function after all your variables created with EACustomGetter.
+ For evaluations or inference, use this saver during training. It will save the
+ global_center_variable of the trained parameters under the original parameter names.
+ Args:
+ var_list: List of variables to save, as per `Saver()`.
+ If set to None, will save all the trainable_variables that have been
+ created before this call.
+ name: The name of the saver.
+ **kwargs: Keyword arguments of `Saver()`.
+ Returns:
+ A `tf.train.Saver` object.
+ Raises:
+ RuntimeError: global_center_variable is empty, please make sure
+ this is called after model created and
+ EACustomGetter is used when declaring you model
+ """
+ if not self._global_map:
+ raise RuntimeError('global_center_variable is empty, please make sure '
+ 'this is called after model created and '
+ 'ElasticAverageCustomGetter is used when declaring you model')
+
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ if not isinstance(var_list, dict):
+ var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+ swapped_var_list = {}
+ has_global_step = False
+ for key, var in var_list.items():
+ tensor = var
+ if False == has_global_step\
+ and GLOBAL_STEP in key.split('/'):
+ has_global_step = True
+
+ if isinstance(var, list) == False:
+ for tvar in variables.trainable_variables():
+ if tvar.op.name == var.op.name:
+ tensor = self._global_map.get(tvar, var)
+ break
+ else: #partitioned variable
+ tensor = [self._global_map.get(lvar, lvar) for lvar in var]
+
+ swapped_var_list[key] = tensor
+
+ # find global_step and add it if missing
+ if False == has_global_step:
+ for ele in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
+ if GLOBAL_STEP in ele.op.name.split('/'):
+ swapped_var_list[ele.op.name] = ele
+
+ return saver.Saver(swapped_var_list, name=name, **kwargs)
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
@@ -351,3 +476,7 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
if self._is_chief:
self._global_init_op = variables.global_variables_initializer()
self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op) \ No newline at end of file
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 5ed8057b86..8a8f7ab080 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -18,19 +18,23 @@ from __future__ import division
from __future__ import print_function
import portpicker
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import \
- ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME
+ ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME, GLOBAL_SHARE_VARS
def create_local_cluster(num_workers, num_ps, protocol="grpc"):
@@ -59,29 +63,49 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
# Creates the workers and return their sessions, graphs, train_ops.
# Chief worker will update at last
-def _get_workers(num_workers, period, workers, moving_rate):
+def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
sessions = []
graphs = []
train_ops = []
+ savers = []
for worker_id in range(num_workers):
graph = ops.Graph()
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
with variable_scope.variable_scope(
- "", custom_getter=ea_coustom), ops.device(
+ "", custom_getter=ea_custom), ops.device(
device_setter.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/task:0/cpu:0",
ps_tasks=1)):
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ global_step = training_util.get_or_create_global_step()
+ local_var = variable_scope.get_variable(initializer=0.0, trainable=False,
+ name='local_var', collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ global_var = variable_scope.get_variable(initializer=0.0, trainable=False,
+ name='global_var', collections=[GLOBAL_SHARE_VARS])
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ if num_ps > 1:
+ with variable_scope.variable_scope("",
+ partitioner=partitioned_variables.fixed_size_partitioner(num_ps, axis=0),
+ custom_getter=ea_custom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=num_ps)):
+
+ partition_var = variable_scope.get_variable(
+ 'partition_var',shape=[2,4], initializer=init_ops.ones_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(-1.0)
grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
opt = ElasticAverageOptimizer(
@@ -89,12 +113,20 @@ def _get_workers(num_workers, period, workers, moving_rate):
num_worker=num_workers,
moving_rate=moving_rate,
communication_period=period,
- ea_custom_getter=ea_coustom)
- train_op = [
+ ea_custom_getter=ea_custom)
+ if num_ps == 1:
+ train_op = [
opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
global_step)
- ]
+ ]
+ else:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1],
+ [grads_part_0, part_0], [grads_part_1, part_1]),
+ global_step)
+ ]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
+ saver = opt.swapping_saver()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
workers[worker_id].target, hooks=[easgd_hook])
@@ -102,8 +134,9 @@ def _get_workers(num_workers, period, workers, moving_rate):
sessions.append(sess)
graphs.append(graph)
train_ops.append(train_op)
+ savers.append(saver)
- return sessions, graphs, train_ops
+ return sessions, graphs, train_ops, savers
class ElasticAverageOptimizerTest(test.TestCase):
@@ -118,7 +151,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
+ sessions, graphs, train_ops, savers = _get_workers(
num_workers, communication_period, workers, 1.0)
var_0 = graphs[0].get_tensor_by_name("v0:0")
@@ -150,6 +183,13 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(3.0, sessions[0].run(var_1_g))
self.assertAllEqual(1, sessions[0].run(global_step))
+ # verify variables are set to right collection
+ with graphs[0].as_default():
+ for ele in ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES):
+ self.assertNotEqual(ele.op.name, 'global_var')
+ for ele in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
+ self.assertNotEqual(ele.op.name, 'local_var')
+
# iteration 3
sessions[0].run(train_ops[0])
@@ -158,6 +198,20 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(2.0, sessions[0].run(var_0_g))
self.assertAllEqual(3.0, sessions[0].run(var_1_g))
self.assertAllEqual(1, sessions[0].run(global_step))
+ sessions[0].run(train_ops[0])
+
+ # save, data will be global value
+ savers[0].save(sessions[0]._sess._sess._sess._sess, save_path='./model/model')
+ ops.reset_default_graph() # restore on a new graph
+ with session.Session() as sess:
+ v0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ v1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ sess.run(variables.local_variables_initializer())
+ global_step = training_util.get_or_create_global_step()
+ saver_opt = saver.Saver(var_list=[v1, v0, global_step])
+ saver_opt.restore(sess, './model/model')
+ self.assertAllEqual(2.0, sess.run(v0))
+ self.assertAllEqual(3.0, sess.run(v1))
def test2Worker1Period(self):
num_workers = 2
@@ -166,8 +220,8 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
- num_workers, communication_period, workers, 0.5)
+ sessions, graphs, train_ops, savers = _get_workers(
+ num_workers, communication_period, workers, 0.5, num_ps=2)
var_0 = graphs[0].get_tensor_by_name("v0:0")
var_1 = graphs[0].get_tensor_by_name("v1:0")
@@ -177,6 +231,8 @@ class ElasticAverageOptimizerTest(test.TestCase):
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/partition_var/part_0:0")
+
# Verify the initialized value.
self.assertAllEqual(0.0, sessions[0].run(var_0))
self.assertAllEqual(1.0, sessions[0].run(var_1))
@@ -194,22 +250,39 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(1.75, sessions[0].run(var_1_g))
self.assertAllEqual(0.75, sessions[1].run(var_0_1))
self.assertAllEqual(1.75, sessions[1].run(var_1_1))
+ # part_0 of global_center copy
+ part_0_g = sessions[0].run(part_0_g)
+
+ savers[0].save(sessions[0]._sess._sess._sess._sess, save_path='./model/model')
+
+ # verify restore of partitioned_variables
+ ops.reset_default_graph() # restore on a new graph
+ g = ops.get_default_graph()
+ with session.Session() as sess, g.as_default():
+ with variable_scope.variable_scope("",
+ partitioner=partitioned_variables.fixed_size_partitioner(num_ps, axis=0)):
+ partition_var = variable_scope.get_variable(
+ 'partition_var',shape=[2,4], initializer=init_ops.ones_initializer)
+ s = saver.Saver(var_list=[partition_var])
+ s.restore(sess, './model/model')
+ part_0 = g.get_tensor_by_name('partition_var/part_0:0')
+ self.assertAllEqual(part_0_g, sess.run(part_0))
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
+ ea_custom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope("", custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_custom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
w = variable_scope.get_variable(initializer=[2, 1], name="w")
- v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
+ v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)