author    Shanqing Cai <cais@google.com>  2017-12-01 13:56:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-01 13:59:31 -0800
commit    ed9163acfd510c26c49201ec9e360e20a2625ca8 (patch)
tree      5d61ab7de410baa898925f0023792e5015817c17 /third_party/examples
parent    ae10f63e2fc76faf5835a660043c328d891c41f0 (diff)
TF Eager: Add SPINN model example for dynamic/recursive NN.
PiperOrigin-RevId: 177636427
Diffstat (limited to 'third_party/examples')
-rw-r--r--  third_party/examples/eager/spinn/BUILD       14
-rw-r--r--  third_party/examples/eager/spinn/LICENSE     29
-rw-r--r--  third_party/examples/eager/spinn/README.md   54
-rw-r--r--  third_party/examples/eager/spinn/spinn.py   732
4 files changed, 829 insertions, 0 deletions
diff --git a/third_party/examples/eager/spinn/BUILD b/third_party/examples/eager/spinn/BUILD
new file mode 100644
index 0000000000..0e39d4696f
--- /dev/null
+++ b/third_party/examples/eager/spinn/BUILD
@@ -0,0 +1,14 @@
+licenses(["notice"]) # 3-clause BSD.
+
+py_binary(
+ name = "spinn",
+ srcs = ["spinn.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/contrib/eager/python/examples/spinn:data",
+ "@six_archive//:six",
+ ],
+)
diff --git a/third_party/examples/eager/spinn/LICENSE b/third_party/examples/eager/spinn/LICENSE
new file mode 100644
index 0000000000..09d493bf1f
--- /dev/null
+++ b/third_party/examples/eager/spinn/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2017,
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
new file mode 100644
index 0000000000..c00d8d9015
--- /dev/null
+++ b/third_party/examples/eager/spinn/README.md
@@ -0,0 +1,54 @@
+# SPINN with TensorFlow eager execution
+
+SPINN, or Stack-Augmented Parser-Interpreter Neural Network, is a recursive
+neural network that utilizes syntactic parse information for natural language
+understanding.
+
+SPINN was originally described by:
+Bowman, S.R., Gauthier, J., Rastogi, A., Gupta, R., Manning, C.D., & Potts, C.
+ (2016). A Fast Unified Model for Parsing and Sentence Understanding.
+ https://arxiv.org/abs/1603.06021
+
+Our implementation is based on @jekbradbury's PyTorch implementation at
+https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py, which was
+released under the BSD 3-Clause License at:
+https://github.com/jekbradbury/examples/blob/spinn/LICENSE
+
+## Content
+
+Python source file(s):
+- `spinn.py`: Model definition and training routines written with TensorFlow
+ eager execution idioms.
+
+## To run
+
+- Make sure you have installed the latest `tf-nightly` or `tf-nightly-gpu` pip
+ package of TensorFlow in order to access the eager execution feature.
+
+- Download and extract the raw SNLI data and GloVe embedding vectors.
+ For example:
+
+ ```bash
+ curl -fSsL https://nlp.stanford.edu/projects/snli/snli_1.0.zip --create-dirs -o /tmp/spinn-data/snli/snli_1.0.zip
+ unzip -d /tmp/spinn-data/snli /tmp/spinn-data/snli/snli_1.0.zip
+ curl -fSsL http://nlp.stanford.edu/data/glove.42B.300d.zip --create-dirs -o /tmp/spinn-data/glove/glove.42B.300d.zip
+ unzip -d /tmp/spinn-data/glove /tmp/spinn-data/glove/glove.42B.300d.zip
+ ```
+
+- Train the model. For example:
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
+ ```
+
+ During training, model checkpoints and TensorBoard summaries will be written
+ periodically to the directory specified with the `--logdir` flag.
+ The training script will reload a saved checkpoint from the directory if it
+ can find one there.
+
+ To view the summaries with TensorBoard:
+
+ ```bash
+ tensorboard --logdir /tmp/spinn-logs
+ ```
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
new file mode 100644
index 0000000000..963ac0e65b
--- /dev/null
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -0,0 +1,732 @@
+r"""Implementation of SPINN in TensorFlow eager execution.
+
+SPINN: Stack-Augmented Parser-Interpreter Neural Network.
+
+This file contains the model definition and code for training the model.
+
+The model definition is based on the PyTorch implementation at:
+ https://github.com/jekbradbury/examples/tree/spinn/snli
+
+which was released under a BSD 3-Clause License at:
+https://github.com/jekbradbury/examples/blob/spinn/LICENSE
+
+Copyright (c) 2017,
+All rights reserved.
+
+See ./LICENSE for more details.
+
+Instructions for use:
+* See `README.md` for details on how to prepare the SNLI and GloVe data.
+* Assuming you have prepared the data at "/tmp/spinn-data", use the following
+ command to train the model:
+
+ ```bash
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
+ ```
+
+ Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs".
+
+References:
+* Bowman, S.R., Gauthier, J., Rastogi, A., Gupta, R., Manning, C.D., & Potts, C.
+ (2016). A Fast Unified Model for Parsing and Sentence Understanding.
+ https://arxiv.org/abs/1603.06021
+* Bradbury, J. (2017). Recursive Neural Networks with PyTorch.
+ https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import itertools
+import os
+import sys
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.spinn import data
+
+
+def _bundle(lstm_iter):
+ """Concatenate a list of Tensors along 1st axis and split result into two.
+
+ Args:
+ lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape
+ (R, 2 * M).
+
+ Returns:
+ A `list` of two dense `Tensor`s, each of which has the shape (N * R, M).
+ """
+ return tf.split(tf.concat(lstm_iter, 0), 2, axis=1)
+
+
+def _unbundle(state):
+ """Concatenate a list of Tensors along 2nd axis and split result.
+
+ This is the inverse of `_bundle`.
+
+ Args:
+ state: A `list` of two dense `Tensor`s, each of which has the shape (R, M).
+
+ Returns:
+ A `list` of `R` dense `Tensors`, each of which has the shape (1, 2 * M).
+ """
+ return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
+
+
+class Reducer(tfe.Network):
+ """A module that applies reduce operation on left and right vectors."""
+
+ def __init__(self, size, tracker_size=None):
+ super(Reducer, self).__init__()
+ self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
+ self.right = self.track_layer(
+ tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ if tracker_size is not None:
+ self.track = self.track_layer(
+ tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ else:
+ self.track = None
+
+ def call(self, left_in, right_in, tracking=None):
+ """Invoke forward pass of the Reduce module.
+
+ This method feeds a linear combination of `left_in`, `right_in` and
+ `tracking` into a Tree LSTM and returns the output of the Tree LSTM.
+
+ Args:
+ left_in: A list of length L. Each item is a dense `Tensor` with
+ the shape (1, n_dims). n_dims is the size of the embedding vector.
+ right_in: A list of the same length as `left_in`. Each item should have
+ the same shape as the items of `left_in`.
+ tracking: Optional list of the same length as `left_in`. Each item is a
+ dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the
+ size of the Tracker's state vector.
+
+ Returns:
+ Output: A list of length batch_size. Each item has the shape (1, n_dims).
+ """
+ left, right = _bundle(left_in), _bundle(right_in)
+ lstm_in = self.left(left[0]) + self.right(right[0])
+ if self.track and tracking:
+ lstm_in += self.track(_bundle(tracking)[0])
+ return _unbundle(self._tree_lstm(left[1], right[1], lstm_in))
+
+ def _tree_lstm(self, c1, c2, lstm_in):
+ a, i, f1, f2, o = tf.split(lstm_in, 5, axis=1)
+ c = tf.tanh(a) * tf.sigmoid(i) + tf.sigmoid(f1) * c1 + tf.sigmoid(f2) * c2
+ h = tf.sigmoid(o) * tf.tanh(c)
+ return h, c
+
+
+class Tracker(tfe.Network):
+ """A module that tracks the history of the sentence with an LSTM."""
+
+ def __init__(self, tracker_size, predict):
+ """Constructor of Tracker.
+
+ Args:
+ tracker_size: Number of dimensions of the underlying `LSTMCell`.
+ predict: (`bool`) Whether prediction mode is enabled.
+ """
+ super(Tracker, self).__init__()
+ self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+ self._state_size = tracker_size
+ if predict:
+ self._transition = self.track_layer(tf.layers.Dense(4))
+ else:
+ self._transition = None
+
+ def reset_state(self):
+ self.state = None
+
+ def call(self, bufs, stacks):
+ """Invoke the forward pass of the Tracker module.
+
+ This method feeds the concatenation of the top two elements of the stacks
+ into an LSTM cell and returns the resultant state of the LSTM cell.
+
+ Args:
+ bufs: A `list` of length batch_size. Each item is a `list` of
+ max_sequence_len (maximum sequence length of the batch). Each item
+ of the nested list is a dense `Tensor` of shape (1, d_proj), where
+ d_proj is the size of the word embedding vector or the size of the
+ vector space that the word embedding vector is projected to.
+ stacks: A `list` of size batch_size. Each item is a `list` of
+ variable length corresponding to the current height of the stack.
+ Each item of the nested list is a dense `Tensor` of shape (1, d_proj).
+
+ Returns:
+ 1. A list of length batch_size. Each item is a dense `Tensor` of shape
+ (1, d_tracker * 2).
+      2. If in predict mode, the result of applying a Dense layer to the
+        first state vector of the RNN; otherwise `None`.
+ """
+ buf = _bundle([buf[-1] for buf in bufs])[0]
+ stack1 = _bundle([stack[-1] for stack in stacks])[0]
+ stack2 = _bundle([stack[-2] for stack in stacks])[0]
+ x = tf.concat([buf, stack1, stack2], 1)
+ if self.state is None:
+ batch_size = int(x.shape[0])
+ zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32)
+ self.state = [zeros, zeros]
+ _, self.state = self._rnn(x, self.state)
+ unbundled = _unbundle(self.state)
+ if self._transition:
+ return unbundled, self._transition(self.state[0])
+ else:
+ return unbundled, None
+
+
+class SPINN(tfe.Network):
+ """Stack-augmented Parser-Interpreter Neural Network.
+
+ See https://arxiv.org/abs/1603.06021 for more details.
+ """
+
+ def __init__(self, config):
+ """Constructor of SPINN.
+
+ Args:
+      config: A `namedtuple` with the following attributes.
+ d_proj - (`int`) number of dimensions of the vector space to project the
+ word embeddings to.
+ d_tracker - (`int`) number of dimensions of the Tracker's state vector.
+      d_hidden - (`int`) number of dimensions of the hidden state of the
+        Reducer module.
+ n_mlp_layers - (`int`) number of multi-layer perceptron layers to use to
+ convert the output of the `Feature` module to logits.
+      predict - (`bool`) Whether the Tracker will make transition predictions.
+ """
+ super(SPINN, self).__init__()
+ self.config = config
+ self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+ if config.d_tracker is not None:
+ self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+ else:
+ self.tracker = None
+
+ def call(self, buffers, transitions, training=False):
+ """Invoke the forward pass of the SPINN model.
+
+ Args:
+ buffers: Dense `Tensor` of shape
+ (max_sequence_len, batch_size, config.d_proj).
+ transitions: Dense `Tensor` with integer values that represent the parse
+ trees of the sentences. A value of 2 indicates "reduce"; a value of 3
+ indicates "shift". Shape: (max_sequence_len * 2 - 3, batch_size).
+ training: Whether the invocation is under training mode.
+
+ Returns:
+ Output `Tensor` of shape (batch_size, config.d_embed).
+ """
+ max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape)
+
+ # Split the buffers into left and right word items and put the initial
+ # items in a stack.
+ splitted = tf.split(
+ tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]),
+ max_sequence_len * batch_size, axis=0)
+ buffers = [splitted[k:k + max_sequence_len]
+ for k in xrange(0, len(splitted), max_sequence_len)]
+ stacks = [[buf[0], buf[0]] for buf in buffers]
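+    # Illustrative shapes (assumed values, not from the original code): with
+    # max_sequence_len=4, batch_size=2 and d_proj=300, `splitted` holds 8
+    # tensors of shape (1, 300), `buffers` becomes 2 per-sentence lists of 4
+    # word tensors each, and each stack starts with two copies of the first
+    # item in that sentence's buffer.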
+
+ if self.tracker:
+ # Reset tracker state for new batch.
+ self.tracker.reset_state()
+
+ num_transitions = transitions.shape[0]
+
+ # Iterate through transitions and perform the appropriate stack-pop, reduce
+ # and stack-push operations.
+ transitions = transitions.numpy()
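+    # Each column of `transitions` is one sentence's shift/reduce sequence;
+    # e.g. (illustrative) a column starting [3, 3, 2, 3, 2, ...] would mean
+    # shift, shift, reduce, shift, reduce for that sentence.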
+ for i in xrange(num_transitions):
+ trans = transitions[i]
+ if self.tracker:
+ # Invoke tracker to obtain the current tracker states for the sentences.
+ tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+ if trans_hypothesis:
+ trans = tf.argmax(trans_hypothesis, axis=-1)
+ else:
+ tracker_states = itertools.repeat(None)
+ lefts, rights, trackings = [], [], []
+ for transition, buf, stack, tracking in zip(
+ trans, buffers, stacks, tracker_states):
+ if int(transition) == 3: # Shift.
+ stack.append(buf.pop())
+ elif int(transition) == 2: # Reduce.
+ rights.append(stack.pop())
+ lefts.append(stack.pop())
+ trackings.append(tracking)
+
+ if rights:
+ reducer_output = self.reducer(lefts, rights, trackings)
+ reduced = iter(reducer_output)
+
+ for transition, stack in zip(trans, stacks):
+ if int(transition) == 2: # Reduce.
+ stack.append(next(reduced))
+ return _bundle([stack.pop() for stack in stacks])[0]
+
+
+class SNLIClassifier(tfe.Network):
+ """SNLI Classifier Model.
+
+  A model aimed at solving the SNLI (Stanford Natural Language Inference)
+ task, using the SPINN model from above. For details of the task, see:
+ https://nlp.stanford.edu/projects/snli/
+ """
+
+ def __init__(self, config, embed):
+ """Constructor of SNLICLassifier.
+
+ Args:
+ config: A namedtuple containing required configurations for the model. It
+ needs to have the following attributes.
+ projection - (`bool`) whether the word vectors are to be projected onto
+ another vector space (of `d_proj` dimensions).
+ d_proj - (`int`) number of dimensions of the vector space to project the
+ word embeddings to.
+ embed_dropout - (`float`) dropout rate for the word embedding vectors.
+ n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to
+ use to convert the output of the `Feature` module to logits.
+ mlp_dropout - (`float`) dropout rate of the MLP layers.
+ d_out - (`int`) number of dimensions of the final output of the MLP
+ layers.
+ lr - (`float`) learning rate.
+      embed: An embedding matrix of shape (vocab_size, d_embed).
+ """
+ super(SNLIClassifier, self).__init__()
+ self.config = config
+ self.embed = tf.constant(embed)
+
+ self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
+ self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
+ self.embed_dropout = self.track_layer(
+ tf.layers.Dropout(rate=config.embed_dropout))
+ self.encoder = self.track_layer(SPINN(config))
+
+ self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
+ self.feature_dropout = self.track_layer(
+ tf.layers.Dropout(rate=config.mlp_dropout))
+
+ self.mlp_dense = []
+ self.mlp_bn = []
+ self.mlp_dropout = []
+ for _ in xrange(config.n_mlp_layers):
+ self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
+ self.mlp_bn.append(
+ self.track_layer(tf.layers.BatchNormalization()))
+ self.mlp_dropout.append(
+ self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
+ self.mlp_output = self.track_layer(tf.layers.Dense(
+ config.d_out,
+ kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
+ maxval=5e-3)))
+
+ def call(self,
+ premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
+ training=False):
+ """Invoke the forward pass the SNLIClassifier model.
+
+ Args:
+ premise: The word indices of the premise sentences, with shape
+ (max_prem_seq_len, batch_size).
+ premise_transition: The transitions for the premise sentences, with shape
+ (max_prem_seq_len * 2 - 3, batch_size).
+ hypothesis: The word indices of the hypothesis sentences, with shape
+ (max_hypo_seq_len, batch_size).
+ hypothesis_transition: The transitions for the hypothesis sentences, with
+ shape (max_hypo_seq_len * 2 - 3, batch_size).
+ training: Whether the invocation is under training mode.
+
+ Returns:
+ The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out
+ is the size of the output vector.
+ """
+ # Perform embedding lookup on the premise and hypothesis inputs, which have
+ # the word-index format.
+ premise_embed = tf.nn.embedding_lookup(self.embed, premise)
+ hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)
+
+ if self.config.projection:
+ # Project the embedding vectors to another vector space.
+ premise_embed = self.projection(premise_embed)
+ hypothesis_embed = self.projection(hypothesis_embed)
+
+ # Perform batch normalization and dropout on the possibly projected word
+ # vectors.
+ premise_embed = self.embed_bn(premise_embed, training=training)
+ hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
+ premise_embed = self.embed_dropout(premise_embed, training=training)
+ hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)
+
+ # Run the batch-normalized and dropout-processed word vectors through the
+ # SPINN encoder.
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
+
+ # Combine encoder outputs for premises and hypotheses into logits.
+    # Then apply batch normalization and dropout to the logits.
+ logits = tf.concat(
+ [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
+ logits = self.feature_dropout(
+ self.feature_bn(logits, training=training), training=training)
+
+ # Apply the multi-layer perceptron on the logits.
+ for dense, bn, dropout in zip(
+ self.mlp_dense, self.mlp_bn, self.mlp_dropout):
+ logits = tf.nn.elu(dense(logits))
+ logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp_output(logits)
+ return logits
+
+
+class SNLIClassifierTrainer(object):
+ """A class that coordinates the training of an SNLIClassifier."""
+
+ def __init__(self, snli_classifier, lr):
+ """Constructor of SNLIClassifierTrainer.
+
+ Args:
+ snli_classifier: An instance of `SNLIClassifier`.
+ lr: Learning rate.
+ """
+ self._model = snli_classifier
+ # Create a custom learning rate Variable for the RMSProp optimizer, because
+ # the learning rate needs to be manually decayed later (see
+ # decay_learning_rate()).
+ self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
+ epsilon=1e-6)
+
+ def loss(self, labels, logits):
+ """Calculate the loss given a batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ logits: The logits output from the forward pass of the SNLIClassifier
+ model, with shape (batch_size, d_out), where d_out is the output
+ dimension size of the SNLIClassifier.
+
+ Returns:
+ The loss value, as a scalar `Tensor`.
+ """
+ return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits))
+
+ def train_batch(self,
+ labels,
+ premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition):
+ """Train model on batch of data.
+
+ Args:
+ labels: The truth labels, with shape (batch_size,).
+ premise: The word indices of the premise sentences, with shape
+ (max_prem_seq_len, batch_size).
+ premise_transition: The transitions for the premise sentences, with shape
+ (max_prem_seq_len * 2 - 3, batch_size).
+ hypothesis: The word indices of the hypothesis sentences, with shape
+ (max_hypo_seq_len, batch_size).
+ hypothesis_transition: The transitions for the hypothesis sentences, with
+ shape (max_hypo_seq_len * 2 - 3, batch_size).
+
+ Returns:
+ 1. loss value as a scalar `Tensor`.
+ 2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is
+ the output dimension size of the SNLIClassifier.
+ """
+ with tfe.GradientTape() as tape:
+ tape.watch(self._model.variables)
+ logits = self._model(premise,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
+ training=True)
+ loss = self.loss(labels, logits)
+ gradients = tape.gradient(loss, self._model.variables)
+ self._optimizer.apply_gradients(zip(gradients, self._model.variables),
+ global_step=tf.train.get_global_step())
+ return loss, logits
+
+ def decay_learning_rate(self, decay_by):
+ """Decay learning rate of the optimizer by factor decay_by."""
+ self._learning_rate.assign(self._learning_rate * decay_by)
+ print("Decayed learning rate of optimizer to: %s" %
+ self._learning_rate.numpy())
+
+ @property
+ def learning_rate(self):
+ return self._learning_rate
+
+
+def _batch_n_correct(logits, label):
+ """Calculate number of correct predictions in a batch.
+
+ Args:
+ logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype
+ `float32`.
+    label: A labels Tensor of shape `(batch_size,)` and dtype `int64`.
+
+ Returns:
+ Number of correct predictions.
+ """
+ return tf.reduce_sum(
+ tf.cast((tf.equal(
+ tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
+
+
+def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu):
+ """Run evaluation on a dataset.
+
+ Args:
+ snli_data: The `data.SnliData` to use in this evaluation.
+ batch_size: The batch size to use during this evaluation.
+ model: An instance of `SNLIClassifier` to evaluate.
+    trainer: An instance of `SNLIClassifierTrainer` to use for this
+ evaluation.
+ use_gpu: Whether GPU is being used.
+
+ Returns:
+ 1. Average loss across all examples of the dataset.
+ 2. Average accuracy rate across all examples of the dataset.
+ """
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+ for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
+ snli_data, batch_size):
+ if use_gpu:
+ label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
+ loss_val = trainer.loss(label, logits)
+ batch_size = tf.shape(label)[0]
+ mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
+ accuracy(tf.argmax(logits, axis=1), label)
+ return mean_loss.result().numpy(), accuracy.result().numpy()
+
+
+def _get_dataset_iterator(snli_data, batch_size):
+ """Get a data iterator for a split of SNLI data.
+
+ Args:
+ snli_data: A `data.SnliData` object.
+ batch_size: The desired batch size.
+
+ Returns:
+ A dataset iterator.
+ """
+ with tf.device("/device:CPU:0"):
+ # Some tf.data ops, such as ShuffleDataset, are available only on CPU.
+ dataset = tf.data.Dataset.from_generator(
+ snli_data.get_generator(batch_size),
+ (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64))
+ dataset = dataset.shuffle(snli_data.num_batches(batch_size))
+ return tfe.Iterator(dataset)
+
+
+def train_spinn(embed, train_data, dev_data, test_data, config):
+ """Train a SPINN model.
+
+ Args:
+ embed: The embedding matrix as a float32 numpy array with shape
+ [vocabulary_size, word_vector_len]. word_vector_len is the length of a
+ word embedding vector.
+ train_data: An instance of `data.SnliData`, for the train split.
+ dev_data: Same as above, for the dev split.
+ test_data: Same as above, for the test split.
+ config: A configuration object. See the argument to this Python binary for
+ details.
+
+ Returns:
+ 1. Final loss value on the test split.
+ 2. Final fraction of correct classifications on the test split.
+ """
+ use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
+ device = "gpu:0" if use_gpu else "cpu:0"
+ print("Using device: %s" % device)
+
+ log_header = (
+ " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
+ " Accuracy Dev/Accuracy")
+ log_template = (
+ "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} "
+ "{:12.4f} {}")
+ dev_log_template = (
+ "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} "
+ "{:8.6f} {:12.4f} {:12.4f}")
+
+ summary_writer = tf.contrib.summary.create_summary_file_writer(
+ config.logdir, flush_millis=10000)
+ train_len = train_data.num_batches(config.batch_size)
+ with tf.device(device), \
+ tfe.restore_variables_on_create(
+ tf.train.latest_checkpoint(config.logdir)), \
+ summary_writer.as_default(), \
+ tf.contrib.summary.always_record_summaries():
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+
+ start = time.time()
+ iterations = 0
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+ print(log_header)
+ for epoch in xrange(config.epochs):
+ batch_idx = 0
+ for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
+ train_data, config.batch_size):
+ if use_gpu:
+ label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
+ # prem_trans and hypo_trans are used for dynamic control flow and can
+ # remain on CPU. Same in _evaluate_on_dataset().
+
+ iterations += 1
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
+ batch_size = tf.shape(label)[0]
+ mean_loss(batch_train_loss.numpy(),
+ weights=batch_size.gpu() if use_gpu else batch_size)
+ accuracy(tf.argmax(batch_train_logits, axis=1), label)
+
+ if iterations % config.save_every == 0:
+ all_variables = (
+ model.variables + [trainer.learning_rate] + [global_step])
+ saver = tfe.Saver(all_variables)
+ saver.save(os.path.join(config.logdir, "ckpt"),
+ global_step=global_step)
+
+ if iterations % config.dev_every == 0:
+ dev_loss, dev_frac_correct = _evaluate_on_dataset(
+ dev_data, config.batch_size, model, trainer, use_gpu)
+ print(dev_log_template.format(
+ time.time() - start,
+ epoch, iterations, 1 + batch_idx, train_len,
+ 100.0 * (1 + batch_idx) / train_len,
+ mean_loss.result(), dev_loss,
+ accuracy.result() * 100.0, dev_frac_correct * 100.0))
+ tf.contrib.summary.scalar("dev/loss", dev_loss)
+ tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct)
+ elif iterations % config.log_every == 0:
+ mean_loss_val = mean_loss.result()
+ accuracy_val = accuracy.result()
+ print(log_template.format(
+ time.time() - start,
+ epoch, iterations, 1 + batch_idx, train_len,
+ 100.0 * (1 + batch_idx) / train_len,
+ mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12))
+ tf.contrib.summary.scalar("train/loss", mean_loss_val)
+ tf.contrib.summary.scalar("train/accuracy", accuracy_val)
+ # Reset metrics.
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+
+ batch_idx += 1
+ if (epoch + 1) % config.lr_decay_every == 0:
+ trainer.decay_learning_rate(config.lr_decay_by)
+
+ test_loss, test_frac_correct = _evaluate_on_dataset(
+ test_data, config.batch_size, model, trainer, use_gpu)
+ print("Final test loss: %g; accuracy: %g%%" %
+ (test_loss, test_frac_correct * 100.0))
+
+
+def main(_):
+ config = FLAGS
+
+ # Load embedding vectors.
+ vocab = data.load_vocabulary(FLAGS.data_root)
+ word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
+
+ print("Loading train, dev and test data...")
+ train_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ dev_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+ test_data = data.SnliData(
+ os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
+ word2index, sentence_len_limit=FLAGS.sentence_len_limit)
+
+ train_spinn(embed, train_data, dev_data, test_data, config)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=
+ "TensorFlow eager implementation of the SPINN SNLI classifier.")
+ parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
+ help="Root directory in which the training data and "
+ "embedding matrix are found. See README.md for how to "
+ "generate such a directory.")
+ parser.add_argument("--sentence_len_limit", type=int, default=-1,
+ help="Maximum allowed sentence length (# of words). "
+ "The default of -1 means unlimited.")
+ parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
+ help="Directory in which summaries will be written for "
+ "TensorBoard.")
+ parser.add_argument("--epochs", type=int, default=50,
+ help="Number of epochs to train.")
+ parser.add_argument("--batch_size", type=int, default=128,
+ help="Batch size to use during training.")
+ parser.add_argument("--d_proj", type=int, default=600,
+ help="Dimensions to project the word embedding vectors "
+ "to.")
+ parser.add_argument("--d_hidden", type=int, default=300,
+ help="Size of the hidden layer of the Tracker.")
+ parser.add_argument("--d_out", type=int, default=4,
+ help="Output dimensions of the SNLIClassifier.")
+ parser.add_argument("--d_mlp", type=int, default=1024,
+ help="Size of each layer of the multi-layer perceptron "
+                           "of the SNLIClassifier.")
+ parser.add_argument("--n_mlp_layers", type=int, default=2,
+ help="Number of layers in the multi-layer perceptron "
+                           "of the SNLIClassifier.")
+ parser.add_argument("--d_tracker", type=int, default=64,
+ help="Size of the tracker LSTM.")
+ parser.add_argument("--log_every", type=int, default=50,
+ help="Print log and write TensorBoard summary every _ "
+ "training batches.")
+ parser.add_argument("--lr", type=float, default=2e-3,
+ help="Initial learning rate.")
+ parser.add_argument("--lr_decay_by", type=float, default=0.75,
+ help="The ratio to multiply the learning rate by every "
+ "time the learning rate is decayed.")
+ parser.add_argument("--lr_decay_every", type=float, default=1,
+ help="Decay the learning rate every _ epoch(s).")
+ parser.add_argument("--dev_every", type=int, default=1000,
+ help="Run evaluation on the dev split every _ training "
+ "batches.")
+ parser.add_argument("--save_every", type=int, default=1000,
+ help="Save checkpoint every _ training batches.")
+ parser.add_argument("--embed_dropout", type=float, default=0.08,
+ help="Word embedding dropout rate.")
+ parser.add_argument("--mlp_dropout", type=float, default=0.07,
+ help="SNLIClassifier multi-layer perceptron dropout "
+ "rate.")
+ parser.add_argument("--no-projection", action="store_false",
+ dest="projection",
+ help="Whether word embedding vectors are projected to "
+ "another set of vectors (see d_proj).")
+ parser.add_argument("--predict_transitions", action="store_true",
+ dest="predict",
+ help="Whether the Tracker will perform prediction.")
+ parser.add_argument("--force_cpu", action="store_true", dest="force_cpu",
+                      help="Force CPU-only execution regardless of whether a "
+                      "GPU is available.")
+ FLAGS, unparsed = parser.parse_known_args()
+
+ tfe.run(main=main, argv=[sys.argv[0]] + unparsed)