diff options
author | Michael Case <mikecase@google.com> | 2018-06-25 10:27:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-25 10:27:33 -0700 |
commit | dde358c1fe366df3d17eec11e1476503db8cbb57 (patch) | |
tree | 94f150d59ab2b53f4e30c6b923472d77faa0b2c8 | |
parent | 28994d19d65c65636eabda7beeee353118a76c45 (diff) | |
parent | 1adbc5aa6927d1a5d7151c31aea1da6e73a1b53c (diff) |
Merge pull request #20203 from alextp/cherrypicks_GUVCA
Add a single positional argument mode for shape inference in subclass…
-rw-r--r-- | tensorflow/python/keras/engine/base_layer.py | 45 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 50 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 27 | ||||
-rw-r--r-- | tensorflow/python/keras/model_subclassing_test.py | 4 |
4 files changed, 98 insertions, 28 deletions
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 24716cfbe4..4814275fd5 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import enum # pylint: disable=g-bad-import-order import inspect # Necessary supplement to tf_inspect to deal with variadic args. import numpy as np @@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +class CallConvention(enum.Enum): + """Calling conventions for passing `Layer` inputs to `Layer.call`.""" + # The Layer takes inputs as its first argument, named "inputs" for + # compatibility with the signature of Layer.__call__. This is the mode assumed + # for Layers which are not subclassed Models. + EXPLICIT_INPUTS_ARGUMENT = 1 + # The Layer takes a single positional argument, not named "inputs". It's + # treated like an "inputs" argument. + SINGLE_POSITIONAL_ARGUMENT = 2 + # The Layer has multiple positional arguments to which its inputs should be + # bound. + POSITIONAL_ARGUMENTS_ARE_INPUTS = 3 + + @tf_export('keras.layers.Layer') class Layer(checkpointable.CheckpointableBase): """Base layer class. @@ -149,7 +164,7 @@ class Layer(checkpointable.CheckpointableBase): self._call_fn_args = function_utils.fn_args(self.call) self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) - self._uses_inputs_arg = True + self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT # These lists will be filled via successive calls # to self._add_inbound_node(). @@ -793,12 +808,22 @@ class Layer(checkpointable.CheckpointableBase): pass # C type such as dict. Masking not supported in this case. def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): - if args and getattr(self, '_uses_inputs_arg', True): - raise TypeError( - 'This Layer takes an `inputs` argument to call(), and only the ' - '`inputs` argument may be specified as a positional argument. ' - 'Pass everything else as a keyword argument (those arguments will' - ' not be tracked as inputs to the Layer).') + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if args: + if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT: + raise TypeError( + 'This Layer takes an `inputs` argument to call(), and only the ' + '`inputs` argument may be specified as a positional argument. ' + 'Pass everything else as a keyword argument (those arguments will' + ' not be tracked as inputs to the Layer).') + elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT: + raise TypeError( + 'This Layer takes a single positional argument to call(), which is ' + 'by convention the inputs argument, and only this argument may be ' + 'specified as a positional argument. Pass everything else as a ' + 'keyword argument (those arguments will not be tracked as inputs ' + 'to the Layer).') # If the layer returns tensors from its inputs, unmodified, # we copy them to avoid loss of tensor metadata. @@ -834,7 +859,11 @@ class Layer(checkpointable.CheckpointableBase): A tuple of (inputs, non_input_kwargs). These may be the same objects as were passed in (call_args and call_kwargs). """ - if getattr(self, '_uses_inputs_arg', True): + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if (call_convention in ( + CallConvention.EXPLICIT_INPUTS_ARGUMENT, + CallConvention.SINGLE_POSITIONAL_ARGUMENT)): assert len(call_args) == 1 # TypeError raised earlier in __call__. return call_args[0], call_kwargs else: diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 3d567b8378..6f27eea1e7 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -135,7 +135,7 @@ class Network(base_layer.Layer): self._in_progress_restore_finalizer = None def _init_graph_network(self, inputs, outputs, name=None): - self._uses_inputs_arg = True + self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT # Normalize and set self.inputs, self.outputs. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. @@ -295,19 +295,55 @@ class Network(base_layer.Layer): def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - call_args = tf_inspect.getargspec(self.call).args - if 'training' in call_args: + call_argspec = tf_inspect.getargspec(self.call) + if 'training' in call_argspec.args: self._expects_training_arg = True else: self._expects_training_arg = False - if 'inputs' in call_args: - self._uses_inputs_arg = True - else: - self._uses_inputs_arg = False + self._call_convention = self._determine_call_convention(call_argspec) self.outputs = None self.inputs = None self.built = False + def _determine_call_convention(self, call_argspec): + """Decides how `self.call()` is invoked. See base_layer.CallConvention.""" + if call_argspec.varargs: + may_take_single_argument = False + else: + try: + # Note: tf_inspect doesn't raise a TypeError when regular inspect would, + # so we need to keep in mind that "getcallargs" may have returned + # something even though we under-specified positional arguments. + all_args = tf_inspect.getcallargs(self.call, None) + self_args = set() + for arg_name, obj in all_args.items(): + if obj is self: + self_args.add(arg_name) + may_take_single_argument = True + except TypeError: + may_take_single_argument = False + if may_take_single_argument: + # A single positional argument (plus "self") is considered equivalent to + # an "inputs" argument. + all_positional_args = len(call_argspec.args) + if call_argspec.defaults is not None: + all_positional_args -= len(call_argspec.defaults) + non_self_positional_args = all_positional_args + for positional_arg_name in call_argspec.args[:all_positional_args]: + if positional_arg_name in self_args: + non_self_positional_args -= 1 + if non_self_positional_args == 1: + if 'inputs' in call_argspec.args[all_positional_args:]: + raise TypeError( + "Model.call() takes a single positional argument (to which " + "inputs are passed by convention) and a separate 'inputs' " + "argument. Unable to determine which arguments are inputs.") + return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT + if 'inputs' in call_argspec.args: + return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT + else: + return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS + def _track_layers(self, layers): """Add Checkpointable dependencies on a list of Layers.""" weight_layer_index = 0 diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 6d625f16c2..04a2aa7664 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -31,12 +31,11 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers +from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import training_arrays from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.engine.base_layer import DeferredTensor -from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops @@ -523,7 +522,7 @@ class Model(Network): # Keep track of state updates created by # stateful metrics (i.e. metrics layers). - if isinstance(metric_fn, Layer) and metric_fn.stateful: + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: self.stateful_metric_names.append(metric_name) self.stateful_metric_functions.append(metric_fn) self.metrics_updates += metric_fn.updates @@ -959,11 +958,17 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ - if not getattr(self, '_uses_inputs_arg', True): + call_convention = getattr( + self, + '_call_convention', + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if call_convention not in ( + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT, + base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT): raise NotImplementedError( - 'Subclassed Models without "inputs" in their call() signatures do ' - 'not yet support shape inference. File a feature request if this ' - 'limitation bothers you.') + 'Subclassed Models without "inputs" (or single positional arguments) ' + 'in their call() signatures do not yet support shape inference. File ' + 'a feature request if this limitation bothers you.') if self.__class__.__name__ == 'Sequential': # Note: we can't test whether the model is `Sequential` via `isinstance` # since `Sequential` depends on `Model`. @@ -1020,11 +1025,11 @@ class Model(Network): else: dummy_output_values = [dummy_output_values] self.outputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_output_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_output_values] self.inputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_input_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_input_values] self.input_names = [ 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] self.output_names = [ diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 86f7e20bec..8fb957da43 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model): if self.use_bn: self.bn = keras.layers.BatchNormalization(axis=-1) - def call(self, inputs): - x = self.dense1(inputs) + def call(self, x): + x = self.dense1(x) if self.use_dp: x = self.dp(x) if self.use_bn: |