From 14b6365a36c8982092ed2010e1e90f66f663deeb Mon Sep 17 00:00:00 2001
From: Shanqing Cai
Date: Tue, 13 Feb 2018 21:49:28 -0800
Subject: tfe SPINN example: Add inference; fix serialization

* Also de-flake a test.

PiperOrigin-RevId: 185637742
---
 third_party/examples/eager/spinn/spinn.py | 139 +++++++++++++++++++++++-------
 1 file changed, 107 insertions(+), 32 deletions(-)

diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index a2fa18eeb1..38ba48d501 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -44,6 +44,7 @@ import os
 import sys
 import time
 
+import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 import tensorflow as tf
@@ -471,6 +472,15 @@ class SNLIClassifierTrainer(object):
   def learning_rate(self):
     return self._learning_rate
 
+  @property
+  def model(self):
+    return self._model
+
+  @property
+  def variables(self):
+    return (self._model.variables + [self.learning_rate] +
+            self._optimizer.variables())
+
 
 def _batch_n_correct(logits, label):
   """Calculate number of correct predictions in a batch.
@@ -488,13 +498,12 @@ def _batch_n_correct(logits, label):
           tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
 
-def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
   """Run evaluation on a dataset.
 
   Args:
     snli_data: The `data.SnliData` to use in this evaluation.
     batch_size: The batch size to use during this evaluation.
-    model: An instance of `SNLIClassifier` to evaluate.
     trainer: An instance of `SNLIClassifierTrainer` to use for this evaluation.
     use_gpu: Whether GPU is being used.
 
   Returns:
@@ -509,7 +518,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
       snli_data, batch_size):
     if use_gpu:
       label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
-    logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
     loss_val = trainer.loss(label, logits)
     batch_size = tf.shape(label)[0]
     mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -536,13 +545,19 @@ def _get_dataset_iterator(snli_data, batch_size):
   return tfe.Iterator(dataset)
 
 
-def train_spinn(embed, train_data, dev_data, test_data, config):
-  """Train a SPINN model.
+def train_or_infer_spinn(embed,
+                         word2index,
+                         train_data,
+                         dev_data,
+                         test_data,
+                         config):
+  """Perform training or inference on a SPINN model.
 
   Args:
     embed: The embedding matrix as a float32 numpy array with shape
       [vocabulary_size, word_vector_len]. word_vector_len is the length of a
       word embedding vector.
+    word2index: A `dict` mapping words to word indices.
     train_data: An instance of `data.SnliData`, for the train split.
     dev_data: Same as above, for the dev split.
     test_data: Same as above, for the test split.
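The `variables` property added to `SNLIClassifierTrainer` above is the heart of the serialization fix: a checkpoint now captures the model weights, the learning-rate variable, and the optimizer's internally created variables as a single list. A minimal sketch of the intended save/restore round trip, assuming TF 1.x eager execution; the variable names and the checkpoint path are illustrative, not from the patch:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# Stand-ins for the patch's `trainer.variables + [global_step]` list.
weights = tfe.Variable([1.0, 2.0], name="weights")        # model variables
learning_rate = tfe.Variable(0.1, name="learning_rate")   # trainer state
global_step = tf.train.get_or_create_global_step()

saver = tfe.Saver([weights, learning_rate, global_step])
prefix = saver.save("/tmp/spinn-sketch-ckpt", global_step=global_step)

# A later process rebuilds the same variable list and restores it; variables
# are matched to checkpoint entries by name.
saver.restore(prefix)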
""" + # TODO(cais): Refactor this function into separate one for training and + # inference. use_gpu = tfe.num_gpus() > 0 and not config.force_cpu device = "gpu:0" if use_gpu else "cpu:0" print("Using device: %s" % device) + if ((config.inference_premise and not config.inference_hypothesis) or + (not config.inference_premise and config.inference_hypothesis)): + raise ValueError( + "--inference_premise and --inference_hypothesis must be both " + "specified or both unspecified, but only one is specified.") + + if config.inference_premise: + # Inference mode. + inference_sentence_pair = [ + data.encode_sentence(config.inference_premise, word2index), + data.encode_sentence(config.inference_hypothesis, word2index)] + else: + inference_sentence_pair = None + log_header = ( " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss" " Accuracy Dev/Accuracy") @@ -569,16 +606,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config): summary_writer = tf.contrib.summary.create_file_writer( config.logdir, flush_millis=10000) - train_len = train_data.num_batches(config.batch_size) + with tf.device(device), \ - tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)), \ summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): - model = SNLIClassifier(config, embed) - global_step = tf.train.get_or_create_global_step() - trainer = SNLIClassifierTrainer(model, config.lr) - + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)): + model = SNLIClassifier(config, embed) + global_step = tf.train.get_or_create_global_step() + trainer = SNLIClassifierTrainer(model, config.lr) + + if inference_sentence_pair: + # Inference mode. + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)): + prem, prem_trans = inference_sentence_pair[0] + hypo, hypo_trans = inference_sentence_pair[1] + hypo_trans = inference_sentence_pair[1][1] + inference_logits = model( # pylint: disable=not-callable + tf.constant(prem), tf.constant(prem_trans), + tf.constant(hypo), tf.constant(hypo_trans), training=False) + inference_logits = np.array(inference_logits[0][1:]) + max_index = np.argmax(inference_logits) + print("\nInference logits:") + for i, (label, logit) in enumerate( + zip(data.POSSIBLE_LABELS, inference_logits)): + winner_tag = " (winner)" if max_index == i else "" + print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) + return inference_logits + + train_len = train_data.num_batches(config.batch_size) start = time.time() iterations = 0 mean_loss = tfe.metrics.Mean() @@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config): # remain on CPU. Same in _evaluate_on_dataset(). 
@@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
           # remain on CPU. Same in _evaluate_on_dataset().
 
         iterations += 1
-        batch_train_loss, batch_train_logits = trainer.train_batch(
-            label, prem, prem_trans, hypo, hypo_trans)
+        with tfe.restore_variables_on_create(
+            tf.train.latest_checkpoint(config.logdir)):
+          batch_train_loss, batch_train_logits = trainer.train_batch(
+              label, prem, prem_trans, hypo, hypo_trans)
         batch_size = tf.shape(label)[0]
         mean_loss(batch_train_loss.numpy(),
                   weights=batch_size.gpu() if use_gpu else batch_size)
         accuracy(tf.argmax(batch_train_logits, axis=1), label)
 
         if iterations % config.save_every == 0:
-          all_variables = (
-              model.variables + [trainer.learning_rate] + [global_step])
+          all_variables = trainer.variables + [global_step]
           saver = tfe.Saver(all_variables)
           saver.save(os.path.join(config.logdir, "ckpt"),
                      global_step=global_step)
 
         if iterations % config.dev_every == 0:
           dev_loss, dev_frac_correct = _evaluate_on_dataset(
-              dev_data, config.batch_size, model, trainer, use_gpu)
+              dev_data, config.batch_size, trainer, use_gpu)
           print(dev_log_template.format(
               time.time() - start,
               epoch, iterations, 1 + batch_idx, train_len,
@@ -638,10 +696,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
           trainer.decay_learning_rate(config.lr_decay_by)
 
   test_loss, test_frac_correct = _evaluate_on_dataset(
-      test_data, config.batch_size, model, trainer, use_gpu)
+      test_data, config.batch_size, trainer, use_gpu)
   print("Final test loss: %g; accuracy: %g%%" %
         (test_loss, test_frac_correct * 100.0))
 
+  return trainer
+
 
 def main(_):
   config = FLAGS
@@ -650,18 +710,24 @@ def main(_):
   vocab = data.load_vocabulary(FLAGS.data_root)
   word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
 
-  print("Loading train, dev and test data...")
-  train_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  dev_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-  test_data = data.SnliData(
-      os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
-      word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-
-  train_spinn(embed, train_data, dev_data, test_data, config)
+  if not (config.inference_premise or config.inference_hypothesis):
+    print("Loading train, dev and test data...")
+    train_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    dev_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+    test_data = data.SnliData(
+        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+  else:
+    train_data = None
+    dev_data = None
+    test_data = None
+
+  train_or_infer_spinn(
+      embed, word2index, train_data, dev_data, test_data, config)
 
 
 if __name__ == "__main__":
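The reason the `train_batch` call above is wrapped in `tfe.restore_variables_on_create` is that, under eager execution, an optimizer creates its slot variables lazily during the first gradient update; only variables created inside the context manager are initialized from the checkpoint rather than from scratch. A sketch of the mechanism, using an illustrative variable in place of the real model and optimizer state:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# latest_checkpoint() may return None on a fresh run, in which case the
# context manager restores nothing and initial values are kept.
checkpoint = tf.train.latest_checkpoint("/tmp/spinn-logs")
with tfe.restore_variables_on_create(checkpoint):
  # Any variable created inside this block, including optimizer slot
  # variables created by a wrapped training step, picks up its
  # checkpointed value by name.
  momentum = tfe.Variable(0.0, name="momentum")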
Must be " + "accompanied by --inference_premise. If specified, will " + "override all training parameters and perform inference.") parser.add_argument("--epochs", type=int, default=50, help="Number of epochs to train.") parser.add_argument("--batch_size", type=int, default=128, -- cgit v1.2.3