diff options
Diffstat (limited to 'third_party/examples/eager/spinn/spinn.py')
-rw-r--r-- | third_party/examples/eager/spinn/spinn.py | 29 |
1 files changed, 11 insertions, 18 deletions
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index f8fb6ecb0c..8a2b24aa4e 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -266,8 +266,7 @@ class SPINN(tf.keras.Model): trackings.append(tracking) if rights: - reducer_output = self.reducer( - lefts, right_in=rights, tracking=trackings) + reducer_output = self.reducer(lefts, rights, trackings) reduced = iter(reducer_output) for transition, stack in zip(trans, stacks): @@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model): # Run the batch-normalized and dropout-processed word vectors through the # SPINN encoder. - premise = self.encoder( - premise_embed, transitions=premise_transition, training=training) - hypothesis = self.encoder( - hypothesis_embed, transitions=hypothesis_transition, training=training) + premise = self.encoder(premise_embed, premise_transition, + training=training) + hypothesis = self.encoder(hypothesis_embed, hypothesis_transition, + training=training) # Combine encoder outputs for premises and hypotheses into logits. # Then apply batch normalization and dropuout on the logits. @@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable): """ with tfe.GradientTape() as tape: tape.watch(self._model.variables) - # TODO(allenl): Allow passing Layer inputs as position arguments. logits = self._model(premise, - premise_transition=premise_transition, - hypothesis=hypothesis, - hypothesis_transition=hypothesis_transition, + premise_transition, + hypothesis, + hypothesis_transition, training=True) loss = self.loss(labels, logits) gradients = tape.gradient(loss, self._model.variables) @@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = trainer.model( - prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=False) + logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -639,11 +635,8 @@ def train_or_infer_spinn(embed, hypo, hypo_trans = inference_sentence_pair[1] hypo_trans = inference_sentence_pair[1][1] inference_logits = model( - tf.constant(prem), - premise_transition=tf.constant(prem_trans), - hypothesis=tf.constant(hypo), - hypothesis_transition=tf.constant(hypo_trans), - training=False) + tf.constant(prem), tf.constant(prem_trans), + tf.constant(hypo), tf.constant(hypo_trans), training=False) inference_logits = inference_logits[0][1:] max_index = tf.argmax(inference_logits) print("\nInference logits:") |