author Shanqing Cai <cais@google.com> 2018-02-13 21:49:28 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-13 21:52:54 -0800
commit 14b6365a36c8982092ed2010e1e90f66f663deeb (patch)
tree f1e21a5faba718cbf631e4c6b25d88d4507a9277
parent 6aad89bb560e4bbeafff9d0c42cc8983ab4a3499 (diff)
tfe SPINN example: Add inference; fix serialization
* Also de-flake a test.

PiperOrigin-RevId: 185637742
-rw-r--r-- tensorflow/contrib/eager/python/examples/spinn/BUILD 6
-rw-r--r-- tensorflow/contrib/eager/python/examples/spinn/data.py 23
-rw-r--r-- tensorflow/contrib/eager/python/examples/spinn/data_test.py 51
-rw-r--r-- tensorflow/contrib/eager/python/examples/spinn/spinn_test.py 87
-rw-r--r-- third_party/examples/eager/spinn/README.md 41
-rw-r--r-- third_party/examples/eager/spinn/spinn.py 139
6 files changed, 287 insertions, 60 deletions
diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD
index 21055cfe11..a1f8a759e2 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/BUILD
+++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD
@@ -38,9 +38,5 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
],
- tags = [
- "manual",
- "no_gpu",
- "no_pip", # because spinn.py is under third_party/.
- ],
+ tags = ["no_pip"], # because spinn.py is under third_party/.
)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py
index fcaae0a4f8..3bc3bb49bc 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/data.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/data.py
@@ -227,6 +227,29 @@ def calculate_bins(length2count, min_bin_size):
return bounds
+def encode_sentence(sentence, word2index):
+ """Encode a single sentence as word indices and shift-reduce code.
+
+ Args:
+ sentence: The sentence with added binary parse information, represented as
+ a string, with all the word items and parentheses separated by spaces.
+ E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'.
+ word2index: A `dict` mapping words to their word indices.
+
+ Returns:
+ 1. Word indices as a numpy array, with shape `(sequence_len, 1)`.
+ 2. Shift-reduce sequence as a numpy array, with shape
+ `(sequence_len * 2 - 3, 1)`.
+ """
+ items = [w for w in sentence.split(" ") if w]
+ words = get_non_parenthesis_words(items)
+ shift_reduce = get_shift_reduce(items)
+ word_indices = pad_and_reverse_word_ids(
+ [[word2index.get(word, UNK_CODE) for word in words]]).T
+ return (word_indices,
+ np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1))
+
+
class SnliData(object):
"""A split of SNLI data."""
diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py
index e4f0b37c50..54fef2c3fe 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/data_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py
@@ -22,6 +22,7 @@ import os
import shutil
import tempfile
+import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.spinn import data
@@ -173,14 +174,9 @@ class DataTest(tf.test.TestCase):
ValueError, "Cannot find GloVe embedding file at"):
data.load_word_vectors(self._temp_data_dir, vocab)
- def testSnliData(self):
- """Unit test for SnliData objects."""
- snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
- fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
- os.makedirs(snli_1_0_dir)
-
+ def _createFakeSnliData(self, fake_snli_file):
# Four sentences in total.
- with open(fake_train_file, "wt") as f:
+ with open(fake_snli_file, "wt") as f:
f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
"sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
"captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
@@ -205,10 +201,7 @@ class DataTest(tf.test.TestCase):
"4705552913.jpg#2\t4705552913.jpg#2r1n\t"
"neutral\tentailment\tneutral\tneutral\tneutral\n")
- glove_dir = os.path.join(self._temp_data_dir, "glove")
- os.makedirs(glove_dir)
- glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
-
+ def _createFakeGloveData(self, glove_file):
words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
with open(glove_file, "wt") as f:
for i, word in enumerate(words):
@@ -220,6 +213,40 @@ class DataTest(tf.test.TestCase):
else:
f.write("\n")
+ def testEncodeSingleSentence(self):
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+ os.makedirs(snli_1_0_dir)
+ self._createFakeSnliData(fake_train_file)
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ glove_dir = os.path.join(self._temp_data_dir, "glove")
+ os.makedirs(glove_dir)
+ glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+ self._createFakeGloveData(glove_file)
+ word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ sentence_variants = [
+ "( Foo ( ( bar baz ) . ) )",
+ " ( Foo ( ( bar baz ) . ) ) ",
+ "( Foo ( ( bar baz ) . ) )"]
+ for sentence in sentence_variants:
+ word_indices, shift_reduce = data.encode_sentence(sentence, word2index)
+ self.assertEqual(np.int64, word_indices.dtype)
+ self.assertEqual((5, 1), word_indices.shape)
+ self.assertAllClose(
+ np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce)
+
+ def testSnliData(self):
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+ os.makedirs(snli_1_0_dir)
+ self._createFakeSnliData(fake_train_file)
+
+ glove_dir = os.path.join(self._temp_data_dir, "glove")
+ os.makedirs(glove_dir)
+ glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+ self._createFakeGloveData(glove_file)
+
vocab = data.load_vocabulary(self._temp_data_dir)
word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
@@ -230,7 +257,7 @@ class DataTest(tf.test.TestCase):
self.assertEqual(1, train_data.num_batches(4))
generator = train_data.get_generator(2)()
- for i in range(2):
+ for _ in range(2):
label, prem, prem_trans, hypo, hypo_trans = next(generator)
self.assertEqual(2, len(label))
self.assertEqual((4, 2), prem.shape)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 7b2f09cba1..eefc06d90d 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -36,6 +36,7 @@ from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
+from tensorflow.python.training import checkpoint_utils
# pylint: enable=g-bad-import-order
@@ -66,13 +67,30 @@ def _generate_synthetic_snli_data_batch(sequence_length,
return labels, prem, prem_trans, hypo, hypo_trans
-def _test_spinn_config(d_embed, d_out, logdir=None):
+def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None):
+ """Generate a config tuple for testing.
+
+ Args:
+ d_embed: Embedding dimensions.
+ d_out: Model output dimensions.
+ logdir: Optional logdir.
+ inference_sentences: A 2-tuple of strings representing the sentences (with
+ binary parsing result), e.g.,
+ ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )").
+
+ Returns:
+ A config tuple.
+ """
config_tuple = collections.namedtuple(
"Config", ["d_hidden", "d_proj", "d_tracker", "predict",
"embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
"d_out", "projection", "lr", "batch_size", "epochs",
"force_cpu", "logdir", "log_every", "dev_every", "save_every",
- "lr_decay_every", "lr_decay_by"])
+ "lr_decay_every", "lr_decay_by", "inference_premise",
+ "inference_hypothesis"])
+
+ inference_premise = inference_sentences[0] if inference_sentences else None
+ inference_hypothesis = inference_sentences[1] if inference_sentences else None
return config_tuple(
d_hidden=d_embed,
d_proj=d_embed * 2,
@@ -86,14 +104,16 @@ def _test_spinn_config(d_embed, d_out, logdir=None):
projection=True,
lr=2e-2,
batch_size=2,
- epochs=10,
+ epochs=20,
force_cpu=False,
logdir=logdir,
log_every=1,
dev_every=2,
save_every=2,
lr_decay_every=1,
- lr_decay_by=0.75)
+ lr_decay_by=0.75,
+ inference_premise=inference_premise,
+ inference_hypothesis=inference_hypothesis)
class SpinnTest(test_util.TensorFlowTestCase):
@@ -288,11 +308,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
# Training on the batch should have led to a change in the loss value.
self.assertNotEqual(loss1.numpy(), loss2.numpy())
- def testTrainSpinn(self):
- """Test with fake toy SNLI data and GloVe vectors."""
-
- # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
- snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ def _create_test_data(self, snli_1_0_dir):
fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
os.makedirs(snli_1_0_dir)
@@ -337,13 +353,52 @@ class SpinnTest(test_util.TensorFlowTestCase):
else:
f.write("\n")
+ return fake_train_file
+
+ def testInferSpinnWorks(self):
+ """Test inference with the spinn model."""
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ self._create_test_data(snli_1_0_dir)
+
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ config = _test_spinn_config(
+ data.WORD_VECTOR_LEN, 4,
+ logdir=os.path.join(self._temp_data_dir, "logdir"),
+ inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )"))
+ logits = spinn.train_or_infer_spinn(
+ embed, word2index, None, None, None, config)
+ self.assertEqual(np.float32, logits.dtype)
+ self.assertEqual((3,), logits.shape)
+
+ def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self):
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ self._create_test_data(snli_1_0_dir)
+
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ config = _test_spinn_config(
+ data.WORD_VECTOR_LEN, 4,
+ logdir=os.path.join(self._temp_data_dir, "logdir"),
+ inference_sentences=("( foo ( bar . ) )", None))
+ with self.assertRaises(ValueError):
+ spinn.train_or_infer_spinn(embed, word2index, None, None, None, config)
+
+ def testTrainSpinn(self):
+ """Test with fake toy SNLI data and GloVe vectors."""
+
+ # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = self._create_test_data(snli_1_0_dir)
+
vocab = data.load_vocabulary(self._temp_data_dir)
word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
train_data = data.SnliData(fake_train_file, word2index)
dev_data = data.SnliData(fake_train_file, word2index)
test_data = data.SnliData(fake_train_file, word2index)
- print(embed)
# 2. Create a fake config.
config = _test_spinn_config(
@@ -351,7 +406,8 @@ class SpinnTest(test_util.TensorFlowTestCase):
logdir=os.path.join(self._temp_data_dir, "logdir"))
# 3. Test training of a SPINN model.
- spinn.train_spinn(embed, train_data, dev_data, test_data, config)
+ trainer = spinn.train_or_infer_spinn(
+ embed, word2index, train_data, dev_data, test_data, config)
# 4. Load train loss values from the summary files and verify that they
# decrease with training.
@@ -363,6 +419,15 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual(config.epochs, len(train_losses))
self.assertLess(train_losses[-1], train_losses[0])
+ # 5. Verify that checkpoints exist and contain all the expected variables.
+ self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
+ ckpt_variable_names = [
+ item[0] for item in checkpoint_utils.list_variables(config.logdir)]
+ self.assertIn("global_step", ckpt_variable_names)
+ for v in trainer.variables:
+ variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
+ self.assertIn(variable_name, ckpt_variable_names)
+
class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
index 6bd3d53e56..335c0fa3b5 100644
--- a/third_party/examples/eager/spinn/README.md
+++ b/third_party/examples/eager/spinn/README.md
@@ -66,3 +66,44 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
```bash
tensorboard --logdir /tmp/spinn-logs
```
+
+- After training, you may use the model to perform inference on input data in
+ the SNLI data format. The premise and hypotheses sentences are specified with
+ the command-line flags `--inference_premise` and `--inference_hypothesis`,
+ respectively. Each sentence should include the words, as well as parentheses
+ representing a binary parse of the sentence. The words and parentheses
+ should all be separated by spaces. For instance,
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+ --inference_hypothesis '( ( The dog ) ( moves . ) )'
+ ```
+
+ which will generate an output like the following, due to the semantic
+ consistency of the two sentences.
+
+ ```none
+ Inference logits:
+ entailment: 1.101249 (winner)
+ contradiction: -2.374171
+ neutral: -0.296733
+ ```
+
+ By contrast, the following sentence pair:
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \
+ --inference_hypothesis '( ( The dog ) ( rests . ) )'
+ ```
+
+ will give you an output like the following, due to the semantic
+ contradiction of the two sentences.
+
+ ```none
+ Inference logits:
+ entailment: -1.070098
+ contradiction: 2.798695 (winner)
+ neutral: -1.402287
+ ```
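
The README section above describes the binary-parse input format in prose. The following is an illustrative sketch (not part of the example's own code) of how such a string maps to shift/reduce transition codes; the numeric codes 3 (shift) and 2 (reduce) are taken from the expected values in the data_test.py change:

```python
def to_transitions(parse):
    # Drop "(" tokens; every word becomes a shift (3), every ")" a reduce (2).
    tokens = [t for t in parse.split(" ") if t]
    return [2 if t == ")" else 3 for t in tokens if t != "("]

print(to_transitions("( Foo ( ( bar baz ) . ) )"))  # [3, 3, 3, 2, 3, 2, 2]
```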
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index a2fa18eeb1..38ba48d501 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -44,6 +44,7 @@ import os
import sys
import time
+import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -471,6 +472,15 @@ class SNLIClassifierTrainer(object):
def learning_rate(self):
return self._learning_rate
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def variables(self):
+ return (self._model.variables + [self.learning_rate] +
+ self._optimizer.variables())
+
def _batch_n_correct(logits, label):
"""Calculate number of correct predictions in a batch.
@@ -488,13 +498,12 @@ def _batch_n_correct(logits, label):
tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
-def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
"""Run evaluation on a dataset.
Args:
snli_data: The `data.SnliData` to use in this evaluation.
batch_size: The batch size to use during this evaluation.
- model: An instance of `SNLIClassifier` to evaluate.
trainer: An instance of `SNLIClassifierTrainer` to use for this
evaluation.
use_gpu: Whether GPU is being used.
@@ -509,7 +518,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -536,13 +545,19 @@ def _get_dataset_iterator(snli_data, batch_size):
return tfe.Iterator(dataset)
-def train_spinn(embed, train_data, dev_data, test_data, config):
- """Train a SPINN model.
+def train_or_infer_spinn(embed,
+ word2index,
+ train_data,
+ dev_data,
+ test_data,
+ config):
+ """Perform Training or Inference on a SPINN model.
Args:
embed: The embedding matrix as a float32 numpy array with shape
[vocabulary_size, word_vector_len]. word_vector_len is the length of a
word embedding vector.
+ word2index: A `dict` mapping word to word index.
train_data: An instance of `data.SnliData`, for the train split.
dev_data: Same as above, for the dev split.
test_data: Same as above, for the test split.
@@ -550,13 +565,35 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
details.
Returns:
- 1. Final loss value on the test split.
- 2. Final fraction of correct classifications on the test split.
+ If `config.inference_premise` and `config.inference_hypothesis` are not
+ `None`, i.e., inference mode: the logits for the possible labels of the
+ SNLI data set, as a numpy array of three floats.
+ else:
+ The trainer object.
+ Raises:
+ ValueError: if only one of config.inference_premise and
+ config.inference_hypothesis is specified.
"""
+ # TODO(cais): Refactor this function into separate one for training and
+ # inference.
use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
device = "gpu:0" if use_gpu else "cpu:0"
print("Using device: %s" % device)
+ if ((config.inference_premise and not config.inference_hypothesis) or
+ (not config.inference_premise and config.inference_hypothesis)):
+ raise ValueError(
+ "--inference_premise and --inference_hypothesis must be both "
+ "specified or both unspecified, but only one is specified.")
+
+ if config.inference_premise:
+ # Inference mode.
+ inference_sentence_pair = [
+ data.encode_sentence(config.inference_premise, word2index),
+ data.encode_sentence(config.inference_hypothesis, word2index)]
+ else:
+ inference_sentence_pair = None
+
log_header = (
" Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
" Accuracy Dev/Accuracy")
@@ -569,16 +606,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
summary_writer = tf.contrib.summary.create_file_writer(
config.logdir, flush_millis=10000)
- train_len = train_data.num_batches(config.batch_size)
+
with tf.device(device), \
- tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)), \
summary_writer.as_default(), \
tf.contrib.summary.always_record_summaries():
- model = SNLIClassifier(config, embed)
- global_step = tf.train.get_or_create_global_step()
- trainer = SNLIClassifierTrainer(model, config.lr)
-
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+
+ if inference_sentence_pair:
+ # Inference mode.
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ prem, prem_trans = inference_sentence_pair[0]
+ hypo, hypo_trans = inference_sentence_pair[1]
+ inference_logits = model( # pylint: disable=not-callable
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
+ inference_logits = np.array(inference_logits[0][1:])
+ max_index = np.argmax(inference_logits)
+ print("\nInference logits:")
+ for i, (label, logit) in enumerate(
+ zip(data.POSSIBLE_LABELS, inference_logits)):
+ winner_tag = " (winner)" if max_index == i else ""
+ print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+ return inference_logits
+
+ train_len = train_data.num_batches(config.batch_size)
start = time.time()
iterations = 0
mean_loss = tfe.metrics.Mean()
@@ -594,23 +651,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
# remain on CPU. Same in _evaluate_on_dataset().
iterations += 1
- batch_train_loss, batch_train_logits = trainer.train_batch(
- label, prem, prem_trans, hypo, hypo_trans)
+ with tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)):
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
batch_size = tf.shape(label)[0]
mean_loss(batch_train_loss.numpy(),
weights=batch_size.gpu() if use_gpu else batch_size)
accuracy(tf.argmax(batch_train_logits, axis=1), label)
if iterations % config.save_every == 0:
- all_variables = (
- model.variables + [trainer.learning_rate] + [global_step])
+ all_variables = trainer.variables + [global_step]
saver = tfe.Saver(all_variables)
saver.save(os.path.join(config.logdir, "ckpt"),
global_step=global_step)
if iterations % config.dev_every == 0:
dev_loss, dev_frac_correct = _evaluate_on_dataset(
- dev_data, config.batch_size, model, trainer, use_gpu)
+ dev_data, config.batch_size, trainer, use_gpu)
print(dev_log_template.format(
time.time() - start,
epoch, iterations, 1 + batch_idx, train_len,
@@ -638,10 +696,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config):
trainer.decay_learning_rate(config.lr_decay_by)
test_loss, test_frac_correct = _evaluate_on_dataset(
- test_data, config.batch_size, model, trainer, use_gpu)
+ test_data, config.batch_size, trainer, use_gpu)
print("Final test loss: %g; accuracy: %g%%" %
(test_loss, test_frac_correct * 100.0))
+ return trainer
+
def main(_):
config = FLAGS
@@ -650,18 +710,24 @@ def main(_):
vocab = data.load_vocabulary(FLAGS.data_root)
word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
- print("Loading train, dev and test data...")
- train_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
- dev_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
- test_data = data.SnliData(
- os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
- word2index, sentence_len_limit=FLAGS.sentence_len_limit)
-
- train_spinn(embed, train_data, dev_data, test_data, config)
+ if not (config.inference_premise or config.inference_hypothesis):
+ print("Loading train, dev and test data...")
+ train_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ dev_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ test_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ else:
+ train_data = None
+ dev_data = None
+ test_data = None
+
+ train_or_infer_spinn(
+ embed, word2index, train_data, dev_data, test_data, config)
if __name__ == "__main__":
@@ -678,6 +744,15 @@ if __name__ == "__main__":
parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
help="Directory in which summaries will be written for "
"TensorBoard.")
+ parser.add_argument("--inference_premise", type=str, default=None,
+ help="Premise sentence for inference. Must be "
+ "accompanied by --inference_hypothesis. If specified, "
+ "will override all training parameters and perform "
+ "inference.")
+ parser.add_argument("--inference_hypothesis", type=str, default=None,
+ help="Hypothesis sentence for inference. Must be "
+ "accompanied by --inference_premise. If specified, will "
+ "override all training parameters and perform inference.")
parser.add_argument("--epochs", type=int, default=50,
help="Number of epochs to train.")
parser.add_argument("--batch_size", type=int, default=128,