author    Shanqing Cai <cais@google.com>  2017-12-01 13:56:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-01 13:59:31 -0800
commit    ed9163acfd510c26c49201ec9e360e20a2625ca8 (patch)
tree      5d61ab7de410baa898925f0023792e5015817c17
parent    ae10f63e2fc76faf5835a660043c328d891c41f0 (diff)
TF Eager: Add SPINN model example for dynamic/recursive NN.
PiperOrigin-RevId: 177636427
-rw-r--r--  tensorflow/contrib/eager/README.md                            |   3
-rw-r--r--  tensorflow/contrib/eager/python/examples/BUILD                |   1
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/BUILD          |  41
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/README.md      |  13
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/data.py        | 350
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/data_test.py   | 243
-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/spinn_test.py  | 409
-rw-r--r--  third_party/examples/eager/spinn/BUILD                        |  14
-rw-r--r--  third_party/examples/eager/spinn/LICENSE                      |  29
-rw-r--r--  third_party/examples/eager/spinn/README.md                    |  54
-rw-r--r--  third_party/examples/eager/spinn/spinn.py                     | 732
11 files changed, 1889 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index dcc370cd00..09242ee47d 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see:
## Changelog
- 2017/10/31: Initial preview release.
+- 2017/12/01: Example of dynamic neural network:
+ [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021).
+ See [README.md](python/examples/spinn/README.md) for details.
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index aa21a6ab99..6aef010a21 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -11,5 +11,6 @@ py_library(
"//tensorflow/contrib/eager/python/examples/resnet50",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
+ "//tensorflow/contrib/eager/python/examples/spinn:data",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD
new file mode 100644
index 0000000000..0263d21325
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD
@@ -0,0 +1,41 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "data",
+ srcs = ["data.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = ["//third_party/py/numpy"],
+)
+
+py_test(
+ name = "data_test",
+ size = "small",
+ srcs = ["data_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "spinn_test",
+ size = "medium",
+ srcs = ["spinn_test.py"],
+ additional_deps = [
+ ":data",
+ "//third_party/examples/eager/spinn",
+ "//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/summary:summary_test_util",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md
new file mode 100644
index 0000000000..eb0637df47
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/spinn/README.md
@@ -0,0 +1,13 @@
+# SPINN: Dynamic neural network with TensorFlow eager execution
+
+This directory contains files supporting the
+[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py),
+including
+
+- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe
+ data.
+- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules.
+
+See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md)
+for detailed background, license and usage information regarding the SPINN code.
+
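A minimal usage sketch of how the utilities in `data.py` fit together (illustrative, not part of the committed files); it assumes the SNLI and GloVe data have already been downloaded under the hypothetical `/tmp/spinn-data` directory, and the batch size is arbitrary:

```python
# Sketch only: load SNLI/GloVe with the utilities added in data.py below.
from tensorflow.contrib.eager.python.examples.spinn import data

data_root = "/tmp/spinn-data"                       # assumed download location
vocab = data.load_vocabulary(data_root)             # set of lower-case words
word2index, embed = data.load_word_vectors(data_root, vocab)

train = data.SnliData(
    data_root + "/snli/snli_1.0/snli_1.0_train.txt", word2index)
batches = train.get_generator(batch_size=128)()     # note: call the returned function
labels, prem, prem_trans, hypo, hypo_trans = next(batches)  # int64, time-major arrays
```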
diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py
new file mode 100644
index 0000000000..a6e046320f
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/spinn/data.py
@@ -0,0 +1,350 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for loading SNLI data and GloVe word vectors for the SPINN model.
+
+See more details about the SNLI data set at:
+ https://nlp.stanford.edu/projects/snli/
+
+See more details about the GloVe pretrained word embeddings at:
+ https://nlp.stanford.edu/projects/glove/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import math
+import os
+import random
+
+import numpy as np
+
+POSSIBLE_LABELS = ("entailment", "contradiction", "neutral")
+
+UNK_CODE = 0 # Code for unknown word tokens.
+PAD_CODE = 1 # Code for padding tokens.
+
+SHIFT_CODE = 3
+REDUCE_CODE = 2
+
+WORD_VECTOR_LEN = 300 # Embedding dimensions.
+
+LEFT_PAREN = "("
+RIGHT_PAREN = ")"
+PARENTHESES = (LEFT_PAREN, RIGHT_PAREN)
+
+
+def get_non_parenthesis_words(items):
+ """Get the non-parenthesis items from a SNLI parsed sentence.
+
+ Args:
+ items: Data items from a parsed SNLI sentence, with parentheses. E.g.,
+ ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...
+
+ Returns:
+ A list of non-parenthesis word items, all converted to lower case. E.g.,
+ ["man", "wearing", "pass", ...
+ """
+ return [x.lower() for x in items if x not in PARENTHESES and x]
+
+
+def get_shift_reduce(items):
+ """Obtain shift-reduce vector from a list of items from the SNLI data.
+
+ Args:
+ items: Data items as a list of str, e.g.,
+ ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...
+
+ Returns:
+ A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and
+ `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE`
+ and `REDUCE_CODE`.
+ """
+ trans = []
+ for item in items:
+ if item == LEFT_PAREN:
+ continue
+ elif item == RIGHT_PAREN:
+ trans.append(REDUCE_CODE)
+ else:
+ trans.append(SHIFT_CODE)
+ return trans
+
+
+def pad_and_reverse_word_ids(sentences):
+ """Pad a list of sentences to the common maximum length + 1.
+
+ Args:
+ sentences: A list of sentences as a list of list of integers. Each integer
+ is a word ID. Each list of integer corresponds to one sentence.
+
+ Returns:
+ A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length
+ is the maximum sentence length (in # of words). Each sentence is padded to
+ the maximum length with `PAD_CODE`, reversed in time order, and then
+ prepended with an extra 1, as required by the model.
+ """
+ max_len = max(len(sent) for sent in sentences)
+ for sent in sentences:
+ if len(sent) < max_len:
+ sent.extend([PAD_CODE] * (max_len - len(sent)))
+ # Reverse in time order and pad an extra one.
+ sentences = np.fliplr(np.array(sentences, dtype=np.int64))
+ sentences = np.concatenate(
+ [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1)
+ return sentences
+
+
+def pad_transitions(sentences_transitions):
+ """Pad a list of shift-reduce transitions to the maximum length."""
+ max_len = max(len(transitions) for transitions in sentences_transitions)
+ for transitions in sentences_transitions:
+ if len(transitions) < max_len:
+ transitions.extend([PAD_CODE] * (max_len - len(transitions)))
+ return np.array(sentences_transitions, dtype=np.int64)
+
+
+def load_vocabulary(data_root):
+ """Load vocabulary from SNLI data files.
+
+ Args:
+ data_root: Root directory of the data. It is assumed that the SNLI data
+ files have been downloaded and extracted to the "snli/snli_1.0"
+ subdirectory of it.
+
+ Returns:
+ Vocabulary as a set of strings.
+
+ Raises:
+ ValueError: If SNLI data files cannot be found.
+ """
+ snli_path = os.path.join(data_root, "snli")
+ snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt")
+ file_names = glob.glob(snli_glob_pattern)
+ if not file_names:
+ raise ValueError(
+ "Cannot find SNLI data files at %s. "
+ "Please download and extract SNLI data first." % snli_glob_pattern)
+
+ print("Loading vocabulary...")
+ vocab = set()
+ for file_name in file_names:
+ with open(os.path.join(snli_path, file_name), "rt") as f:
+ for i, line in enumerate(f):
+ if i == 0:
+ continue
+ items = line.split("\t")
+ premise_words = get_non_parenthesis_words(items[1].split(" "))
+ hypothesis_words = get_non_parenthesis_words(items[2].split(" "))
+ vocab.update(premise_words)
+ vocab.update(hypothesis_words)
+ return vocab
+
+
+def load_word_vectors(data_root, vocab):
+ """Load GloVe word vectors for words present in the vocabulary.
+
+ Args:
+ data_root: Data root directory. It is assumed that the GloVe file
+ has been downloaded and extracted at the "glove/" subdirectory of it.
+ vocab: A `set` of words, representing the vocabulary.
+
+ Returns:
+ 1. word2index: A dict from lower-case word to row index in the embedding
+ matrix, i.e., `embed` below.
+ 2. embed: The embedding matrix as a float32 numpy array. Its shape is
+ [vocabulary_size, WORD_VECTOR_LEN], where vocabulary_size is the number
+ of vocabulary words found in the GloVe file, plus 2 for the special
+ <unk> and <pad> tokens, and WORD_VECTOR_LEN is the embedding
+ dimension (300).
+
+ Raises:
+ ValueError: If GloVe embedding file cannot be found.
+ """
+ glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt")
+ if not os.path.isfile(glove_path):
+ raise ValueError(
+ "Cannot find GloVe embedding file at %s. "
+ "Please download and extract GloVe embeddings first." % glove_path)
+
+ print("Loading word vectors...")
+
+ word2index = dict()
+ embed = []
+
+ embed.append([0] * WORD_VECTOR_LEN) # <unk>
+ embed.append([0] * WORD_VECTOR_LEN) # <pad>
+ word2index["<unk>"] = UNK_CODE
+ word2index["<pad>"] = PAD_CODE
+
+ with open(glove_path, "rt") as f:
+ for line in f:
+ items = line.split(" ")
+ word = items[0]
+ if word in vocab and word not in word2index:
+ word2index[word] = len(embed)
+ vector = np.array([float(item) for item in items[1:]])
+ assert (WORD_VECTOR_LEN,) == vector.shape
+ embed.append(vector)
+ embed = np.array(embed, dtype=np.float32)
+ return word2index, embed
+
+
+def calculate_bins(length2count, min_bin_size):
+ """Calculate bin boundaries given a histogram of lengths and minimum bin size.
+
+ Args:
+ length2count: A `dict` mapping length to sentence count.
+ min_bin_size: Minimum bin size in terms of total number of sentence pairs
+ in the bin.
+
+ Returns:
+ A `list` representing the right bin boundaries, starting from the inclusive
+ right boundary of the first bin. For example, if the output is
+ [10, 20, 35],
+ it means there are three bins: [1, 10], [11, 20] and [21, 35].
+ """
+ bounds = []
+ lengths = sorted(length2count.keys())
+ cum_count = 0
+ for length in lengths:
+ cum_count += length2count[length]
+ if cum_count >= min_bin_size:
+ bounds.append(length)
+ cum_count = 0
+ if bounds[-1] != lengths[-1]:
+ bounds.append(lengths[-1])
+ return bounds
+
+
+class SnliData(object):
+ """A split of SNLI data."""
+
+ def __init__(self, data_file, word2index, sentence_len_limit=-1):
+ """SnliData constructor.
+
+ Args:
+ data_file: Full path to the data file, e.g.,
+ "/tmp/spinn-data/snli/snli_1.0/snli_1.0_train.txt"
+ word2index: A dict from lower-case word to row index in the embedding
+ matrix (see `load_word_vectors()` for details).
+ sentence_len_limit: Maximum allowed sentence length (# of words).
+ A value of <= 0 means unlimited. Sentences longer than this limit
+ are currently discarded, not truncated.
+ """
+
+ self._labels = []
+ self._premises = []
+ self._premise_transitions = []
+ self._hypotheses = []
+ self._hypothesis_transitions = []
+
+ with open(data_file, "rt") as f:
+ for i, line in enumerate(f):
+ if i == 0:
+ # Skip header line.
+ continue
+ items = line.split("\t")
+ if items[0] not in POSSIBLE_LABELS:
+ continue
+
+ premise_items = items[1].split(" ")
+ hypothesis_items = items[2].split(" ")
+ premise_words = get_non_parenthesis_words(premise_items)
+ hypothesis_words = get_non_parenthesis_words(hypothesis_items)
+
+ if (sentence_len_limit > 0 and
+ (len(premise_words) > sentence_len_limit or
+ len(hypothesis_words) > sentence_len_limit)):
+ # TODO(cais): Maybe truncate; do not discard.
+ continue
+
+ premise_ids = [
+ word2index.get(word, UNK_CODE) for word in premise_words]
+ hypothesis_ids = [
+ word2index.get(word, UNK_CODE) for word in hypothesis_words]
+
+ self._premises.append(premise_ids)
+ self._hypotheses.append(hypothesis_ids)
+ self._premise_transitions.append(get_shift_reduce(premise_items))
+ self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items))
+ assert (len(self._premise_transitions[-1]) ==
+ 2 * len(premise_words) - 1)
+ assert (len(self._hypothesis_transitions[-1]) ==
+ 2 * len(hypothesis_words) - 1)
+
+ self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1)
+
+ assert len(self._labels) == len(self._premises)
+ assert len(self._labels) == len(self._hypotheses)
+ assert len(self._labels) == len(self._premise_transitions)
+ assert len(self._labels) == len(self._hypothesis_transitions)
+
+ def num_batches(self, batch_size):
+ """Calculate number of batches given batch size."""
+ return int(math.ceil(len(self._labels) / batch_size))
+
+ def get_generator(self, batch_size):
+ """Obtain a generator for batched data.
+
+ All examples of this SnliData object are randomly shuffled, sorted
+ according to the maximum sentence length of the premise and hypothesis
+ sentences in the pair, and batched.
+
+ Args:
+ batch_size: Desired batch size.
+
+ Returns:
+ A generator for data batches. The generator yields a 5-tuple:
+ label: An array of the shape (batch_size,).
+ premise: An array of the shape (max_premise_len, batch_size), wherein
+ max_premise_len is the maximum length of the (padded) premise
+ sentence in the batch.
+ premise_transitions: An array of the shape (2 * max_premise_len - 3,
+ batch_size).
+ hypothesis: Same as `premise`, but for hypothesis sentences.
+ hypothesis_transitions: Same as `premise_transitions`, but for
+ hypothesis sentences.
+ All the elements of the 5-tuple have dtype `int64`.
+ """
+ # Randomly shuffle examples.
+ zipped = list(zip(
+ self._labels, self._premises, self._premise_transitions,
+ self._hypotheses, self._hypothesis_transitions))
+ random.shuffle(zipped)
+ # Then sort the examples by maximum of the premise and hypothesis sentence
+ # lengths in the pair. During training, the batches are expected to be
+ # shuffled. So it is okay to leave them sorted by max length here.
+ (labels, premises, premise_transitions, hypotheses,
+ hypothesis_transitions) = zip(
+ *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3]))))
+
+ def _generator():
+ begin = 0
+ while begin < len(labels):
+ # The sorting above and the batching here make sure that sentences of
+ # similar max lengths are batched together, minimizing the inefficiency
+ # due to uneven max lengths. The sentences are batched differently in
+ # each call to get_generator() due to the shuffling before sorting
+ # above. The pad_and_reverse_word_ids() and pad_transitions() functions
+ # take care of any remaining unevenness of the max sentence lengths.
+ end = min(begin + batch_size, len(labels))
+ # Transpose, because the SPINN model requires time-major, instead of
+ # batch-major.
+ yield (labels[begin:end],
+ pad_and_reverse_word_ids(premises[begin:end]).T,
+ pad_transitions(premise_transitions[begin:end]).T,
+ pad_and_reverse_word_ids(hypotheses[begin:end]).T,
+ pad_transitions(hypothesis_transitions[begin:end]).T)
+ begin = end
+ return _generator
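A small worked example of the encoding implemented in `data.py` (illustrative, not part of the committed files); the word-to-ID mapping is a made-up assumption:

```python
from tensorflow.contrib.eager.python.examples.spinn import data

parse = "( ( foo bar ) . )".split(" ")
words = data.get_non_parenthesis_words(parse)  # ["foo", "bar", "."]
trans = data.get_shift_reduce(parse)           # [3, 3, 2, 3, 2]; length == 2 * 3 - 1

# With a hypothetical mapping {"foo": 2, "bar": 3, ".": 4}:
ids = [[2, 3, 4]]
padded = data.pad_and_reverse_word_ids(ids)    # array([[1, 4, 3, 2]]): reversed, 1 at head
```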
diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py
new file mode 100644
index 0000000000..e4f0b37c50
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for SPINN data module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+import tensorflow as tf
+
+from tensorflow.contrib.eager.python.examples.spinn import data
+
+
+class DataTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(DataTest, self).setUp()
+ self._temp_data_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self._temp_data_dir)
+ super(DataTest, self).tearDown()
+
+ def testGenNonParenthesisWords(self):
+ seq_with_parse = (
+ "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
+ ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
+ self.assertEqual(
+ ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing",
+ "in", "a", "crowd", "of", "people", "."],
+ data.get_non_parenthesis_words(seq_with_parse.split(" ")))
+
+ def testGetShiftReduce(self):
+ seq_with_parse = (
+ "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
+ ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
+ self.assertEqual(
+ [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2,
+ 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" ")))
+
+ def testPadAndReverseWordIds(self):
+ id_sequences = [[0, 2, 3, 4, 5],
+ [6, 7, 8],
+ [9, 10, 11, 12, 13, 14, 15, 16]]
+ self.assertAllClose(
+ [[1, 1, 1, 1, 5, 4, 3, 2, 0],
+ [1, 1, 1, 1, 1, 1, 8, 7, 6],
+ [1, 16, 15, 14, 13, 12, 11, 10, 9]],
+ data.pad_and_reverse_word_ids(id_sequences))
+
+ def testPadTransitions(self):
+ unpadded = [[3, 3, 3, 2, 2, 2, 2],
+ [3, 3, 2, 2, 2]]
+ self.assertAllClose(
+ [[3, 3, 3, 2, 2, 2, 2],
+ [3, 3, 2, 2, 2, 1, 1]],
+ data.pad_transitions(unpadded))
+
+ def testCalculateBins(self):
+ length2count = {
+ 1: 10,
+ 2: 15,
+ 3: 25,
+ 4: 40,
+ 5: 35,
+ 6: 10}
+ self.assertEqual([2, 3, 4, 5, 6],
+ data.calculate_bins(length2count, 20))
+ self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40))
+ self.assertEqual([4, 6], data.calculate_bins(length2count, 60))
+
+ def testLoadVocabulary(self):
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+ fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt")
+ os.makedirs(snli_1_0_dir)
+
+ with open(fake_train_file, "wt") as f:
+ f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
+ "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
+ "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
+ f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ with open(fake_dev_file, "wt") as f:
+ f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
+ "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
+ "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
+ f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Quux quuz?\t.Corge grault!\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ self.assertSetEqual(
+ {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"},
+ vocab)
+
+ def testLoadVocabularyWithoutFileRaisesError(self):
+ with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
+ data.load_vocabulary(self._temp_data_dir)
+
+ os.makedirs(os.path.join(self._temp_data_dir, "snli"))
+ with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
+ data.load_vocabulary(self._temp_data_dir)
+
+ os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0"))
+ with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
+ data.load_vocabulary(self._temp_data_dir)
+
+ def testLoadWordVectors(self):
+ glove_dir = os.path.join(self._temp_data_dir, "glove")
+ os.makedirs(glove_dir)
+ glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+
+ words = [".", ",", "foo", "bar", "baz"]
+ with open(glove_file, "wt") as f:
+ for i, word in enumerate(words):
+ f.write("%s " % word)
+ for j in range(data.WORD_VECTOR_LEN):
+ f.write("%.5f" % (i * 0.1))
+ if j < data.WORD_VECTOR_LEN - 1:
+ f.write(" ")
+ else:
+ f.write("\n")
+
+ vocab = {"foo", "bar", "baz", "qux", "."}
+ # Notice that "qux" is not present in `words`.
+ word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ self.assertEqual(6, len(word2index))
+ self.assertEqual(0, word2index["<unk>"])
+ self.assertEqual(1, word2index["<pad>"])
+ self.assertEqual(2, word2index["."])
+ self.assertEqual(3, word2index["foo"])
+ self.assertEqual(4, word2index["bar"])
+ self.assertEqual(5, word2index["baz"])
+ self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape)
+ self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :])
+ self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :])
+ self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :])
+ self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :])
+ self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :])
+ self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :])
+
+ def testLoadWordVectorsWithoutFileRaisesError(self):
+ vocab = {"foo", "bar", "baz", "qux", "."}
+ with self.assertRaisesRegexp(
+ ValueError, "Cannot find GloVe embedding file at"):
+ data.load_word_vectors(self._temp_data_dir, vocab)
+
+ os.makedirs(os.path.join(self._temp_data_dir, "glove"))
+ with self.assertRaisesRegexp(
+ ValueError, "Cannot find GloVe embedding file at"):
+ data.load_word_vectors(self._temp_data_dir, vocab)
+
+ def testSnliData(self):
+ """Unit test for SnliData objects."""
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+ os.makedirs(snli_1_0_dir)
+
+ # Four sentences in total.
+ with open(fake_train_file, "wt") as f:
+ f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
+ "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
+ "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
+ f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+
+ glove_dir = os.path.join(self._temp_data_dir, "glove")
+ os.makedirs(glove_dir)
+ glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+
+ words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
+ with open(glove_file, "wt") as f:
+ for i, word in enumerate(words):
+ f.write("%s " % word)
+ for j in range(data.WORD_VECTOR_LEN):
+ f.write("%.5f" % (i * 0.1))
+ if j < data.WORD_VECTOR_LEN - 1:
+ f.write(" ")
+ else:
+ f.write("\n")
+
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ train_data = data.SnliData(fake_train_file, word2index)
+ self.assertEqual(4, train_data.num_batches(1))
+ self.assertEqual(2, train_data.num_batches(2))
+ self.assertEqual(2, train_data.num_batches(3))
+ self.assertEqual(1, train_data.num_batches(4))
+
+ generator = train_data.get_generator(2)()
+ for i in range(2):
+ label, prem, prem_trans, hypo, hypo_trans = next(generator)
+ self.assertEqual(2, len(label))
+ self.assertEqual((4, 2), prem.shape)
+ self.assertEqual((5, 2), prem_trans.shape)
+ self.assertEqual((3, 2), hypo.shape)
+ self.assertEqual((3, 2), hypo_trans.shape)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
new file mode 100644
index 0000000000..84e25cf81a
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -0,0 +1,409 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import gc
+import glob
+import os
+import shutil
+import tempfile
+import time
+
+import numpy as np
+import tensorflow as tf
+
+# pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.spinn import data
+from third_party.examples.eager.spinn import spinn
+from tensorflow.contrib.summary import summary_test_util
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+# pylint: enable=g-bad-import-order
+
+
+def _generate_synthetic_snli_data_batch(sequence_length,
+ batch_size,
+ vocab_size):
+ """Generate a fake batch of SNLI data for testing."""
+ with tf.device("cpu:0"):
+ labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64)
+ prem = tf.random_uniform(
+ (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
+ prem_trans = tf.constant(np.array(
+ [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
+ 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
+ 3, 2, 2]] * batch_size, dtype=np.int64).T)
+ hypo = tf.random_uniform(
+ (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
+ hypo_trans = tf.constant(np.array(
+ [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
+ 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
+ 3, 2, 2]] * batch_size, dtype=np.int64).T)
+ if tfe.num_gpus():
+ labels = labels.gpu()
+ prem = prem.gpu()
+ prem_trans = prem_trans.gpu()
+ hypo = hypo.gpu()
+ hypo_trans = hypo_trans.gpu()
+ return labels, prem, prem_trans, hypo, hypo_trans
+
+
+def _test_spinn_config(d_embed, d_out, logdir=None):
+ config_tuple = collections.namedtuple(
+ "Config", ["d_hidden", "d_proj", "d_tracker", "predict",
+ "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
+ "d_out", "projection", "lr", "batch_size", "epochs",
+ "force_cpu", "logdir", "log_every", "dev_every", "save_every",
+ "lr_decay_every", "lr_decay_by"])
+ return config_tuple(
+ d_hidden=d_embed,
+ d_proj=d_embed * 2,
+ d_tracker=8,
+ predict=False,
+ embed_dropout=0.1,
+ mlp_dropout=0.1,
+ n_mlp_layers=2,
+ d_mlp=32,
+ d_out=d_out,
+ projection=True,
+ lr=2e-2,
+ batch_size=2,
+ epochs=10,
+ force_cpu=False,
+ logdir=logdir,
+ log_every=1,
+ dev_every=2,
+ save_every=2,
+ lr_decay_every=1,
+ lr_decay_by=0.75)
+
+
+class SpinnTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ super(SpinnTest, self).setUp()
+ self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
+ self._temp_data_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self._temp_data_dir)
+ super(SpinnTest, self).tearDown()
+
+ def testBundle(self):
+ with tf.device(self._test_device):
+ lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32),
+ np.array([[0, -1], [-2, -3]], dtype=np.float32),
+ np.array([[0, 2], [4, 6]], dtype=np.float32),
+ np.array([[0, -2], [-4, -6]], dtype=np.float32)]
+ out = spinn._bundle(lstm_iter)
+
+ self.assertEqual(2, len(out))
+ self.assertEqual(tf.float32, out[0].dtype)
+ self.assertEqual(tf.float32, out[1].dtype)
+ self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T,
+ out[0].numpy())
+ self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T,
+ out[1].numpy())
+
+ def testUnbundle(self):
+ with tf.device(self._test_device):
+ state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32),
+ np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)]
+ out = spinn._unbundle(state)
+
+ self.assertEqual(2, len(out))
+ self.assertEqual(tf.float32, out[0].dtype)
+ self.assertEqual(tf.float32, out[1].dtype)
+ self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]),
+ out[0].numpy())
+ self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]),
+ out[1].numpy())
+
+ def testReducer(self):
+ with tf.device(self._test_device):
+ batch_size = 3
+ size = 10
+ tracker_size = 8
+ reducer = spinn.Reducer(size, tracker_size=tracker_size)
+
+ left_in = []
+ right_in = []
+ tracking = []
+ for _ in range(batch_size):
+ left_in.append(tf.random_normal((1, size * 2)))
+ right_in.append(tf.random_normal((1, size * 2)))
+ tracking.append(tf.random_normal((1, tracker_size * 2)))
+
+ out = reducer(left_in, right_in, tracking=tracking)
+ self.assertEqual(batch_size, len(out))
+ self.assertEqual(tf.float32, out[0].dtype)
+ self.assertEqual((1, size * 2), out[0].shape)
+
+ def testReduceTreeLSTM(self):
+ with tf.device(self._test_device):
+ size = 10
+ tracker_size = 8
+ reducer = spinn.Reducer(size, tracker_size=tracker_size)
+
+ lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]],
+ dtype=np.float32)
+ c1 = np.array([[0, 1], [2, 3]], dtype=np.float32)
+ c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32)
+
+ h, c = reducer._tree_lstm(c1, c2, lstm_in)
+ self.assertEqual(tf.float32, h.dtype)
+ self.assertEqual(tf.float32, c.dtype)
+ self.assertEqual((2, 2), h.shape)
+ self.assertEqual((2, 2), c.shape)
+
+ def testTracker(self):
+ with tf.device(self._test_device):
+ batch_size = 2
+ size = 10
+ tracker_size = 8
+ buffer_length = 18
+ stack_size = 3
+
+ tracker = spinn.Tracker(tracker_size, False)
+ tracker.reset_state()
+
+ # Create dummy inputs for testing.
+ bufs = []
+ buf = []
+ for _ in range(buffer_length):
+ buf.append(tf.random_normal((batch_size, size * 2)))
+ bufs.append(buf)
+ self.assertEqual(1, len(bufs))
+ self.assertEqual(buffer_length, len(bufs[0]))
+ self.assertEqual((batch_size, size * 2), bufs[0][0].shape)
+
+ stacks = []
+ stack = []
+ for _ in range(stack_size):
+ stack.append(tf.random_normal((batch_size, size * 2)))
+ stacks.append(stack)
+ self.assertEqual(1, len(stacks))
+ self.assertEqual(3, len(stacks[0]))
+ self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
+
+ for _ in range(2):
+ out1, out2 = tracker(bufs, stacks)
+ self.assertIsNone(out2)
+ self.assertEqual(batch_size, len(out1))
+ self.assertEqual(tf.float32, out1[0].dtype)
+ self.assertEqual((1, tracker_size * 2), out1[0].shape)
+
+ self.assertEqual(tf.float32, tracker.state.c.dtype)
+ self.assertEqual((batch_size, tracker_size), tracker.state.c.shape)
+ self.assertEqual(tf.float32, tracker.state.h.dtype)
+ self.assertEqual((batch_size, tracker_size), tracker.state.h.shape)
+
+ def testSPINN(self):
+ with tf.device(self._test_device):
+ embedding_dims = 10
+ d_tracker = 8
+ sequence_length = 15
+ num_transitions = 27
+
+ config_tuple = collections.namedtuple(
+ "Config", ["d_hidden", "d_proj", "d_tracker", "predict"])
+ config = config_tuple(
+ embedding_dims, embedding_dims * 2, d_tracker, False)
+ s = spinn.SPINN(config)
+
+ # Create some fake data.
+ buffers = tf.random_normal((sequence_length, 1, config.d_proj))
+ transitions = tf.constant(
+ [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3],
+ [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2],
+ [3], [2], [2]], dtype=tf.int64)
+ self.assertEqual(tf.int64, transitions.dtype)
+ self.assertEqual((num_transitions, 1), transitions.shape)
+
+ out = s(buffers, transitions, training=True)
+ self.assertEqual(tf.float32, out.dtype)
+ self.assertEqual((1, embedding_dims), out.shape)
+
+ def testSNLIClassifierAndTrainer(self):
+ with tf.device(self._test_device):
+ vocab_size = 40
+ batch_size = 2
+ d_embed = 10
+ sequence_length = 15
+ d_out = 4
+
+ config = _test_spinn_config(d_embed, d_out)
+
+ # Create fake embedding matrix.
+ embed = tf.random_normal((vocab_size, d_embed))
+
+ model = spinn.SNLIClassifier(config, embed)
+ trainer = spinn.SNLIClassifierTrainer(model, config.lr)
+
+ (labels, prem, prem_trans, hypo,
+ hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
+ batch_size,
+ vocab_size)
+
+ # Invoke model under non-training mode.
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ self.assertEqual(tf.float32, logits.dtype)
+ self.assertEqual((batch_size, d_out), logits.shape)
+
+ # Invoke model under training mode.
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
+ self.assertEqual(tf.float32, logits.dtype)
+ self.assertEqual((batch_size, d_out), logits.shape)
+
+ # Calculate loss.
+ loss1 = trainer.loss(labels, logits)
+ self.assertEqual(tf.float32, loss1.dtype)
+ self.assertEqual((), loss1.shape)
+
+ loss2, logits = trainer.train_batch(
+ labels, prem, prem_trans, hypo, hypo_trans)
+ self.assertEqual(tf.float32, loss2.dtype)
+ self.assertEqual((), loss2.shape)
+ self.assertEqual(tf.float32, logits.dtype)
+ self.assertEqual((batch_size, d_out), logits.shape)
+ # Training on the batch should have led to a change in the loss value.
+ self.assertNotEqual(loss1.numpy(), loss2.numpy())
+
+ def testTrainSpinn(self):
+ """Test with fake toy SNLI data and GloVe vectors."""
+
+ # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
+ snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
+ fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
+ os.makedirs(snli_1_0_dir)
+
+ # Four sentences in total.
+ with open(fake_train_file, "wt") as f:
+ f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
+ "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
+ "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
+ f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+ f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
+ "DummySentence1Parse\tDummySentence2Parse\t"
+ "Foo bar.\tfoo baz.\t"
+ "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
+ "neutral\tentailment\tneutral\tneutral\tneutral\n")
+
+ glove_dir = os.path.join(self._temp_data_dir, "glove")
+ os.makedirs(glove_dir)
+ glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
+
+ words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
+ with open(glove_file, "wt") as f:
+ for i, word in enumerate(words):
+ f.write("%s " % word)
+ for j in range(data.WORD_VECTOR_LEN):
+ f.write("%.5f" % (i * 0.1))
+ if j < data.WORD_VECTOR_LEN - 1:
+ f.write(" ")
+ else:
+ f.write("\n")
+
+ vocab = data.load_vocabulary(self._temp_data_dir)
+ word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)
+
+ train_data = data.SnliData(fake_train_file, word2index)
+ dev_data = data.SnliData(fake_train_file, word2index)
+ test_data = data.SnliData(fake_train_file, word2index)
+ print(embed)
+
+ # 2. Create a fake config.
+ config = _test_spinn_config(
+ data.WORD_VECTOR_LEN, 4,
+ logdir=os.path.join(self._temp_data_dir, "logdir"))
+
+ # 3. Test training of a SPINN model.
+ spinn.train_spinn(embed, train_data, dev_data, test_data, config)
+
+ # 4. Load train loss values from the summary files and verify that they
+ # decrease with training.
+ summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
+ events = summary_test_util.events_from_file(summary_file)
+ train_losses = [event.summary.value[0].simple_value for event in events
+ if event.summary.value
+ and event.summary.value[0].tag == "train/loss"]
+ self.assertEqual(config.epochs, len(train_losses))
+ self.assertLess(train_losses[-1], train_losses[0])
+
+
+class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):
+
+ def benchmarkEagerSpinnSNLIClassifier(self):
+ test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
+ with tf.device(test_device):
+ burn_in_iterations = 2
+ benchmark_iterations = 10
+
+ vocab_size = 1000
+ batch_size = 128
+ sequence_length = 15
+ d_embed = 200
+ d_out = 4
+
+ embed = tf.random_normal((vocab_size, d_embed))
+
+ config = _test_spinn_config(d_embed, d_out)
+ model = spinn.SNLIClassifier(config, embed)
+ trainer = spinn.SNLIClassifierTrainer(model, config.lr)
+
+ (labels, prem, prem_trans, hypo,
+ hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
+ batch_size,
+ vocab_size)
+
+ for _ in range(burn_in_iterations):
+ trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
+
+ gc.collect()
+ start_time = time.time()
+ for _ in range(benchmark_iterations):
+ trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
+ wall_time = time.time() - start_time
+ # Named "examples"_per_sec to conform with other benchmarks.
+ extras = {"examples_per_sec": benchmark_iterations / wall_time}
+ self.report_benchmark(
+ name="Eager_SPINN_SNLIClassifier_Benchmark",
+ iters=benchmark_iterations,
+ wall_time=wall_time,
+ extras=extras)
+
+
+if __name__ == "__main__":
+ test.main()
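Condensed from `testSNLIClassifierAndTrainer` above, a standalone sketch of the classifier's flow on one synthetic batch (illustrative, not part of the committed files); importing the private test helpers and calling `tfe.enable_eager_execution()` directly are assumptions about running this outside the test harness:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.eager.python.examples.spinn.spinn_test import (
    _generate_synthetic_snli_data_batch, _test_spinn_config)

tfe.enable_eager_execution()  # eager preview API at the time of this commit

config = _test_spinn_config(d_embed=10, d_out=4)
embed = tf.random_normal((40, 10))          # fake (vocab_size, d_embed) embedding matrix
model = spinn.SNLIClassifier(config, embed)
trainer = spinn.SNLIClassifierTrainer(model, config.lr)

labels, prem, prem_trans, hypo, hypo_trans = _generate_synthetic_snli_data_batch(
    sequence_length=15, batch_size=2, vocab_size=40)
logits = model(prem, prem_trans, hypo, hypo_trans, training=False)   # shape (2, 4)
loss, logits = trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
```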
diff --git a/third_party/examples/eager/spinn/BUILD b/third_party/examples/eager/spinn/BUILD
new file mode 100644
index 0000000000..0e39d4696f
--- /dev/null
+++ b/third_party/examples/eager/spinn/BUILD
@@ -0,0 +1,14 @@
+licenses(["notice"]) # 3-clause BSD.
+
+py_binary(
+ name = "spinn",
+ srcs = ["spinn.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/contrib/eager/python/examples/spinn:data",
+ "@six_archive//:six",
+ ],
+)
diff --git a/third_party/examples/eager/spinn/LICENSE b/third_party/examples/eager/spinn/LICENSE
new file mode 100644
index 0000000000..09d493bf1f
--- /dev/null
+++ b/third_party/examples/eager/spinn/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2017,
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
new file mode 100644
index 0000000000..c00d8d9015
--- /dev/null
+++ b/third_party/examples/eager/spinn/README.md
@@ -0,0 +1,54 @@
+# SPINN with TensorFlow eager execution
+
+SPINN, or Stack-Augmented Parser-Interpreter Neural Network, is a recursive
+neural network that utilizes syntactic parse information for natural language
+understanding.
+
+SPINN was originally described by:
+Bowman, S.R., Gauthier, J., Rastogi, A., Gupta, R., Manning, C.D., & Potts, C.
+ (2016). A Fast Unified Model for Parsing and Sentence Understanding.
+ https://arxiv.org/abs/1603.06021
+
+Our implementation is based on @jekbradbury's PyTorch implementation at
+https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py, which was
+released under the BSD 3-Clause License at
+https://github.com/jekbradbury/examples/blob/spinn/LICENSE
+
+## Content
+
+Python source file(s):
+- `spinn.py`: Model definition and training routines written with TensorFlow
+ eager execution idioms.
+
+## To run
+
+- Make sure you have installed the latest `tf-nightly` or `tf-nightly-gpu` pip
+ package of TensorFlow in order to access the eager execution feature.
+
+- Download and extract the raw SNLI data and GloVe embedding vectors.
+ For example:
+
+ ```bash
+ curl -fSsL https://nlp.stanford.edu/projects/snli/snli_1.0.zip --create-dirs -o /tmp/spinn-data/snli/snli_1.0.zip
+ unzip -d /tmp/spinn-data/snli /tmp/spinn-data/snli/snli_1.0.zip
+ curl -fSsL http://nlp.stanford.edu/data/glove.42B.300d.zip --create-dirs -o /tmp/spinn-data/glove/glove.42B.300d.zip
+ unzip -d /tmp/spinn-data/glove /tmp/spinn-data/glove/glove.42B.300d.zip
+ ```
+
+- Train model. E.g.,
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
+ ```
+
+ During training, model checkpoints and TensorBoard summaries will be written
+ periodically to the directory specified with the `--logdir` flag.
+ The training script will reload a saved checkpoint from the directory if it
+ can find one there.
+
+ To view the summaries with TensorBoard:
+
+ ```bash
+ tensorboard --logdir /tmp/spinn-logs
+ ```
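For driving training from Python rather than the command line above, a rough sketch assembled from the functions in this commit (illustrative, not part of the committed files); the config field names follow the namedtuple used in `spinn_test.py`, while the concrete values, `/tmp` paths, and explicit eager-enable call are assumptions, not `spinn.py` defaults:

```python
import collections
import os

import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
from third_party.examples.eager.spinn import spinn

tfe.enable_eager_execution()  # assumed: eager mode must be on before building the model

data_root = "/tmp/spinn-data"
vocab = data.load_vocabulary(data_root)
word2index, embed = data.load_word_vectors(data_root, vocab)

snli_dir = os.path.join(data_root, "snli/snli_1.0")
train = data.SnliData(os.path.join(snli_dir, "snli_1.0_train.txt"), word2index)
dev = data.SnliData(os.path.join(snli_dir, "snli_1.0_dev.txt"), word2index)
test = data.SnliData(os.path.join(snli_dir, "snli_1.0_test.txt"), word2index)

Config = collections.namedtuple(
    "Config", ["d_hidden", "d_proj", "d_tracker", "predict", "embed_dropout",
               "mlp_dropout", "n_mlp_layers", "d_mlp", "d_out", "projection",
               "lr", "batch_size", "epochs", "force_cpu", "logdir", "log_every",
               "dev_every", "save_every", "lr_decay_every", "lr_decay_by"])
config = Config(
    d_hidden=300, d_proj=600, d_tracker=64, predict=False, embed_dropout=0.08,
    mlp_dropout=0.07, n_mlp_layers=2, d_mlp=1024, d_out=4, projection=True,
    lr=2e-3, batch_size=128, epochs=50, force_cpu=False,
    logdir="/tmp/spinn-logs", log_every=50, dev_every=1000, save_every=1000,
    lr_decay_every=1, lr_decay_by=0.75)

# Checkpoints and TensorBoard summaries are written to config.logdir.
spinn.train_spinn(embed, train, dev, test, config)
```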
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
new file mode 100644
index 0000000000..963ac0e65b
--- /dev/null
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -0,0 +1,732 @@
+r"""Implementation of SPINN in TensorFlow eager execution.
+
+SPINN: Stack-Augmented Parser-Interpreter Neural Network.
+
+This file contains the model definition and code for training the model.
+
+The model definition is based on PyTorch implementation at:
+ https://github.com/jekbradbury/examples/tree/spinn/snli
+
+which was released under a BSD 3-Clause License at:
+https://github.com/jekbradbury/examples/blob/spinn/LICENSE:
+
+Copyright (c) 2017,
+All rights reserved.
+
+See ./LICENSE for more details.
+
+Instructions for use:
+* See `README.md` for details on how to prepare the SNLI and GloVe data.
+* Assuming you have prepared the data at "/tmp/spinn-data", use the following
+ command to train the model:
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
+ ```
+
+ Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs".
+
+References:
+* Bowman, S.R., Gauthier, J., Rastogi, A., Gupta, R., Manning, C.D., & Potts, C.
+ (2016). A Fast Unified Model for Parsing and Sentence Understanding.
+ https://arxiv.org/abs/1603.06021
+* Bradbury, J. (2017). Recursive Neural Networks with PyTorch.
+ https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import itertools
+import os
+import sys
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.spinn import data
+
+
+def _bundle(lstm_iter):
+ """Concatenate a list of Tensors along 1st axis and split result into two.
+
+ Args:
+ lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape
+ (R, 2 * M).
+
+ Returns:
+ A `list` of two dense `Tensor`s, each of which has the shape (N * R, M).
+ """
+ return tf.split(tf.concat(lstm_iter, 0), 2, axis=1)
+
+
+def _unbundle(state):
+ """Concatenate a list of Tensors along 2nd axis and split result.
+
+ This is the inverse of `_bundle`.
+
+ Args:
+ state: A `list` of two dense `Tensor`s, each of which has the shape (R, M).
+
+ Returns:
+ A `list` of `R` dense `Tensors`, each of which has the shape (1, 2 * M).
+ """
+ return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
+
+
+class Reducer(tfe.Network):
+ """A module that applies reduce operation on left and right vectors."""
+
+ def __init__(self, size, tracker_size=None):
+ super(Reducer, self).__init__()
+ self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
+ self.right = self.track_layer(
+ tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ if tracker_size is not None:
+ self.track = self.track_layer(
+ tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ else:
+ self.track = None
+
+ def call(self, left_in, right_in, tracking=None):
+ """Invoke forward pass of the Reduce module.
+
+ This method feeds a linear combination of `left_in`, `right_in` and
+ `tracking` into a Tree LSTM and returns the output of the Tree LSTM.
+
+ Args:
+ left_in: A list of length L. Each item is a dense `Tensor` with
+ the shape (1, n_dims). n_dims is the size of the embedding vector.
+ right_in: A list of the same length as `left_in`. Each item should have
+ the same shape as the items of `left_in`.
+ tracking: Optional list of the same length as `left_in`. Each item is a
+ dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the
+ size of the Tracker's state vector.
+
+ Returns:
+ Output: A list of length batch_size. Each item has the shape (1, n_dims).
+ """
+ left, right = _bundle(left_in), _bundle(right_in)
+ lstm_in = self.left(left[0]) + self.right(right[0])
+ if self.track and tracking:
+ lstm_in += self.track(_bundle(tracking)[0])
+ return _unbundle(self._tree_lstm(left[1], right[1], lstm_in))
+
+ def _tree_lstm(self, c1, c2, lstm_in):
+ a, i, f1, f2, o = tf.split(lstm_in, 5, axis=1)
+ c = tf.tanh(a) * tf.sigmoid(i) + tf.sigmoid(f1) * c1 + tf.sigmoid(f2) * c2
+ h = tf.sigmoid(o) * tf.tanh(c)
+ return h, c
+
+
+class Tracker(tfe.Network):
+ """A module that tracks the history of the sentence with an LSTM."""
+
+ def __init__(self, tracker_size, predict):
+ """Constructor of Tracker.
+
+ Args:
+ tracker_size: Number of dimensions of the underlying `LSTMCell`.
+ predict: (`bool`) Whether prediction mode is enabled.
+ """
+ super(Tracker, self).__init__()
+ self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+ self._state_size = tracker_size
+ if predict:
+ self._transition = self.track_layer(tf.layers.Dense(4))
+ else:
+ self._transition = None
+
+ def reset_state(self):
+ self.state = None
+
+ def call(self, bufs, stacks):
+ """Invoke the forward pass of the Tracker module.
+
+ This method feeds the concatenation of the top two elements of the stacks
+ into an LSTM cell and returns the resultant state of the LSTM cell.
+
+ Args:
+ bufs: A `list` of length batch_size. Each item is a `list` of length
+ max_sequence_len (the maximum sequence length of the batch). Each item
+ of the nested list is a dense `Tensor` of shape (1, d_proj), where
+ d_proj is the size of the word embedding vector or the size of the
+ vector space that the word embedding vector is projected to.
+ stacks: A `list` of size batch_size. Each item is a `list` of
+ variable length corresponding to the current height of the stack.
+ Each item of the nested list is a dense `Tensor` of shape (1, d_proj).
+
+ Returns:
+ 1. A list of length batch_size. Each item is a dense `Tensor` of shape
+ (1, d_tracker * 2).
+ 2. If under predict mode, result of applying a Dense layer on the
+ first state vector of the RNN. Else, `None`.
+ """
+ buf = _bundle([buf[-1] for buf in bufs])[0]
+ stack1 = _bundle([stack[-1] for stack in stacks])[0]
+ stack2 = _bundle([stack[-2] for stack in stacks])[0]
+ x = tf.concat([buf, stack1, stack2], 1)
+ if self.state is None:
+ batch_size = int(x.shape[0])
+ zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32)
+ self.state = [zeros, zeros]
+ _, self.state = self._rnn(x, self.state)
+ unbundled = _unbundle(self.state)
+ if self._transition:
+ return unbundled, self._transition(self.state[0])
+ else:
+ return unbundled, None
+
+
+class SPINN(tfe.Network):
+ """Stack-augmented Parser-Interpreter Neural Network.
+
+ See https://arxiv.org/abs/1603.06021 for more details.
+ """
+
+ def __init__(self, config):
+ """Constructor of SPINN.
+
+ Args:
+ config: A `namedtuple` with the following attributes.
+ d_proj - (`int`) number of dimensions of the vector space to project the
+ word embeddings to.
+ d_tracker - (`int`) number of dimensions of the Tracker's state vector.
+ d_hidden - (`int`) number of dimensions of the hidden state of the
+ Reducer module.
+ n_mlp_layers - (`int`) number of multi-layer perceptron layers to use to
+ convert the output of the `Feature` module to logits.
+ predict - (`bool`) Whether the Tracker will make predictions.
+ """
+ super(SPINN, self).__init__()
+ self.config = config
+ self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+ if config.d_tracker is not None:
+ self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+ else:
+ self.tracker = None
+
+ def call(self, buffers, transitions, training=False):
+ """Invoke the forward pass of the SPINN model.
+
+ Args:
+ buffers: Dense `Tensor` of shape
+ (max_sequence_len, batch_size, config.d_proj).
+ transitions: Dense `Tensor` with integer values that represent the parse
+ trees of the sentences. A value of 2 indicates "reduce"; a value of 3
+ indicates "shift". Shape: (max_sequence_len * 2 - 3, batch_size).
+ training: Whether the invocation is under training mode.
+
+ Returns:
+ Output `Tensor` of shape (batch_size, config.d_embed).
+ """
+ max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape)
+
+ # Split the buffers into left and right word items and put the initial
+ # items in a stack.
+ splitted = tf.split(
+ tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]),
+ max_sequence_len * batch_size, axis=0)
+ buffers = [splitted[k:k + max_sequence_len]
+ for k in xrange(0, len(splitted), max_sequence_len)]
+ stacks = [[buf[0], buf[0]] for buf in buffers]
+
+ if self.tracker:
+ # Reset tracker state for new batch.
+ self.tracker.reset_state()
+
+ num_transitions = transitions.shape[0]
+
+ # Iterate through transitions and perform the appropriate stack-pop, reduce
+ # and stack-push operations.
+ transitions = transitions.numpy()
+ for i in xrange(num_transitions):
+ trans = transitions[i]
+ if self.tracker:
+ # Invoke tracker to obtain the current tracker states for the sentences.
+ tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+ if trans_hypothesis:
+ trans = tf.argmax(trans_hypothesis, axis=-1)
+ else:
+ tracker_states = itertools.repeat(None)
+ lefts, rights, trackings = [], [], []
+ for transition, buf, stack, tracking in zip(
+ trans, buffers, stacks, tracker_states):
+ if int(transition) == 3: # Shift.
+ stack.append(buf.pop())
+ elif int(transition) == 2: # Reduce.
+ rights.append(stack.pop())
+ lefts.append(stack.pop())
+ trackings.append(tracking)
+
+ if rights:
+ reducer_output = self.reducer(lefts, rights, trackings)
+ reduced = iter(reducer_output)
+
+ for transition, stack in zip(trans, stacks):
+ if int(transition) == 2: # Reduce.
+ stack.append(next(reduced))
+ return _bundle([stack.pop() for stack in stacks])[0]
+
+
+class SNLIClassifier(tfe.Network):
+ """SNLI Classifier Model.
+
+ A model aimed at solving the SNLI (Stanford Natural Language Inference)
+ task, using the SPINN model from above. For details of the task, see:
+ https://nlp.stanford.edu/projects/snli/
+ """
+
+ def __init__(self, config, embed):
+ """Constructor of SNLIClassifier.
+
+ Args:
+ config: A namedtuple containing required configurations for the model. It
+ needs to have the following attributes.
+ projection - (`bool`) whether the word vectors are to be projected onto
+ another vector space (of `d_proj` dimensions).
+ d_proj - (`int`) number of dimensions of the vector space to project the
+ word embeddings to.
+ embed_dropout - (`float`) dropout rate for the word embedding vectors.
+ n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to
+ use to convert the output of the `Feature` module to logits.
+ mlp_dropout - (`float`) dropout rate of the MLP layers.
+ d_out - (`int`) number of dimensions of the final output of the MLP
+ layers.
+ lr - (`float`) learning rate.
+ embed: An embedding matrix of shape (vocab_size, d_embed).
+ """
+ super(SNLIClassifier, self).__init__()
+ self.config = config
+ self.embed = tf.constant(embed)
+
+ self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
+ self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
+ self.embed_dropout = self.track_layer(
+ tf.layers.Dropout(rate=config.embed_dropout))
+ self.encoder = self.track_layer(SPINN(config))
+
+ self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
+ self.feature_dropout = self.track_layer(
+ tf.layers.Dropout(rate=config.mlp_dropout))
+
+ self.mlp_dense = []
+ self.mlp_bn = []
+ self.mlp_dropout = []
+ for _ in xrange(config.n_mlp_layers):
+ self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
+ self.mlp_bn.append(
+ self.track_layer(tf.layers.BatchNormalization()))
+ self.mlp_dropout.append(
+ self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
+ self.mlp_output = self.track_layer(tf.layers.Dense(
+ config.d_out,
+ kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
+ maxval=5e-3)))
+
+ def call(self,
+ premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
+ training=False):
+ """Invoke the forward pass of the SNLIClassifier model.
+
+ Args:
+ premise: The word indices of the premise sentences, with shape
+ (max_prem_seq_len, batch_size).
+ premise_transition: The transitions for the premise sentences, with shape
+ (max_prem_seq_len * 2 - 3, batch_size).
+ hypothesis: The word indices of the hypothesis sentences, with shape
+ (max_hypo_seq_len, batch_size).
+ hypothesis_transition: The transitions for the hypothesis sentences, with
+ shape (max_hypo_seq_len * 2 - 3, batch_size).
+ training: Whether the invocation is under training mode.
+
+ Returns:
+ The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out
+ is the size of the output vector.
+ """
+ # Perform embedding lookup on the premise and hypothesis inputs, which have
+ # the word-index format.
+ premise_embed = tf.nn.embedding_lookup(self.embed, premise)
+ hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)
+
+ if self.config.projection:
+ # Project the embedding vectors to another vector space.
+ premise_embed = self.projection(premise_embed)
+ hypothesis_embed = self.projection(hypothesis_embed)
+
+ # Perform batch normalization and dropout on the possibly projected word
+ # vectors.
+ premise_embed = self.embed_bn(premise_embed, training=training)
+ hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
+ premise_embed = self.embed_dropout(premise_embed, training=training)
+ hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)
+
+ # Run the batch-normalized and dropout-processed word vectors through the
+ # SPINN encoder.
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
+
+ # Combine encoder outputs for premises and hypotheses into logits.
+    # Then apply batch normalization and dropout on the logits.
+ logits = tf.concat(
+ [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
+ logits = self.feature_dropout(
+ self.feature_bn(logits, training=training), training=training)
+
+ # Apply the multi-layer perceptron on the logits.
+ for dense, bn, dropout in zip(
+ self.mlp_dense, self.mlp_bn, self.mlp_dropout):
+ logits = tf.nn.elu(dense(logits))
+ logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp_output(logits)
+ return logits
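+
+  # Illustrative usage sketch (comments only, not executed; `config`,
+  # `embed_matrix` and the batch tensors below are hypothetical placeholders):
+  #
+  #   model = SNLIClassifier(config, embed_matrix)
+  #   logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
+  #   # `logits` has shape (batch_size, config.d_out).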
+
+
+class SNLIClassifierTrainer(object):
+ """A class that coordinates the training of an SNLIClassifier."""
+
+ def __init__(self, snli_classifier, lr):
+ """Constructor of SNLIClassifierTrainer.
+
+ Args:
+ snli_classifier: An instance of `SNLIClassifier`.
+ lr: Learning rate.
+ """
+ self._model = snli_classifier
+ # Create a custom learning rate Variable for the RMSProp optimizer, because
+ # the learning rate needs to be manually decayed later (see
+ # decay_learning_rate()).
+ self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
+ epsilon=1e-6)
+
+ def loss(self, labels, logits):
+ """Calculate the loss given a batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ logits: The logits output from the forward pass of the SNLIClassifier
+ model, with shape (batch_size, d_out), where d_out is the output
+ dimension size of the SNLIClassifier.
+
+ Returns:
+ The loss value, as a scalar `Tensor`.
+ """
+ return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits))
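+
+  # For example (illustrative, not executed; `trainer` stands for a
+  # hypothetical `SNLIClassifierTrainer` instance, with d_out == 4):
+  #
+  #   labels = tf.constant([0, 2])                      # shape (2,)
+  #   logits = tf.constant([[4., 0., 0., 0.],
+  #                         [0., 0., 4., 0.]])          # shape (2, 4)
+  #   trainer.loss(labels, logits)  # A small scalar, since both examples are
+  #                                 # classified correctly.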
+
+ def train_batch(self,
+ labels,
+ premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition):
+ """Train model on batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ premise: The word indices of the premise sentences, with shape
+ (max_prem_seq_len, batch_size).
+ premise_transition: The transitions for the premise sentences, with shape
+ (max_prem_seq_len * 2 - 3, batch_size).
+ hypothesis: The word indices of the hypothesis sentences, with shape
+ (max_hypo_seq_len, batch_size).
+ hypothesis_transition: The transitions for the hypothesis sentences, with
+ shape (max_hypo_seq_len * 2 - 3, batch_size).
+
+ Returns:
+ 1. loss value as a scalar `Tensor`.
+ 2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is
+ the output dimension size of the SNLIClassifier.
+ """
+ with tfe.GradientTape() as tape:
+ tape.watch(self._model.variables)
+ logits = self._model(premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
+ training=True)
+ loss = self.loss(labels, logits)
+ gradients = tape.gradient(loss, self._model.variables)
+ self._optimizer.apply_gradients(zip(gradients, self._model.variables),
+ global_step=tf.train.get_global_step())
+ return loss, logits
+
+ def decay_learning_rate(self, decay_by):
+ """Decay learning rate of the optimizer by factor decay_by."""
+ self._learning_rate.assign(self._learning_rate * decay_by)
+ print("Decayed learning rate of optimizer to: %s" %
+ self._learning_rate.numpy())
+
+ @property
+ def learning_rate(self):
+ return self._learning_rate
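+
+  # Illustrative usage sketch (comments only, not executed; `model` and the
+  # batch tensors are hypothetical placeholders):
+  #
+  #   trainer = SNLIClassifierTrainer(model, lr=2e-3)
+  #   loss, logits = trainer.train_batch(
+  #       labels, prem, prem_trans, hypo, hypo_trans)
+  #   trainer.decay_learning_rate(0.75)  # Learning rate becomes 1.5e-3.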
+
+
+def _batch_n_correct(logits, label):
+ """Calculate number of correct predictions in a batch.
+
+ Args:
+ logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype
+ `float32`.
+    label: A labels Tensor of shape `(batch_size,)` and dtype `int64`.
+
+ Returns:
+ Number of correct predictions.
+ """
+ return tf.reduce_sum(
+ tf.cast((tf.equal(
+ tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
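+
+# For example (illustrative, not executed):
+#   logits = tf.constant([[2.0, 0.1, 0.0], [0.0, 0.2, 1.5]])
+#   label = tf.constant([0, 2], dtype=tf.int64)
+#   _batch_n_correct(logits, label)  # == 2.0: both argmax predictions match.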
+
+
+def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+ """Run evaluation on a dataset.
+
+ Args:
+ snli_data: The `data.SnliData` to use in this evaluation.
+ batch_size: The batch size to use during this evaluation.
+ model: An instance of `SNLIClassifier` to evaluate.
+    trainer: An instance of `SNLIClassifierTrainer` to use for this
+      evaluation.
+    use_gpu: Whether a GPU is being used.
+
+ Returns:
+ 1. Average loss across all examples of the dataset.
+ 2. Average accuracy rate across all examples of the dataset.
+ """
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+ for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
+ snli_data, batch_size):
+ if use_gpu:
+ label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ loss_val = trainer.loss(label, logits)
+    num_examples = tf.shape(label)[0]
+    mean_loss(loss_val,
+              weights=num_examples.gpu() if use_gpu else num_examples)
+ accuracy(tf.argmax(logits, axis=1), label)
+ return mean_loss.result().numpy(), accuracy.result().numpy()
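+
+# Illustrative usage sketch (not executed; `dev_data`, `model` and `trainer`
+# are hypothetical, previously constructed objects):
+#
+#   dev_loss, dev_accuracy = _evaluate_on_dataset(
+#       dev_data, 128, model, trainer, use_gpu=False)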
+
+
+def _get_dataset_iterator(snli_data, batch_size):
+ """Get a data iterator for a split of SNLI data.
+
+ Args:
+ snli_data: A `data.SnliData` object.
+ batch_size: The desired batch size.
+
+ Returns:
+ A dataset iterator.
+ """
+ with tf.device("/device:CPU:0"):
+ # Some tf.data ops, such as ShuffleDataset, are available only on CPU.
+ dataset = tf.data.Dataset.from_generator(
+ snli_data.get_generator(batch_size),
+ (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64))
+ dataset = dataset.shuffle(snli_data.num_batches(batch_size))
+ return tfe.Iterator(dataset)
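+
+# Illustrative usage sketch (not executed; `dev_data` is a hypothetical
+# `data.SnliData` instance):
+#
+#   for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
+#       dev_data, batch_size=128):
+#     ...  # All five elements are int64 Tensors; see `SNLIClassifier.call()`
+#          # for the shape conventions of the word indices and transitions.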
+
+
+def train_spinn(embed, train_data, dev_data, test_data, config):
+ """Train a SPINN model.
+
+ Args:
+ embed: The embedding matrix as a float32 numpy array with shape
+ [vocabulary_size, word_vector_len]. word_vector_len is the length of a
+ word embedding vector.
+ train_data: An instance of `data.SnliData`, for the train split.
+ dev_data: Same as above, for the dev split.
+ test_data: Same as above, for the test split.
+    config: A configuration object. See the command-line arguments of this
+      Python binary for details.
+
+ Returns:
+ 1. Final loss value on the test split.
+ 2. Final fraction of correct classifications on the test split.
+ """
+ use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
+ device = "gpu:0" if use_gpu else "cpu:0"
+ print("Using device: %s" % device)
+
+ log_header = (
+ " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
+ " Accuracy Dev/Accuracy")
+ log_template = (
+ "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} "
+ "{:12.4f} {}")
+ dev_log_template = (
+ "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} "
+ "{:8.6f} {:12.4f} {:12.4f}")
+
+ summary_writer = tf.contrib.summary.create_summary_file_writer(
+ config.logdir, flush_millis=10000)
+ train_len = train_data.num_batches(config.batch_size)
+ with tf.device(device), \
+ tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)), \
+ summary_writer.as_default(), \
+ tf.contrib.summary.always_record_summaries():
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+
+ start = time.time()
+ iterations = 0
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+ print(log_header)
+ for epoch in xrange(config.epochs):
+ batch_idx = 0
+ for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
+ train_data, config.batch_size):
+ if use_gpu:
+ label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
+ # prem_trans and hypo_trans are used for dynamic control flow and can
+ # remain on CPU. Same in _evaluate_on_dataset().
+
+ iterations += 1
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
+        num_examples = tf.shape(label)[0]
+        mean_loss(batch_train_loss.numpy(),
+                  weights=num_examples.gpu() if use_gpu else num_examples)
+ accuracy(tf.argmax(batch_train_logits, axis=1), label)
+
+ if iterations % config.save_every == 0:
+ all_variables = (
+ model.variables + [trainer.learning_rate] + [global_step])
+ saver = tfe.Saver(all_variables)
+ saver.save(os.path.join(config.logdir, "ckpt"),
+ global_step=global_step)
+
+ if iterations % config.dev_every == 0:
+ dev_loss, dev_frac_correct = _evaluate_on_dataset(
+ dev_data, config.batch_size, model, trainer, use_gpu)
+ print(dev_log_template.format(
+ time.time() - start,
+ epoch, iterations, 1 + batch_idx, train_len,
+ 100.0 * (1 + batch_idx) / train_len,
+ mean_loss.result(), dev_loss,
+ accuracy.result() * 100.0, dev_frac_correct * 100.0))
+ tf.contrib.summary.scalar("dev/loss", dev_loss)
+ tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct)
+ elif iterations % config.log_every == 0:
+ mean_loss_val = mean_loss.result()
+ accuracy_val = accuracy.result()
+ print(log_template.format(
+ time.time() - start,
+ epoch, iterations, 1 + batch_idx, train_len,
+ 100.0 * (1 + batch_idx) / train_len,
+ mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12))
+ tf.contrib.summary.scalar("train/loss", mean_loss_val)
+ tf.contrib.summary.scalar("train/accuracy", accuracy_val)
+ # Reset metrics.
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+
+ batch_idx += 1
+ if (epoch + 1) % config.lr_decay_every == 0:
+ trainer.decay_learning_rate(config.lr_decay_by)
+
+ test_loss, test_frac_correct = _evaluate_on_dataset(
+ test_data, config.batch_size, model, trainer, use_gpu)
+  print("Final test loss: %g; accuracy: %g%%" %
+        (test_loss, test_frac_correct * 100.0))
+  return test_loss, test_frac_correct
+
+
+def main(_):
+ config = FLAGS
+
+ # Load embedding vectors.
+ vocab = data.load_vocabulary(FLAGS.data_root)
+ word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
+
+ print("Loading train, dev and test data...")
+ train_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ dev_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ test_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+
+ train_spinn(embed, train_data, dev_data, test_data, config)
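+
+# Example invocation (assuming the data directory has been prepared as
+# described in the accompanying README.md):
+#
+#   python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+#       --batch_size 128 --epochs 50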
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=
+ "TensorFlow eager implementation of the SPINN SNLI classifier.")
+ parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
+ help="Root directory in which the training data and "
+ "embedding matrix are found. See README.md for how to "
+ "generate such a directory.")
+ parser.add_argument("--sentence_len_limit", type=int, default=-1,
+ help="Maximum allowed sentence length (# of words). "
+ "The default of -1 means unlimited.")
+ parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
+ help="Directory in which summaries will be written for "
+ "TensorBoard.")
+ parser.add_argument("--epochs", type=int, default=50,
+ help="Number of epochs to train.")
+ parser.add_argument("--batch_size", type=int, default=128,
+ help="Batch size to use during training.")
+ parser.add_argument("--d_proj", type=int, default=600,
+ help="Dimensions to project the word embedding vectors "
+ "to.")
+ parser.add_argument("--d_hidden", type=int, default=300,
+ help="Size of the hidden layer of the Tracker.")
+ parser.add_argument("--d_out", type=int, default=4,
+ help="Output dimensions of the SNLIClassifier.")
+ parser.add_argument("--d_mlp", type=int, default=1024,
+ help="Size of each layer of the multi-layer perceptron "
+                           "of the SNLIClassifier.")
+ parser.add_argument("--n_mlp_layers", type=int, default=2,
+ help="Number of layers in the multi-layer perceptron "
+                           "of the SNLIClassifier.")
+ parser.add_argument("--d_tracker", type=int, default=64,
+ help="Size of the tracker LSTM.")
+ parser.add_argument("--log_every", type=int, default=50,
+ help="Print log and write TensorBoard summary every _ "
+ "training batches.")
+ parser.add_argument("--lr", type=float, default=2e-3,
+ help="Initial learning rate.")
+ parser.add_argument("--lr_decay_by", type=float, default=0.75,
+ help="The ratio to multiply the learning rate by every "
+ "time the learning rate is decayed.")
+ parser.add_argument("--lr_decay_every", type=float, default=1,
+ help="Decay the learning rate every _ epoch(s).")
+ parser.add_argument("--dev_every", type=int, default=1000,
+ help="Run evaluation on the dev split every _ training "
+ "batches.")
+ parser.add_argument("--save_every", type=int, default=1000,
+ help="Save checkpoint every _ training batches.")
+ parser.add_argument("--embed_dropout", type=float, default=0.08,
+ help="Word embedding dropout rate.")
+ parser.add_argument("--mlp_dropout", type=float, default=0.07,
+ help="SNLIClassifier multi-layer perceptron dropout "
+ "rate.")
+ parser.add_argument("--no-projection", action="store_false",
+ dest="projection",
+                      help="If set, do not project the word embedding vectors "
+                           "to another vector space (see d_proj).")
+ parser.add_argument("--predict_transitions", action="store_true",
+ dest="predict",
+                      help="If set, the Tracker will predict the transitions "
+                           "instead of using the ground-truth transitions.")
+ parser.add_argument("--force_cpu", action="store_true", dest="force_cpu",
+                      help="Force the use of CPU only, regardless of whether "
+                           "a GPU is available.")
+ FLAGS, unparsed = parser.parse_known_args()
+
+ tfe.run(main=main, argv=[sys.argv[0]] + unparsed)