author    Shanqing Cai <cais@google.com>                   2018-02-13 21:49:28 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-13 21:52:54 -0800
commit    14b6365a36c8982092ed2010e1e90f66f663deeb
tree      f1e21a5faba718cbf631e4c6b25d88d4507a9277 /third_party/examples
parent    6aad89bb560e4bbeafff9d0c42cc8983ab4a3499
tfe SPINN example: Add inference; fix serialization
* Also de-flake a test.
PiperOrigin-RevId: 185637742
Diffstat (limited to 'third_party/examples')
-rw-r--r--  third_party/examples/eager/spinn/README.md  |  41
-rw-r--r--  third_party/examples/eager/spinn/spinn.py   | 139
2 files changed, 148 insertions, 32 deletions
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
index 6bd3d53e56..335c0fa3b5 100644
--- a/third_party/examples/eager/spinn/README.md
+++ b/third_party/examples/eager/spinn/README.md
@@ -66,3 +66,44 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
```bash
tensorboard --logdir /tmp/spinn-logs
```
+
+- After training, you may use the model to perform inference on input data in
+  the SNLI data format. The premise and hypothesis sentences are specified with
+ the command-line flags `--inference_premise` and `--inference_hypothesis`,
+  respectively. Each sentence should include the words, as well as parentheses
+  representing a binary parse of the sentence. The words and parentheses should
+  all be separated by spaces (see the encoding sketch below). For instance,
+
+ ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+ --inference_hypothesis '( ( The dog ) ( moves . ) )'
+ ```
+
+  which will generate an output like the following, reflecting the entailment
+  relation between the two sentences.
+
+ ```none
+ Inference logits:
+ entailment: 1.101249 (winner)
+ contradiction: -2.374171
+ neutral: -0.296733
+ ```
+
+ By contrast, the following sentence pair:
+
+ ```bash
+  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+ --inference_hypothesis '( ( The dog ) ( rests . ) )'
+ ```
+
+  will produce an output like the following, reflecting the contradiction
+  between the two sentences.
+
+ ```none
+ Inference logits:
+ entailment: -1.070098
+ contradiction: 2.798695 (winner)
+ neutral: -1.402287
+ ```
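
The parenthesized inputs above are binary parses, which a SPINN-style model reads as a sequence of shift (word) and reduce (closing parenthesis) operations. The sketch below only illustrates that reading of the input format; the helper name `parse_to_transitions` and the transition codes are assumptions for this sketch, and it is not the `data.encode_sentence` function that `spinn.py` actually calls.

```python
# Illustrative sketch only: interpret a space-separated binary parse (the
# format passed to --inference_premise / --inference_hypothesis) as a word
# sequence plus shift/reduce transitions.
SHIFT, REDUCE = 3, 2  # assumed transition codes for this sketch


def parse_to_transitions(parse_str):
  """Splits a binary parse string into (words, transitions)."""
  words, transitions = [], []
  for token in parse_str.split():
    if token == "(":
      continue                    # an opening parenthesis carries no operation
    elif token == ")":
      transitions.append(REDUCE)  # combine the top two elements of the stack
    else:
      words.append(token)
      transitions.append(SHIFT)   # push the word onto the stack
  return words, transitions


print(parse_to_transitions("( ( The dog ) ( moves . ) )"))
# (['The', 'dog', 'moves', '.'], [3, 3, 2, 3, 3, 2, 2])
```
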
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index a2fa18eeb1..38ba48d501 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -44,6 +44,7 @@ import os
import sys
import time
+import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -471,6 +472,15 @@ class SNLIClassifierTrainer(object):
def learning_rate(self):
return self._learning_rate
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def variables(self):
+ return (self._model.variables + [self.learning_rate] +
+ self._optimizer.variables())
+
def _batch_n_correct(logits, label):
"""Calculate number of correct predictions in a batch.
@@ -488,13 +498,12 @@ def _batch_n_correct(logits, label):
tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
-def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
"""Run evaluation on a dataset.
Args:
snli_data: The `data.SnliData` to use in this evaluation.
batch_size: The batch size to use during this evaluation.
- model: An instance of `SNLIClassifier` to evaluate.
trainer: An instance of `SNLIClassifierTrainer` to use for this
evaluation.
use_gpu: Whether GPU is being used.
@@ -509,7 +518,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -536,13 +545,19 @@ def _get_dataset_iterator(snli_data, batch_size):
return tfe.Iterator(dataset)
-def train_spinn(embed, train_data, dev_data, test_data, config):
- """Train a SPINN model.
+def train_or_infer_spinn(embed,
+ word2index,
+ train_data,
+ dev_data,
+ test_data,
+ config):
+ """Perform Training or Inference on a SPINN model.
Args:
embed: The embedding matrix as a float32 numpy array with shape
[vocabulary_size, word_vector_len]. word_vector_len is the length of a
word embedding vector.
+ word2index: A `dict` mapping word to word index.
train_data: An instance of `data.SnliData`, for the train split.
dev_data: Same as above, for the dev split.
test_data: Same as above, for the test split.
@@ -550,13 +565,35 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
details.
Returns:
- 1. Final loss value on the test split.
- 2. Final fraction of correct classifications on the test split.
+    If `config.inference_premise` and `config.inference_hypothesis` are not
+    `None`, i.e., inference mode: the logits for the possible labels of the
+    SNLI data set, as a numpy array of three floats.
+ else:
+ The trainer object.
+ Raises:
+ ValueError: if only one of config.inference_premise and
+ config.inference_hypothesis is specified.
"""
+  # TODO(cais): Refactor this function into separate ones for training and
+ # inference.
use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
device = "gpu:0" if use_gpu else "cpu:0"
print("Using device: %s" % device)
+ if ((config.inference_premise and not config.inference_hypothesis) or
+ (not config.inference_premise and config.inference_hypothesis)):
+ raise ValueError(
+ "--inference_premise and --inference_hypothesis must be both "
+ "specified or both unspecified, but only one is specified.")
+
+ if config.inference_premise:
+ # Inference mode.
+ inference_sentence_pair = [
+ data.encode_sentence(config.inference_premise, word2index),
+ data.encode_sentence(config.inference_hypothesis, word2index)]
+ else:
+ inference_sentence_pair = None
+
log_header = (
" Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
" Accuracy Dev/Accuracy")
@@ -569,16 +606,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
summary_writer = tf.contrib.summary.create_file_writer(
config.logdir, flush_millis=10000)
- train_len = train_data.num_batches(config.batch_size)
+
with tf.device(device), \
- tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)), \
summary_writer.as_default(), \
tf.contrib.summary.always_record_summaries():
- model = SNLIClassifier(config, embed)
- global_step = tf.train.get_or_create_global_step()
- trainer = SNLIClassifierTrainer(model, config.lr)
-
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+
+ if inference_sentence_pair:
+ # Inference mode.
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ prem, prem_trans = inference_sentence_pair[0]
+ hypo, hypo_trans = inference_sentence_pair[1]
+ inference_logits = model( # pylint: disable=not-callable
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
+ inference_logits = np.array(inference_logits[0][1:])
+ max_index = np.argmax(inference_logits)
+ print("\nInference logits:")
+ for i, (label, logit) in enumerate(
+ zip(data.POSSIBLE_LABELS, inference_logits)):
+ winner_tag = " (winner)" if max_index == i else ""
+ print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+ return inference_logits
+
+ train_len = train_data.num_batches(config.batch_size)
start = time.time()
iterations = 0
mean_loss = tfe.metrics.Mean()
@@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
# remain on CPU. Same in _evaluate_on_dataset().
iterations += 1
- batch_train_loss, batch_train_logits = trainer.train_batch(
- label, prem, prem_trans, hypo, hypo_trans)
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
batch_size = tf.shape(label)[0]
mean_loss(batch_train_loss.numpy(),
weights=batch_size.gpu() if use_gpu else batch_size)
accuracy(tf.argmax(batch_train_logits, axis=1), label)
if iterations % config.save_every == 0:
- all_variables = (
- model.variables + [trainer.learning_rate] + [global_step])
+ all_variables = trainer.variables + [global_step]
saver = tfe.Saver(all_variables)
saver.save(os.path.join(config.logdir, "ckpt"),
global_step=global_step)
if iterations % config.dev_every == 0:
dev_loss, dev_frac_correct = _evaluate_on_dataset(
- dev_data, config.batch_size, model, trainer, use_gpu)
+ dev_data, config.batch_size, trainer, use_gpu)
print(dev_log_template.format(
time.time() - start,
epoch, iterations, 1 + batch_idx, train_len,
@@ -638,10 +696,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
trainer.decay_learning_rate(config.lr_decay_by)
test_loss, test_frac_correct = _evaluate_on_dataset(
- test_data, config.batch_size, model, trainer, use_gpu)
+ test_data, config.batch_size, trainer, use_gpu)
print("Final test loss: %g; accuracy: %g%%" %
(test_loss, test_frac_correct * 100.0))
+ return trainer
+
def main(_):
config = FLAGS
@@ -650,18 +710,24 @@ def main(_):
vocab = data.load_vocabulary(FLAGS.data_root)
word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
- print("Loading train, dev and test data...")
- train_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
- dev_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
- test_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-
- train_spinn(embed, train_data, dev_data, test_data, config)
+ if not (config.inference_premise or config.inference_hypothesis):
+ print("Loading train, dev and test data...")
+ train_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ dev_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ test_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ else:
+ train_data = None
+ dev_data = None
+ test_data = None
+
+ train_or_infer_spinn(
+ embed, word2index, train_data, dev_data, test_data, config)
if __name__ == "__main__":
@@ -678,6 +744,15 @@ if __name__ == "__main__":
parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
help="Directory in which summaries will be written for "
"TensorBoard.")
+ parser.add_argument("--inference_premise", type=str, default=None,
+ help="Premise sentence for inference. Must be "
+ "accompanied by --inference_hypothesis. If specified, "
+ "will override all training parameters and perform "
+ "inference.")
+ parser.add_argument("--inference_hypothesis", type=str, default=None,
+ help="Hypothesis sentence for inference. Must be "
+ "accompanied by --inference_premise. If specified, will "
+ "override all training parameters and perform inference.")
parser.add_argument("--epochs", type=int, default=50,
help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=128,