diff options
7 files changed, 246 insertions, 32 deletions
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 591d99edcd..9261823d77 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -173,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase): right_in.append(tf.random_normal((1, size * 2))) tracking.append(tf.random_normal((1, tracker_size * 2))) - out = reducer(left_in, right_in=right_in, tracking=tracking) + out = reducer(left_in, right_in, tracking=tracking) self.assertEqual(batch_size, len(out)) self.assertEqual(tf.float32, out[0].dtype) self.assertEqual((1, size * 2), out[0].shape) @@ -227,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual((batch_size, size * 2), stacks[0][0].shape) for _ in range(2): - out1, out2 = tracker(bufs, stacks=stacks) + out1, out2 = tracker(bufs, stacks) self.assertIsNone(out2) self.assertEqual(batch_size, len(out1)) self.assertEqual(tf.float32, out1[0].dtype) @@ -260,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(tf.int64, transitions.dtype) self.assertEqual((num_transitions, 1), transitions.shape) - out = s(buffers, transitions=transitions, training=True) + out = s(buffers, transitions, training=True) self.assertEqual(tf.float32, out.dtype) self.assertEqual((1, embedding_dims), out.shape) @@ -286,15 +286,12 @@ class SpinnTest(test_util.TensorFlowTestCase): vocab_size) # Invoke model under non-training mode. - logits = model( - prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=False) + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) # Invoke model under training model. - logits = model(prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=True) + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 5615241ae3..755607aafb 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import inspect # Necessary supplement to tf_inspect to deal with variadic args. + from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context @@ -30,6 +32,8 @@ from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.utils import generic_utils from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -143,6 +147,7 @@ class Layer(tf_base_layers.Layer): super(Layer, self).__init__( name=name, dtype=dtype, trainable=trainable, activity_regularizer=kwargs.get('activity_regularizer')) + self._uses_inputs_arg = True # Add properties that are Keras-only for now. self.supports_masking = False @@ -213,7 +218,71 @@ class Layer(tf_base_layers.Layer): """ return inputs - def __call__(self, inputs, **kwargs): + def _inputs_from_call_args(self, call_args, call_kwargs): + """Get Layer inputs from __call__ *args and **kwargs. + + Args: + call_args: The positional arguments passed to __call__. + call_kwargs: The keyword argument dict passed to __call__. + + Returns: + A tuple of (inputs, non_input_kwargs). These may be the same objects as + were passed in (call_args and call_kwargs). + """ + if getattr(self, '_uses_inputs_arg', True): + assert len(call_args) == 1 # TypeError raised earlier in __call__. + return call_args[0], call_kwargs + else: + call_arg_spec = tf_inspect.getargspec(self.call) + # There is no explicit "inputs" argument expected or provided to + # call(). Arguments which have default values are considered non-inputs, + # and arguments without are considered inputs. + if call_arg_spec.defaults: + if call_arg_spec.varargs is not None: + raise TypeError( + 'Layer.call() may not accept both *args and arguments with ' + 'default values (unable to determine which are inputs to the ' + 'Layer).') + keyword_arg_names = set( + call_arg_spec.args[-len(call_arg_spec.defaults):]) + else: + keyword_arg_names = set() + # Training is never an input argument name, to allow signatures like + # call(x, training). + keyword_arg_names.add('training') + _, unwrapped_call = tf_decorator.unwrap(self.call) + bound_args = inspect.getcallargs( + unwrapped_call, *call_args, **call_kwargs) + if call_arg_spec.keywords is not None: + var_kwargs = bound_args.pop(call_arg_spec.keywords) + bound_args.update(var_kwargs) + keyword_arg_names = keyword_arg_names.union(var_kwargs.keys()) + all_args = call_arg_spec.args + if all_args and bound_args[all_args[0]] is self: + # Ignore the 'self' argument of methods + bound_args.pop(call_arg_spec.args[0]) + all_args = all_args[1:] + non_input_arg_values = {} + input_arg_values = [] + remaining_args_are_keyword = False + for argument_name in all_args: + if argument_name in keyword_arg_names: + remaining_args_are_keyword = True + else: + if remaining_args_are_keyword: + raise TypeError( + 'Found a positional argument to call() after a non-input ' + 'argument. All arguments after "training" must be keyword ' + 'arguments, and are not tracked as inputs to the Layer.') + if remaining_args_are_keyword: + non_input_arg_values[argument_name] = bound_args[argument_name] + else: + input_arg_values.append(bound_args[argument_name]) + if call_arg_spec.varargs is not None: + input_arg_values.extend(bound_args[call_arg_spec.varargs]) + return input_arg_values, non_input_arg_values + + def __call__(self, inputs, *args, **kwargs): """Wrapper around self.call(), for handling internal references. If a Keras tensor is passed: @@ -226,6 +295,10 @@ class Layer(tf_base_layers.Layer): Arguments: inputs: Can be a tensor or list/tuple of tensors. + *args: Additional positional arguments to be passed to `call()`. Only + allowed in subclassed Models with custom call() signatures. In other + cases, `Layer` inputs must be passed using the `inputs` argument and + non-inputs must be keyword arguments. **kwargs: Additional keyword arguments to be passed to `call()`. Returns: @@ -234,12 +307,25 @@ class Layer(tf_base_layers.Layer): Raises: ValueError: in case the layer is missing shape information for its `build` call. + TypeError: If positional arguments are passed and this `Layer` is not a + subclassed `Model`. """ # Actually call the layer (optionally building it). - output = super(Layer, self).__call__(inputs, **kwargs) + output = super(Layer, self).__call__(inputs, *args, **kwargs) + + if args and getattr(self, '_uses_inputs_arg', True): + raise TypeError( + 'This Layer takes an `inputs` argument to call(), and only the ' + '`inputs` argument may be specified as a positional argument. Pass ' + 'everything else as a keyword argument (those arguments will not be ' + 'tracked as inputs to the Layer).') + if context.executing_eagerly(): return output + inputs, kwargs = self._inputs_from_call_args( + call_args=(inputs,) + args, call_kwargs=kwargs) + if hasattr(self, '_symbolic_set_inputs') and not self.inputs: # Subclassed network: explicitly set metadata normally set by a call to # self._set_inputs(). diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py index ea4be0d293..9f1c7de115 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -117,6 +117,7 @@ class Network(base_layer.Layer): self._inbound_nodes = [] def _init_graph_network(self, inputs, outputs, name=None): + self._uses_inputs_arg = True # Normalize and set self.inputs, self.outputs. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. @@ -274,11 +275,15 @@ class Network(base_layer.Layer): def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - if 'training' in tf_inspect.getargspec(self.call).args: + call_args = tf_inspect.getargspec(self.call).args + if 'training' in call_args: self._expects_training_arg = True else: self._expects_training_arg = False - + if 'inputs' in call_args: + self._uses_inputs_arg = True + else: + self._uses_inputs_arg = False self.outputs = None self.inputs = None self.built = False diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 08288d353e..971245c162 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -874,6 +874,11 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ + if not getattr(self, '_uses_inputs_arg', True): + raise NotImplementedError( + 'Subclassed Models without "inputs" in their call() signatures do ' + 'not yet support shape inference. File a feature request if this ' + 'limitation bothers you.') if self.__class__.__name__ == 'Sequential': # Note: we can't test whether the model is `Sequential` via `isinstance` # since `Sequential` depends on `Model`. diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py index 58b144365b..4445900330 100644 --- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -22,7 +22,9 @@ import os import tempfile import numpy as np +import six +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.keras._impl import keras @@ -36,6 +38,7 @@ except ImportError: h5py = None +# pylint: disable=not-callable class SimpleTestModel(keras.Model): def __init__(self, use_bn=False, use_dp=False, num_classes=10): @@ -104,7 +107,7 @@ class NestedTestModel1(keras.Model): def call(self, inputs): x = self.dense1(inputs) x = self.bn(x) - x = self.test_net(x) # pylint: disable=not-callable + x = self.test_net(x) return self.dense2(x) @@ -161,7 +164,7 @@ def get_nested_model_3(input_dim, num_classes): return tensor_shape.TensorShape((input_shape[0], 5)) test_model = Inner() - x = test_model(x) # pylint: disable=not-callable + x = test_model(x) outputs = keras.layers.Dense(num_classes)(x) return keras.Model(inputs, outputs, name='nested_model_3') @@ -574,5 +577,128 @@ class ModelSubclassingTest(test.TestCase): self.assertGreater(loss, 0.1) +class CustomCallModel(keras.Model): + + def __init__(self): + super(CustomCallModel, self).__init__() + self.dense1 = keras.layers.Dense(1, activation='relu') + self.dense2 = keras.layers.Dense(1, activation='softmax') + + def call(self, first, second, fiddle_with_output='no', training=True): + combined = self.dense1(first) + self.dense2(second) + if fiddle_with_output == 'yes': + return 10. * combined + else: + return combined + + +class CustomCallSignatureTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_no_inputs_in_signature(self): + model = CustomCallModel() + first = array_ops.ones([2, 3]) + second = array_ops.ones([2, 5]) + output = model(first, second) + self.evaluate([v.initializer for v in model.variables]) + expected_output = self.evaluate(model.dense1(first) + model.dense2(second)) + self.assertAllClose(expected_output, self.evaluate(output)) + output = model(first, second, fiddle_with_output='yes') + self.assertAllClose(10. * expected_output, self.evaluate(output)) + output = model(first, second=second, training=False) + self.assertAllClose(expected_output, self.evaluate(output)) + if not context.executing_eagerly(): + six.assertCountEqual(self, [first, second], model.inputs) + with self.assertRaises(TypeError): + # tf.layers.Layer expects an "inputs" argument, so all-keywords doesn't + # work at the moment. + model(first=first, second=second, fiddle_with_output='yes') + + @test_util.run_in_graph_and_eager_modes() + def test_inputs_in_signature(self): + + class HasInputsAndOtherPositional(keras.Model): + + def call(self, inputs, some_other_arg, training=False): + return inputs + + model = HasInputsAndOtherPositional() + with self.assertRaisesRegexp( + TypeError, 'everything else as a keyword argument'): + model(array_ops.ones([]), array_ops.ones([])) + + @test_util.run_in_graph_and_eager_modes() + def test_kwargs_in_signature(self): + + class HasKwargs(keras.Model): + + def call(self, x, y=3, **key_words): + return x + + model = HasKwargs() + arg = array_ops.ones([]) + model(arg, a=3) + if not context.executing_eagerly(): + six.assertCountEqual(self, [arg], model.inputs) + + @test_util.run_in_graph_and_eager_modes() + def test_args_in_signature(self): + + class HasArgs(keras.Model): + + def call(self, x, *args, **kwargs): + return [x] + list(args) + + model = HasArgs() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + model(arg1, arg2, arg3, a=3) + if not context.executing_eagerly(): + six.assertCountEqual(self, [arg1, arg2, arg3], model.inputs) + + def test_args_and_keywords_in_signature(self): + + class HasArgs(keras.Model): + + def call(self, x, training=True, *args, **kwargs): + return x + + with context.graph_mode(): + model = HasArgs() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + with self.assertRaisesRegexp(TypeError, 'args and arguments with'): + model(arg1, arg2, arg3, a=3) + + def test_training_no_default(self): + + class TrainingNoDefault(keras.Model): + + def call(self, x, training): + return x + + with context.graph_mode(): + model = TrainingNoDefault() + arg = array_ops.ones([]) + model(arg, True) + six.assertCountEqual(self, [arg], model.inputs) + + def test_training_no_default_with_positional(self): + + class TrainingNoDefaultWithPositional(keras.Model): + + def call(self, x, training, positional): + return x + + with context.graph_mode(): + model = TrainingNoDefaultWithPositional() + arg1 = array_ops.ones([]) + arg2 = array_ops.ones([]) + arg3 = array_ops.ones([]) + with self.assertRaisesRegexp(TypeError, 'after a non-input'): + model(arg1, arg2, arg3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 1e5f26a77f..242cdff6f3 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -625,6 +625,8 @@ class Layer(checkpointable.CheckpointableBase): input_list = nest.flatten(inputs) build_graph = not context.executing_eagerly() + # TODO(fchollet, allenl): Make deferred mode work with subclassed Models + # which don't use an "inputs" argument. in_deferred_mode = isinstance(input_list[0], _DeferredTensor) # Ensure the Layer, if being reused, is working with inputs from # the same graph as where it was created. diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index f8fb6ecb0c..8a2b24aa4e 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -266,8 +266,7 @@ class SPINN(tf.keras.Model): trackings.append(tracking) if rights: - reducer_output = self.reducer( - lefts, right_in=rights, tracking=trackings) + reducer_output = self.reducer(lefts, rights, trackings) reduced = iter(reducer_output) for transition, stack in zip(trans, stacks): @@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model): # Run the batch-normalized and dropout-processed word vectors through the # SPINN encoder. - premise = self.encoder( - premise_embed, transitions=premise_transition, training=training) - hypothesis = self.encoder( - hypothesis_embed, transitions=hypothesis_transition, training=training) + premise = self.encoder(premise_embed, premise_transition, + training=training) + hypothesis = self.encoder(hypothesis_embed, hypothesis_transition, + training=training) # Combine encoder outputs for premises and hypotheses into logits. # Then apply batch normalization and dropuout on the logits. @@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable): """ with tfe.GradientTape() as tape: tape.watch(self._model.variables) - # TODO(allenl): Allow passing Layer inputs as position arguments. logits = self._model(premise, - premise_transition=premise_transition, - hypothesis=hypothesis, - hypothesis_transition=hypothesis_transition, + premise_transition, + hypothesis, + hypothesis_transition, training=True) loss = self.loss(labels, logits) gradients = tape.gradient(loss, self._model.variables) @@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = trainer.model( - prem, premise_transition=prem_trans, hypothesis=hypo, - hypothesis_transition=hypo_trans, training=False) + logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -639,11 +635,8 @@ def train_or_infer_spinn(embed, hypo, hypo_trans = inference_sentence_pair[1] hypo_trans = inference_sentence_pair[1][1] inference_logits = model( - tf.constant(prem), - premise_transition=tf.constant(prem_trans), - hypothesis=tf.constant(hypo), - hypothesis_transition=tf.constant(hypo_trans), - training=False) + tf.constant(prem), tf.constant(prem_trans), + tf.constant(hypo), tf.constant(hypo_trans), training=False) inference_logits = inference_logits[0][1:] max_index = tf.argmax(inference_logits) print("\nInference logits:") |