Diffstat (limited to 'third_party/examples/eager/spinn/spinn.py')
-rw-r--r--  third_party/examples/eager/spinn/spinn.py  29
1 file changed, 11 insertions(+), 18 deletions(-)
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index f8fb6ecb0c..8a2b24aa4e 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -266,8 +266,7 @@ class SPINN(tf.keras.Model):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(
- lefts, right_in=rights, tracking=trackings)
+ reducer_output = self.reducer(lefts, rights, trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(
- premise_embed, transitions=premise_transition, training=training)
- hypothesis = self.encoder(
- hypothesis_embed, transitions=hypothesis_transition, training=training)
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropout on the logits.
@@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
- # TODO(allenl): Allow passing Layer inputs as positional arguments.
logits = self._model(premise,
- premise_transition=premise_transition,
- hypothesis=hypothesis,
- hypothesis_transition=hypothesis_transition,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -639,11 +635,8 @@ def train_or_infer_spinn(embed,
prem, prem_trans = inference_sentence_pair[0]
hypo, hypo_trans = inference_sentence_pair[1]
inference_logits = model(
- tf.constant(prem),
- premise_transition=tf.constant(prem_trans),
- hypothesis=tf.constant(hypo),
- hypothesis_transition=tf.constant(hypo_trans),
- training=False)
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
inference_logits = inference_logits[0][1:]
max_index = tf.argmax(inference_logits)
print("\nInference logits:")