aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/examples/eager/spinn/spinn.py
blob: de63ebe9e67d37dcc0ecf309edf1fae89169af5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
r"""Implementation of SPINN in TensorFlow eager execution.

SPINN: Stack-Augmented Parser-Interpreter Neural Network.

This file contains the model definition and code for training the model.

The model definition is based on PyTorch implementation at:
  https://github.com/jekbradbury/examples/tree/spinn/snli

which was released under a BSD 3-Clause License at:
https://github.com/jekbradbury/examples/blob/spinn/LICENSE:

Copyright (c) 2017,
All rights reserved.

See ./LICENSE for more details.

Instructions for use:
* See `README.md` for details on how to prepare the SNLI and GloVe data.
* Assuming you have prepared the data at "/tmp/spinn-data", use the following
  command to train the model:

  ```bash
  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
  ```

  Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs".

References:
* Bowman, S.R., Gauthier, J., Rastogi A., Gupta, R., Manning, C.D., & Potts, C.
  (2016). A Fast Unified Model for Parsing and Sentence Understanding.
  https://arxiv.org/abs/1603.06021
* Bradbury, J. (2017). Recursive Neural Networks with PyTorch.
  https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import itertools
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data


# Short alias for the Keras layers namespace used by the model classes below.
layers = tf.keras.layers


def _bundle(lstm_iter):
  """Stack `Tensor`s row-wise, then cut the result into two column halves.

  Args:
    lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape
      (R, 2 * M).

  Returns:
    A `list` of two dense `Tensor`s, each of which has the shape (N * R, M):
    the left half and the right half (along axis 1) of the stacked input.
  """
  stacked = tf.concat(lstm_iter, 0)
  halves = tf.split(stacked, 2, axis=1)
  return halves


def _unbundle(state):
  """Join two `Tensor`s side by side, then cut the result into single rows.

  Inverse of `_bundle`.

  Args:
    state: A `list` of two dense `Tensor`s, each of which has the shape (R, M).

  Returns:
    A `list` of `R` dense `Tensor`s, each of which has the shape (1, 2 * M).
  """
  num_rows = state[0].shape[0]
  joined = tf.concat(state, 1)
  return tf.split(joined, num_rows, axis=0)


# pylint: disable=not-callable
class Reducer(tf.keras.Model):
  """Merges pairs of (left, right) stack entries with a Tree LSTM."""

  def __init__(self, size, tracker_size=None):
    """Creates the linear projections feeding the Tree LSTM.

    Args:
      size: Size of the Tree LSTM state. Every projection emits `5 * size`
        units, one `size`-wide slice per Tree LSTM gate.
      tracker_size: Size of the Tracker's state, or `None` when no Tracker is
        used; in the latter case no tracking projection is created.
    """
    super(Reducer, self).__init__()
    gate_units = 5 * size
    self.left = layers.Dense(gate_units, activation=None)
    self.right = layers.Dense(gate_units, activation=None, use_bias=False)
    self.track = (
        layers.Dense(gate_units, activation=None, use_bias=False)
        if tracker_size is not None else None)

  def call(self, left_in, right_in, tracking=None):
    """Runs one batched reduce step.

    A linear combination of `left_in`, `right_in` and (optionally) `tracking`
    is fed into a Tree LSTM whose output is returned.

    Args:
      left_in: A list of length L. Each item is a dense `Tensor` with
        the shape (1, n_dims). n_dims is the size of the embedding vector.
      right_in: A list of the same length as `left_in`. Each item should have
        the same shape as the items of `left_in`.
      tracking: Optional list of the same length as `left_in`. Each item is a
        dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the
        size of the Tracker's state vector. Ignored when this Reducer was
        built without a tracking projection.

    Returns:
      A list of length L (one entry per reduced pair). Each item is a dense
      `Tensor` of shape (1, 2 * size) holding the new `h` and `c` side by side.
    """
    left_h, left_c = _bundle(left_in)
    right_h, right_c = _bundle(right_in)
    gates = self.left(left_h) + self.right(right_h)
    if self.track and tracking:
      gates += self.track(_bundle(tracking)[0])
    new_h, new_c = self._tree_lstm(left_c, right_c, gates)
    return _unbundle([new_h, new_c])

  def _tree_lstm(self, c1, c2, lstm_in):
    """Applies one Tree LSTM step.

    Args:
      c1: Cell state of the left children.
      c2: Cell state of the right children.
      lstm_in: Gate pre-activations; split along axis 1 into five equal parts:
        the candidate, the input gate, one forget gate per child and the
        output gate.

    Returns:
      A tuple `(h, c)` of the new hidden state and cell state.
    """
    candidate, in_gate, left_forget, right_forget, out_gate = tf.split(
        lstm_in, 5, axis=1)
    c = (tf.tanh(candidate) * tf.sigmoid(in_gate) +
         tf.sigmoid(left_forget) * c1 +
         tf.sigmoid(right_forget) * c2)
    h = tf.sigmoid(out_gate) * tf.tanh(c)
    return h, c


class Tracker(tf.keras.Model):
  """A module that tracks the history of the sentence with an LSTM."""

  def __init__(self, tracker_size, predict):
    """Constructor of Tracker.

    Args:
      tracker_size: Number of dimensions of the underlying `LSTMCell`.
      predict: (`bool`) Whether prediction mode is enabled. If `True`, a Dense
        layer with 4 output units is created to score the transitions.
    """
    super(Tracker, self).__init__()
    self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
    self._state_size = tracker_size
    if predict:
      self._transition = layers.Dense(4)
    else:
      self._transition = None
    # Initialize the recurrent state here too, so that `call` does not raise
    # AttributeError when `reset_state()` was never invoked; `call` lazily
    # replaces `None` with zeros.
    self.state = None

  def reset_state(self):
    """Clear the recurrent state; the next `call` starts from zero state."""
    self.state = None

  def call(self, bufs, stacks):
    """Invoke the forward pass of the Tracker module.

    This method feeds the concatenation of the last element of each buffer and
    the top two elements of each stack into an LSTM cell and returns the
    resultant state of the LSTM cell.

    Args:
      bufs: A `list` of length batch_size. Each item is a `list` of at most
        max_sequence_len (maximum sequence length of the batch) items. Each
        item of the nested list is a dense `Tensor` of shape (1, d_proj), where
        d_proj is the size of the word embedding vector or the size of the
        vector space that the word embedding vector is projected to.
      stacks: A `list` of size batch_size. Each item is a `list` of
        variable length corresponding to the current height of the stack.
        Each item of the nested list is a dense `Tensor` of shape (1, d_proj).

    Returns:
      1. A list of length batch_size. Each item is a dense `Tensor` of shape
        (1, d_tracker * 2) holding both parts of that example's LSTM state.
      2. If under predict mode, the result of applying the Dense layer to the
        first state vector of the RNN, of shape (batch_size, 4). Else, `None`.
    """
    buf_top = _bundle([buf[-1] for buf in bufs])[0]
    stack_top = _bundle([stack[-1] for stack in stacks])[0]
    stack_second = _bundle([stack[-2] for stack in stacks])[0]
    x = tf.concat([buf_top, stack_top, stack_second], 1)
    if self.state is None:
      batch_size = int(x.shape[0])
      zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32)
      self.state = [zeros, zeros]
    _, self.state = self._rnn(x, self.state)
    unbundled = _unbundle(self.state)
    # NOTE(review): with LSTMCell's default tuple state, `self.state` is
    # presumably (c, h), so `self.state[0]` would be the cell state rather
    # than the hidden output -- confirm this is intended.
    if self._transition is not None:
      return unbundled, self._transition(self.state[0])
    return unbundled, None


class SPINN(tf.keras.Model):
  """Stack-augmented Parser-Interpreter Neural Network.

  See https://arxiv.org/abs/1603.06021 for more details.
  """

  def __init__(self, config):
    """Constructor of SPINN.

    Args:
      config: A configuration object (e.g. a `namedtuple` or the parsed
        command-line flags). SPINN reads the following attributes.
        d_hidden - (`int`) number of the dimensions of the hidden state, for the
          Reducer module.
        d_tracker - (`int`) number of dimensions of the Tracker's state vector,
          or `None` to run without a Tracker.
        predict - (`bool`) whether the Tracker predicts the transitions.
    """
    super(SPINN, self).__init__()
    self.config = config
    self.reducer = Reducer(config.d_hidden, config.d_tracker)
    if config.d_tracker is not None:
      self.tracker = Tracker(config.d_tracker, config.predict)
    else:
      self.tracker = None

  def call(self, buffers, transitions, training=False):
    """Invoke the forward pass of the SPINN model.

    Args:
      buffers: Dense `Tensor` of shape
        (max_sequence_len, batch_size, d_proj).
      transitions: Dense `Tensor` with integer values that represent the parse
        trees of the sentences. A value of 2 indicates "reduce"; a value of 3
        indicates "shift"; any other value leaves the stacks unchanged.
        Shape: (max_sequence_len * 2 - 3, batch_size). Its values are read on
        the host via `.numpy()`.
      training: Whether the invocation is under training mode. Currently not
        used by this method.

    Returns:
      Output `Tensor` of shape (batch_size, d_proj // 2): the first half of the
      top entry of every stack (equal to `config.d_hidden` when
      `d_proj == 2 * d_hidden`).
    """
    max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape)

    # Split the buffers into per-example lists of (1, d_proj) word items and
    # seed every stack with two copies of that example's first buffer item.
    splitted = tf.split(
        tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]),
        max_sequence_len * batch_size, axis=0)
    buffers = [splitted[k:k + max_sequence_len]
               for k in xrange(0, len(splitted), max_sequence_len)]
    stacks = [[buf[0], buf[0]] for buf in buffers]

    if self.tracker:
      # Reset tracker state for new batch.
      self.tracker.reset_state()

    # Iterate through transitions and perform the appropriate stack-pop, reduce
    # and stack-push operations. The transitions drive Python control flow, so
    # they are copied to the host once up front.
    for trans in transitions.numpy():
      if self.tracker:
        # Invoke tracker to obtain the current tracker states for the sentences.
        tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
        # In predict mode `trans_hypothesis` is a (batch_size, 4) Tensor, and
        # `None` otherwise. Test against None explicitly: taking the truth
        # value of a multi-element Tensor raises an error.
        if trans_hypothesis is not None:
          trans = tf.argmax(trans_hypothesis, axis=-1)
      else:
        tracker_states = itertools.repeat(None)
      lefts, rights, trackings = [], [], []
      for transition, buf, stack, tracking in zip(
          trans, buffers, stacks, tracker_states):
        if int(transition) == 3:  # Shift.
          stack.append(buf.pop())
        elif int(transition) == 2:  # Reduce.
          rights.append(stack.pop())
          lefts.append(stack.pop())
          trackings.append(tracking)

      if rights:
        # Reduce every pending pair of the batch in a single Reducer call, then
        # push the results back in the order the pairs were collected.
        reducer_output = self.reducer(lefts, rights, trackings)
        reduced = iter(reducer_output)

        for transition, stack in zip(trans, stacks):
          if int(transition) == 2:  # Reduce.
            stack.append(next(reduced))
    return _bundle([stack.pop() for stack in stacks])[0]


class Perceptron(tf.keras.Model):
  """One Dense -> BatchNorm -> Dropout stage of the SNLIClassifier MLP.

  Stages are chained: each one first evaluates `previous_layer` on its input,
  so calling the last stage runs the whole stack.
  """

  def __init__(self, dimension, dropout_rate, previous_layer):
    """Builds the layers of this stage.

    Args:
      dimension: Number of output units of the Dense layer.
      dropout_rate: Rate of the Dropout layer applied after batch norm.
      previous_layer: Callable `(x, training=...) -> Tensor` computing the
        preceding stages.
    """
    super(Perceptron, self).__init__()
    self.dense = layers.Dense(dimension, activation=tf.nn.elu)
    self.batchnorm = layers.BatchNormalization()
    self.dropout = layers.Dropout(rate=dropout_rate)
    self.previous_layer = previous_layer

  def call(self, x, training):
    """Runs the preceding stages, then Dense, BatchNorm and Dropout."""
    hidden = self.previous_layer(x, training=training)
    hidden = self.dense(hidden)
    hidden = self.batchnorm(hidden, training=training)
    return self.dropout(hidden, training=training)


class SNLIClassifier(tf.keras.Model):
  """SNLI Classifier Model.

  A model aimed at solving the SNLI (Stanford Natural Language Inference)
  task, using the SPINN model from above. For details of the task, see:
    https://nlp.stanford.edu/projects/snli/

  The premise and the hypothesis share the same embedding matrix, projection,
  batch-normalization and dropout layers, and SPINN encoder.
  """

  def __init__(self, config, embed):
    """Constructor of SNLIClassifier.

    Args:
      config: A configuration object (e.g. a namedtuple) containing required
        configurations for the model. It needs to have the following
        attributes.
        projection - (`bool`) whether the word vectors are to be projected onto
          another vector space (of `d_proj` dimensions).
        d_proj - (`int`) number of dimensions of the vector space to project the
          word embeddings to.
        embed_dropout - (`float`) dropout rate for the word embedding vectors.
        n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to
          use to convert the combined sentence features to logits.
        d_mlp - (`int`) number of output units of each MLP layer.
        mlp_dropout - (`float`) dropout rate of the MLP layers.
        d_out - (`int`) number of dimensions of the final output of the MLP
          layers.
        The same object is passed to `SPINN`, which additionally reads
        `d_hidden`, `d_tracker` and `predict`.
      embed: An embedding matrix of shape (vocab_size, d_embed).
    """
    super(SNLIClassifier, self).__init__()
    self.config = config
    # A constant, not a Variable: it is not among the model's variables, so
    # the optimizer never updates the embedding matrix.
    self.embed = tf.constant(embed)

    # Always created, but only applied when config.projection is set.
    self.projection = layers.Dense(config.d_proj)
    self.embed_bn = layers.BatchNormalization()
    self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
    self.encoder = SPINN(config)

    self.feature_bn = layers.BatchNormalization()
    self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)

    # Build the MLP as a chain of Perceptrons that starts from an identity
    # function; each Perceptron evaluates its predecessor before itself.
    current_mlp = lambda result, training: result
    for _ in range(config.n_mlp_layers):
      current_mlp = Perceptron(dimension=config.d_mlp,
                               dropout_rate=config.mlp_dropout,
                               previous_layer=current_mlp)
    self.mlp = current_mlp
    # Final projection to d_out logits, with a small uniform initialization.
    self.mlp_output = layers.Dense(
        config.d_out,
        kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
                                                         maxval=5e-3))

  def call(self,
           premise,
           premise_transition,
           hypothesis,
           hypothesis_transition,
           training=False):
    """Invoke the forward pass of the SNLIClassifier model.

    Args:
      premise: The word indices of the premise sentences, with shape
        (max_prem_seq_len, batch_size).
      premise_transition: The transitions for the premise sentences, with shape
        (max_prem_seq_len * 2 - 3, batch_size).
      hypothesis: The word indices of the hypothesis sentences, with shape
        (max_hypo_seq_len, batch_size).
      hypothesis_transition: The transitions for the hypothesis sentences, with
        shape (max_hypo_seq_len * 2 - 3, batch_size).
      training: Whether the invocation is under training mode.

    Returns:
      The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out
      is the size of the output vector.
    """
    # Perform embedding lookup on the premise and hypothesis inputs, which have
    # the word-index format.
    premise_embed = tf.nn.embedding_lookup(self.embed, premise)
    hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)

    if self.config.projection:
      # Project the embedding vectors to another vector space.
      premise_embed = self.projection(premise_embed)
      hypothesis_embed = self.projection(hypothesis_embed)

    # Perform batch normalization and dropout on the possibly projected word
    # vectors.
    premise_embed = self.embed_bn(premise_embed, training=training)
    hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
    premise_embed = self.embed_dropout(premise_embed, training=training)
    hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)

    # Run the batch-normalized and dropout-processed word vectors through the
    # SPINN encoder. `premise` and `hypothesis` are rebound to the encodings.
    premise = self.encoder(premise_embed, premise_transition,
                           training=training)
    hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
                              training=training)

    # Combine the encoder outputs for premises and hypotheses into a single
    # feature vector (held in `logits`), then apply batch normalization and
    # dropout to it.
    logits = tf.concat(
        [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
    logits = self.feature_dropout(
        self.feature_bn(logits, training=training), training=training)

    # Apply the multi-layer perceptron on the features.
    logits = self.mlp(logits, training=training)
    logits = self.mlp_output(logits)
    return logits


class SNLIClassifierTrainer(tfe.Checkpointable):
  """Owns the optimizer and learning rate used to train an `SNLIClassifier`."""

  def __init__(self, snli_classifier, lr):
    """Pairs a classifier with an RMSProp optimizer.

    Args:
      snli_classifier: The `SNLIClassifier` to train.
      lr: Initial learning rate.
    """
    self._model = snli_classifier
    # The learning rate is held in a Variable rather than a Python float so
    # that decay_learning_rate() can change it in place for the optimizer.
    self._learning_rate = tf.Variable(lr, name="learning_rate")
    self._optimizer = tf.train.RMSPropOptimizer(
        self._learning_rate, epsilon=1e-6)

  def loss(self, labels, logits):
    """Mean sparse softmax cross-entropy over a batch.

    Args:
      labels: The truth labels, with shape (batch_size,).
      logits: The logits output from the forward pass of the SNLIClassifier
        model, with shape (batch_size, d_out), where d_out is the output
        dimension size of the SNLIClassifier.

    Returns:
      The loss value, as a scalar `Tensor`.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example)

  def train_batch(self,
                  labels,
                  premise,
                  premise_transition,
                  hypothesis,
                  hypothesis_transition):
    """Runs one optimization step on a batch of data.

    Args:
      labels: The truth labels, with shape (batch_size,).
      premise: The word indices of the premise sentences, with shape
        (max_prem_seq_len, batch_size).
      premise_transition: The transitions for the premise sentences, with shape
        (max_prem_seq_len * 2 - 3, batch_size).
      hypothesis: The word indices of the hypothesis sentences, with shape
        (max_hypo_seq_len, batch_size).
      hypothesis_transition: The transitions for the hypothesis sentences, with
        shape (max_hypo_seq_len * 2 - 3, batch_size).

    Returns:
      1. loss value as a scalar `Tensor`.
      2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is
        the output dimension size of the SNLIClassifier.
    """
    with tf.GradientTape() as tape:
      tape.watch(self._model.variables)
      logits = self._model(premise, premise_transition,
                           hypothesis, hypothesis_transition,
                           training=True)
      batch_loss = self.loss(labels, logits)
    # Read the variable list only after the forward pass: on the first batch
    # it is presumably still empty beforehand, because Keras layers create
    # their variables when first called.
    model_vars = self._model.variables
    gradients = tape.gradient(batch_loss, model_vars)
    self._optimizer.apply_gradients(
        zip(gradients, model_vars), global_step=tf.train.get_global_step())
    return batch_loss, logits

  def decay_learning_rate(self, decay_by):
    """Multiplies the learning rate by `decay_by` and prints the new value."""
    new_rate = self._learning_rate * decay_by
    self._learning_rate.assign(new_rate)
    current = self._learning_rate.numpy()
    print("Decayed learning rate of optimizer to: %s" % current)

  @property
  def learning_rate(self):
    """The learning-rate `Variable` used by the optimizer."""
    return self._learning_rate

  @property
  def model(self):
    """The `SNLIClassifier` being trained."""
    return self._model

  @property
  def variables(self):
    """Model variables, then the learning rate, then optimizer variables."""
    model_vars = self._model.variables
    return model_vars + [self.learning_rate] + self._optimizer.variables()


def _batch_n_correct(logits, label):
  """Count how many argmax predictions in a batch match their labels.

  Args:
    logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype
      `float32`.
    label: A labels Tensor of shape `(batch_size,)` and dtype `int64`.

  Returns:
    Number of correct predictions, as a host (numpy) float value.
  """
  predictions = tf.argmax(logits, axis=1)
  hits = tf.cast(tf.equal(predictions, label), tf.float32)
  return tf.reduce_sum(hits).numpy()


def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
  """Run evaluation on a dataset.

  Args:
    snli_data: The `data.SnliData` to use in this evaluation.
    batch_size: The batch size to use during this evaluation.
    trainer: An instance of `SNLIClassifierTrainer` to use for this
      evaluation.
    use_gpu: Whether GPU is being used.

  Returns:
    1. Average loss across all examples of the dataset.
    2. Average accuracy rate across all examples of the dataset.
  """
  mean_loss = tfe.metrics.Mean()
  accuracy = tfe.metrics.Accuracy()
  for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
      snli_data, batch_size):
    if use_gpu:
      # prem_trans and hypo_trans drive Python control flow in SPINN.call
      # (they are read with .numpy()), so they stay on the CPU.
      label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
    loss_val = trainer.loss(label, logits)
    # Weight each batch's mean loss by its number of examples so the result is
    # a per-example average. Use a separate name instead of overwriting the
    # `batch_size` argument.
    num_examples = tf.shape(label)[0]
    mean_loss(loss_val,
              weights=num_examples.gpu() if use_gpu else num_examples)
    accuracy(tf.argmax(logits, axis=1), label)
  return mean_loss.result().numpy(), accuracy.result().numpy()


def _get_dataset_iterator(snli_data, batch_size):
  """Build a shuffled iterator over batches of one SNLI split.

  Each element is a 5-tuple of int64 tensors, consumed by callers as
  (label, premise, premise_transition, hypothesis, hypothesis_transition).

  Args:
    snli_data: A `data.SnliData` object.
    batch_size: The desired batch size.

  Returns:
    A dataset iterator.
  """
  # Some tf.data ops, such as ShuffleDataset, are available only on CPU.
  with tf.device("/device:CPU:0"):
    output_types = (tf.int64,) * 5
    dataset = tf.data.Dataset.from_generator(
        snli_data.get_generator(batch_size), output_types)
    shuffled = dataset.shuffle(snli_data.num_batches(batch_size))
    return tfe.Iterator(shuffled)


def train_or_infer_spinn(embed,
                         word2index,
                         train_data,
                         dev_data,
                         test_data,
                         config):
  """Perform Training or Inference on a SPINN model.

  Args:
    embed: The embedding matrix as a float32 numpy array with shape
      [vocabulary_size, word_vector_len]. word_vector_len is the length of a
      word embedding vector.
    word2index: A `dict` mapping word to word index.
    train_data: An instance of `data.SnliData`, for the train split.
    dev_data: Same as above, for the dev split.
    test_data: Same as above, for the test split.
    config: A configuration object. See the argument to this Python binary for
      details.

  Returns:
    If `config.inference_premise` and `config.inference_hypothesis` are both
      set, i.e., inference mode: the logits for the possible labels of the
      SNLI data set, as a 1-D `Tensor` of `config.d_out - 1` floats.
    else:
      The trainer object.
  Raises:
    ValueError: if only one of config.inference_premise and
      config.inference_hypothesis is specified.
  """
  # TODO(cais): Refactor this function into separate one for training and
  #   inference.
  use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
  device = "gpu:0" if use_gpu else "cpu:0"
  print("Using device: %s" % device)

  # Exactly one of the two inference flags being set is an error.
  if bool(config.inference_premise) != bool(config.inference_hypothesis):
    raise ValueError(
        "--inference_premise and --inference_hypothesis must be both "
        "specified or both unspecified, but only one is specified.")

  if config.inference_premise:
    # Inference mode.
    inference_sentence_pair = [
        data.encode_sentence(config.inference_premise, word2index),
        data.encode_sentence(config.inference_hypothesis, word2index)]
  else:
    inference_sentence_pair = None

  log_header = (
      "  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss"
      "     Accuracy  Dev/Accuracy")
  log_template = (
      "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} "
      "{:12.4f} {}")
  dev_log_template = (
      "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} "
      "{:8.6f} {:12.4f} {:12.4f}")

  summary_writer = tf.contrib.summary.create_file_writer(
      config.logdir, flush_millis=10000)

  with tf.device(device), \
       summary_writer.as_default(), \
       tf.contrib.summary.always_record_summaries():
    model = SNLIClassifier(config, embed)
    global_step = tf.train.get_or_create_global_step()
    trainer = SNLIClassifierTrainer(model, config.lr)
    checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
    checkpoint.restore(tf.train.latest_checkpoint(config.logdir))

    if inference_sentence_pair:
      # Inference mode.
      prem, prem_trans = inference_sentence_pair[0]
      hypo, hypo_trans = inference_sentence_pair[1]
      inference_logits = model(
          tf.constant(prem), tf.constant(prem_trans),
          tf.constant(hypo), tf.constant(hypo_trans), training=False)
      # Take the single example of the batch and drop output index 0.
      # NOTE(review): presumably index 0 has no entry in
      # data.POSSIBLE_LABELS -- confirm against the data module.
      inference_logits = inference_logits[0][1:]
      # Convert to a Python int so that `max_index == i` compares values: an
      # EagerTensor's `==` is identity-based in TF 1.x, which would never
      # mark the winner.
      max_index = int(tf.argmax(inference_logits))
      print("\nInference logits:")
      for i, (label, logit) in enumerate(
          zip(data.POSSIBLE_LABELS, inference_logits)):
        winner_tag = " (winner)" if max_index == i else ""
        print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
      return inference_logits

    train_len = train_data.num_batches(config.batch_size)
    start = time.time()
    iterations = 0
    mean_loss = tfe.metrics.Mean()
    accuracy = tfe.metrics.Accuracy()
    print(log_header)
    for epoch in xrange(config.epochs):
      batch_idx = 0
      for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
          train_data, config.batch_size):
        if use_gpu:
          label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
          # prem_trans and hypo_trans are used for dynamic control flow and can
          # remain on CPU. Same in _evaluate_on_dataset().

        iterations += 1
        batch_train_loss, batch_train_logits = trainer.train_batch(
            label, prem, prem_trans, hypo, hypo_trans)
        batch_size = tf.shape(label)[0]
        mean_loss(batch_train_loss.numpy(),
                  weights=batch_size.gpu() if use_gpu else batch_size)
        accuracy(tf.argmax(batch_train_logits, axis=1), label)

        if iterations % config.save_every == 0:
          checkpoint.save(os.path.join(config.logdir, "ckpt"))

        if iterations % config.dev_every == 0:
          dev_loss, dev_frac_correct = _evaluate_on_dataset(
              dev_data, config.batch_size, trainer, use_gpu)
          print(dev_log_template.format(
              time.time() - start,
              epoch, iterations, 1 + batch_idx, train_len,
              100.0 * (1 + batch_idx) / train_len,
              mean_loss.result(), dev_loss,
              accuracy.result() * 100.0, dev_frac_correct * 100.0))
          tf.contrib.summary.scalar("dev/loss", dev_loss)
          tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct)
        elif iterations % config.log_every == 0:
          mean_loss_val = mean_loss.result()
          accuracy_val = accuracy.result()
          print(log_template.format(
              time.time() - start,
              epoch, iterations, 1 + batch_idx, train_len,
              100.0 * (1 + batch_idx) / train_len,
              mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12))
          tf.contrib.summary.scalar("train/loss", mean_loss_val)
          tf.contrib.summary.scalar("train/accuracy", accuracy_val)
          # Reset metrics.
          mean_loss = tfe.metrics.Mean()
          accuracy = tfe.metrics.Accuracy()

        batch_idx += 1
      if (epoch + 1) % config.lr_decay_every == 0:
        trainer.decay_learning_rate(config.lr_decay_by)

    test_loss, test_frac_correct = _evaluate_on_dataset(
        test_data, config.batch_size, trainer, use_gpu)
    print("Final test loss: %g; accuracy: %g%%" %
          (test_loss, test_frac_correct * 100.0))

  return trainer


def main(_):
  """Load the vocabulary, word vectors and (for training) the SNLI splits.

  Then hands everything to `train_or_infer_spinn`. The splits are loaded only
  when neither inference flag is set.

  Args:
    _: Unused.
  """
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(config.data_root)
  word2index, embed = data.load_word_vectors(config.data_root, vocab)

  train_data = dev_data = test_data = None
  if not (config.inference_premise or config.inference_hypothesis):
    print("Loading train, dev and test data...")
    split_files = ("snli/snli_1.0/snli_1.0_train.txt",
                   "snli/snli_1.0/snli_1.0_dev.txt",
                   "snli/snli_1.0/snli_1.0_test.txt")
    train_data, dev_data, test_data = [
        data.SnliData(os.path.join(config.data_root, split_file), word2index,
                      sentence_len_limit=config.sentence_len_limit)
        for split_file in split_files]

  train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      description=
      "TensorFlow eager implementation of the SPINN SNLI classifier.")
  parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
                      help="Root directory in which the training data and "
                      "embedding matrix are found. See README.md for how to "
                      "generate such a directory.")
  parser.add_argument("--sentence_len_limit", type=int, default=-1,
                      help="Maximum allowed sentence length (# of words). "
                      "The default of -1 means unlimited.")
  parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
                      help="Directory in which summaries will be written for "
                      "TensorBoard.")
  parser.add_argument("--inference_premise", type=str, default=None,
                      help="Premise sentence for inference. Must be "
                      "accompanied by --inference_hypothesis. If specified, "
                      "will override all training parameters and perform "
                      "inference.")
  parser.add_argument("--inference_hypothesis", type=str, default=None,
                      help="Hypothesis sentence for inference. Must be "
                      "accompanied by --inference_premise. If specified, will "
                      "override all training parameters and perform inference.")
  parser.add_argument("--epochs", type=int, default=50,
                      help="Number of epochs to train.")
  parser.add_argument("--batch_size", type=int, default=128,
                      help="Batch size to use during training.")
  parser.add_argument("--d_proj", type=int, default=600,
                      help="Dimensions to project the word embedding vectors "
                      "to.")
  # d_hidden sizes the Reducer's Tree LSTM; the Tracker is sized by d_tracker.
  parser.add_argument("--d_hidden", type=int, default=300,
                      help="Size of the hidden state of the Reducer.")
  parser.add_argument("--d_out", type=int, default=4,
                      help="Output dimensions of the SNLIClassifier.")
  parser.add_argument("--d_mlp", type=int, default=1024,
                      help="Size of each layer of the multi-layer perceptron "
                      "of the SNLICLassifier.")
  parser.add_argument("--n_mlp_layers", type=int, default=2,
                      help="Number of layers in the multi-layer perceptron "
                      "of the SNLICLassifier.")
  parser.add_argument("--d_tracker", type=int, default=64,
                      help="Size of the tracker LSTM.")
  parser.add_argument("--log_every", type=int, default=50,
                      help="Print log and write TensorBoard summary every _ "
                      "training batches.")
  parser.add_argument("--lr", type=float, default=2e-3,
                      help="Initial learning rate.")
  parser.add_argument("--lr_decay_by", type=float, default=0.75,
                      help="The ratio to multiply the learning rate by every "
                      "time the learning rate is decayed.")
  parser.add_argument("--lr_decay_every", type=float, default=1,
                      help="Decay the learning rate every _ epoch(s).")
  parser.add_argument("--dev_every", type=int, default=1000,
                      help="Run evaluation on the dev split every _ training "
                      "batches.")
  parser.add_argument("--save_every", type=int, default=1000,
                      help="Save checkpoint every _ training batches.")
  parser.add_argument("--embed_dropout", type=float, default=0.08,
                      help="Word embedding dropout rate.")
  parser.add_argument("--mlp_dropout", type=float, default=0.07,
                      help="SNLIClassifier multi-layer perceptron dropout "
                      "rate.")
  parser.add_argument("--no-projection", action="store_false",
                      dest="projection",
                      help="Whether word embedding vectors are projected to "
                      "another set of vectors (see d_proj).")
  parser.add_argument("--predict_transitions", action="store_true",
                      dest="predict",
                      help="Whether the Tracker will perform prediction.")
  parser.add_argument("--force_cpu", action="store_true", dest="force_cpu",
                      help="Force use CPU-only regardless of whether a GPU is "
                      "available.")
  FLAGS, unparsed = parser.parse_known_args()

  tfe.run(main=main, argv=[sys.argv[0]] + unparsed)