aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/examples
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-03-28 10:03:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 10:05:24 -0700
commit5a213116df09c19c3ee0eecb5fc79444e5671e80 (patch)
treebec028a2db003cc632913321fe70a50f8afdbc21 /third_party/examples
parent119ed5aa2acb6df04595835f6dfa99f5422449f2 (diff)
Allow positional arguments in tf.keras.Model subclasses
Makes the tf.keras.Layer.__call__ signature identical to tf.layers.Layer.__call__, but makes passing positional arguments other than "inputs" an error in most cases. The only case it's allowed is subclassed Models which do not have an "inputs" argument to their call() method. This means subclassed Models no longer need to pass all but the first argument as a keyword argument (or do list packing/unpacking) when call() takes multiple Tensor arguments. Includes errors for cases where it is ambiguous whether an argument indicates an input, but otherwise doesn't do much to support non-"inputs" call() signatures for shape inference or deferred Tensors. The definition of an input/non-input is pretty clear, so that cleanup will mostly be tracking down all of the users of "self.call" and getting them to pass inputs as positional arguments if necessary. PiperOrigin-RevId: 190787899
Diffstat (limited to 'third_party/examples')
-rw-r--r--third_party/examples/eager/spinn/spinn.py29
1 file changed, 11 insertions, 18 deletions
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index f8fb6ecb0c..8a2b24aa4e 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -266,8 +266,7 @@ class SPINN(tf.keras.Model):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(
- lefts, right_in=rights, tracking=trackings)
+ reducer_output = self.reducer(lefts, rights, trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(
- premise_embed, transitions=premise_transition, training=training)
- hypothesis = self.encoder(
- hypothesis_embed, transitions=hypothesis_transition, training=training)
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropuout on the logits.
@@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
- # TODO(allenl): Allow passing Layer inputs as position arguments.
logits = self._model(premise,
- premise_transition=premise_transition,
- hypothesis=hypothesis,
- hypothesis_transition=hypothesis_transition,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -639,11 +635,8 @@ def train_or_infer_spinn(embed,
hypo, hypo_trans = inference_sentence_pair[1]
hypo_trans = inference_sentence_pair[1][1]
inference_logits = model(
- tf.constant(prem),
- premise_transition=tf.constant(prem_trans),
- hypothesis=tf.constant(hypo),
- hypothesis_transition=tf.constant(hypo_trans),
- training=False)
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
inference_logits = inference_logits[0][1:]
max_index = tf.argmax(inference_logits)
print("\nInference logits:")