author    Michael Case <mikecase@google.com>    2018-06-25 10:27:33 -0700
committer GitHub <noreply@github.com>           2018-06-25 10:27:33 -0700
commit  dde358c1fe366df3d17eec11e1476503db8cbb57 (patch)
tree    94f150d59ab2b53f4e30c6b923472d77faa0b2c8
parent  28994d19d65c65636eabda7beeee353118a76c45 (diff)
parent  1adbc5aa6927d1a5d7151c31aea1da6e73a1b53c (diff)
Merge pull request #20203 from alextp/cherrypicks_GUVCA
Add a single positional argument mode for shape inference in subclass…
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py        45
-rw-r--r--  tensorflow/python/keras/engine/network.py            50
-rw-r--r--  tensorflow/python/keras/engine/training.py           27
-rw-r--r--  tensorflow/python/keras/model_subclassing_test.py     4
4 files changed, 98 insertions, 28 deletions
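
This merge teaches Keras to treat a call() method that takes a single positional argument (not necessarily named `inputs`) as if that argument were the inputs, both for connectivity tracking and for deferred shape inference. A minimal sketch of the newly supported pattern; the model name here is illustrative and not part of the patch:

    import tensorflow as tf

    class SingleArgModel(tf.keras.Model):
      """Subclassed model whose call() takes one positional argument, `x`."""

      def __init__(self):
        super(SingleArgModel, self).__init__()
        self.dense = tf.keras.layers.Dense(4)

      def call(self, x):
        # By the new convention, `x` is treated as the inputs argument.
        return self.dense(x)

    model = SingleArgModel()
    outputs = model(tf.zeros([2, 3]))  # Calling works as before; shape inference now works too.
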
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 24716cfbe4..4814275fd5 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import enum # pylint: disable=g-bad-import-order
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+class CallConvention(enum.Enum):
+ """Calling conventions for passing `Layer` inputs to `Layer.call`."""
+ # The Layer takes inputs as its first argument, named "inputs" for
+ # compatibility with the signature of Layer.__call__. This is the mode assumed
+ # for Layers which are not subclassed Models.
+ EXPLICIT_INPUTS_ARGUMENT = 1
+ # The Layer takes a single positional argument, not named "inputs". It's
+ # treated like an "inputs" argument.
+ SINGLE_POSITIONAL_ARGUMENT = 2
+ # The Layer has multiple positional arguments to which its inputs should be
+ # bound.
+ POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
+
+
@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
@@ -149,7 +164,7 @@ class Layer(checkpointable.CheckpointableBase):
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
- self._uses_inputs_arg = True
+ self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT
# These lists will be filled via successive calls
# to self._add_inbound_node().
@@ -793,12 +808,22 @@ class Layer(checkpointable.CheckpointableBase):
pass # C type such as dict. Masking not supported in this case.
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
- if args and getattr(self, '_uses_inputs_arg', True):
- raise TypeError(
- 'This Layer takes an `inputs` argument to call(), and only the '
- '`inputs` argument may be specified as a positional argument. '
- 'Pass everything else as a keyword argument (those arguments will'
- ' not be tracked as inputs to the Layer).')
+ call_convention = getattr(self, '_call_convention',
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if args:
+ if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
+ raise TypeError(
+ 'This Layer takes an `inputs` argument to call(), and only the '
+ '`inputs` argument may be specified as a positional argument. '
+ 'Pass everything else as a keyword argument (those arguments will'
+ ' not be tracked as inputs to the Layer).')
+ elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
+ raise TypeError(
+ 'This Layer takes a single positional argument to call(), which is '
+ 'by convention the inputs argument, and only this argument may be '
+ 'specified as a positional argument. Pass everything else as a '
+ 'keyword argument (those arguments will not be tracked as inputs '
+ 'to the Layer).')
# If the layer returns tensors from its inputs, unmodified,
# we copy them to avoid loss of tensor metadata.
@@ -834,7 +859,11 @@ class Layer(checkpointable.CheckpointableBase):
A tuple of (inputs, non_input_kwargs). These may be the same objects as
were passed in (call_args and call_kwargs).
"""
- if getattr(self, '_uses_inputs_arg', True):
+ call_convention = getattr(self, '_call_convention',
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if (call_convention in (
+ CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+ CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
assert len(call_args) == 1 # TypeError raised earlier in __call__.
return call_args[0], call_kwargs
else:
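
With the CallConvention enum in place, _set_connectivity_metadata_ keeps its old error for explicit-`inputs` layers and adds a parallel one for the single-positional-argument case: any extra positional arguments raise a TypeError, while the multi-positional convention allows them. A standalone sketch of that gating, hand-written for illustration rather than taken from the patch:

    import enum

    class CallConvention(enum.Enum):
      EXPLICIT_INPUTS_ARGUMENT = 1
      SINGLE_POSITIONAL_ARGUMENT = 2
      POSITIONAL_ARGUMENTS_ARE_INPUTS = 3

    def check_extra_positional_args(call_convention, extra_args):
      # Mirrors the check above: extra positional args are only legal when all
      # positional arguments are themselves treated as inputs.
      if extra_args and call_convention in (
          CallConvention.EXPLICIT_INPUTS_ARGUMENT,
          CallConvention.SINGLE_POSITIONAL_ARGUMENT):
        raise TypeError('Only the inputs may be passed positionally; pass '
                        'everything else as keyword arguments.')

    check_extra_positional_args(
        CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS, (1, 2))  # No error.
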
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 3d567b8378..6f27eea1e7 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -135,7 +135,7 @@ class Network(base_layer.Layer):
self._in_progress_restore_finalizer = None
def _init_graph_network(self, inputs, outputs, name=None):
- self._uses_inputs_arg = True
+ self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -295,19 +295,55 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- call_args = tf_inspect.getargspec(self.call).args
- if 'training' in call_args:
+ call_argspec = tf_inspect.getargspec(self.call)
+ if 'training' in call_argspec.args:
self._expects_training_arg = True
else:
self._expects_training_arg = False
- if 'inputs' in call_args:
- self._uses_inputs_arg = True
- else:
- self._uses_inputs_arg = False
+ self._call_convention = self._determine_call_convention(call_argspec)
self.outputs = None
self.inputs = None
self.built = False
+ def _determine_call_convention(self, call_argspec):
+ """Decides how `self.call()` is invoked. See base_layer.CallConvention."""
+ if call_argspec.varargs:
+ may_take_single_argument = False
+ else:
+ try:
+ # Note: tf_inspect doesn't raise a TypeError when regular inspect would,
+ # so we need to keep in mind that "getcallargs" may have returned
+ # something even though we under-specified positional arguments.
+ all_args = tf_inspect.getcallargs(self.call, None)
+ self_args = set()
+ for arg_name, obj in all_args.items():
+ if obj is self:
+ self_args.add(arg_name)
+ may_take_single_argument = True
+ except TypeError:
+ may_take_single_argument = False
+ if may_take_single_argument:
+ # A single positional argument (plus "self") is considered equivalent to
+ # an "inputs" argument.
+ all_positional_args = len(call_argspec.args)
+ if call_argspec.defaults is not None:
+ all_positional_args -= len(call_argspec.defaults)
+ non_self_positional_args = all_positional_args
+ for positional_arg_name in call_argspec.args[:all_positional_args]:
+ if positional_arg_name in self_args:
+ non_self_positional_args -= 1
+ if non_self_positional_args == 1:
+ if 'inputs' in call_argspec.args[all_positional_args:]:
+ raise TypeError(
+ "Model.call() takes a single positional argument (to which "
+ "inputs are passed by convention) and a separate 'inputs' "
+ "argument. Unable to determine which arguments are inputs.")
+ return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT
+ if 'inputs' in call_argspec.args:
+ return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
+ else:
+ return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
+
def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
weight_layer_index = 0
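
_determine_call_convention classifies the call() signature: exactly one non-self positional argument (and no *args) selects SINGLE_POSITIONAL_ARGUMENT, an argument literally named `inputs` selects EXPLICIT_INPUTS_ARGUMENT, and anything else falls back to POSITIONAL_ARGUMENTS_ARE_INPUTS. A simplified sketch of that decision using plain `inspect`; it skips the getcallargs-based self detection and the conflicting-`inputs` error above:

    import inspect

    def classify_call(fn):
      argspec = inspect.getfullargspec(fn)
      args = [a for a in argspec.args if a != 'self']
      n_defaults = len(argspec.defaults or ())
      positional = args[:len(args) - n_defaults]
      if not argspec.varargs and len(positional) == 1:
        # A lone positional argument is treated as the inputs by convention.
        return 'SINGLE_POSITIONAL_ARGUMENT'
      if 'inputs' in args:
        return 'EXPLICIT_INPUTS_ARGUMENT'
      return 'POSITIONAL_ARGUMENTS_ARE_INPUTS'

    print(classify_call(lambda self, x, training=None: x))                   # SINGLE_POSITIONAL_ARGUMENT
    print(classify_call(lambda self, inputs, extra, training=None: inputs))  # EXPLICIT_INPUTS_ARGUMENT
    print(classify_call(lambda self, a, b: (a, b)))                          # POSITIONAL_ARGUMENTS_ARE_INPUTS
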
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 6d625f16c2..04a2aa7664 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -31,12 +31,11 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine.base_layer import DeferredTensor
-from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
@@ -523,7 +522,7 @@ class Model(Network):
# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
- if isinstance(metric_fn, Layer) and metric_fn.stateful:
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
@@ -959,11 +958,17 @@ class Model(Network):
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
"""
- if not getattr(self, '_uses_inputs_arg', True):
+ call_convention = getattr(
+ self,
+ '_call_convention',
+ base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if call_convention not in (
+ base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+ base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT):
raise NotImplementedError(
- 'Subclassed Models without "inputs" in their call() signatures do '
- 'not yet support shape inference. File a feature request if this '
- 'limitation bothers you.')
+ 'Subclassed Models without "inputs" (or single positional arguments) '
+ 'in their call() signatures do not yet support shape inference. File '
+ 'a feature request if this limitation bothers you.')
if self.__class__.__name__ == 'Sequential':
# Note: we can't test whether the model is `Sequential` via `isinstance`
# since `Sequential` depends on `Model`.
@@ -1020,11 +1025,11 @@ class Model(Network):
else:
dummy_output_values = [dummy_output_values]
self.outputs = [
- DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_output_values]
+ base_layer.DeferredTensor(shape=(None for _ in v.shape),
+ dtype=v.dtype) for v in dummy_output_values]
self.inputs = [
- DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_input_values]
+ base_layer.DeferredTensor(shape=(None for _ in v.shape),
+ dtype=v.dtype) for v in dummy_input_values]
self.input_names = [
'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
self.output_names = [
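
On the training side, relaxing the check in _set_inputs means a subclassed model whose call() takes a single positional argument can be compiled and fit directly on NumPy data, with input and output DeferredTensors created from a dummy batch. A hedged usage sketch; the model definition is illustrative, not from the patch:

    import numpy as np
    import tensorflow as tf

    class Regressor(tf.keras.Model):
      def __init__(self):
        super(Regressor, self).__init__()
        self.fc = tf.keras.layers.Dense(1)

      def call(self, x):  # single positional argument, not named `inputs`
        return self.fc(x)

    model = Regressor()
    model.compile(optimizer='sgd', loss='mse')
    # No explicit build or call needed: shapes are inferred from the first batch.
    # Before this change, this signature raised NotImplementedError here.
    model.fit(np.random.random((8, 5)), np.random.random((8, 1)), epochs=1, verbose=0)
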
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 86f7e20bec..8fb957da43 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model):
if self.use_bn:
self.bn = keras.layers.BatchNormalization(axis=-1)
- def call(self, inputs):
- x = self.dense1(inputs)
+ def call(self, x):
+ x = self.dense1(x)
if self.use_dp:
x = self.dp(x)
if self.use_bn:
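
The test change above renames the call() argument from `inputs` to `x`, so SimpleTestModel now exercises the new SINGLE_POSITIONAL_ARGUMENT path in the existing subclassing tests instead of the explicit-`inputs` path.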