about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/eager/python/examples/revnet/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/main.py')
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/main.py  231
1 file changed, 126 insertions(+), 105 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index 1065592509..dcd4e1697f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -23,7 +23,6 @@ import sys
from absl import flags
import tensorflow as tf
-from tqdm import tqdm
from tensorflow.contrib.eager.python.examples.revnet import cifar_input
from tensorflow.contrib.eager.python.examples.revnet import config as config_
from tensorflow.contrib.eager.python.examples.revnet import revnet
@@ -32,19 +31,111 @@ tfe = tf.contrib.eager
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
- if FLAGS.data_dir is None:
- raise ValueError("No supplied data directory")
+ tf.enable_eager_execution()
- if not os.path.exists(FLAGS.data_dir):
- raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))
+ config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
+ ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
+ data_dir=FLAGS.data_dir, config=config)
+ model = revnet.RevNet(config=config)
+ global_step = tf.train.get_or_create_global_step() # Ensure correct summary
+ global_step.assign(1)
+ learning_rate = tf.train.piecewise_constant(
+ global_step, config.lr_decay_steps, config.lr_list)
+ optimizer = tf.train.MomentumOptimizer(
+ learning_rate, momentum=config.momentum)
+ checkpointer = tf.train.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=global_step)
- tf.enable_eager_execution()
- config = config_.get_hparams_cifar_38()
+ if FLAGS.use_defun:
+ model.call = tfe.defun(model.call)
+
+ if FLAGS.train_dir:
+ summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
+ if FLAGS.restore:
+ latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
+ checkpointer.restore(latest_path)
+ print("Restored latest checkpoint at path:\"{}\" "
+ "with global_step: {}".format(latest_path, global_step.numpy()))
+ sys.stdout.flush()
+
+ for x, y in ds_train:
+ train_one_iter(model, x, y, optimizer, global_step=global_step)
+
+ if global_step.numpy() % config.log_every == 0:
+ it_test = ds_test.make_one_shot_iterator()
+ acc_test, loss_test = evaluate(model, it_test)
+
+ if FLAGS.validate:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
+ it_validation = ds_validation.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
+ acc_validation, loss_validation = evaluate(model, it_validation)
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "validation set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_validation,
+ loss_validation, acc_test, loss_test))
+ else:
+ print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_test, loss_test))
+ sys.stdout.flush()
+ if FLAGS.train_dir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("Test accuracy", acc_test)
+ tf.contrib.summary.scalar("Test loss", loss_test)
+ if FLAGS.validate:
+ tf.contrib.summary.scalar("Training accuracy", acc_train)
+ tf.contrib.summary.scalar("Training loss", loss_train)
+ tf.contrib.summary.scalar("Validation accuracy", acc_validation)
+ tf.contrib.summary.scalar("Validation loss", loss_validation)
+
+ if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
+ saved_path = checkpointer.save(
+ file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
+ print("Saved checkpoint at path: \"{}\" "
+ "with global_step: {}".format(saved_path, global_step.numpy()))
+ sys.stdout.flush()
+
+
+def get_config(config_name="revnet-38", dataset="cifar-10"):
+ """Return configuration."""
+ print("Config: {}".format(config_name))
+ sys.stdout.flush()
+ config = {
+ "revnet-38": config_.get_hparams_cifar_38(),
+ "revnet-110": config_.get_hparams_cifar_110(),
+ "revnet-164": config_.get_hparams_cifar_164(),
+ }[config_name]
+
+ if dataset == "cifar-10":
+ config.add_hparam("n_classes", 10)
+ config.add_hparam("dataset", "cifar-10")
+ else:
+ config.add_hparam("n_classes", 100)
+ config.add_hparam("dataset", "cifar-100")
+
+ return config
+
+
+def get_datasets(data_dir, config):
+ """Return dataset."""
+ if data_dir is None:
+ raise ValueError("No supplied data directory")
+ if not os.path.exists(data_dir):
+ raise ValueError("Data directory {} does not exist".format(data_dir))
+ if config.dataset not in ["cifar-10", "cifar-100"]:
+ raise ValueError("Unknown dataset {}".format(config.dataset))
+
+ print("Training on {} dataset.".format(config.dataset))
+ sys.stdout.flush()
+ data_dir = os.path.join(data_dir, config.dataset)
if FLAGS.validate:
# 40k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train",
data_aug=True,
batch_size=config.batch_size,
@@ -55,7 +146,7 @@ def main(_):
prefetch=config.batch_size)
# 10k Training set
ds_validation = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="validation",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -67,7 +158,7 @@ def main(_):
else:
# 50k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train_all",
data_aug=True,
batch_size=config.batch_size,
@@ -76,10 +167,11 @@ def main(_):
data_format=config.data_format,
dtype=config.dtype,
prefetch=config.batch_size)
+ ds_validation = None
- # Always compute loss and accuracy on whole training and test set
+ # Always compute loss and accuracy on whole test set
ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="train_all",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -90,7 +182,7 @@ def main(_):
prefetch=config.eval_batch_size)
ds_test = cifar_input.get_ds_from_tfrecords(
- data_dir=FLAGS.data_dir,
+ data_dir=data_dir,
split="test",
data_aug=False,
batch_size=config.eval_batch_size,
@@ -100,103 +192,23 @@ def main(_):
dtype=config.dtype,
prefetch=config.eval_batch_size)
- model = revnet.RevNet(config=config)
- global_step = tfe.Variable(1, trainable=False)
- learning_rate = tf.train.piecewise_constant(
- global_step, config.lr_decay_steps, config.lr_list)
- optimizer = tf.train.MomentumOptimizer(
- learning_rate, momentum=config.momentum)
- checkpointer = tf.train.Checkpoint(
- optimizer=optimizer, model=model, optimizer_step=global_step)
-
- if FLAGS.train_dir:
- summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
- if FLAGS.restore:
- latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
- checkpointer.restore(latest_path)
- print("Restored latest checkpoint at path:\"{}\" "
- "with global_step: {}".format(latest_path, global_step.numpy()))
- sys.stdout.flush()
-
- warmup(model, config)
-
- for x, y in ds_train:
- loss = train_one_iter(model, x, y, optimizer, global_step=global_step)
-
- if global_step.numpy() % config.log_every == 0:
- it_train = ds_train_one_shot.make_one_shot_iterator()
- acc_train, loss_train = evaluate(model, it_train)
- it_test = ds_test.make_one_shot_iterator()
- acc_test, loss_test = evaluate(model, it_test)
- if FLAGS.validate:
- it_validation = ds_validation.make_one_shot_iterator()
- acc_validation, loss_validation = evaluate(model, it_validation)
- print("Iter {}, "
- "training set accuracy {:.4f}, loss {:.4f}; "
- "validation set accuracy {:.4f}, loss {:4.f}"
- "test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_train, loss_train, acc_validation,
- loss_validation, acc_test, loss_test))
- else:
- print("Iter {}, "
- "training set accuracy {:.4f}, loss {:.4f}; "
- "test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_train, loss_train, acc_test,
- loss_test))
- sys.stdout.flush()
-
- if FLAGS.train_dir:
- with summary_writer.as_default():
- with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("Training loss", loss)
- tf.contrib.summary.scalar("Test accuracy", acc_test)
- if FLAGS.validate:
- tf.contrib.summary.scalar("Validation accuracy", acc_validation)
-
- if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
- saved_path = checkpointer.save(
- file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
- print("Saved checkpoint at path: \"{}\" "
- "with global_step: {}".format(saved_path, global_step.numpy()))
- sys.stdout.flush()
-
-
-def warmup(model, config, steps=1):
- mock_input = tf.random_normal((config.batch_size,) + config.input_shape)
- for _ in range(steps):
- model(mock_input, training=False)
+ return ds_train, ds_train_one_shot, ds_validation, ds_test
-def train_one_iter(model,
- inputs,
- labels,
- optimizer,
- global_step=None,
- verbose=False):
+def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- if FLAGS.manual_grad:
- if verbose:
- print("Using manual gradients")
- grads, vars_, loss = model.compute_gradients(inputs, labels)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
- else: # For correctness validation
- if verbose:
- print("Not using manual gradients")
- with tf.GradientTape() as tape:
- logits, _ = model(inputs, training=True)
- loss = model.compute_loss(logits=logits, labels=labels)
- grads = tape.gradient(loss, model.trainable_variables)
- optimizer.apply_gradients(
- zip(grads, model.trainable_variables), global_step=global_step)
-
- return loss.numpy()
+ grads, vars_, logits, loss = model.compute_gradients(
+ inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+
+ return logits, loss
def evaluate(model, iterator):
"""Compute accuracy with the given dataset iterator."""
mean_loss = tfe.metrics.Mean()
accuracy = tfe.metrics.Accuracy()
- for x, y in tqdm(iterator):
+ for x, y in iterator:
logits, _ = model(x, training=False)
loss = model.compute_loss(logits=logits, labels=y)
accuracy(
@@ -209,11 +221,11 @@ def evaluate(model, iterator):
if __name__ == "__main__":
flags.DEFINE_string(
+ "data_dir", default=None, help="Directory to load tfrecords")
+ flags.DEFINE_string(
"train_dir",
default=None,
help="[Optional] Directory to store the training information")
- flags.DEFINE_string(
- "data_dir", default=None, help="Directory to load tfrecords")
flags.DEFINE_boolean(
"restore",
default=False,
@@ -222,9 +234,18 @@ if __name__ == "__main__":
"validate",
default=False,
help="[Optional] Use the validation set or not for hyperparameter search")
+ flags.DEFINE_string(
+ "dataset",
+ default="cifar-10",
+ help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
+ flags.DEFINE_string(
+ "config",
+ default="revnet-38",
+ help="[Optional] Architecture of network. "
+ "Other options include `revnet-110` and `revnet-164`")
flags.DEFINE_boolean(
- "manual_grad",
+ "use_defun",
default=False,
- help="[Optional] Use manual gradient graph to save memory")
+ help="[Optional] Use `tfe.defun` to boost performance.")
FLAGS = flags.FLAGS
tf.app.run(main)