author    | 2017-12-01 13:56:10 -0800
committer | 2017-12-01 13:59:31 -0800
commit    | ed9163acfd510c26c49201ec9e360e20a2625ca8 (patch)
tree      | 5d61ab7de410baa898925f0023792e5015817c17
parent    | ae10f63e2fc76faf5835a660043c328d891c41f0 (diff)
TF Eager: Add SPINN model example for dynamic/recursive NN.
PiperOrigin-RevId: 177636427
-rw-r--r-- | tensorflow/contrib/eager/README.md | 3
-rw-r--r-- | tensorflow/contrib/eager/python/examples/BUILD | 1
-rw-r--r-- | tensorflow/contrib/eager/python/examples/spinn/BUILD | 41
-rw-r--r-- | tensorflow/contrib/eager/python/examples/spinn/README.md | 13
-rw-r--r-- | tensorflow/contrib/eager/python/examples/spinn/data.py | 350
-rw-r--r-- | tensorflow/contrib/eager/python/examples/spinn/data_test.py | 243
-rw-r--r-- | tensorflow/contrib/eager/python/examples/spinn/spinn_test.py | 409
-rw-r--r-- | third_party/examples/eager/spinn/BUILD | 14
-rw-r--r-- | third_party/examples/eager/spinn/LICENSE | 29
-rw-r--r-- | third_party/examples/eager/spinn/README.md | 54
-rw-r--r-- | third_party/examples/eager/spinn/spinn.py | 732
11 files changed, 1889 insertions, 0 deletions
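For orientation before the full diff, here is a minimal sketch (an editor's illustration, not part of the commit) of how the `data.py` utilities added below turn an SNLI binary parse into lower-cased tokens and shift-reduce transitions, using the `SHIFT_CODE = 3` / `REDUCE_CODE = 2` constants defined in that file. The sentence and its parse are made up:

```python
from tensorflow.contrib.eager.python.examples.spinn import data

# A toy binary parse in SNLI's parenthesized format (hypothetical sentence).
parse = "( ( The cat ) ( sat down ) )".split(" ")

# Lower-cased word tokens with the parentheses stripped out.
words = data.get_non_parenthesis_words(parse)  # ['the', 'cat', 'sat', 'down']

# Shift-reduce transitions: one shift (3) per word, one reduce (2) per ")".
transitions = data.get_shift_reduce(parse)     # [3, 3, 2, 3, 3, 2, 2]

# A sentence of n words always yields 2 * n - 1 transitions.
assert len(transitions) == 2 * len(words) - 1
```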
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index dcc370cd00..09242ee47d 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. +- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab99..6aef010a21 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,5 +11,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000..0263d21325 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,41 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000..eb0637df47 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000..a6e046320f --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI setence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parenthis word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data. + + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. + """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. 
+ sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. + """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." % glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # <unk> + embed.append([0] * WORD_VECTOR_LEN) # <pad> + word2index["<unk>"] = UNK_CODE + word2index["<pad>"] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. 
+ + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. + """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. + continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. + + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. 
+ hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. + zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sotring + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaning unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. + yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000..e4f0b37c50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for SPINN data module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.spinn import data + + +class DataTest(tf.test.TestCase): + + def setUp(self): + super(DataTest, self).setUp() + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(DataTest, self).tearDown() + + def testGenNonParenthesisWords(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . 
) )") + self.assertEqual( + ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", + "in", "a", "crowd", "of", "people", "."], + data.get_non_parenthesis_words(seq_with_parse.split(" "))) + + def testGetShiftReduce(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, + 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) + + def testPadAndReverseWordIds(self): + id_sequences = [[0, 2, 3, 4, 5], + [6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]] + self.assertAllClose( + [[1, 1, 1, 1, 5, 4, 3, 2, 0], + [1, 1, 1, 1, 1, 1, 8, 7, 6], + [1, 16, 15, 14, 13, 12, 11, 10, 9]], + data.pad_and_reverse_word_ids(id_sequences)) + + def testPadTransitions(self): + unpadded = [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2]] + self.assertAllClose( + [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2, 1, 1]], + data.pad_transitions(unpadded)) + + def testCalculateBins(self): + length2count = { + 1: 10, + 2: 15, + 3: 25, + 4: 40, + 5: 35, + 6: 10} + self.assertEqual([2, 3, 4, 5, 6], + data.calculate_bins(length2count, 20)) + self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) + self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) + + def testLoadVoacbulary(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") + os.makedirs(snli_1_0_dir) + + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + with open(fake_dev_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! 
)\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Quux quuz?\t.Corge grault!\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + self.assertSetEqual( + {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, + vocab) + + def testLoadVoacbularyWithoutFileRaisesError(self): + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + def testLoadWordVectors(self): + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", ",", "foo", "bar", "baz"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = {"foo", "bar", "baz", "qux", "."} + # Notice that "qux" is not present in `words`. + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + self.assertEqual(6, len(word2index)) + self.assertEqual(0, word2index["<unk>"]) + self.assertEqual(1, word2index["<pad>"]) + self.assertEqual(2, word2index["."]) + self.assertEqual(3, word2index["foo"]) + self.assertEqual(4, word2index["bar"]) + self.assertEqual(5, word2index["baz"]) + self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) + self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) + self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) + self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) + + def testLoadWordVectorsWithoutFileRaisesError(self): + vocab = {"foo", "bar", "baz", "qux", "."} + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + os.makedirs(os.path.join(self._temp_data_dir, "glove")) + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + def testSnliData(self): + """Unit test for SnliData objects.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . 
)\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + self.assertEqual(4, train_data.num_batches(1)) + self.assertEqual(2, train_data.num_batches(2)) + self.assertEqual(2, train_data.num_batches(3)) + self.assertEqual(1, train_data.num_batches(4)) + + generator = train_data.get_generator(2)() + for i in range(2): + label, prem, prem_trans, hypo, hypo_trans = next(generator) + self.assertEqual(2, len(label)) + self.assertEqual((4, 2), prem.shape) + self.assertEqual((5, 2), prem_trans.shape) + self.assertEqual((3, 2), hypo.shape) + self.assertEqual((3, 2), hypo_trans.shape) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py new file mode 100644 index 0000000000..84e25cf81a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -0,0 +1,409 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gc +import glob +import os +import shutil +import tempfile +import time + +import numpy as np +import tensorflow as tf + +# pylint: disable=g-bad-import-order +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data +from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +# pylint: enable=g-bad-import-order + + +def _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size): + """Generate a fake batch of SNLI data for testing.""" + with tf.device("cpu:0"): + labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) + prem = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + prem_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + hypo = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + hypo_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + if tfe.num_gpus(): + labels = labels.gpu() + prem = prem.gpu() + prem_trans = prem_trans.gpu() + hypo = hypo.gpu() + hypo_trans = hypo_trans.gpu() + return labels, prem, prem_trans, hypo, hypo_trans + + +def _test_spinn_config(d_embed, d_out, logdir=None): + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict", + "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", + "d_out", "projection", "lr", "batch_size", "epochs", + "force_cpu", "logdir", "log_every", "dev_every", "save_every", + "lr_decay_every", "lr_decay_by"]) + return config_tuple( + d_hidden=d_embed, + d_proj=d_embed * 2, + d_tracker=8, + predict=False, + embed_dropout=0.1, + mlp_dropout=0.1, + n_mlp_layers=2, + d_mlp=32, + d_out=d_out, + projection=True, + lr=2e-2, + batch_size=2, + epochs=10, + force_cpu=False, + logdir=logdir, + log_every=1, + dev_every=2, + save_every=2, + lr_decay_every=1, + lr_decay_by=0.75) + + +class SpinnTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(SpinnTest, self).setUp() + self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(SpinnTest, self).tearDown() + + def testBundle(self): + with tf.device(self._test_device): + lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[0, -1], [-2, -3]], dtype=np.float32), + np.array([[0, 2], [4, 6]], dtype=np.float32), + np.array([[0, -2], [-4, -6]], dtype=np.float32)] + out = spinn._bundle(lstm_iter) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, + out[0].numpy()) + self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, + out[1].numpy()) + + def testUnbunbdle(self): + with tf.device(self._test_device): + state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), + np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] + out = spinn._unbundle(state) 
+ + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), + out[0].numpy()) + self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), + out[1].numpy()) + + def testReducer(self): + with tf.device(self._test_device): + batch_size = 3 + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + left_in = [] + right_in = [] + tracking = [] + for _ in range(batch_size): + left_in.append(tf.random_normal((1, size * 2))) + right_in.append(tf.random_normal((1, size * 2))) + tracking.append(tf.random_normal((1, tracker_size * 2))) + + out = reducer(left_in, right_in, tracking=tracking) + self.assertEqual(batch_size, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual((1, size * 2), out[0].shape) + + def testReduceTreeLSTM(self): + with tf.device(self._test_device): + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], + dtype=np.float32) + c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) + + h, c = reducer._tree_lstm(c1, c2, lstm_in) + self.assertEqual(tf.float32, h.dtype) + self.assertEqual(tf.float32, c.dtype) + self.assertEqual((2, 2), h.shape) + self.assertEqual((2, 2), c.shape) + + def testTracker(self): + with tf.device(self._test_device): + batch_size = 2 + size = 10 + tracker_size = 8 + buffer_length = 18 + stack_size = 3 + + tracker = spinn.Tracker(tracker_size, False) + tracker.reset_state() + + # Create dummy inputs for testing. + bufs = [] + buf = [] + for _ in range(buffer_length): + buf.append(tf.random_normal((batch_size, size * 2))) + bufs.append(buf) + self.assertEqual(1, len(bufs)) + self.assertEqual(buffer_length, len(bufs[0])) + self.assertEqual((batch_size, size * 2), bufs[0][0].shape) + + stacks = [] + stack = [] + for _ in range(stack_size): + stack.append(tf.random_normal((batch_size, size * 2))) + stacks.append(stack) + self.assertEqual(1, len(stacks)) + self.assertEqual(3, len(stacks[0])) + self.assertEqual((batch_size, size * 2), stacks[0][0].shape) + + for _ in range(2): + out1, out2 = tracker(bufs, stacks) + self.assertIsNone(out2) + self.assertEqual(batch_size, len(out1)) + self.assertEqual(tf.float32, out1[0].dtype) + self.assertEqual((1, tracker_size * 2), out1[0].shape) + + self.assertEqual(tf.float32, tracker.state.c.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) + self.assertEqual(tf.float32, tracker.state.h.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) + + def testSPINN(self): + with tf.device(self._test_device): + embedding_dims = 10 + d_tracker = 8 + sequence_length = 15 + num_transitions = 27 + + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) + config = config_tuple( + embedding_dims, embedding_dims * 2, d_tracker, False) + s = spinn.SPINN(config) + + # Create some fake data. 
+ buffers = tf.random_normal((sequence_length, 1, config.d_proj)) + transitions = tf.constant( + [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], + [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], + [3], [2], [2]], dtype=tf.int64) + self.assertEqual(tf.int64, transitions.dtype) + self.assertEqual((num_transitions, 1), transitions.shape) + + out = s(buffers, transitions, training=True) + self.assertEqual(tf.float32, out.dtype) + self.assertEqual((1, embedding_dims), out.shape) + + def testSNLIClassifierAndTrainer(self): + with tf.device(self._test_device): + vocab_size = 40 + batch_size = 2 + d_embed = 10 + sequence_length = 15 + d_out = 4 + + config = _test_spinn_config(d_embed, d_out) + + # Create fake embedding matrix. + embed = tf.random_normal((vocab_size, d_embed)) + + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + # Invoke model under non-training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Invoke model under training model. + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Calculate loss. + loss1 = trainer.loss(labels, logits) + self.assertEqual(tf.float32, loss1.dtype) + self.assertEqual((), loss1.shape) + + loss2, logits = trainer.train_batch( + labels, prem, prem_trans, hypo, hypo_trans) + self.assertEqual(tf.float32, loss2.dtype) + self.assertEqual((), loss2.shape) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + # Training on the batch should have led to a change in the loss value. + self.assertNotEqual(loss1.numpy(), loss2.numpy()) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . 
)\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + dev_data = data.SnliData(fake_train_file, word2index) + test_data = data.SnliData(fake_train_file, word2index) + print(embed) + + # 2. Create a fake config. + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir")) + + # 3. Test training of a SPINN model. + spinn.train_spinn(embed, train_data, dev_data, test_data, config) + + # 4. Load train loss values from the summary files and verify that they + # decrease with training. + summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] + events = summary_test_util.events_from_file(summary_file) + train_losses = [event.summary.value[0].simple_value for event in events + if event.summary.value + and event.summary.value[0].tag == "train/loss"] + self.assertEqual(config.epochs, len(train_losses)) + self.assertLess(train_losses[-1], train_losses[0]) + + +class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): + + def benchmarkEagerSpinnSNLIClassifier(self): + test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + with tf.device(test_device): + burn_in_iterations = 2 + benchmark_iterations = 10 + + vocab_size = 1000 + batch_size = 128 + sequence_length = 15 + d_embed = 200 + d_out = 4 + + embed = tf.random_normal((vocab_size, d_embed)) + + config = _test_spinn_config(d_embed, d_out) + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + for _ in range(burn_in_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + + gc.collect() + start_time = time.time() + for _ in xrange(benchmark_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + wall_time = time.time() - start_time + # Named "examples"_per_sec to conform with other benchmarks. + extras = {"examples_per_sec": benchmark_iterations / wall_time} + self.report_benchmark( + name="Eager_SPINN_SNLIClassifier_Benchmark", + iters=benchmark_iterations, + wall_time=wall_time, + extras=extras) + + +if __name__ == "__main__": + test.main() diff --git a/third_party/examples/eager/spinn/BUILD b/third_party/examples/eager/spinn/BUILD new file mode 100644 index 0000000000..0e39d4696f --- /dev/null +++ b/third_party/examples/eager/spinn/BUILD @@ -0,0 +1,14 @@ +licenses(["notice"]) # 3-clause BSD. 
+ +py_binary( + name = "spinn", + srcs = ["spinn.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/contrib/eager/python/examples/spinn:data", + "@six_archive//:six", + ], +) diff --git a/third_party/examples/eager/spinn/LICENSE b/third_party/examples/eager/spinn/LICENSE new file mode 100644 index 0000000000..09d493bf1f --- /dev/null +++ b/third_party/examples/eager/spinn/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2017, +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md new file mode 100644 index 0000000000..c00d8d9015 --- /dev/null +++ b/third_party/examples/eager/spinn/README.md @@ -0,0 +1,54 @@ +# SPINN with TensorFlow eager execution + +SPINN, or Stack-Augmented Parser-Interpreter Neural Network, is a recursive +neural network that utilizes syntactic parse information for natural language +understanding. + +SPINN was originally described by: +Bowman, S.R., Gauthier, J., Rastogi A., Gupta, R., Manning, C.D., & Potts, C. + (2016). A Fast Unified Model for Parsing and Sentence Understanding. + https://arxiv.org/abs/1603.06021 + +Our implementation is based on @jekbradbury's PyTorch implementation at: +https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py, + +which was released under the BSD 3-Clause License at: +https://github.com/jekbradbury/examples/blob/spinn/LICENSE + +## Content + +Python source file(s): +- `spinn.py`: Model definition and training routines written with TensorFlow + eager execution idioms. + +## To run + +- Make sure you have installed the latest `tf-nightly` or `tf-nightly-gpu` pip + package of TensorFlow in order to access the eager execution feature. + +- Download and extract the raw SNLI data and GloVe embedding vectors. 
+ For example: + + ```bash + curl -fSsL https://nlp.stanford.edu/projects/snli/snli_1.0.zip --create-dirs -o /tmp/spinn-data/snli/snli_1.0.zip + unzip -d /tmp/spinn-data/snli /tmp/spinn-data/snli/snli_1.0.zip + curl -fSsL http://nlp.stanford.edu/data/glove.42B.300d.zip --create-dirs -o /tmp/spinn-data/glove/glove.42B.300d.zip + unzip -d /tmp/spinn-data/glove /tmp/spinn-data/glove/glove.42B.300d.zip + ``` + +- Train model. E.g., + + ```bash + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs + ``` + + During training, model checkpoints and TensorBoard summaries will be written + periodically to the directory specified with the `--logdir` flag. + The training script will reload a saved checkpoint from the directory if it + can find one there. + + To view the summaries with TensorBoard: + + ```bash + tensorboard --logdir /tmp/spinn-logs + ``` diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py new file mode 100644 index 0000000000..963ac0e65b --- /dev/null +++ b/third_party/examples/eager/spinn/spinn.py @@ -0,0 +1,732 @@ +r"""Implementation of SPINN in TensorFlow eager execution. + +SPINN: Stack-Augmented Parser-Interpreter Neural Network. + +Ths file contains model definition and code for training the model. + +The model definition is based on PyTorch implementation at: + https://github.com/jekbradbury/examples/tree/spinn/snli + +which was released under a BSD 3-Clause License at: +https://github.com/jekbradbury/examples/blob/spinn/LICENSE: + +Copyright (c) 2017, +All rights reserved. + +See ./LICENSE for more details. + +Instructions for use: +* See `README.md` for details on how to prepare the SNLI and GloVe data. +* Suppose you have prepared the data at "/tmp/spinn-data", use the folloing + command to train the model: + + ```bash + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs + ``` + + Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs". + +References: +* Bowman, S.R., Gauthier, J., Rastogi A., Gupta, R., Manning, C.D., & Potts, C. + (2016). A Fast Unified Model for Parsing and Sentence Understanding. + https://arxiv.org/abs/1603.06021 +* Bradbury, J. (2017). Recursive Neural Networks with PyTorch. + https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import itertools +import os +import sys +import time + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data + + +def _bundle(lstm_iter): + """Concatenate a list of Tensors along 1st axis and split result into two. + + Args: + lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape + (R, 2 * M). + + Returns: + A `list` of two dense `Tensor`s, each of which has the shape (N * R, M). + """ + return tf.split(tf.concat(lstm_iter, 0), 2, axis=1) + + +def _unbundle(state): + """Concatenate a list of Tensors along 2nd axis and split result. + + This is the inverse of `_bundle`. + + Args: + state: A `list` of two dense `Tensor`s, each of which has the shape (R, M). + + Returns: + A `list` of `R` dense `Tensors`, each of which has the shape (1, 2 * M). 
+ """ + return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0) + + +class Reducer(tfe.Network): + """A module that applies reduce operation on left and right vectors.""" + + def __init__(self, size, tracker_size=None): + super(Reducer, self).__init__() + self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None)) + self.right = self.track_layer( + tf.layers.Dense(5 * size, activation=None, use_bias=False)) + if tracker_size is not None: + self.track = self.track_layer( + tf.layers.Dense(5 * size, activation=None, use_bias=False)) + else: + self.track = None + + def call(self, left_in, right_in, tracking=None): + """Invoke forward pass of the Reduce module. + + This method feeds a linear combination of `left_in`, `right_in` and + `tracking` into a Tree LSTM and returns the output of the Tree LSTM. + + Args: + left_in: A list of length L. Each item is a dense `Tensor` with + the shape (1, n_dims). n_dims is the size of the embedding vector. + right_in: A list of the same length as `left_in`. Each item should have + the same shape as the items of `left_in`. + tracking: Optional list of the same length as `left_in`. Each item is a + dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the + size of the Tracker's state vector. + + Returns: + Output: A list of length batch_size. Each item has the shape (1, n_dims). + """ + left, right = _bundle(left_in), _bundle(right_in) + lstm_in = self.left(left[0]) + self.right(right[0]) + if self.track and tracking: + lstm_in += self.track(_bundle(tracking)[0]) + return _unbundle(self._tree_lstm(left[1], right[1], lstm_in)) + + def _tree_lstm(self, c1, c2, lstm_in): + a, i, f1, f2, o = tf.split(lstm_in, 5, axis=1) + c = tf.tanh(a) * tf.sigmoid(i) + tf.sigmoid(f1) * c1 + tf.sigmoid(f2) * c2 + h = tf.sigmoid(o) * tf.tanh(c) + return h, c + + +class Tracker(tfe.Network): + """A module that tracks the history of the sentence with an LSTM.""" + + def __init__(self, tracker_size, predict): + """Constructor of Tracker. + + Args: + tracker_size: Number of dimensions of the underlying `LSTMCell`. + predict: (`bool`) Whether prediction mode is enabled. + """ + super(Tracker, self).__init__() + self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size)) + self._state_size = tracker_size + if predict: + self._transition = self.track_layer(tf.layers.Dense(4)) + else: + self._transition = None + + def reset_state(self): + self.state = None + + def call(self, bufs, stacks): + """Invoke the forward pass of the Tracker module. + + This method feeds the concatenation of the top two elements of the stacks + into an LSTM cell and returns the resultant state of the LSTM cell. + + Args: + bufs: A `list` of length batch_size. Each item is a `list` of + max_sequence_len (maximum sequence length of the batch). Each item + of the nested list is a dense `Tensor` of shape (1, d_proj), where + d_proj is the size of the word embedding vector or the size of the + vector space that the word embedding vector is projected to. + stacks: A `list` of size batch_size. Each item is a `list` of + variable length corresponding to the current height of the stack. + Each item of the nested list is a dense `Tensor` of shape (1, d_proj). + + Returns: + 1. A list of length batch_size. Each item is a dense `Tensor` of shape + (1, d_tracker * 2). + 2. If under predict mode, result of applying a Dense layer on the + first state vector of the RNN. Else, `None`. 
+ """ + buf = _bundle([buf[-1] for buf in bufs])[0] + stack1 = _bundle([stack[-1] for stack in stacks])[0] + stack2 = _bundle([stack[-2] for stack in stacks])[0] + x = tf.concat([buf, stack1, stack2], 1) + if self.state is None: + batch_size = int(x.shape[0]) + zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32) + self.state = [zeros, zeros] + _, self.state = self._rnn(x, self.state) + unbundled = _unbundle(self.state) + if self._transition: + return unbundled, self._transition(self.state[0]) + else: + return unbundled, None + + +class SPINN(tfe.Network): + """Stack-augmented Parser-Interpreter Neural Network. + + See https://arxiv.org/abs/1603.06021 for more details. + """ + + def __init__(self, config): + """Constructor of SPINN. + + Args: + config: A `namedtupled` with the following attributes. + d_proj - (`int`) number of dimensions of the vector space to project the + word embeddings to. + d_tracker - (`int`) number of dimensions of the Tracker's state vector. + d_hidden - (`int`) number of the dimensions of the hidden state, for the + Reducer module. + n_mlp_layers - (`int`) number of multi-layer perceptron layers to use to + convert the output of the `Feature` module to logits. + predict - (`bool`) Whether the Tracker will enabled predictions. + """ + super(SPINN, self).__init__() + self.config = config + self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker)) + if config.d_tracker is not None: + self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict)) + else: + self.tracker = None + + def call(self, buffers, transitions, training=False): + """Invoke the forward pass of the SPINN model. + + Args: + buffers: Dense `Tensor` of shape + (max_sequence_len, batch_size, config.d_proj). + transitions: Dense `Tensor` with integer values that represent the parse + trees of the sentences. A value of 2 indicates "reduce"; a value of 3 + indicates "shift". Shape: (max_sequence_len * 2 - 3, batch_size). + training: Whether the invocation is under training mode. + + Returns: + Output `Tensor` of shape (batch_size, config.d_embed). + """ + max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape) + + # Split the buffers into left and right word items and put the initial + # items in a stack. + splitted = tf.split( + tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]), + max_sequence_len * batch_size, axis=0) + buffers = [splitted[k:k + max_sequence_len] + for k in xrange(0, len(splitted), max_sequence_len)] + stacks = [[buf[0], buf[0]] for buf in buffers] + + if self.tracker: + # Reset tracker state for new batch. + self.tracker.reset_state() + + num_transitions = transitions.shape[0] + + # Iterate through transitions and perform the appropriate stack-pop, reduce + # and stack-push operations. + transitions = transitions.numpy() + for i in xrange(num_transitions): + trans = transitions[i] + if self.tracker: + # Invoke tracker to obtain the current tracker states for the sentences. + tracker_states, trans_hypothesis = self.tracker(buffers, stacks) + if trans_hypothesis: + trans = tf.argmax(trans_hypothesis, axis=-1) + else: + tracker_states = itertools.repeat(None) + lefts, rights, trackings = [], [], [] + for transition, buf, stack, tracking in zip( + trans, buffers, stacks, tracker_states): + if int(transition) == 3: # Shift. + stack.append(buf.pop()) + elif int(transition) == 2: # Reduce. 
+ rights.append(stack.pop()) + lefts.append(stack.pop()) + trackings.append(tracking) + + if rights: + reducer_output = self.reducer(lefts, rights, trackings) + reduced = iter(reducer_output) + + for transition, stack in zip(trans, stacks): + if int(transition) == 2: # Reduce. + stack.append(next(reduced)) + return _bundle([stack.pop() for stack in stacks])[0] + + +class SNLIClassifier(tfe.Network): + """SNLI Classifier Model. + + A model aimed at solving the SNLI (Standford Natural Language Inference) + task, using the SPINN model from above. For details of the task, see: + https://nlp.stanford.edu/projects/snli/ + """ + + def __init__(self, config, embed): + """Constructor of SNLICLassifier. + + Args: + config: A namedtuple containing required configurations for the model. It + needs to have the following attributes. + projection - (`bool`) whether the word vectors are to be projected onto + another vector space (of `d_proj` dimensions). + d_proj - (`int`) number of dimensions of the vector space to project the + word embeddings to. + embed_dropout - (`float`) dropout rate for the word embedding vectors. + n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to + use to convert the output of the `Feature` module to logits. + mlp_dropout - (`float`) dropout rate of the MLP layers. + d_out - (`int`) number of dimensions of the final output of the MLP + layers. + lr - (`float`) learning rate. + embed: A embedding matrix of shape (vocab_size, d_embed). + """ + super(SNLIClassifier, self).__init__() + self.config = config + self.embed = tf.constant(embed) + + self.projection = self.track_layer(tf.layers.Dense(config.d_proj)) + self.embed_bn = self.track_layer(tf.layers.BatchNormalization()) + self.embed_dropout = self.track_layer( + tf.layers.Dropout(rate=config.embed_dropout)) + self.encoder = self.track_layer(SPINN(config)) + + self.feature_bn = self.track_layer(tf.layers.BatchNormalization()) + self.feature_dropout = self.track_layer( + tf.layers.Dropout(rate=config.mlp_dropout)) + + self.mlp_dense = [] + self.mlp_bn = [] + self.mlp_dropout = [] + for _ in xrange(config.n_mlp_layers): + self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp))) + self.mlp_bn.append( + self.track_layer(tf.layers.BatchNormalization())) + self.mlp_dropout.append( + self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout))) + self.mlp_output = self.track_layer(tf.layers.Dense( + config.d_out, + kernel_initializer=tf.random_uniform_initializer(minval=-5e-3, + maxval=5e-3))) + + def call(self, + premise, + premise_transition, + hypothesis, + hypothesis_transition, + training=False): + """Invoke the forward pass the SNLIClassifier model. + + Args: + premise: The word indices of the premise sentences, with shape + (max_prem_seq_len, batch_size). + premise_transition: The transitions for the premise sentences, with shape + (max_prem_seq_len * 2 - 3, batch_size). + hypothesis: The word indices of the hypothesis sentences, with shape + (max_hypo_seq_len, batch_size). + hypothesis_transition: The transitions for the hypothesis sentences, with + shape (max_hypo_seq_len * 2 - 3, batch_size). + training: Whether the invocation is under training mode. + + Returns: + The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out + is the size of the output vector. + """ + # Perform embedding lookup on the premise and hypothesis inputs, which have + # the word-index format. 
+ premise_embed = tf.nn.embedding_lookup(self.embed, premise)
+ hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)
+
+ if self.config.projection:
+ # Project the embedding vectors to another vector space.
+ premise_embed = self.projection(premise_embed)
+ hypothesis_embed = self.projection(hypothesis_embed)
+
+ # Perform batch normalization and dropout on the possibly projected word
+ # vectors.
+ premise_embed = self.embed_bn(premise_embed, training=training)
+ hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
+ premise_embed = self.embed_dropout(premise_embed, training=training)
+ hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)
+
+ # Run the batch-normalized and dropout-processed word vectors through the
+ # SPINN encoder.
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
+
+ # Combine encoder outputs for premises and hypotheses into logits.
+ # Then apply batch normalization and dropout on the logits.
+ logits = tf.concat(
+ [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
+ logits = self.feature_dropout(
+ self.feature_bn(logits, training=training), training=training)
+
+ # Apply the multi-layer perceptron on the logits.
+ for dense, bn, dropout in zip(
+ self.mlp_dense, self.mlp_bn, self.mlp_dropout):
+ logits = tf.nn.elu(dense(logits))
+ logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp_output(logits)
+ return logits
+
+
+class SNLIClassifierTrainer(object):
+ """A class that coordinates the training of an SNLIClassifier."""
+
+ def __init__(self, snli_classifier, lr):
+ """Constructor of SNLIClassifierTrainer.
+
+ Args:
+ snli_classifier: An instance of `SNLIClassifier`.
+ lr: Learning rate.
+ """
+ self._model = snli_classifier
+ # Create a custom learning rate Variable for the RMSProp optimizer, because
+ # the learning rate needs to be manually decayed later (see
+ # decay_learning_rate()).
+ self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
+ epsilon=1e-6)
+
+ def loss(self, labels, logits):
+ """Calculate the loss given a batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ logits: The logits output from the forward pass of the SNLIClassifier
+ model, with shape (batch_size, d_out), where d_out is the output
+ dimension size of the SNLIClassifier.
+
+ Returns:
+ The loss value, as a scalar `Tensor`.
+ """
+ return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits))
+
+ def train_batch(self,
+ labels,
+ premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition):
+ """Train the model on a batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ premise: The word indices of the premise sentences, with shape
+ (max_prem_seq_len, batch_size).
+ premise_transition: The transitions for the premise sentences, with shape
+ (max_prem_seq_len * 2 - 3, batch_size).
+ hypothesis: The word indices of the hypothesis sentences, with shape
+ (max_hypo_seq_len, batch_size).
+ hypothesis_transition: The transitions for the hypothesis sentences, with
+ shape (max_hypo_seq_len * 2 - 3, batch_size).
+
+ Returns:
+ 1. loss value as a scalar `Tensor`.
+ 2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is
+ the output dimension size of the SNLIClassifier.
+ """ + with tfe.GradientTape() as tape: + tape.watch(self._model.variables) + logits = self._model(premise, + premise_transition, + hypothesis, + hypothesis_transition, + training=True) + loss = self.loss(labels, logits) + gradients = tape.gradient(loss, self._model.variables) + self._optimizer.apply_gradients(zip(gradients, self._model.variables), + global_step=tf.train.get_global_step()) + return loss, logits + + def decay_learning_rate(self, decay_by): + """Decay learning rate of the optimizer by factor decay_by.""" + self._learning_rate.assign(self._learning_rate * decay_by) + print("Decayed learning rate of optimizer to: %s" % + self._learning_rate.numpy()) + + @property + def learning_rate(self): + return self._learning_rate + + +def _batch_n_correct(logits, label): + """Calculate number of correct predictions in a batch. + + Args: + logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype + `float32`. + label: A labels Tensor of shape `(batch_size,)` and dtype `int64` + + Returns: + Number of correct predictions. + """ + return tf.reduce_sum( + tf.cast((tf.equal( + tf.argmax(logits, axis=1), label)), tf.float32)).numpy() + + +def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu): + """Run evaluation on a dataset. + + Args: + snli_data: The `data.SnliData` to use in this evaluation. + batch_size: The batch size to use during this evaluation. + model: An instance of `SNLIClassifier` to evaluate. + trainer: An instance of `SNLIClassifierTrainer to use for this + evaluation. + use_gpu: Whether GPU is being used. + + Returns: + 1. Average loss across all examples of the dataset. + 2. Average accuracy rate across all examples of the dataset. + """ + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator( + snli_data, batch_size): + if use_gpu: + label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + loss_val = trainer.loss(label, logits) + batch_size = tf.shape(label)[0] + mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) + accuracy(tf.argmax(logits, axis=1), label) + return mean_loss.result().numpy(), accuracy.result().numpy() + + +def _get_dataset_iterator(snli_data, batch_size): + """Get a data iterator for a split of SNLI data. + + Args: + snli_data: A `data.SnliData` object. + batch_size: The desired batch size. + + Returns: + A dataset iterator. + """ + with tf.device("/device:CPU:0"): + # Some tf.data ops, such as ShuffleDataset, are available only on CPU. + dataset = tf.data.Dataset.from_generator( + snli_data.get_generator(batch_size), + (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64)) + dataset = dataset.shuffle(snli_data.num_batches(batch_size)) + return tfe.Iterator(dataset) + + +def train_spinn(embed, train_data, dev_data, test_data, config): + """Train a SPINN model. + + Args: + embed: The embedding matrix as a float32 numpy array with shape + [vocabulary_size, word_vector_len]. word_vector_len is the length of a + word embedding vector. + train_data: An instance of `data.SnliData`, for the train split. + dev_data: Same as above, for the dev split. + test_data: Same as above, for the test split. + config: A configuration object. See the argument to this Python binary for + details. + + Returns: + 1. Final loss value on the test split. + 2. Final fraction of correct classifications on the test split. 
+ """ + use_gpu = tfe.num_gpus() > 0 and not config.force_cpu + device = "gpu:0" if use_gpu else "cpu:0" + print("Using device: %s" % device) + + log_header = ( + " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss" + " Accuracy Dev/Accuracy") + log_template = ( + "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} " + "{:12.4f} {}") + dev_log_template = ( + "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} " + "{:8.6f} {:12.4f} {:12.4f}") + + summary_writer = tf.contrib.summary.create_summary_file_writer( + config.logdir, flush_millis=10000) + train_len = train_data.num_batches(config.batch_size) + with tf.device(device), \ + tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)), \ + summary_writer.as_default(), \ + tf.contrib.summary.always_record_summaries(): + model = SNLIClassifier(config, embed) + global_step = tf.train.get_or_create_global_step() + trainer = SNLIClassifierTrainer(model, config.lr) + + start = time.time() + iterations = 0 + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + print(log_header) + for epoch in xrange(config.epochs): + batch_idx = 0 + for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator( + train_data, config.batch_size): + if use_gpu: + label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() + # prem_trans and hypo_trans are used for dynamic control flow and can + # remain on CPU. Same in _evaluate_on_dataset(). + + iterations += 1 + batch_train_loss, batch_train_logits = trainer.train_batch( + label, prem, prem_trans, hypo, hypo_trans) + batch_size = tf.shape(label)[0] + mean_loss(batch_train_loss.numpy(), + weights=batch_size.gpu() if use_gpu else batch_size) + accuracy(tf.argmax(batch_train_logits, axis=1), label) + + if iterations % config.save_every == 0: + all_variables = ( + model.variables + [trainer.learning_rate] + [global_step]) + saver = tfe.Saver(all_variables) + saver.save(os.path.join(config.logdir, "ckpt"), + global_step=global_step) + + if iterations % config.dev_every == 0: + dev_loss, dev_frac_correct = _evaluate_on_dataset( + dev_data, config.batch_size, model, trainer, use_gpu) + print(dev_log_template.format( + time.time() - start, + epoch, iterations, 1 + batch_idx, train_len, + 100.0 * (1 + batch_idx) / train_len, + mean_loss.result(), dev_loss, + accuracy.result() * 100.0, dev_frac_correct * 100.0)) + tf.contrib.summary.scalar("dev/loss", dev_loss) + tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct) + elif iterations % config.log_every == 0: + mean_loss_val = mean_loss.result() + accuracy_val = accuracy.result() + print(log_template.format( + time.time() - start, + epoch, iterations, 1 + batch_idx, train_len, + 100.0 * (1 + batch_idx) / train_len, + mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12)) + tf.contrib.summary.scalar("train/loss", mean_loss_val) + tf.contrib.summary.scalar("train/accuracy", accuracy_val) + # Reset metrics. + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + + batch_idx += 1 + if (epoch + 1) % config.lr_decay_every == 0: + trainer.decay_learning_rate(config.lr_decay_by) + + test_loss, test_frac_correct = _evaluate_on_dataset( + test_data, config.batch_size, model, trainer, use_gpu) + print("Final test loss: %g; accuracy: %g%%" % + (test_loss, test_frac_correct * 100.0)) + + +def main(_): + config = FLAGS + + # Load embedding vectors. 
+ vocab = data.load_vocabulary(FLAGS.data_root)
+ word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
+
+ print("Loading train, dev and test data...")
+ train_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ dev_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ test_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+
+ train_spinn(embed, train_data, dev_data, test_data, config)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=
+ "TensorFlow eager implementation of the SPINN SNLI classifier.")
+ parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
+ help="Root directory in which the training data and "
+ "embedding matrix are found. See README.md for how to "
+ "generate such a directory.")
+ parser.add_argument("--sentence_len_limit", type=int, default=-1,
+ help="Maximum allowed sentence length (# of words). "
+ "The default of -1 means unlimited.")
+ parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
+ help="Directory in which summaries will be written for "
+ "TensorBoard.")
+ parser.add_argument("--epochs", type=int, default=50,
+ help="Number of epochs to train.")
+ parser.add_argument("--batch_size", type=int, default=128,
+ help="Batch size to use during training.")
+ parser.add_argument("--d_proj", type=int, default=600,
+ help="Dimensions to project the word embedding vectors "
+ "to.")
+ parser.add_argument("--d_hidden", type=int, default=300,
+ help="Size of the hidden layer of the Tracker.")
+ parser.add_argument("--d_out", type=int, default=4,
+ help="Output dimensions of the SNLIClassifier.")
+ parser.add_argument("--d_mlp", type=int, default=1024,
+ help="Size of each layer of the multi-layer perceptron "
+ "of the SNLIClassifier.")
+ parser.add_argument("--n_mlp_layers", type=int, default=2,
+ help="Number of layers in the multi-layer perceptron "
+ "of the SNLIClassifier.")
+ parser.add_argument("--d_tracker", type=int, default=64,
+ help="Size of the tracker LSTM.")
+ parser.add_argument("--log_every", type=int, default=50,
+ help="Print log and write TensorBoard summary every _ "
+ "training batches.")
+ parser.add_argument("--lr", type=float, default=2e-3,
+ help="Initial learning rate.")
+ parser.add_argument("--lr_decay_by", type=float, default=0.75,
+ help="The ratio to multiply the learning rate by every "
+ "time the learning rate is decayed.")
+ parser.add_argument("--lr_decay_every", type=float, default=1,
+ help="Decay the learning rate every _ epoch(s).")
+ parser.add_argument("--dev_every", type=int, default=1000,
+ help="Run evaluation on the dev split every _ training "
+ "batches.")
+ parser.add_argument("--save_every", type=int, default=1000,
+ help="Save checkpoint every _ training batches.")
+ parser.add_argument("--embed_dropout", type=float, default=0.08,
+ help="Word embedding dropout rate.")
+ parser.add_argument("--mlp_dropout", type=float, default=0.07,
+ help="SNLIClassifier multi-layer perceptron dropout "
+ "rate.")
+ parser.add_argument("--no-projection", action="store_false",
+ dest="projection",
+ help="If set, do not project the word embedding vectors "
+ "onto another vector space (see d_proj).")
+ parser.add_argument("--predict_transitions", action="store_true",
dest="predict", + help="Whether the Tracker will perform prediction.") + parser.add_argument("--force_cpu", action="store_true", dest="force_cpu", + help="Force use CPU-only regardless of whether a GPU is " + "available.") + FLAGS, unparsed = parser.parse_known_args() + + tfe.run(main=main, argv=[sys.argv[0]] + unparsed) |