author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-04 05:31:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-04 05:34:52 -0700
commit    71b19430e8484b136e0b872f6a543aff8a242587 (patch)
tree      b783147361aafd2a412d7c6b0c7ad1a578410e1e /tensorflow/contrib/kfac
parent    1f5324ca69bc1017972eef8e418691cff9a86dd7 (diff)
Sync replicas distributed training example with two strategies:
1) Interleave covariance and inverse update ops with training op.
2) Run the inverse and covariance ops on separate dedicated workers.

PiperOrigin-RevId: 191579634
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/examples/BUILD                              |  24
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet.py                         | 315
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py  |  62
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py  |  48
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py (renamed from tensorflow/contrib/kfac/examples/convnet_mnist_main.py) | 32
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/convnet_test.py              |  17
6 files changed, 411 insertions, 87 deletions
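
For orientation, here is a minimal usage sketch (not part of the commit) of how the two strategies are selected through the `op_strategy` argument that this change adds to `convnet.train_mnist_distributed_sync_replicas`; the worker counts, paths, and epoch count below are illustrative placeholders.

from tensorflow.contrib.kfac.examples import convnet

# Strategy 1: the chief worker interleaves the covariance and inverse update
# ops with the training op.
convnet.train_mnist_distributed_sync_replicas(
    task_id=0, is_chief=True, num_worker_tasks=5, num_ps_tasks=2,
    master="", data_dir="/tmp/mnist", num_epochs=200,
    op_strategy="chief_worker")

# Strategy 2: dedicated workers run the covariance and inverse update ops.
convnet.train_mnist_distributed_sync_replicas(
    task_id=0, is_chief=True, num_worker_tasks=5, num_ps_tasks=2,
    master="", data_dir="/tmp/mnist", num_epochs=200,
    op_strategy="dedicated_workers")
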
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
index 7dd40c19c5..8186fa1c62 100644
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ b/tensorflow/contrib/kfac/examples/BUILD
@@ -28,8 +28,28 @@ py_library(
)
py_binary(
- name = "convnet_mnist_main",
- srcs = ["convnet_mnist_main.py"],
+ name = "convnet_mnist_single_main",
+ srcs = ["convnet_mnist_single_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_multi_tower_main",
+ srcs = ["convnet_mnist_multi_tower_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_distributed_main",
+ srcs = ["convnet_mnist_distributed_main.py"],
srcs_version = "PY2AND3",
deps = [
":convnet",
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
index 39d80addaa..e8e3353091 100644
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -37,6 +37,8 @@ import tensorflow as tf
from tensorflow.contrib.kfac.examples import mlp
from tensorflow.contrib.kfac.examples import mnist
+from tensorflow.contrib.kfac.python.ops import optimizer as opt
+
lc = tf.contrib.kfac.layer_collection
oq = tf.contrib.kfac.op_queue
@@ -48,12 +50,18 @@ __all__ = [
"linear_layer",
"build_model",
"minimize_loss_single_machine",
- "minimize_loss_distributed",
+ "distributed_grads_only_and_ops_chief_worker",
+ "distributed_grads_and_ops_dedicated_workers",
"train_mnist_single_machine",
- "train_mnist_distributed",
+ "train_mnist_distributed_sync_replicas",
+ "train_mnist_multitower"
]
+# Inverse update ops will be run every _INVERT_EVERY iterations.
+_INVERT_EVERY = 10
+
+
def conv_layer(layer_id, inputs, kernel_size, out_channels):
"""Builds a convolutional layer with ReLU non-linearity.
@@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection):
accuracy = tf.reduce_mean(
tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
- tf.summary.scalar("loss", loss)
- tf.summary.scalar("accuracy", accuracy)
+ with tf.device("/cpu:0"):
+ tf.summary.scalar("loss", loss)
+ tf.summary.scalar("accuracy", accuracy)
# Register parameters. K-FAC needs to know about the inputs, outputs, and
# parameters of each conv/fully connected layer and the logits powering the
@@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection):
def minimize_loss_single_machine(loss,
accuracy,
layer_collection,
+ device="/gpu:0",
session_config=None):
"""Minimize loss with K-FAC on a single machine.
- A single Session is responsible for running all of K-FAC's ops.
+ A single Session is responsible for running all of K-FAC's ops. The covariance
+ and inverse update ops are placed on `device`. All model variables are on CPU.
Args:
loss: 0-D Tensor. Loss to be minimized.
accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
+ device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+ update ops are run on this device.
session_config: None or tf.ConfigProto. Configuration for tf.Session().
Returns:
final value for 'accuracy'.
"""
# Train with K-FAC.
- global_step = tf.train.get_or_create_global_step()
+ g_step = tf.train.get_or_create_global_step()
optimizer = opt.KfacOptimizer(
learning_rate=0.0001,
cov_ema_decay=0.95,
damping=0.001,
layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=[device],
+ inv_devices=[device],
momentum=0.9)
- train_op = optimizer.minimize(loss, global_step=global_step)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ with tf.device(device):
+ train_op = optimizer.minimize(loss, global_step=g_step)
+
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
tf.logging.info("Starting training.")
with tf.train.MonitoredTrainingSession(config=session_config) as sess:
while not sess.should_stop():
- global_step_, loss_, accuracy_, _, _ = sess.run(
- [global_step, loss, accuracy, train_op, optimizer.cov_update_op])
-
- if global_step_ % 100 == 0:
- sess.run(optimizer.inv_update_op)
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, inverse_op])
- if global_step_ % 100 == 0:
+ if (global_step_ + 1) % _INVERT_EVERY == 0:
tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
global_step_, loss_, accuracy_)
@@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks):
return int(np.ceil(0.6 * num_tasks))
-def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection):
- """Minimize loss with an synchronous implementation of K-FAC.
+def _make_distributed_train_op(
+ task_id,
+ num_worker_tasks,
+ num_ps_tasks,
+ layer_collection
+):
+ """Creates optimizer and distributed training op.
- Different tasks are responsible for different parts of K-FAC's Ops. The first
- 60% of tasks update weights; the next 20% accumulate covariance statistics;
- the last 20% invert the matrices used to precondition gradients.
+ Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
+ the train op.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
+ optimizer.
+ optimizer: Instance of `opt.KfacOptimizer`.
+ global_step: `tensor`, Global step.
+ """
+ tf.logging.info("Task id : %d", task_id)
+ with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ momentum=0.9)
+ sync_optimizer = tf.train.SyncReplicasOptimizer(
+ opt=optimizer,
+ replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
+ total_num_replicas=num_worker_tasks)
+ return sync_optimizer, optimizer, global_step
+
+
+def distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection, invert_every=10):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+ All workers perform gradient computation. The chief worker applies the gradient
+ after averaging the gradients obtained from all the workers. All workers block
+ execution until the update is applied. The chief worker runs the covariance and
+ inverse update ops. Covariance and inverse matrices are placed on parameter
+ servers in a round robin manner. For further details on synchronous
+ distributed optimization check `tf.train.SyncReplicasOptimizer`.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables. If 0,
parameter servers are not used.
@@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
run with each step.
layer_collection: LayerCollection instance describing model architecture.
Used by K-FAC to construct preconditioner.
+ invert_every: `int`, Number of steps between inverse updates.
Returns:
final value for 'accuracy'.
@@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
Raises:
ValueError: if task_id >= num_worker_tasks.
"""
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- momentum=0.9)
- inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
- sync_optimizer = tf.train.SyncReplicasOptimizer(
- opt=optimizer,
- replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+ tf.logging.info("Starting training.")
+ hooks = [sync_optimizer.make_session_run_hook(is_chief)]
+
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ if is_chief:
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ update_op = tf.cond(
+ tf.equal(tf.mod(global_step + 1, invert_every), 0),
+ lambda: make_update_op(inv_update_thunks),
+ tf.no_op)
+ else:
+ update_op = train_op
+
+ with tf.train.MonitoredTrainingSession(
+ master=master,
+ is_chief=is_chief,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ stop_grace_period_secs=0) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, update_op])
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+ loss_, accuracy_)
+ return accuracy_
+
+
+def distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+ Different workers are responsible for different parts of K-FAC's Ops. The
+ first 60% of tasks compute gradients; the next 20% accumulate covariance
+ statistics; the last 20% invert the matrices used to precondition gradients.
+ The chief worker applies the aggregated gradients.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ master: string. IP and port of TensorFlow runtime process. Set to empty
+ string to run locally.
+ checkpoint_dir: string or None. Path to store checkpoints under.
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
+ run with each step.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ final value for 'accuracy'.
+
+ Raises:
+ ValueError: if task_id >= num_worker_tasks.
+ """
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+ inv_update_queue = oq.OpQueue(inv_update_ops)
tf.logging.info("Starting training.")
is_chief = (task_id == 0)
@@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
if _is_gradient_task(task_id, num_worker_tasks):
learning_op = train_op
elif _is_cov_update_task(task_id, num_worker_tasks):
- learning_op = optimizer.cov_update_op
+ learning_op = cov_update_op
elif _is_inv_update_task(task_id, num_worker_tasks):
# TODO(duckworthd): Running this op before cov_update_op has been run a
# few times can result in "InvalidArgumentError: Cholesky decomposition
@@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
return accuracy_
-def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
+def train_mnist_single_machine(data_dir,
+ num_epochs,
+ use_fake_data=False,
+ device="/gpu:0"):
"""Train a ConvNet on MNIST.
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
use_fake_data: bool. If True, generate a synthetic dataset.
+ device: string. Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+ update ops are run on this device.
Returns:
accuracy of model on the final minibatch of training data.
@@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
examples, labels, num_labels=10, layer_collection=layer_collection)
# Fit model.
- return minimize_loss_single_machine(loss, accuracy, layer_collection)
+ return minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device=device)
def train_mnist_multitower(data_dir, num_epochs, num_towers,
- use_fake_data=True):
+ use_fake_data=True, devices=None):
"""Train a ConvNet on MNIST.
+ Training data is split equally among the towers. Each tower computes loss on
+ its own batch of data and the loss is aggregated on the CPU. The model
+ variables are placed on the first tower. The covariance and inverse update ops
+ and variables are placed on GPUs in a round robin manner.
+
Args:
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
num_towers: int. Number of CPUs to split inference across.
use_fake_data: bool. If True, generate a synthetic dataset.
+ devices: list of strings or None. CPU or GPU devices for the towers; the
+ covariance and inverse update ops are run on these devices.
Returns:
accuracy of model on the final minibatch of training data.
"""
+ if devices:
+ device_count = {"GPU": num_towers}
+ else:
+ device_count = {"CPU": num_towers}
+
+ devices = devices or [
+ "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
+ ]
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
tower_batch_size = 128
@@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
layer_collection = lc.LayerCollection()
tower_results = []
for tower_id in range(num_towers):
- with tf.device("/cpu:%d" % tower_id):
+ with tf.device(devices[tower_id]):
with tf.name_scope("tower%d" % tower_id):
with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
tf.logging.info("Building tower %d." % tower_id)
@@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers,
accuracy = tf.reduce_mean(accuracies)
# Fit model.
+
session_config = tf.ConfigProto(
- allow_soft_placement=False, device_count={
- "CPU": num_towers
- })
- return minimize_loss_single_machine(
- loss, accuracy, layer_collection, session_config=session_config)
+ allow_soft_placement=False,
+ device_count=device_count,
+ )
+
+ g_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=devices,
+ inv_devices=devices,
+ momentum=0.9)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+ train_op = optimizer.minimize(loss, global_step=g_step)
-def train_mnist_distributed(task_id,
- num_worker_tasks,
- num_ps_tasks,
- master,
- data_dir,
- num_epochs,
- use_fake_data=False):
- """Train a ConvNet on MNIST.
+ def make_update_op(update_thunks):
+ update_op = [thunk() for thunk in update_thunks]
+ return tf.group(*update_op)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([train_op, cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+
+ tf.logging.info("Starting training.")
+ with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, inverse_op])
+
+ if (global_step_ + 1) % _INVERT_EVERY == 0:
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+ global_step_, loss_, accuracy_)
+
+
+def train_mnist_distributed_sync_replicas(task_id,
+ is_chief,
+ num_worker_tasks,
+ num_ps_tasks,
+ master,
+ data_dir,
+ num_epochs,
+ op_strategy,
+ use_fake_data=False):
+ """Train a ConvNet on MNIST using Sync replicas optimizer.
Args:
task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
num_worker_tasks: int. Number of workers in this distributed training setup.
num_ps_tasks: int. Number of parameter servers holding variables.
master: string. IP and port of TensorFlow runtime process.
data_dir: string. Directory to read MNIST examples from.
num_epochs: int. Number of passes to make over the training set.
+ op_strategy: `string`, Strategy to run the covariance and inverse
+ ops. If op_strategy == `chief_worker` then covariance and inverse
+ update ops are run on the chief worker; otherwise they are run on
+ dedicated workers.
+
use_fake_data: bool. If True, generate a synthetic dataset.
Returns:
accuracy of model on the final minibatch of training data.
+
+ Raises:
+ ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
"""
# Load a dataset.
tf.logging.info("Loading MNIST into memory.")
@@ -448,9 +649,17 @@ def train_mnist_distributed(task_id,
# Fit model.
checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
- return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
- master, checkpoint_dir, loss, accuracy,
- layer_collection)
+ if op_strategy == "chief_worker":
+ return distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ elif op_strategy == "dedicated_workers":
+ return distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ else:
+ raise ValueError("Only supported op strategies are: {}, {}".format(
+ "chief_worker", "dedicated_workers"))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
new file mode 100644
index 0000000000..b4c2d4a9e9
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Distributed training with sync replicas optimizer. See
+`convnet.train_mnist_distributed_sync_replicas` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("task", -1, "Task identifier")
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
+flags.DEFINE_string(
+ "cov_inv_op_strategy", "chief_worker",
+ "In dist training mode run the cov, inv ops on chief or dedicated workers."
+)
+flags.DEFINE_string("master", "local", "Session master.")
+flags.DEFINE_integer("ps_tasks", 2,
+ "Number of tasks in the parameter server job.")
+flags.DEFINE_integer("replicas_to_aggregate", 5,
+ "Number of replicas to aggregate.")
+flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
+flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
+
+
+def _is_chief():
+ """Determines whether a job is the chief worker."""
+ if "chief_worker" in FLAGS.brain_jobs:
+ return FLAGS.brain_job_name == "chief_worker"
+ else:
+ return FLAGS.task == 0
+
+
+def main(unused_argv):
+ _ = unused_argv
+ convnet.train_mnist_distributed_sync_replicas(
+ FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
+ FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
new file mode 100644
index 0000000000..4249bf8a8d
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Multi tower training mode. See `convnet.train_mnist_multitower` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
+flags.DEFINE_integer("num_towers", 2,
+ "Number of towers for multi tower training.")
+
+
+def main(unused_argv):
+ _ = unused_argv
+ assert FLAGS.num_towers > 1
+ devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
+ convnet.train_mnist_multitower(
+ FLAGS.data_dir,
+ num_epochs=200,
+ num_towers=FLAGS.num_towers,
+ devices=devices)
+
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
index b0c6fbde19..3aa52aff19 100644
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
@@ -14,44 +14,26 @@
# ==============================================================================
r"""Train a ConvNet on MNIST using K-FAC.
-See convnet.py for details.
+Train on single machine. See `convnet.train_mnist_single_machine` for details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import argparse
-import sys
+from absl import flags
import tensorflow as tf
from tensorflow.contrib.kfac.examples import convnet
-FLAGS = None
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-def main(argv):
- _ = argv
-
- if FLAGS.num_towers > 1:
- convnet.train_mnist_multitower(
- FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
- else:
- convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+def main(unused_argv):
+ convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir",
- type=str,
- default="/tmp/mnist",
- help="Directory to store dataset in.")
- parser.add_argument(
- "--num_towers",
- type=int,
- default=1,
- help="Number of CPUs to split minibatch across.")
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
index 8d86c2bb51..6de775cc79 100644
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
@@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase):
def testMinimizeLossSingleMachine(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy,
- layer_collection)
- self.assertLess(accuracy_, 1.0)
+ accuracy_ = convnet.minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device="/cpu:0")
+ self.assertLess(accuracy_, 2.0)
def testMinimizeLossDistributed(self):
with tf.Graph().as_default():
loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_distributed(
+ accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
task_id=0,
+ is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
@@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase):
loss=loss,
accuracy=accuracy,
layer_collection=layer_collection)
- self.assertLess(accuracy_, 1.0)
+ self.assertLess(accuracy_, 2.0)
def testTrainMnistSingleMachine(self):
with tf.Graph().as_default():
@@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase):
# but there are too few parameters for the model to effectively memorize
# the training set the way an MLP can.
convnet.train_mnist_single_machine(
- data_dir=None, num_epochs=1, use_fake_data=True)
+ data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
def testTrainMnistMultitower(self):
with tf.Graph().as_default():
@@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase):
def testTrainMnistDistributed(self):
with tf.Graph().as_default():
# Ensure model training doesn't crash.
- convnet.train_mnist_distributed(
+ convnet.train_mnist_distributed_sync_replicas(
task_id=0,
+ is_chief=True,
num_worker_tasks=1,
num_ps_tasks=0,
master="",
data_dir=None,
num_epochs=1,
+ op_strategy="chief_worker",
use_fake_data=True)
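
For readers skimming the diff, a condensed, standalone sketch of the thunk-based update pattern the new example code uses is shown below. It assumes the tf.contrib.kfac optimizer API as of this commit (KfacOptimizer with make_vars_and_create_op_thunks); `loss` and `layer_collection` come from a model built elsewhere, and the function name is a hypothetical helper, not part of the library.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import optimizer as opt

_INVERT_EVERY = 10  # run inverse updates every 10 training steps, as in convnet.py


def build_kfac_train_op(loss, layer_collection, device="/cpu:0"):
  """Builds one op that trains, updates covariances every step, and
  refreshes inverses every _INVERT_EVERY steps."""
  g_step = tf.train.get_or_create_global_step()
  optimizer = opt.KfacOptimizer(
      learning_rate=0.0001,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      placement_strategy="round_robin",
      cov_devices=[device],
      inv_devices=[device],
      momentum=0.9)
  # The thunks lazily build one covariance / inverse update op per factor.
  cov_update_thunks, inv_update_thunks = (
      optimizer.make_vars_and_create_op_thunks())
  train_op = optimizer.minimize(loss, global_step=g_step)

  def make_update_op(thunks):
    # Materialize the ops and group them into a single op.
    return tf.group(*[thunk() for thunk in thunks])

  cov_update_op = make_update_op(cov_update_thunks)
  with tf.control_dependencies([train_op, cov_update_op]):
    # Interleave: covariances update every step, inverses only periodically.
    return tf.cond(
        tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0),
        lambda: make_update_op(inv_update_thunks),
        tf.no_op)

Running the returned op inside a MonitoredTrainingSession loop, as the example does, executes training, covariance, and (periodically) inverse updates in one sess.run call.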