aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-31 19:03:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 19:07:09 -0700
commit3df9efb6fd65d7cf1249f9cad44c53d7f0a142b9 (patch)
tree042814a4607ccb4b409916beeb2f9797cbece5e7
parentd3095c93fc042cf6200f5552e910804e1f9dc196 (diff)
Add a single positional argument mode for shape inference in subclassed Models.
Allows fit() when call's signature looks something like call(x, training=True). Calling conventions are "inputs", single positional, and multiple positional. Right now the distinction between "inputs" and single positional calling conventions is the text of one error message. Both support shape inference (which just hasn't been implemented for multiple positional input arguments yet). PiperOrigin-RevId: 198815483
-rw-r--r--tensorflow/python/keras/engine/base_layer.py45
-rw-r--r--tensorflow/python/keras/engine/network.py50
-rw-r--r--tensorflow/python/keras/engine/training.py27
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py4
4 files changed, 98 insertions, 28 deletions
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 24716cfbe4..4814275fd5 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import enum # pylint: disable=g-bad-import-order
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+class CallConvention(enum.Enum):
+ """Calling conventions for passing `Layer` inputs to `Layer.call`."""
+ # The Layer takes inputs as its first argument, named "inputs" for
+ # compatibility with the signature of Layer.__call__. This is the mode assumed
+ # for Layers which are not subclassed Models.
+ EXPLICIT_INPUTS_ARGUMENT = 1
+ # The Layer takes a single positional argument, not named "inputs". It's
+ # treated like an "inputs" argument.
+ SINGLE_POSITIONAL_ARGUMENT = 2
+ # The Layer has multiple positional arguments to which its inputs should be
+ # bound.
+ POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
+
+
@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
@@ -149,7 +164,7 @@ class Layer(checkpointable.CheckpointableBase):
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
- self._uses_inputs_arg = True
+ self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT
# These lists will be filled via successive calls
# to self._add_inbound_node().
@@ -793,12 +808,22 @@ class Layer(checkpointable.CheckpointableBase):
pass # C type such as dict. Masking not supported in this case.
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
- if args and getattr(self, '_uses_inputs_arg', True):
- raise TypeError(
- 'This Layer takes an `inputs` argument to call(), and only the '
- '`inputs` argument may be specified as a positional argument. '
- 'Pass everything else as a keyword argument (those arguments will'
- ' not be tracked as inputs to the Layer).')
+ call_convention = getattr(self, '_call_convention',
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if args:
+ if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
+ raise TypeError(
+ 'This Layer takes an `inputs` argument to call(), and only the '
+ '`inputs` argument may be specified as a positional argument. '
+ 'Pass everything else as a keyword argument (those arguments will'
+ ' not be tracked as inputs to the Layer).')
+ elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
+ raise TypeError(
+ 'This Layer takes a single positional argument to call(), which is '
+ 'by convention the inputs argument, and only this argument may be '
+ 'specified as a positional argument. Pass everything else as a '
+ 'keyword argument (those arguments will not be tracked as inputs '
+ 'to the Layer).')
# If the layer returns tensors from its inputs, unmodified,
# we copy them to avoid loss of tensor metadata.
@@ -834,7 +859,11 @@ class Layer(checkpointable.CheckpointableBase):
A tuple of (inputs, non_input_kwargs). These may be the same objects as
were passed in (call_args and call_kwargs).
"""
- if getattr(self, '_uses_inputs_arg', True):
+ call_convention = getattr(self, '_call_convention',
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if (call_convention in (
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+ CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
assert len(call_args) == 1 # TypeError raised earlier in __call__.
return call_args[0], call_kwargs
else:
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index f63ca1a207..d43aba6875 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -134,7 +134,7 @@ class Network(base_layer.Layer):
self._in_progress_restore_finalizer = None
def _init_graph_network(self, inputs, outputs, name=None):
- self._uses_inputs_arg = True
+ self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -294,19 +294,55 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- call_args = tf_inspect.getargspec(self.call).args
- if 'training' in call_args:
+ call_argspec = tf_inspect.getargspec(self.call)
+ if 'training' in call_argspec.args:
self._expects_training_arg = True
else:
self._expects_training_arg = False
- if 'inputs' in call_args:
- self._uses_inputs_arg = True
- else:
- self._uses_inputs_arg = False
+ self._call_convention = self._determine_call_convention(call_argspec)
self.outputs = None
self.inputs = None
self.built = False
+ def _determine_call_convention(self, call_argspec):
+ """Decides how `self.call()` is invoked. See base_layer.CallConvention."""
+ if call_argspec.varargs:
+ may_take_single_argument = False
+ else:
+ try:
+ # Note: tf_inspect doesn't raise a TypeError when regular inspect would,
+ # so we need to keep in mind that "getcallargs" may have returned
+ # something even though we under-specified positional arguments.
+ all_args = tf_inspect.getcallargs(self.call, None)
+ self_args = set()
+ for arg_name, obj in all_args.items():
+ if obj is self:
+ self_args.add(arg_name)
+ may_take_single_argument = True
+ except TypeError:
+ may_take_single_argument = False
+ if may_take_single_argument:
+ # A single positional argument (plus "self") is considered equivalent to
+ # an "inputs" argument.
+ all_positional_args = len(call_argspec.args)
+ if call_argspec.defaults is not None:
+ all_positional_args -= len(call_argspec.defaults)
+ non_self_positional_args = all_positional_args
+ for positional_arg_name in call_argspec.args[:all_positional_args]:
+ if positional_arg_name in self_args:
+ non_self_positional_args -= 1
+ if non_self_positional_args == 1:
+ if 'inputs' in call_argspec.args[all_positional_args:]:
+ raise TypeError(
+ "Model.call() takes a single positional argument (to which "
+ "inputs are passed by convention) and a separate 'inputs' "
+ "argument. Unable to determine which arguments are inputs.")
+ return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT
+ if 'inputs' in call_argspec.args:
+ return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
+ else:
+ return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
+
def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
weight_layer_index = 0
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 6d625f16c2..04a2aa7664 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -31,12 +31,11 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine.base_layer import DeferredTensor
-from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
@@ -523,7 +522,7 @@ class Model(Network):
# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
- if isinstance(metric_fn, Layer) and metric_fn.stateful:
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
@@ -959,11 +958,17 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
- if not getattr(self, '_uses_inputs_arg', True):
+ call_convention = getattr(
+ self,
+ '_call_convention',
+ base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if call_convention not in (
+ base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+ base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT):
raise NotImplementedError(
- 'Subclassed Models without "inputs" in their call() signatures do '
- 'not yet support shape inference. File a feature request if this '
- 'limitation bothers you.')
+ 'Subclassed Models without "inputs" (or single positional arguments) '
+ 'in their call() signatures do not yet support shape inference. File '
+ 'a feature request if this limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
@@ -1020,11 +1025,11 @@ class Model(Network):
else:
dummy_output_values = [dummy_output_values]
self.outputs = [
- DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_output_values]
+ base_layer.DeferredTensor(shape=(None for _ in v.shape),
+ dtype=v.dtype) for v in dummy_output_values]
self.inputs = [
- DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_input_values]
+ base_layer.DeferredTensor(shape=(None for _ in v.shape),
+ dtype=v.dtype) for v in dummy_input_values]
self.input_names = [
'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
self.output_names = [
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 86f7e20bec..8fb957da43 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model):
if self.use_bn:
self.bn = keras.layers.BatchNormalization(axis=-1)
- def call(self, inputs):
- x = self.dense1(inputs)
+ def call(self, x):
+ x = self.dense1(x)
if self.use_dp:
x = self.dp(x)
if self.use_bn: