aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-03-28 10:03:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 10:05:24 -0700
commit5a213116df09c19c3ee0eecb5fc79444e5671e80 (patch)
treebec028a2db003cc632913321fe70a50f8afdbc21
parent119ed5aa2acb6df04595835f6dfa99f5422449f2 (diff)
Allow positional arguments in tf.keras.Model subclasses
Makes the tf.keras.Layer.__call__ signature identical to tf.layers.Layer.__call__, but makes passing positional arguments other than "inputs" an error in most cases. The only case it's allowed is subclassed Models which do not have an "inputs" argument to their call() method. This means subclassed Models no longer need to pass all but the first argument as a keyword argument (or do list packing/unpacking) when call() takes multiple Tensor arguments. Includes errors for cases where whether an argument indicates an input is ambiguous, but otherwise doesn't do much to support non-"inputs" call() signatures for shape inference or deferred Tensors. The definition of an input/non-input is pretty clear, so that cleanup will mostly be tracking down all of the users of "self.call" and getting them to pass inputs as positional arguments if necessary. PiperOrigin-RevId: 190787899
-rw-r--r--tensorflow/contrib/eager/python/examples/spinn/spinn_test.py13
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py90
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/network.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/model_subclassing_test.py130
-rw-r--r--tensorflow/python/layers/base.py2
-rw-r--r--third_party/examples/eager/spinn/spinn.py29
7 files changed, 246 insertions, 32 deletions
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 591d99edcd..9261823d77 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -173,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
right_in.append(tf.random_normal((1, size * 2)))
tracking.append(tf.random_normal((1, tracker_size * 2)))
- out = reducer(left_in, right_in=right_in, tracking=tracking)
+ out = reducer(left_in, right_in, tracking=tracking)
self.assertEqual(batch_size, len(out))
self.assertEqual(tf.float32, out[0].dtype)
self.assertEqual((1, size * 2), out[0].shape)
@@ -227,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
for _ in range(2):
- out1, out2 = tracker(bufs, stacks=stacks)
+ out1, out2 = tracker(bufs, stacks)
self.assertIsNone(out2)
self.assertEqual(batch_size, len(out1))
self.assertEqual(tf.float32, out1[0].dtype)
@@ -260,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual(tf.int64, transitions.dtype)
self.assertEqual((num_transitions, 1), transitions.shape)
- out = s(buffers, transitions=transitions, training=True)
+ out = s(buffers, transitions, training=True)
self.assertEqual(tf.float32, out.dtype)
self.assertEqual((1, embedding_dims), out.shape)
@@ -286,15 +286,12 @@ class SpinnTest(test_util.TensorFlowTestCase):
vocab_size)
# Invoke model under non-training mode.
- logits = model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# Invoke model under training model.
- logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=True)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 5615241ae3..755607aafb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import inspect # Necessary supplement to tf_inspect to deal with variadic args.
+
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
@@ -30,6 +32,8 @@ from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -143,6 +147,7 @@ class Layer(tf_base_layers.Layer):
super(Layer, self).__init__(
name=name, dtype=dtype, trainable=trainable,
activity_regularizer=kwargs.get('activity_regularizer'))
+ self._uses_inputs_arg = True
# Add properties that are Keras-only for now.
self.supports_masking = False
@@ -213,7 +218,71 @@ class Layer(tf_base_layers.Layer):
"""
return inputs
- def __call__(self, inputs, **kwargs):
+ def _inputs_from_call_args(self, call_args, call_kwargs):
+ """Get Layer inputs from __call__ *args and **kwargs.
+
+ Args:
+ call_args: The positional arguments passed to __call__.
+ call_kwargs: The keyword argument dict passed to __call__.
+
+ Returns:
+ A tuple of (inputs, non_input_kwargs). These may be the same objects as
+ were passed in (call_args and call_kwargs).
+ """
+ if getattr(self, '_uses_inputs_arg', True):
+ assert len(call_args) == 1 # TypeError raised earlier in __call__.
+ return call_args[0], call_kwargs
+ else:
+ call_arg_spec = tf_inspect.getargspec(self.call)
+ # There is no explicit "inputs" argument expected or provided to
+ # call(). Arguments which have default values are considered non-inputs,
+ # and arguments without are considered inputs.
+ if call_arg_spec.defaults:
+ if call_arg_spec.varargs is not None:
+ raise TypeError(
+ 'Layer.call() may not accept both *args and arguments with '
+ 'default values (unable to determine which are inputs to the '
+ 'Layer).')
+ keyword_arg_names = set(
+ call_arg_spec.args[-len(call_arg_spec.defaults):])
+ else:
+ keyword_arg_names = set()
+ # Training is never an input argument name, to allow signatures like
+ # call(x, training).
+ keyword_arg_names.add('training')
+ _, unwrapped_call = tf_decorator.unwrap(self.call)
+ bound_args = inspect.getcallargs(
+ unwrapped_call, *call_args, **call_kwargs)
+ if call_arg_spec.keywords is not None:
+ var_kwargs = bound_args.pop(call_arg_spec.keywords)
+ bound_args.update(var_kwargs)
+ keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
+ all_args = call_arg_spec.args
+ if all_args and bound_args[all_args[0]] is self:
+ # Ignore the 'self' argument of methods
+ bound_args.pop(call_arg_spec.args[0])
+ all_args = all_args[1:]
+ non_input_arg_values = {}
+ input_arg_values = []
+ remaining_args_are_keyword = False
+ for argument_name in all_args:
+ if argument_name in keyword_arg_names:
+ remaining_args_are_keyword = True
+ else:
+ if remaining_args_are_keyword:
+ raise TypeError(
+ 'Found a positional argument to call() after a non-input '
+ 'argument. All arguments after "training" must be keyword '
+ 'arguments, and are not tracked as inputs to the Layer.')
+ if remaining_args_are_keyword:
+ non_input_arg_values[argument_name] = bound_args[argument_name]
+ else:
+ input_arg_values.append(bound_args[argument_name])
+ if call_arg_spec.varargs is not None:
+ input_arg_values.extend(bound_args[call_arg_spec.varargs])
+ return input_arg_values, non_input_arg_values
+
+ def __call__(self, inputs, *args, **kwargs):
"""Wrapper around self.call(), for handling internal references.
If a Keras tensor is passed:
@@ -226,6 +295,10 @@ class Layer(tf_base_layers.Layer):
Arguments:
inputs: Can be a tensor or list/tuple of tensors.
+ *args: Additional positional arguments to be passed to `call()`. Only
+ allowed in subclassed Models with custom call() signatures. In other
+ cases, `Layer` inputs must be passed using the `inputs` argument and
+ non-inputs must be keyword arguments.
**kwargs: Additional keyword arguments to be passed to `call()`.
Returns:
@@ -234,12 +307,25 @@ class Layer(tf_base_layers.Layer):
Raises:
ValueError: in case the layer is missing shape information
for its `build` call.
+ TypeError: If positional arguments are passed and this `Layer` is not a
+ subclassed `Model`.
"""
# Actually call the layer (optionally building it).
- output = super(Layer, self).__call__(inputs, **kwargs)
+ output = super(Layer, self).__call__(inputs, *args, **kwargs)
+
+ if args and getattr(self, '_uses_inputs_arg', True):
+ raise TypeError(
+ 'This Layer takes an `inputs` argument to call(), and only the '
+ '`inputs` argument may be specified as a positional argument. Pass '
+ 'everything else as a keyword argument (those arguments will not be '
+ 'tracked as inputs to the Layer).')
+
if context.executing_eagerly():
return output
+ inputs, kwargs = self._inputs_from_call_args(
+ call_args=(inputs,) + args, call_kwargs=kwargs)
+
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by a call to
# self._set_inputs().
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index ea4be0d293..9f1c7de115 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -117,6 +117,7 @@ class Network(base_layer.Layer):
self._inbound_nodes = []
def _init_graph_network(self, inputs, outputs, name=None):
+ self._uses_inputs_arg = True
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -274,11 +275,15 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- if 'training' in tf_inspect.getargspec(self.call).args:
+ call_args = tf_inspect.getargspec(self.call).args
+ if 'training' in call_args:
self._expects_training_arg = True
else:
self._expects_training_arg = False
-
+ if 'inputs' in call_args:
+ self._uses_inputs_arg = True
+ else:
+ self._uses_inputs_arg = False
self.outputs = None
self.inputs = None
self.built = False
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 08288d353e..971245c162 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -874,6 +874,11 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
+ if not getattr(self, '_uses_inputs_arg', True):
+ raise NotImplementedError(
+ 'Subclassed Models without "inputs" in their call() signatures do '
+ 'not yet support shape inference. File a feature request if this '
+ 'limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 58b144365b..4445900330 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -22,7 +22,9 @@ import os
import tempfile
import numpy as np
+import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
@@ -36,6 +38,7 @@ except ImportError:
h5py = None
+# pylint: disable=not-callable
class SimpleTestModel(keras.Model):
def __init__(self, use_bn=False, use_dp=False, num_classes=10):
@@ -104,7 +107,7 @@ class NestedTestModel1(keras.Model):
def call(self, inputs):
x = self.dense1(inputs)
x = self.bn(x)
- x = self.test_net(x) # pylint: disable=not-callable
+ x = self.test_net(x)
return self.dense2(x)
@@ -161,7 +164,7 @@ def get_nested_model_3(input_dim, num_classes):
return tensor_shape.TensorShape((input_shape[0], 5))
test_model = Inner()
- x = test_model(x) # pylint: disable=not-callable
+ x = test_model(x)
outputs = keras.layers.Dense(num_classes)(x)
return keras.Model(inputs, outputs, name='nested_model_3')
@@ -574,5 +577,128 @@ class ModelSubclassingTest(test.TestCase):
self.assertGreater(loss, 0.1)
+class CustomCallModel(keras.Model):
+
+ def __init__(self):
+ super(CustomCallModel, self).__init__()
+ self.dense1 = keras.layers.Dense(1, activation='relu')
+ self.dense2 = keras.layers.Dense(1, activation='softmax')
+
+ def call(self, first, second, fiddle_with_output='no', training=True):
+ combined = self.dense1(first) + self.dense2(second)
+ if fiddle_with_output == 'yes':
+ return 10. * combined
+ else:
+ return combined
+
+
+class CustomCallSignatureTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_no_inputs_in_signature(self):
+ model = CustomCallModel()
+ first = array_ops.ones([2, 3])
+ second = array_ops.ones([2, 5])
+ output = model(first, second)
+ self.evaluate([v.initializer for v in model.variables])
+ expected_output = self.evaluate(model.dense1(first) + model.dense2(second))
+ self.assertAllClose(expected_output, self.evaluate(output))
+ output = model(first, second, fiddle_with_output='yes')
+ self.assertAllClose(10. * expected_output, self.evaluate(output))
+ output = model(first, second=second, training=False)
+ self.assertAllClose(expected_output, self.evaluate(output))
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [first, second], model.inputs)
+ with self.assertRaises(TypeError):
+ # tf.layers.Layer expects an "inputs" argument, so all-keywords doesn't
+ # work at the moment.
+ model(first=first, second=second, fiddle_with_output='yes')
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_inputs_in_signature(self):
+
+ class HasInputsAndOtherPositional(keras.Model):
+
+ def call(self, inputs, some_other_arg, training=False):
+ return inputs
+
+ model = HasInputsAndOtherPositional()
+ with self.assertRaisesRegexp(
+ TypeError, 'everything else as a keyword argument'):
+ model(array_ops.ones([]), array_ops.ones([]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_kwargs_in_signature(self):
+
+ class HasKwargs(keras.Model):
+
+ def call(self, x, y=3, **key_words):
+ return x
+
+ model = HasKwargs()
+ arg = array_ops.ones([])
+ model(arg, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_args_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, *args, **kwargs):
+ return [x] + list(args)
+
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ model(arg1, arg2, arg3, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg1, arg2, arg3], model.inputs)
+
+ def test_args_and_keywords_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, training=True, *args, **kwargs):
+ return x
+
+ with context.graph_mode():
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'args and arguments with'):
+ model(arg1, arg2, arg3, a=3)
+
+ def test_training_no_default(self):
+
+ class TrainingNoDefault(keras.Model):
+
+ def call(self, x, training):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefault()
+ arg = array_ops.ones([])
+ model(arg, True)
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ def test_training_no_default_with_positional(self):
+
+ class TrainingNoDefaultWithPositional(keras.Model):
+
+ def call(self, x, training, positional):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefaultWithPositional()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'after a non-input'):
+ model(arg1, arg2, arg3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 1e5f26a77f..242cdff6f3 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -625,6 +625,8 @@ class Layer(checkpointable.CheckpointableBase):
input_list = nest.flatten(inputs)
build_graph = not context.executing_eagerly()
+ # TODO(fchollet, allenl): Make deferred mode work with subclassed Models
+ # which don't use an "inputs" argument.
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index f8fb6ecb0c..8a2b24aa4e 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -266,8 +266,7 @@ class SPINN(tf.keras.Model):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(
- lefts, right_in=rights, tracking=trackings)
+ reducer_output = self.reducer(lefts, rights, trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(
- premise_embed, transitions=premise_transition, training=training)
- hypothesis = self.encoder(
- hypothesis_embed, transitions=hypothesis_transition, training=training)
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropuout on the logits.
@@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
- # TODO(allenl): Allow passing Layer inputs as position arguments.
logits = self._model(premise,
- premise_transition=premise_transition,
- hypothesis=hypothesis,
- hypothesis_transition=hypothesis_transition,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -639,11 +635,8 @@ def train_or_infer_spinn(embed,
hypo, hypo_trans = inference_sentence_pair[1]
hypo_trans = inference_sentence_pair[1][1]
inference_logits = model(
- tf.constant(prem),
- premise_transition=tf.constant(prem_trans),
- hypothesis=tf.constant(hypo),
- hypothesis_transition=tf.constant(hypo_trans),
- training=False)
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
inference_logits = inference_logits[0][1:]
max_index = tf.argmax(inference_logits)
print("\nInference logits:")