diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-04 05:31:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-04 05:34:52 -0700 |
commit | 71b19430e8484b136e0b872f6a543aff8a242587 (patch) | |
tree | b783147361aafd2a412d7c6b0c7ad1a578410e1e /tensorflow/contrib/kfac | |
parent | 1f5324ca69bc1017972eef8e418691cff9a86dd7 (diff) |
Sync replicas distributed training example with two strategies:
1) Interleave covariance and inverse update ops with training op.
2) Run the inverse and covariance ops on separate dedicated workers.
PiperOrigin-RevId: 191579634
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/examples/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/convnet.py | 315 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py | 62 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py | 48 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py (renamed from tensorflow/contrib/kfac/examples/convnet_mnist_main.py) | 32 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/tests/convnet_test.py | 17 |
6 files changed, 411 insertions, 87 deletions
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD index 7dd40c19c5..8186fa1c62 100644 --- a/tensorflow/contrib/kfac/examples/BUILD +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -28,8 +28,28 @@ py_library( ) py_binary( - name = "convnet_mnist_main", - srcs = ["convnet_mnist_main.py"], + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], srcs_version = "PY2AND3", deps = [ ":convnet", diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 39d80addaa..e8e3353091 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -37,6 +37,8 @@ import tensorflow as tf from tensorflow.contrib.kfac.examples import mlp from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + lc = tf.contrib.kfac.layer_collection oq = tf.contrib.kfac.op_queue @@ -48,12 +50,18 @@ __all__ = [ "linear_layer", "build_model", "minimize_loss_single_machine", - "minimize_loss_distributed", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", "train_mnist_single_machine", - "train_mnist_distributed", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" ] +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + def conv_layer(layer_id, inputs, kernel_size, out_channels): """Builds a convolutional layer with ReLU non-linearity. @@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection): accuracy = tf.reduce_mean( tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) # Register parameters. K-FAC needs to know about the inputs, outputs, and # parameters of each conv/fully connected layer and the logits powering the @@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection): def minimize_loss_single_machine(loss, accuracy, layer_collection, + device="/gpu:0", session_config=None): """Minimize loss with K-FAC on a single machine. - A single Session is responsible for running all of K-FAC's ops. + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() + g_step = tf.train.get_or_create_global_step() optimizer = opt.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], momentum=0.9) - train_op = optimizer.minimize(loss, global_step=global_step) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) - if global_step_ % 100 == 0: + if (global_step_ + 1) % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks): return int(np.ceil(0.6 * num_tasks)) -def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection): - """Minimize loss with an synchronous implementation of K-FAC. +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. - Different tasks are responsible for different parts of K-FAC's Ops. The first - 60% of tasks update weights; the next 20% accumulate covariance statistics; - the last 20% invert the matrices used to precondition gradients. + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. + + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution untill the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. @@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. @@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, Raises: ValueError: if task_id >= num_worker_tasks. """ - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) - train_op = sync_optimizer.minimize(loss, global_step=global_step) + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + update_op = tf.cond( + tf.equal(tf.mod(global_step + 1, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + else: + update_op = train_op + + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, update_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) tf.logging.info("Starting training.") is_chief = (task_id == 0) @@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = optimizer.cov_update_op + learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): # TODO(duckworthd): Running this op before cov_update_op has been run a # few times can result in "InvalidArgumentError: Cholesky decomposition @@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, return accuracy_ -def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): """Train a ConvNet on MNIST. Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. @@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): examples, labels, num_labels=10, layer_collection=layer_collection) # Fit model. - return minimize_loss_single_machine(loss, accuracy, layer_collection) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True): + use_fake_data=True, devices=None): """Train a ConvNet on MNIST. + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 @@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, layer_collection = lc.LayerCollection() tower_results = [] for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): + with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) @@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, accuracy = tf.reduce_mean(accuracies) # Fit model. + session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize_loss_single_machine( - loss, accuracy, layer_collection, session_config=session_config) + allow_soft_placement=False, + device_count=device_count, + ) + + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = optimizer.minimize(loss, global_step=g_step) -def train_mnist_distributed(task_id, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - use_fake_data=False): - """Train a ConvNet on MNIST. + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) + + if (global_step_ + 1) % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covaraiance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") @@ -448,9 +649,17 @@ def train_mnist_distributed(task_id, # Fit model. checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, - master, checkpoint_dir, loss, accuracy, - layer_collection) + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) if __name__ == "__main__": diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000..b4c2d4a9e9 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py new file mode 100644 index 0000000000..4249bf8a8d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Multi tower training mode. See `convnet.train_mnist_multitower` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") + + +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) + + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py index b0c6fbde19..3aa52aff19 100644 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -14,44 +14,26 @@ # ============================================================================== r"""Train a ConvNet on MNIST using K-FAC. -See convnet.py for details. +Train on single machine. See `convnet.train_mnist_single_machine` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import flags import tensorflow as tf from tensorflow.contrib.kfac.examples import convnet -FLAGS = None +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") -def main(argv): - _ = argv - - if FLAGS.num_towers > 1: - convnet.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) +def main(unused_argv): + convnet.train_mnist_single_gpu(FLAGS.data_dir, num_epochs=200) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 8d86c2bb51..6de775cc79 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase): def testMinimizeLossSingleMachine(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, - layer_collection) - self.assertLess(accuracy_, 1.0) + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) def testMinimizeLossDistributed(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_distributed( + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", @@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase): loss=loss, accuracy=accuracy, layer_collection=layer_collection) - self.assertLess(accuracy_, 1.0) + self.assertLess(accuracy_, 2.0) def testTrainMnistSingleMachine(self): with tf.Graph().as_default(): @@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase): # but there are too few parameters for the model to effectively memorize # the training set the way an MLP can. convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True) + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") def testTrainMnistMultitower(self): with tf.Graph().as_default(): @@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase): def testTrainMnistDistributed(self): with tf.Graph().as_default(): # Ensure model training doesn't crash. - convnet.train_mnist_distributed( + convnet.train_mnist_distributed_sync_replicas( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", data_dir=None, num_epochs=1, + op_strategy="chief_worker", use_fake_data=True) |