-rw-r--r--  tensorflow/contrib/eager/python/examples/spinn/spinn_test.py    13
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/base_layer.py        90
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/network.py            9
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training.py           5
-rw-r--r--  tensorflow/python/keras/_impl/keras/model_subclassing_test.py  130
-rw-r--r--  tensorflow/python/layers/base.py                                  2
-rw-r--r--  third_party/examples/eager/spinn/spinn.py                        29
7 files changed, 246 insertions(+), 32 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 591d99edcd..9261823d77 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -173,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
right_in.append(tf.random_normal((1, size * 2)))
tracking.append(tf.random_normal((1, tracker_size * 2)))
- out = reducer(left_in, right_in=right_in, tracking=tracking)
+ out = reducer(left_in, right_in, tracking=tracking)
self.assertEqual(batch_size, len(out))
self.assertEqual(tf.float32, out[0].dtype)
self.assertEqual((1, size * 2), out[0].shape)
@@ -227,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual((batch_size, size * 2), stacks[0][0].shape)
for _ in range(2):
- out1, out2 = tracker(bufs, stacks=stacks)
+ out1, out2 = tracker(bufs, stacks)
self.assertIsNone(out2)
self.assertEqual(batch_size, len(out1))
self.assertEqual(tf.float32, out1[0].dtype)
@@ -260,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
self.assertEqual(tf.int64, transitions.dtype)
self.assertEqual((num_transitions, 1), transitions.shape)
- out = s(buffers, transitions=transitions, training=True)
+ out = s(buffers, transitions, training=True)
self.assertEqual(tf.float32, out.dtype)
self.assertEqual((1, embedding_dims), out.shape)
@@ -286,15 +286,12 @@ class SpinnTest(test_util.TensorFlowTestCase):
vocab_size)
# Invoke model under non-training mode.
- logits = model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
# Invoke model under training mode.
- logits = model(prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=True)
+ logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
self.assertEqual(tf.float32, logits.dtype)
self.assertEqual((batch_size, d_out), logits.shape)
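The test changes above rely on the new calling convention introduced in base_layer.py below: a subclassed Model's call() may take several positional tensor inputs, with non-inputs such as training passed by keyword. A minimal sketch of that convention; the model and layer names here are illustrative, not part of this patch:

    import tensorflow as tf

    class TwoInputModel(tf.keras.Model):  # hypothetical example model

      def __init__(self):
        super(TwoInputModel, self).__init__()
        self.dense = tf.keras.layers.Dense(4)

      def call(self, first, second, training=False):
        # Both positional arguments are tracked as inputs to the Model.
        return self.dense(tf.concat([first, second], axis=-1))

    model = TwoInputModel()
    # Positional inputs now work; before this change, everything after the
    # first argument had to be passed by keyword.
    logits = model(tf.ones([2, 3]), tf.ones([2, 5]), training=True)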
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 5615241ae3..755607aafb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import inspect # Necessary supplement to tf_inspect to deal with variadic args.
+
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
@@ -30,6 +32,8 @@ from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -143,6 +147,7 @@ class Layer(tf_base_layers.Layer):
super(Layer, self).__init__(
name=name, dtype=dtype, trainable=trainable,
activity_regularizer=kwargs.get('activity_regularizer'))
+ self._uses_inputs_arg = True
# Add properties that are Keras-only for now.
self.supports_masking = False
@@ -213,7 +218,71 @@ class Layer(tf_base_layers.Layer):
"""
return inputs
- def __call__(self, inputs, **kwargs):
+ def _inputs_from_call_args(self, call_args, call_kwargs):
+ """Get Layer inputs from __call__ *args and **kwargs.
+
+ Args:
+ call_args: The positional arguments passed to __call__.
+ call_kwargs: The keyword argument dict passed to __call__.
+
+ Returns:
+ A tuple of (inputs, non_input_kwargs). These may be the same objects as
+ were passed in (call_args and call_kwargs).
+ """
+ if getattr(self, '_uses_inputs_arg', True):
+ assert len(call_args) == 1 # TypeError raised earlier in __call__.
+ return call_args[0], call_kwargs
+ else:
+ call_arg_spec = tf_inspect.getargspec(self.call)
+ # There is no explicit "inputs" argument expected or provided to
+ # call(). Arguments which have default values are considered non-inputs,
+ # and arguments without are considered inputs.
+ if call_arg_spec.defaults:
+ if call_arg_spec.varargs is not None:
+ raise TypeError(
+ 'Layer.call() may not accept both *args and arguments with '
+ 'default values (unable to determine which are inputs to the '
+ 'Layer).')
+ keyword_arg_names = set(
+ call_arg_spec.args[-len(call_arg_spec.defaults):])
+ else:
+ keyword_arg_names = set()
+ # Training is never an input argument name, to allow signatures like
+ # call(x, training).
+ keyword_arg_names.add('training')
+ _, unwrapped_call = tf_decorator.unwrap(self.call)
+ bound_args = inspect.getcallargs(
+ unwrapped_call, *call_args, **call_kwargs)
+ if call_arg_spec.keywords is not None:
+ var_kwargs = bound_args.pop(call_arg_spec.keywords)
+ bound_args.update(var_kwargs)
+ keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
+ all_args = call_arg_spec.args
+ if all_args and bound_args[all_args[0]] is self:
+ # Ignore the 'self' argument of methods
+ bound_args.pop(call_arg_spec.args[0])
+ all_args = all_args[1:]
+ non_input_arg_values = {}
+ input_arg_values = []
+ remaining_args_are_keyword = False
+ for argument_name in all_args:
+ if argument_name in keyword_arg_names:
+ remaining_args_are_keyword = True
+ else:
+ if remaining_args_are_keyword:
+ raise TypeError(
+ 'Found a positional argument to call() after a non-input '
+ 'argument. All arguments after "training" must be keyword '
+ 'arguments, and are not tracked as inputs to the Layer.')
+ if remaining_args_are_keyword:
+ non_input_arg_values[argument_name] = bound_args[argument_name]
+ else:
+ input_arg_values.append(bound_args[argument_name])
+ if call_arg_spec.varargs is not None:
+ input_arg_values.extend(bound_args[call_arg_spec.varargs])
+ return input_arg_values, non_input_arg_values
+
+ def __call__(self, inputs, *args, **kwargs):
"""Wrapper around self.call(), for handling internal references.
If a Keras tensor is passed:
@@ -226,6 +295,10 @@ class Layer(tf_base_layers.Layer):
Arguments:
inputs: Can be a tensor or list/tuple of tensors.
+ *args: Additional positional arguments to be passed to `call()`. Only
+ allowed in subclassed Models with custom call() signatures. In other
+ cases, `Layer` inputs must be passed using the `inputs` argument and
+ non-inputs must be keyword arguments.
**kwargs: Additional keyword arguments to be passed to `call()`.
Returns:
@@ -234,12 +307,25 @@ class Layer(tf_base_layers.Layer):
Raises:
ValueError: in case the layer is missing shape information
for its `build` call.
+ TypeError: If positional arguments are passed and this `Layer` is not a
+ subclassed `Model`.
"""
# Actually call the layer (optionally building it).
- output = super(Layer, self).__call__(inputs, **kwargs)
+ output = super(Layer, self).__call__(inputs, *args, **kwargs)
+
+ if args and getattr(self, '_uses_inputs_arg', True):
+ raise TypeError(
+ 'This Layer takes an `inputs` argument to call(), and only the '
+ '`inputs` argument may be specified as a positional argument. Pass '
+ 'everything else as a keyword argument (those arguments will not be '
+ 'tracked as inputs to the Layer).')
+
if context.executing_eagerly():
return output
+ inputs, kwargs = self._inputs_from_call_args(
+ call_args=(inputs,) + args, call_kwargs=kwargs)
+
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by a call to
# self._set_inputs().
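The splitting rule at the heart of _inputs_from_call_args: positional parameters without defaults are treated as Layer inputs, while parameters with defaults, plus "training", are treated as non-input keyword arguments. A standalone sketch of just that rule using the standard library (the patch itself goes through tf_inspect/tf_decorator and additionally handles *args, **kwargs, and bound methods):

    import inspect

    def split_call_args(call, *args, **kwargs):
      # Parameters with default values are non-inputs; the rest are inputs.
      spec = inspect.getfullargspec(call)  # tf_inspect.getargspec in the patch
      defaults = spec.defaults or ()
      non_input_names = set(spec.args[-len(defaults):]) if defaults else set()
      non_input_names.add('training')  # "training" is never an input.
      bound = inspect.getcallargs(call, *args, **kwargs)
      inputs = [bound[name] for name in spec.args
                if name not in non_input_names]
      non_inputs = {name: bound[name]
                    for name in non_input_names if name in bound}
      return inputs, non_inputs

    def call(first, second, fiddle_with_output='no', training=True):
      return first

    inputs, non_inputs = split_call_args(call, 1, 2, training=False)
    # inputs == [1, 2]
    # non_inputs == {'fiddle_with_output': 'no', 'training': False}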
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index ea4be0d293..9f1c7de115 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -117,6 +117,7 @@ class Network(base_layer.Layer):
self._inbound_nodes = []
def _init_graph_network(self, inputs, outputs, name=None):
+ self._uses_inputs_arg = True
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -274,11 +275,15 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- if 'training' in tf_inspect.getargspec(self.call).args:
+ call_args = tf_inspect.getargspec(self.call).args
+ if 'training' in call_args:
self._expects_training_arg = True
else:
self._expects_training_arg = False
-
+ if 'inputs' in call_args:
+ self._uses_inputs_arg = True
+ else:
+ self._uses_inputs_arg = False
self.outputs = None
self.inputs = None
self.built = False
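A condensed sketch of the signature inspection _init_subclassed_network now performs, with the standard-library inspect module standing in for tf_inspect and an illustrative class:

    import inspect

    class ModelWithCustomCall(object):  # stand-in for a keras.Model subclass

      def call(self, first, second, training=False):
        return first

    call_args = inspect.getfullargspec(ModelWithCustomCall.call).args
    # call_args == ['self', 'first', 'second', 'training']
    expects_training_arg = 'training' in call_args  # True
    uses_inputs_arg = 'inputs' in call_args         # False: positional mode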
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 08288d353e..971245c162 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -874,6 +874,11 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
+ if not getattr(self, '_uses_inputs_arg', True):
+ raise NotImplementedError(
+ 'Subclassed Models without "inputs" in their call() signatures do '
+ 'not yet support shape inference. File a feature request if this '
+ 'limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
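In other words, shape inference stays limited to subclassed Models whose call() takes an inputs argument. An illustration of the two cases, with hypothetical class names:

    import tensorflow as tf

    class Supported(tf.keras.Model):

      def call(self, inputs, training=False):
        # 'inputs' in the signature: shape inference is supported.
        return inputs

    class Unsupported(tf.keras.Model):

      def call(self, first, second):
        # No 'inputs' argument: the new check raises NotImplementedError
        # when training/evaluation needs to infer input shapes.
        return first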
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 58b144365b..4445900330 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -22,7 +22,9 @@ import os
import tempfile
import numpy as np
+import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
@@ -36,6 +38,7 @@ except ImportError:
h5py = None
+# pylint: disable=not-callable
class SimpleTestModel(keras.Model):
def __init__(self, use_bn=False, use_dp=False, num_classes=10):
@@ -104,7 +107,7 @@ class NestedTestModel1(keras.Model):
def call(self, inputs):
x = self.dense1(inputs)
x = self.bn(x)
- x = self.test_net(x) # pylint: disable=not-callable
+ x = self.test_net(x)
return self.dense2(x)
@@ -161,7 +164,7 @@ def get_nested_model_3(input_dim, num_classes):
return tensor_shape.TensorShape((input_shape[0], 5))
test_model = Inner()
- x = test_model(x) # pylint: disable=not-callable
+ x = test_model(x)
outputs = keras.layers.Dense(num_classes)(x)
return keras.Model(inputs, outputs, name='nested_model_3')
@@ -574,5 +577,128 @@ class ModelSubclassingTest(test.TestCase):
self.assertGreater(loss, 0.1)
+class CustomCallModel(keras.Model):
+
+ def __init__(self):
+ super(CustomCallModel, self).__init__()
+ self.dense1 = keras.layers.Dense(1, activation='relu')
+ self.dense2 = keras.layers.Dense(1, activation='softmax')
+
+ def call(self, first, second, fiddle_with_output='no', training=True):
+ combined = self.dense1(first) + self.dense2(second)
+ if fiddle_with_output == 'yes':
+ return 10. * combined
+ else:
+ return combined
+
+
+class CustomCallSignatureTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_no_inputs_in_signature(self):
+ model = CustomCallModel()
+ first = array_ops.ones([2, 3])
+ second = array_ops.ones([2, 5])
+ output = model(first, second)
+ self.evaluate([v.initializer for v in model.variables])
+ expected_output = self.evaluate(model.dense1(first) + model.dense2(second))
+ self.assertAllClose(expected_output, self.evaluate(output))
+ output = model(first, second, fiddle_with_output='yes')
+ self.assertAllClose(10. * expected_output, self.evaluate(output))
+ output = model(first, second=second, training=False)
+ self.assertAllClose(expected_output, self.evaluate(output))
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [first, second], model.inputs)
+ with self.assertRaises(TypeError):
+ # tf.layers.Layer expects an "inputs" argument, so all-keywords doesn't
+ # work at the moment.
+ model(first=first, second=second, fiddle_with_output='yes')
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_inputs_in_signature(self):
+
+ class HasInputsAndOtherPositional(keras.Model):
+
+ def call(self, inputs, some_other_arg, training=False):
+ return inputs
+
+ model = HasInputsAndOtherPositional()
+ with self.assertRaisesRegexp(
+ TypeError, 'everything else as a keyword argument'):
+ model(array_ops.ones([]), array_ops.ones([]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_kwargs_in_signature(self):
+
+ class HasKwargs(keras.Model):
+
+ def call(self, x, y=3, **key_words):
+ return x
+
+ model = HasKwargs()
+ arg = array_ops.ones([])
+ model(arg, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_args_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, *args, **kwargs):
+ return [x] + list(args)
+
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ model(arg1, arg2, arg3, a=3)
+ if not context.executing_eagerly():
+ six.assertCountEqual(self, [arg1, arg2, arg3], model.inputs)
+
+ def test_args_and_keywords_in_signature(self):
+
+ class HasArgs(keras.Model):
+
+ def call(self, x, training=True, *args, **kwargs):
+ return x
+
+ with context.graph_mode():
+ model = HasArgs()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'args and arguments with'):
+ model(arg1, arg2, arg3, a=3)
+
+ def test_training_no_default(self):
+
+ class TrainingNoDefault(keras.Model):
+
+ def call(self, x, training):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefault()
+ arg = array_ops.ones([])
+ model(arg, True)
+ six.assertCountEqual(self, [arg], model.inputs)
+
+ def test_training_no_default_with_positional(self):
+
+ class TrainingNoDefaultWithPositional(keras.Model):
+
+ def call(self, x, training, positional):
+ return x
+
+ with context.graph_mode():
+ model = TrainingNoDefaultWithPositional()
+ arg1 = array_ops.ones([])
+ arg2 = array_ops.ones([])
+ arg3 = array_ops.ones([])
+ with self.assertRaisesRegexp(TypeError, 'after a non-input'):
+ model(arg1, arg2, arg3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 1e5f26a77f..242cdff6f3 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -625,6 +625,8 @@ class Layer(checkpointable.CheckpointableBase):
input_list = nest.flatten(inputs)
build_graph = not context.executing_eagerly()
+ # TODO(fchollet, allenl): Make deferred mode work with subclassed Models
+ # which don't use an "inputs" argument.
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index f8fb6ecb0c..8a2b24aa4e 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -266,8 +266,7 @@ class SPINN(tf.keras.Model):
trackings.append(tracking)
if rights:
- reducer_output = self.reducer(
- lefts, right_in=rights, tracking=trackings)
+ reducer_output = self.reducer(lefts, rights, trackings)
reduced = iter(reducer_output)
for transition, stack in zip(trans, stacks):
@@ -388,10 +387,10 @@ class SNLIClassifier(tf.keras.Model):
# Run the batch-normalized and dropout-processed word vectors through the
# SPINN encoder.
- premise = self.encoder(
- premise_embed, transitions=premise_transition, training=training)
- hypothesis = self.encoder(
- hypothesis_embed, transitions=hypothesis_transition, training=training)
+ premise = self.encoder(premise_embed, premise_transition,
+ training=training)
+ hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
+ training=training)
# Combine encoder outputs for premises and hypotheses into logits.
# Then apply batch normalization and dropout on the logits.
@@ -465,11 +464,10 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
"""
with tfe.GradientTape() as tape:
tape.watch(self._model.variables)
- # TODO(allenl): Allow passing Layer inputs as position arguments.
logits = self._model(premise,
- premise_transition=premise_transition,
- hypothesis=hypothesis,
- hypothesis_transition=hypothesis_transition,
+ premise_transition,
+ hypothesis,
+ hypothesis_transition,
training=True)
loss = self.loss(labels, logits)
gradients = tape.gradient(loss, self._model.variables)
@@ -533,9 +531,7 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
snli_data, batch_size):
if use_gpu:
label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
- logits = trainer.model(
- prem, premise_transition=prem_trans, hypothesis=hypo,
- hypothesis_transition=hypo_trans, training=False)
+ logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
loss_val = trainer.loss(label, logits)
batch_size = tf.shape(label)[0]
mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
@@ -639,11 +635,8 @@ def train_or_infer_spinn(embed,
hypo, hypo_trans = inference_sentence_pair[1]
hypo_trans = inference_sentence_pair[1][1]
inference_logits = model(
- tf.constant(prem),
- premise_transition=tf.constant(prem_trans),
- hypothesis=tf.constant(hypo),
- hypothesis_transition=tf.constant(hypo_trans),
- training=False)
+ tf.constant(prem), tf.constant(prem_trans),
+ tf.constant(hypo), tf.constant(hypo_trans), training=False)
inference_logits = inference_logits[0][1:]
max_index = tf.argmax(inference_logits)
print("\nInference logits:")