diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/main.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/main.py | 95 |
1 file changed, 45 insertions, 50 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index e2f43b03f9..dcd4e1697f 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -31,8 +31,11 @@ tfe = tf.contrib.eager def main(_): """Eager execution workflow with RevNet trained on CIFAR-10.""" - config = get_config() - ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config) + tf.enable_eager_execution() + + config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets( + data_dir=FLAGS.data_dir, config=config) model = revnet.RevNet(config=config) global_step = tf.train.get_or_create_global_step() # Ensure correct summary global_step.assign(1) @@ -43,6 +46,9 @@ def main(_): checkpointer = tf.train.Checkpoint( optimizer=optimizer, model=model, optimizer_step=global_step) + if FLAGS.use_defun: + model.call = tfe.defun(model.call) + if FLAGS.train_dir: summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) if FLAGS.restore: @@ -52,46 +58,37 @@ def main(_): "with global_step: {}".format(latest_path, global_step.numpy())) sys.stdout.flush() - if FLAGS.manual_grad: - print("Using manual gradients.") - else: - print("Not using manual gradients.") - sys.stdout.flush() - for x, y in ds_train: train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_train = ds_train_one_shot.make_one_shot_iterator() it_test = ds_test.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) acc_test, loss_test = evaluate(model, it_test) if FLAGS.validate: + it_train = ds_train_one_shot.make_one_shot_iterator() it_validation = ds_validation.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) acc_validation, loss_validation = evaluate(model, it_validation) print("Iter {}, " 
"training set accuracy {:.4f}, loss {:.4f}; " - "validation set accuracy {:.4f}, loss {:4.f}" + "validation set accuracy {:.4f}, loss {:.4f}; " "test accuracy {:.4f}, loss {:.4f}".format( global_step.numpy(), acc_train, loss_train, acc_validation, loss_validation, acc_test, loss_test)) else: - print("Iter {}, " - "training set accuracy {:.4f}, loss {:.4f}; " - "test accuracy {:.4f}, loss {:.4f}".format( - global_step.numpy(), acc_train, loss_train, acc_test, - loss_test)) + print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_test, loss_test)) sys.stdout.flush() if FLAGS.train_dir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("Training accuracy", acc_train) tf.contrib.summary.scalar("Test accuracy", acc_test) - tf.contrib.summary.scalar("Training loss", loss_train) tf.contrib.summary.scalar("Test loss", loss_test) if FLAGS.validate: + tf.contrib.summary.scalar("Training accuracy", acc_train) + tf.contrib.summary.scalar("Training loss", loss_train) tf.contrib.summary.scalar("Validation accuracy", acc_validation) tf.contrib.summary.scalar("Validation loss", loss_validation) @@ -103,34 +100,38 @@ def main(_): sys.stdout.flush() -def get_config(): +def get_config(config_name="revnet-38", dataset="cifar-10"): """Return configuration.""" - print("Config: {}".format(FLAGS.config)) + print("Config: {}".format(config_name)) sys.stdout.flush() config = { "revnet-38": config_.get_hparams_cifar_38(), "revnet-110": config_.get_hparams_cifar_110(), "revnet-164": config_.get_hparams_cifar_164(), - }[FLAGS.config] + }[config_name] - if FLAGS.dataset == "cifar-100": - config.n_classes = 100 + if dataset == "cifar-10": + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + else: + config.add_hparam("n_classes", 100) + config.add_hparam("dataset", "cifar-100") return config -def get_datasets(config): +def get_datasets(data_dir, config): """Return dataset.""" 
- if FLAGS.data_dir is None: + if data_dir is None: raise ValueError("No supplied data directory") - if not os.path.exists(FLAGS.data_dir): - raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) - if FLAGS.dataset not in ["cifar-10", "cifar-100"]: - raise ValueError("Unknown dataset {}".format(FLAGS.dataset)) + if not os.path.exists(data_dir): + raise ValueError("Data directory {} does not exist".format(data_dir)) + if config.dataset not in ["cifar-10", "cifar-100"]: + raise ValueError("Unknown dataset {}".format(config.dataset)) - print("Training on {} dataset.".format(FLAGS.dataset)) + print("Training on {} dataset.".format(config.dataset)) sys.stdout.flush() - data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset) + data_dir = os.path.join(data_dir, config.dataset) if FLAGS.validate: # 40k Training set ds_train = cifar_input.get_ds_from_tfrecords( @@ -168,7 +169,7 @@ def get_datasets(config): prefetch=config.batch_size) ds_validation = None - # Always compute loss and accuracy on whole training and test set + # Always compute loss and accuracy on whole test set ds_train_one_shot = cifar_input.get_ds_from_tfrecords( data_dir=data_dir, split="train_all", @@ -196,19 +197,11 @@ def get_datasets(config): def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - if FLAGS.manual_grad: - grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) - else: # For correctness validation - with tf.GradientTape() as tape: - logits, _ = model(inputs, training=True) - loss = model.compute_loss(logits=logits, labels=labels) - tf.logging.info("Logits are placed on device: {}".format(logits.device)) - grads = tape.gradient(loss, model.trainable_variables) - optimizer.apply_gradients( - zip(grads, model.trainable_variables), global_step=global_step) + grads, vars_, logits, loss = model.compute_gradients( + inputs, labels, 
training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) - return loss.numpy() + return logits, loss def evaluate(model, iterator): @@ -241,16 +234,18 @@ if __name__ == "__main__": "validate", default=False, help="[Optional] Use the validation set or not for hyperparameter search") - flags.DEFINE_boolean( - "manual_grad", - default=False, - help="[Optional] Use manual gradient graph to save memory") flags.DEFINE_string( "dataset", default="cifar-10", help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") flags.DEFINE_string( - "config", default="revnet-38", help="[Optional] Architecture of network.") + "config", + default="revnet-38", + help="[Optional] Architecture of network. " + "Other options include `revnet-110` and `revnet-164`") + flags.DEFINE_boolean( + "use_defun", + default=False, + help="[Optional] Use `tfe.defun` to boost performance.") FLAGS = flags.FLAGS - tf.enable_eager_execution() tf.app.run(main) |