author     Allen Lavoie <allenl@google.com>            2018-03-23 15:12:21 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-25 04:13:12 -0700
commit     db51253fce5882bf766e19b97131d90f0947d0df (patch)
tree       ed75f26f706dc089af7f5592f650d52756f4bfc4 /third_party/examples
parent     fce07c395d7c3931bc809183031c232651eb0638 (diff)
Convert the eager SPINN example to use tf.keras.Model and object-based checkpointing.
Uses a more recursive/functional tracking style which avoids numbering layers. Maybe this is too magical and we should adapt tf.keras.Sequential first? Let me know what you think.

PiperOrigin-RevId: 190282346
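For readers unfamiliar with the pattern this commit adopts, the sketch below (illustrative only, not part of the patch) shows how a tf.keras.Model subclass tracks sub-layers assigned as attributes without explicit track_layer() calls, and how tfe.Checkpoint saves and restores them by object structure rather than by variable name. TinyModel and the checkpoint path are hypothetical stand-ins; the TF 1.x tf.contrib.eager API used by this example is assumed.

# Illustrative sketch only (not part of this commit); assumes the TF 1.x
# tf.contrib.eager API used by this example. TinyModel and the checkpoint
# path are hypothetical stand-ins for the real SPINN classes.
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
layers = tf.keras.layers


class TinyModel(tf.keras.Model):
  """Sub-layers assigned as attributes are tracked without track_layer()."""

  def __init__(self):
    super(TinyModel, self).__init__()
    self.hidden = layers.Dense(4, activation=tf.nn.relu)  # tracked automatically
    self.out = layers.Dense(1)                            # tracked automatically

  def call(self, x):
    return self.out(self.hidden(x))


model = TinyModel()
global_step = tf.train.get_or_create_global_step()
# Object-based checkpointing: variables are matched via the object graph
# (checkpoint -> model -> hidden/out), not via global variable names.
checkpoint = tfe.Checkpoint(model=model, global_step=global_step)

model(tf.zeros([1, 3]))  # Run the model once so its variables are created.
save_path = checkpoint.save("/tmp/tiny_ckpt/ckpt")
checkpoint.restore(save_path)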
Diffstat (limited to 'third_party/examples')
-rw-r--r--  third_party/examples/eager/spinn/spinn.py | 168
1 file changed, 91 insertions(+), 77 deletions(-)
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 8a1c7db2ea..f8fb6ecb0c 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -51,6 +51,9 @@ import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
+layers = tf.keras.layers
+
+
def _bundle(lstm_iter):
"""Concatenate a list of Tensors along 1st axis and split result into two.
@@ -78,17 +81,16 @@ def _unbundle(state):
return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
-class Reducer(tfe.Network):
+# pylint: disable=not-callable
+class Reducer(tf.keras.Model):
"""A module that applies reduce operation on left and right vectors."""
def __init__(self, size, tracker_size=None):
super(Reducer, self).__init__()
- self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None))
- self.right = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.left = layers.Dense(5 * size, activation=None)
+ self.right = layers.Dense(5 * size, activation=None, use_bias=False)
if tracker_size is not None:
- self.track = self.track_layer(
- tf.layers.Dense(5 * size, activation=None, use_bias=False))
+ self.track = layers.Dense(5 * size, activation=None, use_bias=False)
else:
self.track = None
@@ -123,7 +125,7 @@ class Reducer(tfe.Network):
return h, c
-class Tracker(tfe.Network):
+class Tracker(tf.keras.Model):
"""A module that tracks the history of the sentence with an LSTM."""
def __init__(self, tracker_size, predict):
@@ -134,10 +136,10 @@ class Tracker(tfe.Network):
predict: (`bool`) Whether prediction mode is enabled.
"""
super(Tracker, self).__init__()
- self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size))
+ self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
self._state_size = tracker_size
if predict:
- self._transition = self.track_layer(tf.layers.Dense(4))
+ self._transition = layers.Dense(4)
else:
self._transition = None
@@ -182,7 +184,7 @@ class Tracker(tfe.Network):
return unbundled, None
-class SPINN(tfe.Network):
+class SPINN(tf.keras.Model):
"""Stack-augmented Parser-Interpreter Neural Network.
See https://arxiv.org/abs/1603.06021 for more details.
@@ -204,9 +206,9 @@ class SPINN(tfe.Network):
"""
super(SPINN, self).__init__()
self.config = config
- self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker))
+ self.reducer = Reducer(config.d_hidden, config.d_tracker)
if config.d_tracker is not None:
- self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict))
+ self.tracker = Tracker(config.d_tracker, config.predict)
else:
self.tracker = None
@@ -248,7 +250,7 @@ class SPINN(tfe.Network):
trans = transitions[i]
if self.tracker:
# Invoke tracker to obtain the current tracker states for the sentences.
- tracker_states, trans_hypothesis = self.tracker(buffers, stacks)
+ tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
if trans_hypothesis:
trans = tf.argmax(trans_hypothesis, axis=-1)
else:
@@ -264,7 +266,8 @@ class SPINN(tfe.Network):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(lefts, rights, trackings)
+ reducer_output = self.reducer(
+ lefts, right_in=rights, tracking=trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -273,7 +276,27 @@ class SPINN(tfe.Network):
return _bundle([stack.pop() for stack in stacks])[0]
-class SNLIClassifier(tfe.Network):
+class Perceptron(tf.keras.Model):
+ """One layer of the SNLIClassifier multi-layer perceptron."""
+
+ def __init__(self, dimension, dropout_rate, previous_layer):
+ """Configure the Perceptron."""
+ super(Perceptron, self).__init__()
+ self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)
+ self.batchnorm = layers.BatchNormalization()
+ self.dropout = layers.Dropout(rate=dropout_rate)
+ self.previous_layer = previous_layer
+
+ def call(self, x, training):
+ """Run previous Perceptron layers, then this one."""
+ x = self.previous_layer(x, training=training)
+ x = self.dense(x)
+ x = self.batchnorm(x, training=training)
+ x = self.dropout(x, training=training)
+ return x
+
+
+class SNLIClassifier(tf.keras.Model):
"""SNLI Classifier Model.
A model aimed at solving the SNLI (Stanford Natural Language Inference)
@@ -304,29 +327,24 @@ class SNLIClassifier(tfe.Network):
self.config = config
self.embed = tf.constant(embed)
- self.projection = self.track_layer(tf.layers.Dense(config.d_proj))
- self.embed_bn = self.track_layer(tf.layers.BatchNormalization())
- self.embed_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.embed_dropout))
- self.encoder = self.track_layer(SPINN(config))
-
- self.feature_bn = self.track_layer(tf.layers.BatchNormalization())
- self.feature_dropout = self.track_layer(
- tf.layers.Dropout(rate=config.mlp_dropout))
-
- self.mlp_dense = []
- self.mlp_bn = []
- self.mlp_dropout = []
- for _ in xrange(config.n_mlp_layers):
- self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp)))
- self.mlp_bn.append(
- self.track_layer(tf.layers.BatchNormalization()))
- self.mlp_dropout.append(
- self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout)))
- self.mlp_output = self.track_layer(tf.layers.Dense(
+ self.projection = layers.Dense(config.d_proj)
+ self.embed_bn = layers.BatchNormalization()
+ self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
+ self.encoder = SPINN(config)
+
+ self.feature_bn = layers.BatchNormalization()
+ self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)
+
+ current_mlp = lambda result, training: result
+ for _ in range(config.n_mlp_layers):
+ current_mlp = Perceptron(dimension=config.d_mlp,
+ dropout_rate=config.mlp_dropout,
+ previous_layer=current_mlp)
+ self.mlp = current_mlp
+ self.mlp_output = layers.Dense(
config.d_out,
kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
- maxval=5e-3)))
+ maxval=5e-3))
def call(self,
premise,
@@ -370,10 +388,10 @@ class SNLIClassifier(tfe.Network):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(premise_embed, premise_transition,
- training=training)
- hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
- training=training)
+ premise = self.encoder(
+ premise_embed, transitions=premise_transition, training=training)
+ hypothesis = self.encoder(
+ hypothesis_embed, transitions=hypothesis_transition, training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropout on the logits.
@@ -383,15 +401,12 @@ class SNLIClassifier(tfe.Network):
self.feature_bn(logits, training=training), training=training)
# Apply the multi-layer perceptron on the logits.
- for dense, bn, dropout in zip(
- self.mlp_dense, self.mlp_bn, self.mlp_dropout):
- logits = tf.nn.elu(dense(logits))
- logits = dropout(bn(logits, training=training), training=training)
+ logits = self.mlp(logits, training=training)
logits = self.mlp_output(logits)
return logits
-class SNLIClassifierTrainer(object):
+class SNLIClassifierTrainer(tfe.Checkpointable):
"""A class that coordinates the training of an SNLIClassifier."""
def __init__(self, snli_classifier, lr):
@@ -450,10 +465,11 @@ class SNLIClassifierTrainer(object):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
+ # TODO(allenl): Allow passing Layer inputs as positional arguments.
logits = self._model(premise,
- premise_transition,
- hypothesis,
- hypothesis_transition,
+ premise_transition=premise_transition,
+ hypothesis=hypothesis,
+ hypothesis_transition=hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -517,7 +533,9 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
+ logits = trainer.model(
+ prem, premise_transition=prem_trans, hypothesis=hypo,
+ hypothesis_transition=hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -609,29 +627,30 @@ def train_or_infer_spinn(embed,
with tf.device(device), \
summary_writer.as_default(), \
tf.contrib.summary.always_record_summaries():
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- model = SNLIClassifier(config, embed)
- global_step = tf.train.get_or_create_global_step()
- trainer = SNLIClassifierTrainer(model, config.lr)
+ model = SNLIClassifier(config, embed)
+ global_step = tf.train.get_or_create_global_step()
+ trainer = SNLIClassifierTrainer(model, config.lr)
+ checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step)
+ checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
if inference_sentence_pair:
# Inference mode.
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- prem, prem_trans = inference_sentence_pair[0]
- hypo, hypo_trans = inference_sentence_pair[1]
- hypo_trans = inference_sentence_pair[1][1]
- inference_logits = model( # pylint: disable=not-callable
- tf.constant(prem), tf.constant(prem_trans),
- tf.constant(hypo), tf.constant(hypo_trans), training=False)
- inference_logits = inference_logits[0][1:]
- max_index = tf.argmax(inference_logits)
- print("\nInference logits:")
- for i, (label, logit) in enumerate(
- zip(data.POSSIBLE_LABELS, inference_logits)):
- winner_tag = " (winner)" if max_index == i else ""
- print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
+ prem, prem_trans = inference_sentence_pair[0]
+ hypo, hypo_trans = inference_sentence_pair[1]
+ hypo_trans = inference_sentence_pair[1][1]
+ inference_logits = model(
+ tf.constant(prem),
+ premise_transition=tf.constant(prem_trans),
+ hypothesis=tf.constant(hypo),
+ hypothesis_transition=tf.constant(hypo_trans),
+ training=False)
+ inference_logits = inference_logits[0][1:]
+ max_index = tf.argmax(inference_logits)
+ print("\nInference logits:")
+ for i, (label, logit) in enumerate(
+ zip(data.POSSIBLE_LABELS, inference_logits)):
+ winner_tag = " (winner)" if max_index == i else ""
+ print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
return inference_logits
train_len = train_data.num_batches(config.batch_size)
@@ -650,20 +669,15 @@ def train_or_infer_spinn(embed,
# remain on CPU. Same in _evaluate_on_dataset().
iterations += 1
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(config.logdir)):
- batch_train_loss, batch_train_logits = trainer.train_batch(
- label, prem, prem_trans, hypo, hypo_trans)
+ batch_train_loss, batch_train_logits = trainer.train_batch(
+ label, prem, prem_trans, hypo, hypo_trans)
batch_size = tf.shape(label)[0]
mean_loss(batch_train_loss.numpy(),
weights=batch_size.gpu() if use_gpu else batch_size)
accuracy(tf.argmax(batch_train_logits, axis=1), label)
if iterations % config.save_every == 0:
- all_variables = trainer.variables + [global_step]
- saver = tfe.Saver(all_variables)
- saver.save(os.path.join(config.logdir, "ckpt"),
- global_step=global_step)
+ checkpoint.save(os.path.join(config.logdir, "ckpt"))
if iterations % config.dev_every == 0:
dev_loss, dev_frac_correct = _evaluate_on_dataset(