Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/main.py')
-rw-r--r-- tensorflow/contrib/eager/python/examples/revnet/main.py | 95
1 file changed, 45 insertions(+), 50 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
index e2f43b03f9..dcd4e1697f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/main.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -31,8 +31,11 @@ tfe = tf.contrib.eager
def main(_):
"""Eager execution workflow with RevNet trained on CIFAR-10."""
- config = get_config()
- ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config)
+ tf.enable_eager_execution()
+
+ config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset)
+ ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(
+ data_dir=FLAGS.data_dir, config=config)
model = revnet.RevNet(config=config)
global_step = tf.train.get_or_create_global_step() # Ensure correct summary
global_step.assign(1)
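
With eager execution enabled, global_step is an ordinary tf.Variable that can be assigned and read immediately; the assign(1) above presumably makes the first logged summary land on step 1 rather than 0. A minimal standalone sketch of the same startup sequence (TF 1.x API):

import tensorflow as tf

tf.enable_eager_execution()  # must run before any op or tensor is created in TF 1.x

global_step = tf.train.get_or_create_global_step()
global_step.assign(1)
print(global_step.numpy())  # -> 1
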
@@ -43,6 +46,9 @@ def main(_):
checkpointer = tf.train.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=global_step)
+ if FLAGS.use_defun:
+ model.call = tfe.defun(model.call)
+
if FLAGS.train_dir:
summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
if FLAGS.restore:
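
tfe.defun traces a Python function once per input signature and thereafter runs it as a compiled graph function, which typically speeds up each step while keeping the eager workflow; shadowing the bound model.call is the pattern this diff adds. A hedged sketch with a toy Keras model standing in for revnet.RevNet (ToyModel is illustrative, not from this file):

import tensorflow as tf

tfe = tf.contrib.eager
tf.enable_eager_execution()

class ToyModel(tf.keras.Model):
  """Hypothetical stand-in for revnet.RevNet."""

  def __init__(self):
    super(ToyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x, training=True):
    return self.dense(x)

model = ToyModel()
model.call = tfe.defun(model.call)  # shadow the bound method with a graph function
logits = model(tf.zeros([32, 8]), training=True)
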
@@ -52,46 +58,37 @@ def main(_):
"with global_step: {}".format(latest_path, global_step.numpy()))
sys.stdout.flush()
- if FLAGS.manual_grad:
- print("Using manual gradients.")
- else:
- print("Not using manual gradients.")
- sys.stdout.flush()
-
for x, y in ds_train:
train_one_iter(model, x, y, optimizer, global_step=global_step)
if global_step.numpy() % config.log_every == 0:
- it_train = ds_train_one_shot.make_one_shot_iterator()
it_test = ds_test.make_one_shot_iterator()
- acc_train, loss_train = evaluate(model, it_train)
acc_test, loss_test = evaluate(model, it_test)
if FLAGS.validate:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
it_validation = ds_validation.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
acc_validation, loss_validation = evaluate(model, it_validation)
print("Iter {}, "
"training set accuracy {:.4f}, loss {:.4f}; "
- "validation set accuracy {:.4f}, loss {:4.f}"
+ "validation set accuracy {:.4f}, loss {:.4f}; "
"test accuracy {:.4f}, loss {:.4f}".format(
global_step.numpy(), acc_train, loss_train, acc_validation,
loss_validation, acc_test, loss_test))
else:
- print("Iter {}, "
- "training set accuracy {:.4f}, loss {:.4f}; "
- "test accuracy {:.4f}, loss {:.4f}".format(
- global_step.numpy(), acc_train, loss_train, acc_test,
- loss_test))
+ print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_test, loss_test))
sys.stdout.flush()
if FLAGS.train_dir:
with summary_writer.as_default():
with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("Training accuracy", acc_train)
tf.contrib.summary.scalar("Test accuracy", acc_test)
- tf.contrib.summary.scalar("Training loss", loss_train)
tf.contrib.summary.scalar("Test loss", loss_test)
if FLAGS.validate:
+ tf.contrib.summary.scalar("Training accuracy", acc_train)
+ tf.contrib.summary.scalar("Training loss", loss_train)
tf.contrib.summary.scalar("Validation accuracy", acc_validation)
tf.contrib.summary.scalar("Validation loss", loss_validation)
@@ -103,34 +100,38 @@ def main(_):
sys.stdout.flush()
-def get_config():
+def get_config(config_name="revnet-38", dataset="cifar-10"):
"""Return configuration."""
- print("Config: {}".format(FLAGS.config))
+ print("Config: {}".format(config_name))
sys.stdout.flush()
config = {
"revnet-38": config_.get_hparams_cifar_38(),
"revnet-110": config_.get_hparams_cifar_110(),
"revnet-164": config_.get_hparams_cifar_164(),
- }[FLAGS.config]
+ }[config_name]
- if FLAGS.dataset == "cifar-100":
- config.n_classes = 100
+ if dataset == "cifar-10":
+ config.add_hparam("n_classes", 10)
+ config.add_hparam("dataset", "cifar-10")
+ else:
+ config.add_hparam("n_classes", 100)
+ config.add_hparam("dataset", "cifar-100")
return config
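
config here is presumably a tf.contrib.training.HParams object (the helpers are named get_hparams_cifar_*); add_hparam registers a brand-new key and raises if the key already exists, which is why the rewrite adds n_classes and dataset in both branches instead of overwriting a preset field. A hedged sketch:

import tensorflow as tf

hps = tf.contrib.training.HParams(batch_size=128)  # illustrative preset
hps.add_hparam("n_classes", 10)
hps.add_hparam("dataset", "cifar-10")
print(hps.n_classes, hps.dataset)  # -> 10 cifar-10
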
-def get_datasets(config):
+def get_datasets(data_dir, config):
"""Return dataset."""
- if FLAGS.data_dir is None:
+ if data_dir is None:
raise ValueError("No supplied data directory")
- if not os.path.exists(FLAGS.data_dir):
- raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))
- if FLAGS.dataset not in ["cifar-10", "cifar-100"]:
- raise ValueError("Unknown dataset {}".format(FLAGS.dataset))
+ if not os.path.exists(data_dir):
+ raise ValueError("Data directory {} does not exist".format(data_dir))
+ if config.dataset not in ["cifar-10", "cifar-100"]:
+ raise ValueError("Unknown dataset {}".format(config.dataset))
- print("Training on {} dataset.".format(FLAGS.dataset))
+ print("Training on {} dataset.".format(config.dataset))
sys.stdout.flush()
- data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset)
+ data_dir = os.path.join(data_dir, config.dataset)
if FLAGS.validate:
# 40k Training set
ds_train = cifar_input.get_ds_from_tfrecords(
@@ -168,7 +169,7 @@ def get_datasets(config):
prefetch=config.batch_size)
ds_validation = None
- # Always compute loss and accuracy on whole training and test set
+ # Always compute loss and accuracy on whole test set
ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
data_dir=data_dir,
split="train_all",
@@ -196,19 +197,11 @@ def get_datasets(config):
def train_one_iter(model, inputs, labels, optimizer, global_step=None):
"""Train for one iteration."""
- if FLAGS.manual_grad:
- grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
- optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
- else: # For correctness validation
- with tf.GradientTape() as tape:
- logits, _ = model(inputs, training=True)
- loss = model.compute_loss(logits=logits, labels=labels)
- tf.logging.info("Logits are placed on device: {}".format(logits.device))
- grads = tape.gradient(loss, model.trainable_variables)
- optimizer.apply_gradients(
- zip(grads, model.trainable_variables), global_step=global_step)
+ grads, vars_, logits, loss = model.compute_gradients(
+ inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
- return loss.numpy()
+ return logits, loss
def evaluate(model, iterator):
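
RevNet's compute_gradients reconstructs activations from the reversible blocks during the backward pass instead of storing them on a tape, which is the memory saving the removed --manual_grad flag referred to; with it now the only path, the tape branch became dead code. For comparison, a hedged sketch of the generic tape-based step that was deleted (the loss function here is illustrative, standing in for model.compute_loss):

import tensorflow as tf

tf.enable_eager_execution()

def tape_train_one_iter(model, inputs, labels, optimizer, global_step=None):
  # Standard eager training step: record the forward pass, then differentiate.
  with tf.GradientTape() as tape:
    logits = model(inputs, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(
      zip(grads, model.trainable_variables), global_step=global_step)
  return logits, loss
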
@@ -241,16 +234,18 @@ if __name__ == "__main__":
"validate",
default=False,
help="[Optional] Use the validation set or not for hyperparameter search")
- flags.DEFINE_boolean(
- "manual_grad",
- default=False,
- help="[Optional] Use manual gradient graph to save memory")
flags.DEFINE_string(
"dataset",
default="cifar-10",
help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
flags.DEFINE_string(
- "config", default="revnet-38", help="[Optional] Architecture of network.")
+ "config",
+ default="revnet-38",
+ help="[Optional] Architecture of network. "
+ "Other options include `revnet-110` and `revnet-164`")
+ flags.DEFINE_boolean(
+ "use_defun",
+ default=False,
+ help="[Optional] Use `tfe.defun` to boost performance.")
FLAGS = flags.FLAGS
- tf.enable_eager_execution()
tf.app.run(main)
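
tf.app.run(main) parses the flags defined above and then invokes main; with tf.enable_eager_execution() moved inside main(), nothing at module import time touches the runtime. A hedged sketch of the same entry-point pattern (the flag and paths are placeholders):

import tensorflow as tf
from absl import flags  # tf.flags is an alias for absl.flags in TF 1.x

flags.DEFINE_string("data_dir", default=None, help="Directory with CIFAR tfrecords")
FLAGS = flags.FLAGS

def main(_):
  tf.enable_eager_execution()
  print("data_dir:", FLAGS.data_dir)

if __name__ == "__main__":
  # e.g.: python main.py --data_dir=/tmp/cifar --config=revnet-38 --use_defun
  tf.app.run(main)
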