aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-04-10 13:49:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 13:51:54 -0700
commit693b339ab2f062ec5bbb29f976c5d1fd94fbffa5 (patch)
tree1e11b6becc6b156b6e89ebb3f4d5d2f886bed188 /tensorflow/python/layers
parent6b593d329005ffb1a10b1c9cd1374d2cdb620b21 (diff)
Refactor layers:
- tf.layers layers now subclasses tf.keras.layers layers. - tf.keras.layers is now agnostic to variable scopes and global collections (future-proof). It also uses ResourceVariable everywhere by default. - As a result tf.keras.layers is in general lower-complexity, with fewer hacks and workarounds. However some of current code is temporary (variable creation should be moved to Checkpointable, arguably, and there are some dependency issues that will require later refactors). - The legacy tf.layers layers behavior is kept, with references to variable scopes and global collections injected in the subclassed tf.layers.base.Layer class (the content of tf.layers.base.Layer is the complexity differential between the old implementation and the new one). Note: this refactor does slightly change the behavior of tf.layers.base.Layer, by disabling extreme edge-case behavior that either has long been invalid, or is dangerous and should most definitely be disabled. This will not affect any users since such behaviors only existed in the base Layer unit tests. The behaviors disabled are: - Option to create reusable variables in `call` (already invalid for some time). - Option to use a variable scope to create layer variables outside of the layer while not having the layer track such variables locally. PiperOrigin-RevId: 192339798
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py1443
-rw-r--r--tensorflow/python/layers/base_test.py94
-rw-r--r--tensorflow/python/layers/convolutional.py702
-rw-r--r--tensorflow/python/layers/core.py142
-rw-r--r--tensorflow/python/layers/normalization.py516
-rw-r--r--tensorflow/python/layers/pooling.py258
-rw-r--r--tensorflow/python/layers/utils_test.py29
7 files changed, 192 insertions, 2992 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index ec741d3265..64db49c900 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -12,148 +12,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-
-# pylint: disable=unused-import,g-bad-import-order
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
import copy
-import re
-import weakref
-import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.layers import utils as layers_util
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
+from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
-@tf_export('layers.Layer')
-class Layer(checkpointable.CheckpointableBase):
- """Base layer class.
+InputSpec = base_layer.InputSpec # pylint: disable=invalid-name
- This is the class from which all layers inherit, implementing common
- infrastructure functionality.
- A layer is a class implementing common neural networks operations, such
- as convolution, batch norm, etc. These operations require managing variables,
- losses, and updates, as well as applying TensorFlow ops to input tensors.
+@tf_export('layers.Layer')
+class Layer(base_layer.Layer):
+ """Base layer class.
- Users will just instantiate it and then treat it as a callable.
+ It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
+ instead.
- We recommend that descendants of Layer implement the following methods:
- * `__init__()`: Save configuration in member variables
- * `build()`: Called once from `__call__`, when we know the shapes of inputs
- and `dtype`. Should have the calls to `add_variable()`, and then
- call the super's `build()` (which sets `self.built = True`, which is
- nice in case the user wants to call `build()` manually before the
- first `__call__`).
- * `call()`: Called in `__call__` after making sure `build()` has been called
- once. Should actually perform the logic of applying the layer to the
- input tensors (which should be passed in as the first argument).
+ Arguments:
+ trainable: Boolean, whether the layer's variables should be trainable.
+ name: String name of the layer.
+ dtype: Default dtype of the layer's weights (default of `None` means use the
+ type of the first input).
Read-only properties:
- `name`: The name of the layer (string).
- `dtype`: Default dtype of the layer (default of `None` means use the
+ name: The name of the layer (string).
+ dtype: Default dtype of the layer's weights (default of `None` means use the
type of the first input).
- `trainable_variables`: List of trainable variables.
- `non_trainable_variables`: List of non-trainable variables.
- `variables`: List of all variables of this layer, trainable and
+ trainable_variables: List of trainable variables.
+ non_trainable_variables: List of non-trainable variables.
+ variables: List of all variables of this layer, trainable and
non-trainable.
- `updates`: List of update ops of this layer.
- `losses`: List of losses added by this layer.
+ updates: List of update ops of this layer.
+ losses: List of losses added by this layer.
+ trainable_weights: List of variables to be included in backprop.
+ non_trainable_weights: List of variables that should not be
+ included in backprop.
+ weights: The concatenation of the lists trainable_weights and
+ non_trainable_weights (in this order).
Mutable properties:
- `trainable`: Whether the layer should be trained (boolean).
- `input_spec`: Optional (list of) `InputSpec` object(s) specifying the
+ trainable: Whether the layer should be trained (boolean).
+ input_spec: Optional (list of) `InputSpec` object(s) specifying the
constraints on inputs that can be accepted by the layer.
"""
def __init__(self, trainable=True, name=None, dtype=None,
- activity_regularizer=None, **kwargs):
- # We use a kwargs dict here because these kwargs only exist
- # for compatibility reasons.
- # The list of kwargs is subject to changes in the future.
- # We do not want to commit to it or to expose the list to users at all.
- # Note this is exactly as safe as defining kwargs in the function signature,
- # the only difference being that the list of valid kwargs is defined
- # below rather rather in the signature, and default values are defined
- # in calls to kwargs.get().
- allowed_kwargs = {
- '_scope',
- '_reuse',
- 'input_shape', # For compatibility with Keras `Sequential` model.
- 'batch_size', # For compatibility with Keras `Sequential` model.
- }
- for kwarg in kwargs:
- if kwarg not in allowed_kwargs:
- raise TypeError('Keyword argument not understood:', kwarg)
-
- # Mutable properties
- # Indicates whether the layer's weights are updated during training
- # and whether the layer's updates are run during training
- self.trainable = trainable
- # A stateful layer is a layer whose updates are run during inference too,
- # for instance stateful RNNs.
- self.stateful = False
- # Indicates whether `build` needs to be called upon layer call, to create
- # the layer's weights.
- self.built = False
- # Provides information about which inputs are compatible with the layer.
- self.input_spec = None
-
- if activity_regularizer and context.executing_eagerly():
- raise ValueError(
- ('Activity regularization is not supported when executing eagerly. '
- 'Got activity_regularizer=%s') % (activity_regularizer,))
- self._activity_regularizer = activity_regularizer
+ **kwargs):
+ # For backwards compatibility, legacy layers do not use `ResourceVariable`
+ # by default.
+ self._use_resource_variables = False
+ scope = kwargs.pop('_scope', None)
+ self._reuse = kwargs.pop('_reuse', None)
+
+ # Avoid an incorrect lint error
self._trainable_weights = []
- self._non_trainable_weights = []
- self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
- self._losses = []
- self._reuse = kwargs.get('_reuse')
- self._graph = None # Will be set at build time.
- self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
- self._call_fn_args = estimator_util.fn_args(self.call)
- self._compute_previous_mask = ('mask' in self._call_fn_args or
- hasattr(self, 'compute_mask'))
- self._call_has_scope_arg = 'scope' in self._call_fn_args
-
- # These lists will be filled via successive calls
- # to self._add_inbound_node().
- self._inbound_nodes = []
- self._outbound_nodes = []
+ self.built = False
- self._init_set_name(name)
+ super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
+ **kwargs)
- # Determine variable scope.
- scope = kwargs.get('_scope')
+ self._graph = None
+ self._call_has_scope_arg = 'scope' in self._call_fn_args
if scope:
with vs.variable_scope(scope) as captured_scope:
self._scope = captured_scope
else:
self._scope = None
+ self._current_scope = None
- # Set `_batch_input_shape` attribute
- # for compatibility with Keras `Sequential` model.
- if 'input_shape' in kwargs:
- batch_size = kwargs.get('batch_size')
- self._batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
+ @property
+ def graph(self):
+ if context.executing_eagerly():
+ raise RuntimeError('Layer.graph not supported when executing eagerly.')
+ return self._graph
def _init_set_name(self, name):
# Determine layer name (non-unique).
@@ -166,18 +109,15 @@ class Layer(checkpointable.CheckpointableBase):
self._name, base_name = self._make_unique_name()
self._base_name = base_name
- @property
- def dtype(self):
- return self._dtype
-
- @property
- def name(self):
- return self._name
-
- @property
- def activity_regularizer(self):
- """Optional regularizer function for the output of this layer."""
- return self._activity_regularizer
+ def _make_unique_name(self, name_uid_map=None, avoid_names=None,
+ namespace='', zero_based=False):
+ base_name = base_layer.to_snake_case(self.__class__.__name__)
+ name = base_layer.unique_layer_name(base_name,
+ name_uid_map=name_uid_map,
+ avoid_names=avoid_names,
+ namespace=namespace,
+ zero_based=zero_based)
+ return (name, base_name)
@property
def scope_name(self):
@@ -189,271 +129,16 @@ class Layer(checkpointable.CheckpointableBase):
'querying `scope_name`.')
return self._scope.name
- @property
- def trainable_weights(self):
- return self._trainable_weights if self.trainable else []
-
- @property
- def non_trainable_weights(self):
- if self.trainable:
- return self._non_trainable_weights
- else:
- return self._trainable_weights + self._non_trainable_weights
-
- @property
- def trainable_variables(self):
- return self.trainable_weights
-
- @property
- def non_trainable_variables(self):
- return self.non_trainable_weights
-
- @property
- def weights(self):
- """Returns the list of all layer variables/weights.
-
- Returns:
- A list of variables.
- """
- return self.trainable_weights + self.non_trainable_weights
-
- @property
- def variables(self):
- """Returns the list of all layer variables/weights.
-
- Returns:
- A list of variables.
- """
- return self.weights
-
- @property
- def updates(self):
- if context.executing_eagerly():
- raise RuntimeError('Layer.updates not supported in Eager mode.')
- if not self.trainable and not self.stateful:
- return []
- return self._updates
-
- def add_update(self, updates, inputs=None):
- """Add update op(s), potentially dependent on layer inputs.
-
- Weight updates (for instance, the updates of the moving mean and variance
- in a BatchNormalization layer) may be dependent on the inputs passed
- when calling a layer. Hence, when reusing the same layer on
- different inputs `a` and `b`, some entries in `layer.updates` may be
- dependent on `a` and some on `b`. This method automatically keeps track
- of dependencies.
-
- The `get_updates_for` method allows to retrieve the updates relevant to a
- specific set of inputs.
-
- This call is ignored in Eager mode.
-
- Arguments:
- updates: Update op, or list/tuple of update ops.
- inputs: If anything other than None is passed, it signals the updates
- are conditional on some of the layer's inputs,
- and thus they should only be run where these inputs are available.
- This is the case for BatchNormalization updates, for instance.
- If None, the updates will be taken into account unconditionally,
- and you are responsible for making sure that any dependency they might
- have is available at runtime.
- A step counter might fall into this category.
- """
- if context.executing_eagerly():
- return # Updates already applied when in eager mode.
-
- updates = _to_list(updates)
- updates = [x if isinstance(x, ops.Operation)
- else ops.convert_to_tensor(x) for x in updates]
- self._updates += updates
- if inputs is None:
- for u in updates:
- u._unconditional_update = True # pylint: disable=protected-access
- else:
- for u in updates:
- u._unconditional_update = False # pylint: disable=protected-access
-
- def get_updates_for(self, inputs):
- """Retrieves updates relevant to a specific set of inputs.
-
- Arguments:
- inputs: Input tensor or list/tuple of input tensors.
-
- Returns:
- List of update ops of the layer that depend on `inputs`.
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if context.executing_eagerly():
- raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
-
- # Updates disabled if layer is not trainable and not explicitly stateful.
- if not self.trainable and not self.stateful:
- return []
-
- if inputs is None:
- # Requesting unconditional updates.
- return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access
-
- # Requesting input-conditional updates.
- inputs = nest.flatten(inputs)
- reachable = layers_util.get_reachable_from_inputs(inputs, self.updates)
- updates = []
- for update in self.updates:
- if update in reachable:
- updates.append(update)
- return updates
-
- @property
- def losses(self):
- """Losses which are associated with this `Layer`.
-
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
-
- Returns:
- A list of tensors.
- """
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
-
def add_loss(self, losses, inputs=None):
- """Add loss tensor(s), potentially dependent on layer inputs.
-
- Some losses (for instance, activity regularization losses) may be dependent
- on the inputs passed when calling a layer. Hence, when reusing the same
- layer on different inputs `a` and `b`, some entries in `layer.losses` may
- be dependent on `a` and some on `b`. This method automatically keeps track
- of dependencies.
-
- The `get_losses_for` method allows to retrieve the losses relevant to a
- specific set of inputs.
-
- Note that `add_loss` is not supported when executing eagerly. Instead,
- variable regularizers may be added through `add_variable`. Activity
- regularization is not supported directly (but such losses may be returned
- from `Layer.call()`).
-
- Arguments:
- losses: Loss tensor, or list/tuple of tensors.
- inputs: If anything other than None is passed, it signals the losses
- are conditional on some of the layer's inputs,
- and thus they should only be run where these inputs are available.
- This is the case for activity regularization losses, for instance.
- If `None` is passed, the losses are assumed
- to be unconditional, and will apply across all dataflows of the layer
- (e.g. weight regularization losses).
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
- losses = _to_list(losses)
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+ previous_losses_length = len(self._losses)
+ super(Layer, self).add_loss(losses, inputs=inputs)
# TODO(fchollet): deprecate collection below.
- _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)
-
- def get_losses_for(self, inputs):
- """Retrieves losses relevant to a specific set of inputs.
-
- Arguments:
- inputs: Input tensor or list/tuple of input tensors.
-
- Returns:
- List of loss tensors of the layer that depend on `inputs`.
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if context.executing_eagerly():
- raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
-
- if inputs is None:
- # Requesting unconditional losses.
- return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access
-
- # Requesting input-conditional losses.
- inputs = nest.flatten(inputs)
- # Retrieve the set of tensors in the TF graph that depend on `inputs`.
- # The losses we want to return will be part of this set.
- # To avoid unnecessary work, we stop the search in case all of
- # `self.losses` have been retrieved.
- reachable = layers_util.get_reachable_from_inputs(inputs, self.losses)
- losses = []
- for loss in self.losses:
- if loss in reachable:
- losses.append(loss)
- return losses
-
- def build(self, _):
- """Creates the variables of the layer."""
- self.built = True
-
- def call(self, inputs, **kwargs): # pylint: disable=unused-argument
- """The logic of the layer lives here.
+ new_losses = self._losses[previous_losses_length:]
+ _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
- Arguments:
- inputs: input tensor(s).
- **kwargs: additional keyword arguments.
-
- Returns:
- Output tensor(s).
- """
- return inputs
-
- def _name_scope_name(self, current_variable_scope):
+ def _name_scope(self):
"""Determines op naming for the Layer."""
- return current_variable_scope.original_name_scope
-
- def compute_output_shape(self, input_shape):
- """Computes the output shape of the layer given the input shape.
-
- Args:
- input_shape: A (possibly nested tuple of) `TensorShape`. It need not
- be fully defined (e.g. the batch size may be unknown).
-
- Returns:
- A (possibly nested tuple of) `TensorShape`.
-
- Raises:
- TypeError: if `input_shape` is not a (possibly nested tuple of)
- `TensorShape`.
- ValueError: if `input_shape` is incomplete or is incompatible with the
- the layer.
- """
- raise NotImplementedError
-
- def _make_unique_name(self, name_uid_map=None, avoid_names=None,
- namespace='', zero_based=False):
- base_name = _to_snake_case(self.__class__.__name__)
- name = _unique_layer_name(base_name, name_uid_map=name_uid_map,
- avoid_names=avoid_names, namespace=namespace,
- zero_based=zero_based)
- return (name, base_name)
+ return self._current_scope.original_name_scope
def _set_scope(self, scope=None):
if self._scope is None:
@@ -467,10 +152,11 @@ class Layer(checkpointable.CheckpointableBase):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_variable(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
- partitioner=None):
+ def add_weight(self, name, shape, dtype=None,
+ initializer=None, regularizer=None,
+ trainable=True, constraint=None,
+ use_resource=None,
+ partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
Arguments:
@@ -486,6 +172,7 @@ class Layer(checkpointable.CheckpointableBase):
then this parameter is ignored and any added variables are also
marked as non-trainable.
constraint: constraint instance (callable).
+ use_resource: Whether to use `ResourceVariable`.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -504,10 +191,6 @@ class Layer(checkpointable.CheckpointableBase):
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
"""
-
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
init_graph = None
if not context.executing_eagerly():
default_graph = ops.get_default_graph()
@@ -530,71 +213,43 @@ class Layer(checkpointable.CheckpointableBase):
self._set_scope(None)
reuse = self.built or self._reuse
+ prev_len_trainable = len(self._trainable_weights)
with vs.variable_scope(
self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
- with ops.name_scope(self._name_scope_name(scope)):
- variable = self._add_variable_with_custom_getter(
- name=name,
- shape=shape,
- getter=vs.get_variable,
- # Manage errors in Layer rather than Checkpointable.
- overwrite=True,
- initializer=initializer,
+ self._current_scope = scope
+ with ops.name_scope(self._name_scope()):
+ use_resource = (use_resource or
+ self._use_resource_variables or
+ scope.use_resource)
+ variable = super(Layer, self).add_weight(
+ name,
+ shape,
dtype=dtypes.as_dtype(dtype),
+ initializer=initializer or scope.initializer,
+ trainable=trainable,
constraint=constraint,
- trainable=trainable and self.trainable,
- partitioner=partitioner)
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
-
- if variable in existing_variables:
- # To match the behavior of tf.get_variable(), we only apply
- # regularization if the variable is newly created.
- return variable
-
+ partitioner=partitioner,
+ use_resource=use_resource,
+ getter=vs.get_variable)
+
+ if regularizer:
+ if context.executing_eagerly() or variable not in existing_variables:
+ self._handle_weight_regularization(name, variable, regularizer)
+
+ if init_graph is not None:
+ # Handle edge case where a custom getter has overridden `trainable`.
+ # There is one known occurrence of this, in unit test
+ # testBasicRNNCellNotTrainable in
+ # contrib.rnn.python.kernel_tests.core_rnn_cell_test
with init_graph.as_default():
trainable_variables = tf_variables.trainable_variables()
if (trainable and self.trainable and
variable not in trainable_variables):
# A custom getter / variable scope overrode the trainable flag.
- trainable = False
-
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # and initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request '
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
-
- if trainable:
- self._trainable_weights.append(variable)
- else:
- self._non_trainable_weights.append(variable)
+ extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
+ self._trainable_weights = self._trainable_weights[
+ :prev_len_trainable]
+ self._non_trainable_weights += extra_trainable_vars
return variable
def __call__(self, inputs, *args, **kwargs):
@@ -622,35 +277,14 @@ class Layer(checkpointable.CheckpointableBase):
ValueError: if the layer's `call` method returns None (an invalid value).
"""
self._set_scope(kwargs.pop('scope', None))
- input_list = nest.flatten(inputs)
- build_graph = not context.executing_eagerly()
- # TODO(fchollet, allenl): Make deferred mode work with subclassed Models
- # which don't use an "inputs" argument.
- in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
- # Ensure the Layer, if being reused, is working with inputs from
- # the same graph as where it was created.
- if build_graph:
+ if not context.executing_eagerly():
try:
# Set layer's "graph" at build time
- self._graph = ops._get_graph_from_inputs(input_list, graph=self._graph) # pylint: disable=protected-access
+ self._graph = ops._get_graph_from_inputs(nest.flatten(inputs), # pylint: disable=protected-access
+ graph=self._graph)
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
- if build_graph or in_deferred_mode:
- user_kwargs = copy.copy(kwargs)
-
- # Handle Keras mask propagation from previous layer to current layer.
- previous_mask = None
- if (not hasattr(self, '_compute_previous_mask') or
- self._compute_previous_mask):
- previous_mask = _collect_previous_mask(inputs)
- if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = estimator_util.fn_args(self.call)
- if ('mask' in self._call_fn_args and 'mask' not in kwargs and
- not _is_all_none(previous_mask)):
- # The previous layer generated a mask, and mask was not explicitly pass
- # to __call__, hence we set previous_mask as the default value.
- kwargs['mask'] = previous_mask
if self.built:
try:
@@ -667,134 +301,27 @@ class Layer(checkpointable.CheckpointableBase):
else:
scope_context_manager = vs.variable_scope(
self._scope, reuse=self._reuse, auxiliary_name_scope=False)
- input_shapes = None
- with scope_context_manager as scope:
- with ops.name_scope(self._name_scope_name(scope)):
- if not self.built:
- if not build_graph:
- # Activity regularization is currently unsupported in Eager mode.
- if self._activity_regularizer:
- raise ValueError(
- 'activity_regularizer currently unsupported with '
- 'eager execution enabled. Found an activity_regularizer in '
- '%s(%s).' % (self.__class__.__name__, self))
- if not build_graph and not in_deferred_mode:
- # TODO(agarwal): support _keras_history in Eager mode.
- for x in input_list:
- if hasattr(x, '_keras_history'):
- raise ValueError('_keras_history currently unsupported in '
- 'Eager mode. Found _keras_history in %s while '
- 'executing __call__ for %s(%s)' %
- (x, self.__class_.__name__, self))
-
- # Check input assumptions set before layer building, e.g. input rank.
- self._assert_input_compatibility(inputs)
- if input_list and self._dtype is None:
- try:
- self._dtype = input_list[0].dtype.base_dtype.name
- except AttributeError:
- pass
- if all(hasattr(x, 'get_shape') for x in input_list):
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
- self.build(input_shapes)
- try:
- # Note: not all sub-classes of Layer call Layer.__init__ (especially
- # the ones under tensorflow/python/keras). Hence we recompute this
- # attribute here if it is not set.
- # TODO(agarwal): Fix the sub-classes and avoid this complexity.
- call_has_scope_arg = self._call_has_scope_arg
- except AttributeError:
- self._call_fn_args = estimator_util.fn_args(self.call)
- self._call_has_scope_arg = 'scope' in self._call_fn_args
- call_has_scope_arg = self._call_has_scope_arg
- if call_has_scope_arg:
- kwargs['scope'] = scope
- # Check input assumptions set after layer building, e.g. input shape.
- if build_graph or in_deferred_mode:
- self._assert_input_compatibility(inputs)
-
- if not in_deferred_mode:
- outputs = self.call(inputs, *args, **kwargs)
- if outputs is None:
- raise ValueError('A layer\'s `call` method should return a Tensor '
- 'or a list of Tensors, not None.')
- else:
- # Deferred mode behavior: use `compute_output_shape` to
- # infer the number of outputs of the layer and their shapes.
- if input_shapes is None:
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
-
- output_shapes = self.compute_output_shape(input_shapes)
- output_shapes = nest.flatten(output_shapes)
- outputs = [
- # TODO(fchollet): name the deferred tensors?
- _DeferredTensor(shape=shape, dtype=self._dtype)
- for shape in output_shapes
- ]
- if len(outputs) == 1:
- outputs = outputs[0]
- if build_graph:
- # Apply activity regularization.
- # Note that it should be applied every time the layer creates a new
- # output, since it is output-specific.
- if self._activity_regularizer:
- output_list = nest.flatten(outputs)
- for output in output_list:
- with ops.name_scope('ActivityRegularizer'):
- activity_regularization = self._activity_regularizer(output)
- self.add_loss(activity_regularization, inputs=inputs)
+ with scope_context_manager as scope:
+ self._current_scope = scope
- # TODO(fchollet): consider enabling masking for Eager mode.
- if hasattr(self, 'compute_mask'):
- output_mask = self.compute_mask(inputs, previous_mask)
- if isinstance(outputs, (list, tuple)):
- if output_mask is None:
- output_mask = [None for _ in range(len(outputs))]
- for x, m in zip(outputs, output_mask):
- x._keras_mask = m # pylint: disable=protected-access
- else:
- outputs._keras_mask = output_mask # pylint: disable=protected-access
+ try:
+ call_has_scope_arg = self._call_has_scope_arg
+ except AttributeError:
+ self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_has_scope_arg = 'scope' in self._call_fn_args
+ call_has_scope_arg = self._call_has_scope_arg
+ if call_has_scope_arg:
+ kwargs['scope'] = scope
- if build_graph:
- # If all input tensors have history metadata,
- # we update the output tensors
- # with corresponding history metadata, thus eventually allowing to use
- # these tensors to instantiate a Network.
- if _have_all_keras_metadata(inputs):
- # If the layer returns tensors from its inputs, unmodified,
- # we copy them to avoid loss of tensor metadata.
- output_ls = nest.flatten(outputs)
- output_ls_copy = []
- for x in output_ls:
- if x in input_list:
- with ops.name_scope(scope.original_name_scope):
- x = array_ops.identity(x)
- output_ls_copy.append(x)
- if len(output_ls_copy) == 1:
- outputs = output_ls_copy[0]
- else:
- outputs = output_ls_copy
+ # Actually call layer
+ outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
+ if not context.executing_eagerly():
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
-
- if in_deferred_mode or build_graph:
- if _have_all_keras_metadata(inputs):
- # Add an inbound node to the layer, so it can keep track of this call.
- # This updates the layer history of the output tensor(s).
- self._add_inbound_node(
- input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)
-
- self.built = True
return outputs
- @property
- def graph(self):
- if context.executing_eagerly():
- raise RuntimeError('Layer.graph not supported in Eager mode.')
- return self._graph
-
def __deepcopy__(self, memo):
no_copy = set(['_graph'])
shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
@@ -806,658 +333,12 @@ class Layer(checkpointable.CheckpointableBase):
setattr(result, k, v)
elif k in shallow_copy:
setattr(result, k, copy.copy(v))
- elif _is_tensor_or_tensor_list(v):
+ elif base_layer.is_tensor_or_tensor_list(v):
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
- def apply(self, inputs, *args, **kwargs):
- """Apply the layer on a input.
-
- This simply wraps `self.__call__`.
-
- Arguments:
- inputs: Input tensor(s).
- *args: additional positional arguments to be passed to `self.call`.
- **kwargs: additional keyword arguments to be passed to `self.call`.
-
- Returns:
- Output tensor(s).
- """
- return self.__call__(inputs, *args, **kwargs)
-
- def _add_inbound_node(self,
- input_tensors,
- output_tensors,
- arguments=None):
- """Internal method to create an inbound node for the layer.
-
- Arguments:
- input_tensors: list of input tensors.
- output_tensors: list of output tensors.
- arguments: dictionary of keyword arguments that were passed to the
- `call` method of the layer at the call that created the node.
- """
- input_tensors = nest.flatten(input_tensors)
- output_tensors = nest.flatten(output_tensors)
-
- # Collect input tensor(s) coordinates.
- inbound_layers = []
- node_indices = []
- tensor_indices = []
- for x in input_tensors:
- assert hasattr(x, '_keras_history')
- inbound_layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
- inbound_layers.append(inbound_layer)
- node_indices.append(node_index)
- tensor_indices.append(tensor_index)
-
- # Create node, add it to inbound nodes.
- Node(
- self,
- inbound_layers=inbound_layers,
- node_indices=node_indices,
- tensor_indices=tensor_indices,
- input_tensors=input_tensors,
- output_tensors=output_tensors,
- arguments=arguments)
-
- # Update tensor history metadata.
- for i in range(len(output_tensors)):
- # The metadata attribute consists of 1) a layer instance
- # 2) a node index for the layer, 3) a tensor index for the node.
- # The allows layer reuse (multiple nodes per layer) and multi-output
- # or multi-input layers (e.g. a layer can return multiple tensors,
- # and each can be sent to a different layer).
- output_tensors[i]._keras_history = (self, len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access
-
- def _get_node_attribute_at_index(self, node_index, attr, attr_name):
- """Private utility to retrieves an attribute (e.g. inputs) from a node.
-
- This is used to implement the methods:
- - get_input_shape_at
- - get_output_shape_at
- - get_input_at
- etc...
-
- Arguments:
- node_index: Integer index of the node from which
- to retrieve the attribute.
- attr: Exact node attribute name.
- attr_name: Human-readable attribute name, for error messages.
-
- Returns:
- The layer's attribute `attr` at the node of index `node_index`.
-
- Raises:
- RuntimeError: If the layer has no inbound nodes, or if called in Eager
- mode.
- ValueError: If the index provided does not match any node.
- """
- if not self._inbound_nodes:
- raise RuntimeError('The layer has never been called '
- 'and thus has no defined ' + attr_name + '.')
- if not len(self._inbound_nodes) > node_index:
- raise ValueError('Asked to get ' + attr_name + ' at node ' +
- str(node_index) + ', but the layer has only ' +
- str(len(self._inbound_nodes)) + ' inbound nodes.')
- values = getattr(self._inbound_nodes[node_index], attr)
- if len(values) == 1:
- return values[0]
- else:
- return values
-
- def get_input_shape_at(self, node_index):
- """Retrieves the input shape(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A shape tuple
- (or list of shape tuples if the layer has multiple inputs).
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- return self._get_node_attribute_at_index(node_index, 'input_shapes',
- 'input shape')
-
- def get_output_shape_at(self, node_index):
- """Retrieves the output shape(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A shape tuple
- (or list of shape tuples if the layer has multiple outputs).
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if context.executing_eagerly():
- raise RuntimeError(
- 'Layer.get_output_shape_at not supported in Eager mode.')
- return self._get_node_attribute_at_index(node_index, 'output_shapes',
- 'output shape')
-
- def get_input_at(self, node_index):
- """Retrieves the input tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A tensor (or list of tensors if the layer has multiple inputs).
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if context.executing_eagerly():
- raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
- return self._get_node_attribute_at_index(node_index, 'input_tensors',
- 'input')
-
- def get_output_at(self, node_index):
- """Retrieves the output tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A tensor (or list of tensors if the layer has multiple outputs).
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- return self._get_node_attribute_at_index(node_index, 'output_tensors',
- 'output')
-
- @property
- def input(self):
- """Retrieves the input tensor(s) of a layer.
-
- Only applicable if the layer has exactly one input,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Input tensor or list of input tensors.
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
-
- Raises:
- RuntimeError: If called in Eager mode.
- AttributeError: If no inbound nodes are found.
- """
- if not self._inbound_nodes:
- raise AttributeError('Layer ' + self.name +
- ' is not connected, no input to return.')
- return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
-
- @property
- def output(self):
- """Retrieves the output tensor(s) of a layer.
-
- Only applicable if the layer has exactly one output,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Output tensor or list of output tensors.
-
- Raises:
- AttributeError: if the layer is connected to more than one incoming
- layers.
- RuntimeError: if called in Eager mode.
- """
- if not self._inbound_nodes:
- raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
- return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
-
- @property
- def input_shape(self):
- """Retrieves the input shape(s) of a layer.
-
- Only applicable if the layer has exactly one input,
- i.e. if it is connected to one incoming layer, or if all inputs
- have the same shape.
-
- Returns:
- Input shape, as an integer shape tuple
- (or list of shape tuples, one tuple per input tensor).
-
- Raises:
- AttributeError: if the layer has no defined input_shape.
- RuntimeError: if called in Eager mode.
- """
- if not self._inbound_nodes:
- raise AttributeError('The layer has never been called '
- 'and thus has no defined input shape.')
- all_input_shapes = set(
- [str(node.input_shapes) for node in self._inbound_nodes])
- if len(all_input_shapes) == 1:
- input_shapes = self._inbound_nodes[0].input_shapes
- if len(input_shapes) == 1:
- return tuple(tensor_shape.TensorShape(input_shapes[0]).as_list())
- else:
- return [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in input_shapes
- ]
- else:
- raise AttributeError('The layer "' + str(self.name) +
- ' has multiple inbound nodes, '
- 'with different input shapes. Hence '
- 'the notion of "input shape" is '
- 'ill-defined for the layer. '
- 'Use `get_input_shape_at(node_index)` '
- 'instead.')
-
- def count_params(self):
- """Count the total number of scalars composing the weights.
-
- Returns:
- An integer count.
-
- Raises:
- ValueError: if the layer isn't yet built
- (in which case its weights aren't yet defined).
- """
- if not self.built:
- if self.__class__.__name__ == 'Sequential':
- self.build() # pylint: disable=no-value-for-parameter
- else:
- raise ValueError('You tried to call `count_params` on ' + self.name +
- ', but the layer isn\'t built. '
- 'You can build it manually via: `' + self.name +
- '.build(batch_input_shape)`.')
- weight_shapes = [w.get_shape().as_list() for w in self.weights]
- return int(sum([np.prod(w) for w in weight_shapes]))
-
- @property
- def output_shape(self):
- """Retrieves the output shape(s) of a layer.
-
- Only applicable if the layer has one output,
- or if all outputs have the same shape.
-
- Returns:
- Output shape, as an integer shape tuple
- (or list of shape tuples, one tuple per output tensor).
-
- Raises:
- AttributeError: if the layer has no defined output shape.
- RuntimeError: if called in Eager mode.
- """
- if not self._inbound_nodes:
- raise AttributeError('The layer has never been called '
- 'and thus has no defined output shape.')
- all_output_shapes = set(
- [str(node.output_shapes) for node in self._inbound_nodes])
- if len(all_output_shapes) == 1:
- output_shapes = self._inbound_nodes[0].output_shapes
- if len(output_shapes) == 1:
- return tuple(tensor_shape.TensorShape(output_shapes[0]).as_list())
- else:
- return [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in output_shapes
- ]
- else:
- raise AttributeError('The layer "%s"'
- ' has multiple inbound nodes, '
- 'with different output shapes. Hence '
- 'the notion of "output shape" is '
- 'ill-defined for the layer. '
- 'Use `get_output_shape_at(node_index)` '
- 'instead.' % self.name)
-
- @property
- def inbound_nodes(self):
- """Deprecated, do NOT use! Only for compatibility with external Keras."""
- return self._inbound_nodes
-
- @property
- def outbound_nodes(self):
- """Deprecated, do NOT use! Only for compatibility with external Keras."""
- return self._outbound_nodes
-
- def _assert_input_compatibility(self, inputs):
- """Checks compatibility between the layer and provided inputs.
-
- This checks that the tensor(s) `inputs` verify the input assumptions
- of the layer (if any). If not, a clear and actional exception gets raised.
-
- Arguments:
- inputs: input tensor or list of input tensors.
-
- Raises:
- ValueError: in case of mismatch between
- the provided inputs and the expectations of the layer.
- """
- if not self.input_spec:
- return
- if not isinstance(self.input_spec, (list, tuple)):
- input_spec = nest.flatten(self.input_spec)
- else:
- input_spec = self.input_spec
- inputs = nest.flatten(inputs)
- if len(inputs) != len(input_spec):
- raise ValueError('Layer ' + self.name + ' expects ' +
- str(len(input_spec)) + ' inputs, '
- 'but it received ' + str(len(inputs)) +
- ' input tensors. Inputs received: ' + str(inputs))
- for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
- if spec is None:
- continue
-
- if (spec.ndim is not None or
- spec.min_ndim is not None or
- spec.max_ndim is not None):
- if x.get_shape().ndims is None:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'its rank is undefined, but the layer requires a '
- 'defined rank.')
-
- # Check ndim.
- if spec.ndim is not None:
- ndim = x.get_shape().ndims
- if ndim != spec.ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected ndim=' + str(spec.ndim) + ', found ndim=' +
- str(ndim) + '. Full shape received: ' +
- str(x.get_shape().as_list()))
- if spec.max_ndim is not None:
- ndim = x.get_shape().ndims
- if ndim is not None and ndim > spec.max_ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected max_ndim=' + str(spec.max_ndim) +
- ', found ndim=' + str(ndim))
- if spec.min_ndim is not None:
- ndim = x.get_shape().ndims
- if ndim is not None and ndim < spec.min_ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- ': expected min_ndim=' + str(spec.min_ndim) +
- ', found ndim=' + str(ndim) +
- '. Full shape received: ' +
- str(x.get_shape().as_list()))
- # Check dtype.
- if spec.dtype is not None:
- if x.dtype != spec.dtype:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected dtype=' + str(spec.dtype) +
- ', found dtype=' + str(x.dtype))
- # Check specific shape axes.
- if spec.axes:
- shape = x.get_shape().as_list()
- if shape is not None:
- for axis, value in spec.axes.items():
- if hasattr(value, 'value'):
- value = value.value
- if value is not None and shape[int(axis)] not in {value, None}:
- raise ValueError(
- 'Input ' + str(input_index) + ' of layer ' + self.name + ' is'
- ' incompatible with the layer: expected axis ' + str(axis) +
- ' of input shape to have value ' + str(value) +
- ' but received input with shape ' + str(shape))
- # Check shape.
- if spec.shape is not None:
- shape = x.get_shape().as_list()
- if shape is not None:
- for spec_dim, dim in zip(spec.shape, shape):
- if spec_dim is not None and dim is not None:
- if spec_dim != dim:
- raise ValueError('Input ' + str(input_index) +
- ' is incompatible with layer ' + self.name +
- ': expected shape=' + str(spec.shape) +
- ', found shape=' + str(shape))
-
-
-@tf_export('keras.layers.InputSpec', 'layers.InputSpec')
-class InputSpec(object):
- """Specifies the ndim, dtype and shape of every input to a layer.
-
- Every layer should expose (if appropriate) an `input_spec` attribute:
- a list of instances of InputSpec (one per input tensor).
-
- A None entry in a shape is compatible with any dimension,
- a None shape is compatible with any shape.
-
- Arguments:
- dtype: Expected DataType of the input.
- shape: Shape tuple, expected shape of the input
- (may include None for unchecked axes).
- ndim: Integer, expected rank of the input.
- max_ndim: Integer, maximum rank of the input.
- min_ndim: Integer, minimum rank of the input.
- axes: Dictionary mapping integer axes to
- a specific dimension value.
- """
-
- def __init__(self,
- dtype=None,
- shape=None,
- ndim=None,
- max_ndim=None,
- min_ndim=None,
- axes=None):
- self.dtype = dtype
- self.shape = shape
- if shape is not None:
- self.ndim = len(shape)
- else:
- self.ndim = ndim
- self.max_ndim = max_ndim
- self.min_ndim = min_ndim
- self.axes = axes or {}
-
- def __repr__(self):
- spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
- ('shape=' + str(self.shape)) if self.shape else '',
- ('ndim=' + str(self.ndim)) if self.ndim else '',
- ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
- ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
- ('axes=' + str(self.axes)) if self.axes else '']
- return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
-
-
-class Node(object):
- """A `Node` describes the connectivity between two layers.
-
- Each time a layer is connected to some new input,
- a node is added to `layer._inbound_nodes`.
- Each time the output of a layer is used by another layer,
- a node is added to `layer._outbound_nodes`.
-
- Arguments:
- outbound_layer: the layer that takes
- `input_tensors` and turns them into `output_tensors`
- (the node gets created when the `call`
- method of the layer was called).
- inbound_layers: a list of layers, the same length as `input_tensors`,
- the layers from where `input_tensors` originate.
- node_indices: a list of integers, the same length as `inbound_layers`.
- `node_indices[i]` is the origin node of `input_tensors[i]`
- (necessary since each inbound layer might have several nodes,
- e.g. if the layer is being shared with a different data stream).
- tensor_indices: a list of integers,
- the same length as `inbound_layers`.
- `tensor_indices[i]` is the index of `input_tensors[i]` within the
- output of the inbound layer
- (necessary since each inbound layer might
- have multiple tensor outputs, with each one being
- independently manipulable).
- input_tensors: list of input tensors.
- output_tensors: list of output tensors.
- arguments: dictionary of keyword arguments that were passed to the
- `call` method of the layer at the call that created the node.
-
- `node_indices` and `tensor_indices` are basically fine-grained coordinates
- describing the origin of the `input_tensors`.
-
- A node from layer A to layer B is added to:
- - A._outbound_nodes
- - B._inbound_nodes
- """
-
- def __init__(self,
- outbound_layer,
- inbound_layers,
- node_indices,
- tensor_indices,
- input_tensors,
- output_tensors,
- arguments=None):
- # Layer instance (NOT a list).
- if isinstance(outbound_layer, list):
- raise ValueError(
- '`outbound_layer` should be a layer instance, not a list.')
- # this is the layer that takes a list of input tensors
- # and turns them into a list of output tensors.
- # the current node will be added to
- # the inbound_nodes of outbound_layer.
- self.outbound_layer = outbound_layer
-
- # The following 3 properties describe where
- # the input tensors come from: which layers,
- # and for each layer, which node and which
- # tensor output of each node.
-
- # List of layer instances.
- self.inbound_layers = inbound_layers
- # List of integers, 1:1 mapping with inbound_layers.
- self.node_indices = node_indices
- # List of integers, 1:1 mapping with inbound_layers.
- self.tensor_indices = tensor_indices
-
- # Following 2 properties:
- # tensor inputs and outputs of outbound_layer.
-
- # List of tensors. 1:1 mapping with inbound_layers.
- self.input_tensors = input_tensors
- # List of tensors, created by outbound_layer.call().
- self.output_tensors = output_tensors
-
- # Following 2 properties: input and output shapes.
-
- # List of shape tuples, shapes of input_tensors.
- self.input_shapes = [layers_util.static_shape(x) for x in input_tensors]
- # List of shape tuples, shapes of output_tensors.
- self.output_shapes = [layers_util.static_shape(x) for x in output_tensors]
-
- # Optional keyword arguments to layer's `call`.
- self.arguments = arguments
-
- # Add nodes to all layers involved.
- for layer in inbound_layers:
- if layer is not None:
- # For compatibility with external Keras, we use the deprecated
- # accessor here.
- layer.outbound_nodes.append(self)
- # For compatibility with external Keras, we use the deprecated
- # accessor here.
- outbound_layer.inbound_nodes.append(self)
-
- def get_config(self):
- inbound_names = []
- for layer in self.inbound_layers:
- if layer:
- inbound_names.append(layer.name)
- else:
- inbound_names.append(None)
- return {
- 'outbound_layer': self.outbound_layer.name,
- 'inbound_layers': inbound_names,
- 'node_indices': self.node_indices,
- 'tensor_indices': self.tensor_indices
- }
-
-
-class _DeferredTensor(object):
- """Tensor-like object used to build graphs of layers in Eager mode.
-
- When calling a layer on a DeferredTensor, the layer will not perform any
- computation and will simply perfom shape inference to return new
- DeferredTensors with appropriate shape information. Thus DeferredTensor
- behaves like a graph-mode Tensor when manipulated by layers.
- """
-
- def __init__(self, shape, dtype, name=None):
- self.shape = tensor_shape.TensorShape(shape)
- if dtype is None:
- self.dtype = dtypes.as_dtype(np.float32)
- else:
- self.dtype = dtypes.as_dtype(dtype)
- self.name = name
-
- def get_shape(self):
- return self.shape
-
- def __str__(self):
- return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
- self.get_shape(),
- self.dtype.name)
-
- def __repr__(self):
- return "<_DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
- self.get_shape(),
- self.dtype.name)
-
-
-def _is_tensor_or_tensor_list(v):
- v = nest.flatten(v)
- if v and isinstance(v[0], ops.Tensor):
- return True
- else:
- return False
-
-
-def _to_snake_case(name):
- intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
- insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
- # If the class is private the name starts with "_" which is not secure
- # for creating scopes. We prefix the name with "private" in this case.
- if insecure[0] != '_':
- return insecure
- return 'private' + insecure
-
-
-def _to_list(x):
- """This normalizes a list/tuple or single element into a list.
-
- If a single element is passed, we return
- a list of size 1 containing the element.
-
- Arguments:
- x: list or tuple or single element.
-
- Returns:
- A list.
- """
- if isinstance(x, (list, tuple)):
- return list(x)
- return [x]
-
def _add_elements_to_collection(elements, collection_list):
if context.executing_eagerly():
@@ -1473,105 +354,3 @@ def _add_elements_to_collection(elements, collection_list):
if element not in collection_set:
collection.append(element)
-
-def _is_all_none(iterable_or_element):
- if not isinstance(iterable_or_element, (list, tuple)):
- iterable = [iterable_or_element]
- else:
- iterable = iterable_or_element
- # We cannot use Python's `any` because the iterable may return Tensors.
- for element in iterable:
- if element is not None:
- return False
- return True
-
-
-def _have_all_keras_metadata(iterable_or_element):
- if not isinstance(iterable_or_element, (list, tuple)):
- iterable = [iterable_or_element]
- else:
- iterable = iterable_or_element
- return all([hasattr(x, '_keras_history') for x in iterable])
-
-
-def _collect_previous_mask(input_tensors):
- """Retrieves the output mask(s) of the previous node.
-
- Arguments:
- input_tensors: A tensor or list of tensors.
-
- Returns:
- A mask tensor or list of mask tensors.
- """
- input_tensors = nest.flatten(input_tensors)
- masks = []
- for x in input_tensors:
- if hasattr(x, '_keras_mask'):
- mask = x._keras_mask # pylint: disable=protected-access
- masks.append(mask)
- else:
- masks.append(None)
- if len(masks) == 1:
- return masks[0]
- return masks
-
-
-# A global dictionary mapping graph objects to an index of counters used
-# for various layer names in each graph.
-# Allows to give unique autogenerated names to layers, in a graph-specific way.
-PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()
-
-
-def _get_default_graph_uid_map():
- graph = ops.get_default_graph()
- name_uid_map = PER_GRAPH_LAYER_NAME_UIDS.get(graph, None)
- if name_uid_map is None:
- name_uid_map = collections.defaultdict(int)
- PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map
- return name_uid_map
-
-
-def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
- zero_based=False):
- """Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
-
- Arguments:
- name: String name to make unique.
- name_uid_map: An optional defaultdict(int) to use when creating unique
- names. If None (default), uses a per-Graph dictionary.
- avoid_names: An optional set or dict with names which should not be used. If
- None (default) does not avoid any names.
- namespace: Gets a name which is unique within the (graph, namespace). Layers
- which are not Networks use a blank namespace and so get graph-global
- names.
- zero_based: If True, name sequences start with no suffix (e.g. "dense",
- "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
-
- Returns:
- Unique string name.
-
- Example:
-
- ```python
- _unique_layer_name('dense') # dense_1
- _unique_layer_name('dense') # dense_2
- ```
- """
- if name_uid_map is None:
- name_uid_map = _get_default_graph_uid_map()
- if avoid_names is None:
- avoid_names = set()
- proposed_name = None
- while proposed_name is None or proposed_name in avoid_names:
- name_key = (namespace, name)
- if zero_based:
- number = name_uid_map[name_key]
- if number:
- proposed_name = name + '_' + str(number)
- else:
- proposed_name = name
- name_uid_map[name_key] += 1
- else:
- name_uid_map[name_key] += 1
- proposed_name = name + '_' + str(name_uid_map[name_key])
- return proposed_name
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 9ed4afeaba..c05c675263 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -94,61 +94,6 @@ class BaseLayerTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):
core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.)
- def testGetVariable(self):
- with self.test_session():
-
- class MyLayer(base_layers.Layer):
-
- def build(self, input_shape):
- self.my_var = self.add_variable(
- 'my_var', [2, 2], initializer=init_ops.zeros_initializer())
-
- def call(self, inputs):
- return inputs * 2
-
- layer = MyLayer(name='my_layer')
- inputs = random_ops.random_uniform((5,), seed=1)
- layer.apply(inputs)
- layer.apply(inputs)
- self.assertEqual([v.name for v in layer.variables],
- ['my_layer/my_var:0'])
-
- # Creating a layer with no scope leads to lazy construction of
- # the scope at apply() time. It uses scope "<current scope>/base_name"
- lazy_layer = MyLayer(_reuse=True)
- with variable_scope.variable_scope('new_scope'):
- with variable_scope.variable_scope('my_layer'):
- variable_scope.get_variable('my_var', [2, 2])
-
- # Smoke test: it runs.
- lazy_layer.apply(inputs)
- # The variables were created outside of the Layer, and
- # reuse=True, so the Layer does not own them and they are not
- # stored in its collection.
- self.assertEqual(lazy_layer.variables, [])
- self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
-
- # Creating a layer with no scope leads to lazy construction of
- # the scope at apply() time. If 'scope' argument is passed to
- # apply(), it uses that scope when accessing variables.
- lazy_layer = MyLayer(_reuse=True)
- with variable_scope.variable_scope('new_scope') as new_scope:
- variable_scope.get_variable('my_var', [2, 2])
-
- # Smoke test: it runs.
- lazy_layer.apply(inputs, scope=new_scope)
- # The variables were created outside of the Layer, and
- # reuse=True, so the Layer does not own them and they are not
- # stored in its collection.
- self.assertEqual(lazy_layer.variables, [])
- self.assertEqual(lazy_layer._scope.name, 'new_scope')
-
- # Checking for graph equality is only done in GRAPH mode.
- with ops.Graph().as_default():
- inputs_ng = random_ops.random_uniform((5,), seed=1)
- with self.assertRaisesRegexp(ValueError, r'graph are not the same'):
- layer.apply(inputs_ng)
-
@test_util.run_in_graph_and_eager_modes()
def testCall(self):
@@ -165,38 +110,6 @@ class BaseLayerTest(test.TestCase):
# op is only supported in GRAPH mode
self.assertEqual(outputs.op.name, 'my_layer/Square')
- def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):
- # Note that this test is only run in Graph mode since with EAGER mode we can
- # still create a new variable on second call.
-
- class MyLayer(base_layers.Layer):
-
- def build(self, _):
- # Do not mark the layer as built.
- pass
-
- def call(self, inputs):
- self.my_var = self.add_variable('my_var', [2, 2])
- if self.built:
- # Skip creating on the first call; try to create after it's
- # built. This is expected to fail.
- self.add_variable('this_will_break_on_second_call', [2, 2])
- return inputs + math_ops.square(self.my_var)
-
- layer = MyLayer(name='my_layer')
- inputs = random_ops.random_uniform((2,), seed=1)
- outputs = layer.apply(inputs)
- self.assertEqual(layer.built, True)
- self.assertEqual(outputs.op.name, 'my_layer/add')
- self.assertEqual([v.name
- for v in layer.variables], ['my_layer/my_var:0'])
- with self.assertRaisesRegexp(ValueError,
- 'my_layer/this_will_break_on_second_call'):
- layer.apply(inputs)
- # The list of variables hasn't changed.
- self.assertEqual([v.name
- for v in layer.variables], ['my_layer/my_var:0'])
-
@test_util.run_in_graph_and_eager_modes()
def testDeepCopy(self):
@@ -645,13 +558,14 @@ class BaseLayerTest(test.TestCase):
def testLayerGraphSetInFirstApply(self):
with ops.Graph().as_default():
- layer = core_layers.Dense(1) # Graph at construction time is ignored
+ # Graph at construction time is ignored
+ layer = core_layers.Dense(1)
with ops.Graph().as_default():
- layer.apply(constant_op.constant([[1]]))
+ layer.apply(constant_op.constant([[1.]]))
# layer is now bound to second Graph
with ops.Graph().as_default(), self.assertRaisesRegexp(
ValueError, 'Input graph and Layer graph are not the same'):
- layer.apply(constant_op.constant([[1]]))
+ layer.apply(constant_op.constant([[1.]]))
if __name__ == '__main__':
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 2d99b1688f..34a1487e74 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -23,6 +23,7 @@ from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import layers as keras_layers
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
@@ -32,201 +33,8 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.util.tf_export import tf_export
-class _Conv(base.Layer):
- """Abstract nD convolution layer (private, used as implementation base).
-
- This layer creates a convolution kernel that is convolved
- (actually cross-correlated) with the layer input to produce a tensor of
- outputs. If `use_bias` is True (and a `bias_initializer` is provided),
- a bias vector is created and added to the outputs. Finally, if
- `activation` is not `None`, it is applied to the outputs as well.
-
- Arguments:
- rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: An integer or tuple/list of n integers, specifying the
- length of the convolution window.
- strides: An integer or tuple/list of n integers,
- specifying the stride length of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
- dilation_rate: An integer or tuple/list of n integers, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- use_bias: Boolean, whether the layer uses a bias.
- kernel_initializer: An initializer for the convolution kernel.
- bias_initializer: An initializer for the bias vector. If None, the default
- initializer will be used.
- kernel_regularizer: Optional regularizer for the convolution kernel.
- bias_regularizer: Optional regularizer for the bias vector.
- activity_regularizer: Optional regularizer function for the output.
- kernel_constraint: Optional projection function to be applied to the
- kernel after being updated by an `Optimizer` (e.g. used to implement
- norm constraints or value constraints for layer weights). The function
- must take as input the unprojected variable and must return the
- projected variable (which must have the same shape). Constraints are
- not safe to use when doing asynchronous distributed training.
- bias_constraint: Optional projection function to be applied to the
- bias after being updated by an `Optimizer`.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- name: A string, the name of the layer.
- """
-
- def __init__(self, rank,
- filters,
- kernel_size,
- strides=1,
- padding='valid',
- data_format='channels_last',
- dilation_rate=1,
- activation=None,
- use_bias=True,
- kernel_initializer=None,
- bias_initializer=init_ops.zeros_initializer(),
- kernel_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- kernel_constraint=None,
- bias_constraint=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_Conv, self).__init__(trainable=trainable, name=name,
- activity_regularizer=activity_regularizer,
- **kwargs)
- self.rank = rank
- self.filters = filters
- self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
- self.strides = utils.normalize_tuple(strides, rank, 'strides')
- self.padding = utils.normalize_padding(padding)
- self.data_format = utils.normalize_data_format(data_format)
- self.dilation_rate = utils.normalize_tuple(
- dilation_rate, rank, 'dilation_rate')
- self.activation = activation
- self.use_bias = use_bias
- self.kernel_initializer = kernel_initializer
- self.bias_initializer = bias_initializer
- self.kernel_regularizer = kernel_regularizer
- self.bias_regularizer = bias_regularizer
- self.kernel_constraint = kernel_constraint
- self.bias_constraint = bias_constraint
- self.input_spec = base.InputSpec(ndim=self.rank + 2)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = -1
- if input_shape[channel_axis].value is None:
- raise ValueError('The channel dimension of the inputs '
- 'should be defined. Found `None`.')
- input_dim = input_shape[channel_axis].value
- kernel_shape = self.kernel_size + (input_dim, self.filters)
-
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
- if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
- else:
- self.bias = None
- self.input_spec = base.InputSpec(ndim=self.rank + 2,
- axes={channel_axis: input_dim})
- self._convolution_op = nn_ops.Convolution(
- input_shape,
- filter_shape=self.kernel.get_shape(),
- dilation_rate=self.dilation_rate,
- strides=self.strides,
- padding=self.padding.upper(),
- data_format=utils.convert_data_format(self.data_format,
- self.rank + 2))
- self.built = True
-
- def call(self, inputs):
- outputs = self._convolution_op(inputs, self.kernel)
-
- if self.use_bias:
- if self.data_format == 'channels_first':
- if self.rank == 1:
- # nn.bias_add does not accept a 1D input tensor.
- bias = array_ops.reshape(self.bias, (1, self.filters, 1))
- outputs += bias
- if self.rank == 2:
- outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
- if self.rank == 3:
- # As of Mar 2017, direct addition is significantly slower than
- # bias_add when computing gradients. To use bias_add, we collapse Z
- # and Y into a single dimension to obtain a 4D input tensor.
- outputs_shape = outputs.shape.as_list()
- if outputs_shape[0] is None:
- outputs_shape[0] = -1
- outputs_4d = array_ops.reshape(outputs,
- [outputs_shape[0], outputs_shape[1],
- outputs_shape[2] * outputs_shape[3],
- outputs_shape[4]])
- outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
- outputs = array_ops.reshape(outputs_4d, outputs_shape)
- else:
- outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
-
- if self.activation is not None:
- return self.activation(outputs)
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.data_format == 'channels_last':
- space = input_shape[1:-1]
- new_space = []
- for i in range(len(space)):
- new_dim = utils.conv_output_length(
- space[i],
- self.kernel_size[i],
- padding=self.padding,
- stride=self.strides[i],
- dilation=self.dilation_rate[i])
- new_space.append(new_dim)
- return tensor_shape.TensorShape([input_shape[0]] + new_space +
- [self.filters])
- else:
- space = input_shape[2:]
- new_space = []
- for i in range(len(space)):
- new_dim = utils.conv_output_length(
- space[i],
- self.kernel_size[i],
- padding=self.padding,
- stride=self.strides[i],
- dilation=self.dilation_rate[i])
- new_space.append(new_dim)
- return tensor_shape.TensorShape([input_shape[0], self.filters] +
- new_space)
-
-
@tf_export('layers.Conv1D')
-class Conv1D(_Conv):
+class Conv1D(keras_layers.Conv1D, base.Layer):
"""1D convolution layer (e.g. temporal convolution).
This layer creates a convolution kernel that is convolved
@@ -294,8 +102,7 @@ class Conv1D(_Conv):
trainable=True,
name=None,
**kwargs):
- super(Convolution1D, self).__init__(
- rank=1,
+ super(Conv1D, self).__init__(
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -417,7 +224,7 @@ def conv1d(inputs,
@tf_export('layers.Conv2D')
-class Conv2D(_Conv):
+class Conv2D(keras_layers.Conv2D, base.Layer):
"""2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
@@ -493,7 +300,6 @@ class Conv2D(_Conv):
name=None,
**kwargs):
super(Conv2D, self).__init__(
- rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -622,7 +428,7 @@ def conv2d(inputs,
@tf_export('layers.Conv3D')
-class Conv3D(_Conv):
+class Conv3D(keras_layers.Conv3D, base.Layer):
"""3D convolution layer (e.g. spatial convolution over volumes).
This layer creates a convolution kernel that is convolved
@@ -699,7 +505,6 @@ class Conv3D(_Conv):
name=None,
**kwargs):
super(Conv3D, self).__init__(
- rank=3,
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -828,169 +633,8 @@ def conv3d(inputs,
return layer.apply(inputs)
-class _SeparableConv(_Conv):
- """Abstract base layer for separable nD convolution.
-
- This layer performs a depthwise convolution that acts separately on
- channels, followed by a pointwise convolution that mixes channels.
- If `use_bias` is True and a bias initializer is provided,
- it adds a bias vector to the output.
- It then optionally applies an activation function to produce the final output.
-
- Arguments:
- rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
- filters: Integer, the dimensionality of the output space (i.e. the number
- of filters in the convolution).
- kernel_size: A tuple or list of integers specifying the spatial
- dimensions of the filters. Can be a single integer to specify the same
- value for all spatial dimensions.
- strides: A tuple or list of integers specifying the strides
- of the convolution. Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any `stride` value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
- dilation_rate: An integer or tuple/list of 2 integers, specifying
- the dilation rate to use for dilated convolution.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any stride value != 1.
- depth_multiplier: The number of depthwise convolution output channels for
- each input channel. The total number of depthwise convolution output
- channels will be equal to `num_filters_in * depth_multiplier`.
- activation: Activation function. Set it to None to maintain a
- linear activation.
- use_bias: Boolean, whether the layer uses a bias.
- depthwise_initializer: An initializer for the depthwise convolution kernel.
- pointwise_initializer: An initializer for the pointwise convolution kernel.
- bias_initializer: An initializer for the bias vector. If None, the default
- initializer will be used.
- depthwise_regularizer: Optional regularizer for the depthwise
- convolution kernel.
- pointwise_regularizer: Optional regularizer for the pointwise
- convolution kernel.
- bias_regularizer: Optional regularizer for the bias vector.
- activity_regularizer: Optional regularizer function for the output.
- depthwise_constraint: Optional projection function to be applied to the
- depthwise kernel after being updated by an `Optimizer` (e.g. used for
- norm constraints or value constraints for layer weights). The function
- must take as input the unprojected variable and must return the
- projected variable (which must have the same shape). Constraints are
- not safe to use when doing asynchronous distributed training.
- pointwise_constraint: Optional projection function to be applied to the
- pointwise kernel after being updated by an `Optimizer`.
- bias_constraint: Optional projection function to be applied to the
- bias after being updated by an `Optimizer`.
- trainable: Boolean, if `True` also add variables to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- name: A string, the name of the layer.
- """
-
- def __init__(self,
- rank,
- filters,
- kernel_size,
- strides=1,
- padding='valid',
- data_format='channels_last',
- dilation_rate=1,
- depth_multiplier=1,
- activation=None,
- use_bias=True,
- depthwise_initializer=None,
- pointwise_initializer=None,
- bias_initializer=init_ops.zeros_initializer(),
- depthwise_regularizer=None,
- pointwise_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- depthwise_constraint=None,
- pointwise_constraint=None,
- bias_constraint=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_SeparableConv, self).__init__(
- rank=rank,
- filters=filters,
- kernel_size=kernel_size,
- strides=strides,
- padding=padding,
- data_format=data_format,
- dilation_rate=dilation_rate,
- activation=activation,
- use_bias=use_bias,
- bias_regularizer=bias_regularizer,
- activity_regularizer=activity_regularizer,
- bias_constraint=bias_constraint,
- trainable=trainable,
- name=name,
- **kwargs)
- self.depth_multiplier = depth_multiplier
- self.depthwise_initializer = depthwise_initializer
- self.pointwise_initializer = pointwise_initializer
- self.depthwise_regularizer = depthwise_regularizer
- self.pointwise_regularizer = pointwise_regularizer
- self.depthwise_constraint = depthwise_constraint
- self.pointwise_constraint = pointwise_constraint
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = -1
- if input_shape[channel_axis].value is None:
- raise ValueError('The channel dimension of the inputs '
- 'should be defined. Found `None`.')
- input_dim = input_shape[channel_axis].value
- self.input_spec = base.InputSpec(ndim=self.rank + 2,
- axes={channel_axis: input_dim})
- depthwise_kernel_shape = self.kernel_size + (input_dim,
- self.depth_multiplier)
- pointwise_kernel_shape = (
- 1,) * self.rank + (self.depth_multiplier * input_dim, self.filters)
-
- self.depthwise_kernel = self.add_variable(
- name='depthwise_kernel',
- shape=depthwise_kernel_shape,
- initializer=self.depthwise_initializer,
- regularizer=self.depthwise_regularizer,
- constraint=self.depthwise_constraint,
- trainable=True,
- dtype=self.dtype)
- self.pointwise_kernel = self.add_variable(
- name='pointwise_kernel',
- shape=pointwise_kernel_shape,
- initializer=self.pointwise_initializer,
- regularizer=self.pointwise_regularizer,
- constraint=self.pointwise_constraint,
- trainable=True,
- dtype=self.dtype)
- if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
- else:
- self.bias = None
- self.built = True
-
- def call(self, inputs):
- raise NotImplementedError
-
-
@tf_export('layers.SeparableConv1D')
-class SeparableConv1D(_SeparableConv):
+class SeparableConv1D(keras_layers.SeparableConv1D, base.Layer):
"""Depthwise separable 1D convolution.
This layer performs a depthwise convolution that acts separately on
@@ -1072,7 +716,6 @@ class SeparableConv1D(_SeparableConv):
name=None,
**kwargs):
super(SeparableConv1D, self).__init__(
- rank=1,
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -1096,45 +739,9 @@ class SeparableConv1D(_SeparableConv):
name=name,
**kwargs)
- def call(self, inputs):
- if self.data_format == 'channels_last':
- strides = (1,) + self.strides * 2 + (1,)
- spatial_start_dim = 1
- else:
- strides = (1, 1) + self.strides * 2
- spatial_start_dim = 2
-
- # Explicitly broadcast inputs and kernels to 4D.
- # TODO(fchollet): refactor when a native separable_conv1d op is available.
- inputs = array_ops.expand_dims(inputs, spatial_start_dim)
- depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0)
- pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
- dilation_rate = (1,) + self.dilation_rate
-
- outputs = nn.separable_conv2d(
- inputs,
- depthwise_kernel,
- pointwise_kernel,
- strides=strides,
- padding=self.padding.upper(),
- rate=dilation_rate,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- if self.use_bias:
- outputs = nn.bias_add(
- outputs,
- self.bias,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- outputs = array_ops.squeeze(outputs, [spatial_start_dim])
-
- if self.activation is not None:
- return self.activation(outputs)
- return outputs
-
@tf_export('layers.SeparableConv2D')
-class SeparableConv2D(_SeparableConv):
+class SeparableConv2D(keras_layers.SeparableConv2D, base.Layer):
"""Depthwise separable 2D convolution.
This layer performs a depthwise convolution that acts separately on
@@ -1221,7 +828,6 @@ class SeparableConv2D(_SeparableConv):
name=None,
**kwargs):
super(SeparableConv2D, self).__init__(
- rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -1245,31 +851,6 @@ class SeparableConv2D(_SeparableConv):
name=name,
**kwargs)
- def call(self, inputs):
- # Apply the actual ops.
- if self.data_format == 'channels_last':
- strides = (1,) + self.strides + (1,)
- else:
- strides = (1, 1) + self.strides
- outputs = nn.separable_conv2d(
- inputs,
- self.depthwise_kernel,
- self.pointwise_kernel,
- strides=strides,
- padding=self.padding.upper(),
- rate=self.dilation_rate,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- if self.use_bias:
- outputs = nn.bias_add(
- outputs,
- self.bias,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- if self.activation is not None:
- return self.activation(outputs)
- return outputs
-
@tf_export('layers.separable_conv1d')
def separable_conv1d(inputs,
@@ -1511,7 +1092,7 @@ def separable_conv2d(inputs,
@tf_export('layers.Conv2DTranspose')
-class Conv2DTranspose(Conv2D):
+class Conv2DTranspose(keras_layers.Conv2DTranspose, base.Layer):
"""Transposed 2D convolution layer (sometimes called 2D Deconvolution).
The need for transposed convolutions generally arises
@@ -1576,8 +1157,8 @@ class Conv2DTranspose(Conv2D):
name=None,
**kwargs):
super(Conv2DTranspose, self).__init__(
- filters,
- kernel_size,
+ filters=filters,
+ kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
@@ -1593,120 +1174,6 @@ class Conv2DTranspose(Conv2D):
trainable=trainable,
name=name,
**kwargs)
- self.input_spec = base.InputSpec(ndim=4)
-
- def build(self, input_shape):
- if len(input_shape) != 4:
- raise ValueError('Inputs should have rank 4. Received input shape: ' +
- str(input_shape))
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = -1
- if input_shape[channel_axis] is None:
- raise ValueError('The channel dimension of the inputs '
- 'should be defined. Found `None`.')
- input_dim = input_shape[channel_axis]
- self.input_spec = base.InputSpec(ndim=4, axes={channel_axis: input_dim})
- kernel_shape = self.kernel_size + (self.filters, input_dim)
-
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
- if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
- else:
- self.bias = None
- self.built = True
-
- def call(self, inputs):
- inputs_shape = array_ops.shape(inputs)
- batch_size = inputs_shape[0]
- if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
- else:
- c_axis, h_axis, w_axis = 3, 1, 2
-
- height, width = inputs_shape[h_axis], inputs_shape[w_axis]
- kernel_h, kernel_w = self.kernel_size
- stride_h, stride_w = self.strides
-
- # Infer the dynamic output shape:
- out_height = utils.deconv_output_length(height,
- kernel_h,
- self.padding,
- stride_h)
- out_width = utils.deconv_output_length(width,
- kernel_w,
- self.padding,
- stride_w)
- if self.data_format == 'channels_first':
- output_shape = (batch_size, self.filters, out_height, out_width)
- strides = (1, 1, stride_h, stride_w)
- else:
- output_shape = (batch_size, out_height, out_width, self.filters)
- strides = (1, stride_h, stride_w, 1)
-
- output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv2d_transpose(
- inputs,
- self.kernel,
- output_shape_tensor,
- strides,
- padding=self.padding.upper(),
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- if not context.executing_eagerly():
- # Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
- outputs.set_shape(out_shape)
-
- if self.use_bias:
- outputs = nn.bias_add(
- outputs,
- self.bias,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
-
- if self.activation is not None:
- return self.activation(outputs)
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- output_shape = list(input_shape)
- if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
- else:
- c_axis, h_axis, w_axis = 3, 1, 2
-
- kernel_h, kernel_w = self.kernel_size
- stride_h, stride_w = self.strides
-
- output_shape[c_axis] = self.filters
- output_shape[h_axis] = utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
- output_shape[w_axis] = utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
- return tensor_shape.TensorShape(output_shape)
@tf_export('layers.conv2d_transpose')
@@ -1806,7 +1273,7 @@ def conv2d_transpose(inputs,
@tf_export('layers.Conv3DTranspose')
-class Conv3DTranspose(Conv3D):
+class Conv3DTranspose(keras_layers.Conv3DTranspose, base.Layer):
"""Transposed 3D convolution layer (sometimes called 3D Deconvolution).
Arguments:
@@ -1885,153 +1352,6 @@ class Conv3DTranspose(Conv3D):
trainable=trainable,
name=name,
**kwargs)
- self.input_spec = base.InputSpec(ndim=5)
-
- def build(self, input_shape):
- if len(input_shape) != 5:
- raise ValueError('Inputs should have rank 5, received input shape:',
- str(input_shape))
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = -1
- if input_shape[channel_axis] is None:
- raise ValueError('The channel dimension of the inputs '
- 'should be defined, found None: ' + str(input_shape))
- input_dim = input_shape[channel_axis]
- kernel_shape = self.kernel_size + (self.filters, input_dim)
-
- self.kernel = self.add_variable(
- 'kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
- if self.use_bias:
- self.bias = self.add_variable(
- 'bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
- else:
- self.bias = None
- self.built = True
-
- def call(self, inputs):
- inputs_shape = array_ops.shape(inputs)
- batch_size = inputs_shape[0]
- if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
- else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- self.input_spec = base.InputSpec(ndim=5,
- axes={c_axis: inputs_shape[c_axis]})
-
- depth = inputs_shape[d_axis]
- height = inputs_shape[h_axis]
- width = inputs_shape[w_axis]
-
- kernel_d, kernel_h, kernel_w = self.kernel_size
- stride_d, stride_h, stride_w = self.strides
-
- # Infer the dynamic output shape:
- out_depth = utils.deconv_output_length(depth,
- kernel_d,
- self.padding,
- stride_d)
- out_height = utils.deconv_output_length(height,
- kernel_h,
- self.padding,
- stride_h)
- out_width = utils.deconv_output_length(width,
- kernel_w,
- self.padding,
- stride_w)
- if self.data_format == 'channels_first':
- output_shape = (batch_size, self.filters, out_depth, out_height,
- out_width)
- strides = (1, 1, stride_d, stride_h, stride_w)
- else:
- output_shape = (batch_size, out_depth, out_height, out_width,
- self.filters)
- strides = (1, stride_d, stride_h, stride_w, 1)
-
- output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv3d_transpose(
- inputs,
- self.kernel,
- output_shape_tensor,
- strides,
- data_format=utils.convert_data_format(self.data_format, ndim=5),
- padding=self.padding.upper())
-
- if not context.executing_eagerly():
- # Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[d_axis] = utils.deconv_output_length(out_shape[d_axis],
- kernel_d,
- self.padding,
- stride_d)
- out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
- outputs.set_shape(out_shape)
-
- if self.use_bias:
- outputs_shape = outputs.shape.as_list()
- if outputs_shape[0] is None:
- outputs_shape[0] = -1
- if self.data_format == 'channels_first':
- outputs_4d = array_ops.reshape(outputs, [
- outputs_shape[0], outputs_shape[1],
- outputs_shape[2] * outputs_shape[3], outputs_shape[4]
- ])
- else:
- outputs_4d = array_ops.reshape(outputs, [
- outputs_shape[0], outputs_shape[1] * outputs_shape[2],
- outputs_shape[3], outputs_shape[4]
- ])
- outputs_4d = nn.bias_add(
- outputs_4d,
- self.bias,
- data_format=utils.convert_data_format(self.data_format, ndim=4))
- outputs = array_ops.reshape(outputs_4d, outputs_shape)
-
- if self.activation is not None:
- return self.activation(outputs)
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- output_shape = list(input_shape)
- if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
- else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- kernel_d, kernel_h, kernel_w = self.kernel_size
- stride_d, stride_h, stride_w = self.strides
-
- output_shape[c_axis] = self.filters
- output_shape[d_axis] = utils.deconv_output_length(
- output_shape[d_axis], kernel_d, self.padding, stride_d)
- output_shape[h_axis] = utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
- output_shape[w_axis] = utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
- return tensor_shape.TensorShape(output_shape)
@tf_export('layers.conv3d_transpose')
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index e598d9f83a..6d8e9eac87 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -27,23 +27,14 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
-from tensorflow.python.eager import context
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import layers as keras_layers
from tensorflow.python.layers import base
-from tensorflow.python.layers import utils
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import gen_math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import standard_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export('layers.Dense')
-class Dense(base.Layer):
+class Dense(keras_layers.Dense, base.Layer):
"""Densely-connected layer class.
This layer implements the operation:
@@ -108,73 +99,19 @@ class Dense(base.Layer):
trainable=True,
name=None,
**kwargs):
- super(Dense, self).__init__(trainable=trainable, name=name,
+ super(Dense, self).__init__(units=units,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
+ kernel_constraint=kernel_constraint,
+ bias_constraint=bias_constraint,
+ trainable=trainable,
+ name=name,
**kwargs)
- self.units = units
- self.activation = activation
- self.use_bias = use_bias
- self.kernel_initializer = kernel_initializer
- self.bias_initializer = bias_initializer
- self.kernel_regularizer = kernel_regularizer
- self.bias_regularizer = bias_regularizer
- self.kernel_constraint = kernel_constraint
- self.bias_constraint = bias_constraint
- self.input_spec = base.InputSpec(min_ndim=2)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if input_shape[-1].value is None:
- raise ValueError('The last dimension of the inputs to `Dense` '
- 'should be defined. Found `None`.')
- self.input_spec = base.InputSpec(min_ndim=2,
- axes={-1: input_shape[-1].value})
- self.kernel = self.add_variable('kernel',
- shape=[input_shape[-1].value, self.units],
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- dtype=self.dtype,
- trainable=True)
- if self.use_bias:
- self.bias = self.add_variable('bias',
- shape=[self.units,],
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- dtype=self.dtype,
- trainable=True)
- else:
- self.bias = None
- self.built = True
-
- def call(self, inputs):
- inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
- shape = inputs.get_shape().as_list()
- if len(shape) > 2:
- # Broadcasting is required for the inputs.
- outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
- [0]])
- # Reshape the output back to the original ndim of the input.
- if not context.executing_eagerly():
- output_shape = shape[:-1] + [self.units]
- outputs.set_shape(output_shape)
- else:
- outputs = gen_math_ops.mat_mul(inputs, self.kernel)
- if self.use_bias:
- outputs = nn.bias_add(outputs, self.bias)
- if self.activation is not None:
- return self.activation(outputs) # pylint: disable=not-callable
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- input_shape = input_shape.with_rank_at_least(2)
- if input_shape[-1].value is None:
- raise ValueError(
- 'The innermost dimension of input_shape must be defined, but saw: %s'
- % input_shape)
- return input_shape[:-1].concatenate(self.units)
@tf_export('layers.dense')
@@ -254,7 +191,7 @@ def dense(
@tf_export('layers.Dropout')
-class Dropout(base.Layer):
+class Dropout(keras_layers.Dropout, base.Layer):
"""Applies Dropout to the input.
Dropout consists in randomly setting a fraction `rate` of input units to 0
@@ -282,31 +219,14 @@ class Dropout(base.Layer):
seed=None,
name=None,
**kwargs):
- super(Dropout, self).__init__(name=name, **kwargs)
- self.rate = rate
- self.noise_shape = noise_shape
- self.seed = seed
-
- def _get_noise_shape(self, inputs):
- # Subclasses of `Dropout` may implement `_get_noise_shape(self, inputs)`,
- # which will override `self.noise_shape`, and allows for custom noise
- # shapes with dynamically sized inputs.
- if self.noise_shape is None:
- return self.noise_shape
- return nn_ops._get_noise_shape(inputs, self.noise_shape)
+ super(Dropout, self).__init__(rate=rate,
+ noise_shape=noise_shape,
+ seed=seed,
+ name=name,
+ **kwargs)
def call(self, inputs, training=False):
-
- def dropped_inputs():
- return nn.dropout(inputs, 1 - self.rate,
- noise_shape=self._get_noise_shape(inputs),
- seed=self.seed)
- return utils.smart_cond(training,
- dropped_inputs,
- lambda: array_ops.identity(inputs))
-
- def compute_output_shape(self, input_shape):
- return input_shape
+ return super(Dropout, self).call(inputs, training=training)
@tf_export('layers.dropout')
@@ -352,7 +272,7 @@ def dropout(inputs,
@tf_export('layers.Flatten')
-class Flatten(base.Layer):
+class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Examples:
@@ -367,25 +287,7 @@ class Flatten(base.Layer):
# now `y` has shape `(None, None)`
```
"""
-
- def __init__(self, **kwargs):
- super(Flatten, self).__init__(**kwargs)
- self.input_spec = base.InputSpec(min_ndim=2)
-
- def call(self, inputs):
- outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
- if not context.executing_eagerly():
- outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- output_shape = [input_shape[0]]
- if all(input_shape[1:]):
- output_shape += [np.prod(input_shape[1:])]
- else:
- output_shape += [None]
- return tensor_shape.TensorShape(output_shape)
+ pass
@tf_export('layers.flatten')
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 83b201e642..33284b0d69 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -24,26 +24,14 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import layers as keras_layers
from tensorflow.python.layers import base
-from tensorflow.python.layers import utils
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import moving_averages
from tensorflow.python.util.tf_export import tf_export
@tf_export('layers.BatchNormalization')
-class BatchNormalization(base.Layer):
+class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -143,485 +131,31 @@ class BatchNormalization(base.Layer):
name=None,
**kwargs):
super(BatchNormalization, self).__init__(
- name=name, trainable=trainable, **kwargs)
- if isinstance(axis, list):
- self.axis = axis[:]
- else:
- self.axis = axis
- self.momentum = momentum
- self.epsilon = epsilon
- self.center = center
- self.scale = scale
- self.beta_initializer = beta_initializer
- self.gamma_initializer = gamma_initializer
- self.moving_mean_initializer = moving_mean_initializer
- self.moving_variance_initializer = moving_variance_initializer
- self.beta_regularizer = beta_regularizer
- self.gamma_regularizer = gamma_regularizer
- self.beta_constraint = beta_constraint
- self.gamma_constraint = gamma_constraint
- self.renorm = renorm
- self.virtual_batch_size = virtual_batch_size
- self.adjustment = adjustment
- if fused is None:
- fused = True
-
- self.fused = fused
- self._bessels_correction_test_only = True
-
- if renorm:
- renorm_clipping = renorm_clipping or {}
- keys = ['rmax', 'rmin', 'dmax']
- if set(renorm_clipping) - set(keys):
- raise ValueError('renorm_clipping %s contains keys not in %s' %
- (renorm_clipping, keys))
- self.renorm_clipping = renorm_clipping
- self.renorm_momentum = renorm_momentum
-
- def _add_tower_local_variable(self, *args, **kwargs):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope('mean'):
- return self.add_variable(*args, **kwargs)
-
- def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- if not input_shape.ndims:
- raise ValueError('Input has undefined rank:', input_shape)
- ndims = len(input_shape)
-
- # Convert axis to list and resolve negatives
- if isinstance(self.axis, int):
- self.axis = [self.axis]
-
- if not isinstance(self.axis, list):
- raise TypeError('axis must be int or list, type given: %s'
- % type(self.axis))
-
- for idx, x in enumerate(self.axis):
- if x < 0:
- self.axis[idx] = ndims + x
-
- # Validate axes
- for x in self.axis:
- if x < 0 or x >= ndims:
- raise ValueError('Invalid axis: %d' % x)
- if len(self.axis) != len(set(self.axis)):
- raise ValueError('Duplicate axis: %s' % self.axis)
-
- if self.virtual_batch_size is not None:
- if self.virtual_batch_size <= 0:
- raise ValueError('virtual_batch_size must be a positive integer that '
- 'divides the true batch size of the input Tensor')
- # If using virtual batches, the first dimension must be the batch
- # dimension and cannot be the batch norm axis
- if 0 in self.axis:
- raise ValueError('When using virtual_batch_size, the batch dimension '
- 'must be 0 and thus axis cannot include 0')
- if self.adjustment is not None:
- raise ValueError('When using virtual_batch_size, adjustment cannot '
- 'be specified')
-
- if self.fused:
- # Currently fused batch norm doesn't support renorm. It also only supports
- # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
- # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
- # output back to its original shape accordingly.
- self.fused = (not self.renorm and
- ndims == 4 and
- self.axis in [[1], [3]] and
- self.virtual_batch_size is None and
- self.adjustment is None)
- # TODO(chrisying): fused batch norm is currently not supported for
- # multi-axis batch norm and by extension virtual batches. In some cases,
- # it might be possible to use fused batch norm but would require reshaping
- # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
- # particularly tricky. A compromise might be to just support the most
- # common use case (turning 5D w/ virtual batch to NCHW)
-
- if self.fused:
- if self.axis == [1]:
- self._data_format = 'NCHW'
- elif self.axis == [3]:
- self._data_format = 'NHWC'
- else:
- raise ValueError('Unsupported axis, fused batch norm only supports '
- 'axis == [1] or axis == [3]')
-
- # Raise parameters of fp16 batch norm to fp32
- if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
- param_dtype = dtypes.float32
- else:
- param_dtype = self.dtype or dtypes.float32
-
- axis_to_dim = {x: input_shape[x].value for x in self.axis}
- for x in axis_to_dim:
- if axis_to_dim[x] is None:
- raise ValueError('Input has undefined `axis` dimension. Input shape: ',
- input_shape)
- self.input_spec = base.InputSpec(ndim=ndims, axes=axis_to_dim)
-
- if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
- # Single axis batch norm (most common/default use-case)
- param_shape = (list(axis_to_dim.values())[0],)
- else:
- # Parameter shape is the original shape but with 1 in all non-axis dims
- param_shape = [axis_to_dim[i] if i in axis_to_dim
- else 1 for i in range(ndims)]
- if self.virtual_batch_size is not None:
- # When using virtual batches, add an extra dim at index 1
- param_shape.insert(1, 1)
- for idx, x in enumerate(self.axis):
- self.axis[idx] = x + 1 # Account for added dimension
-
- if self.scale:
- self.gamma = self.add_variable(
- name='gamma',
- shape=param_shape,
- dtype=param_dtype,
- initializer=self.gamma_initializer,
- regularizer=self.gamma_regularizer,
- constraint=self.gamma_constraint,
- trainable=True)
- else:
- self.gamma = None
- if self.fused:
- self._gamma_const = array_ops.constant(
- 1.0, dtype=param_dtype, shape=param_shape)
-
- if self.center:
- self.beta = self.add_variable(
- name='beta',
- shape=param_shape,
- dtype=param_dtype,
- initializer=self.beta_initializer,
- regularizer=self.beta_regularizer,
- constraint=self.beta_constraint,
- trainable=True)
- else:
- self.beta = None
- if self.fused:
- self._beta_const = array_ops.constant(
- 0.0, dtype=param_dtype, shape=param_shape)
-
- # Disable variable partitioning when creating the moving mean and variance
- try:
- if self._scope:
- partitioner = self._scope.partitioner
- self._scope.set_partitioner(None)
- else:
- partitioner = None
- self.moving_mean = self._add_tower_local_variable(
- name='moving_mean',
- shape=param_shape,
- dtype=param_dtype,
- initializer=self.moving_mean_initializer,
- trainable=False)
-
- self.moving_variance = self._add_tower_local_variable(
- name='moving_variance',
- shape=param_shape,
- dtype=param_dtype,
- initializer=self.moving_variance_initializer,
- trainable=False)
-
- if self.renorm:
- # Create variables to maintain the moving mean and standard deviation.
- # These are used in training and thus are different from the moving
- # averages above. The renorm variables are colocated with moving_mean
- # and moving_variance.
- # NOTE: below, the outer `with device` block causes the current device
- # stack to be cleared. The nested ones use a `lambda` to set the desired
- # device and ignore any devices that may be set by the custom getter.
- def _renorm_variable(name, shape):
- var = self._add_tower_local_variable(
- name=name,
- shape=shape,
- dtype=param_dtype,
- initializer=init_ops.zeros_initializer(),
- trainable=False)
- return var
-
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_mean):
- self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
- self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
- # We initialize renorm_stddev to 0, and maintain the (0-initialized)
- # renorm_stddev_weight. This allows us to (1) mix the average
- # stddev with the minibatch stddev early in training, and (2) compute
- # the unbiased average stddev by dividing renorm_stddev by the weight.
- with distribute_lib.get_distribution_strategy().colocate_vars_with(
- self.moving_variance):
- self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
- self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
- ())
- finally:
- if partitioner:
- self._scope.set_partitioner(partitioner)
- self.built = True
-
- def _assign_moving_average(self, variable, value, momentum):
- with ops.name_scope(None, 'AssignMovingAvg',
- [variable, value, momentum]) as scope:
- decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
-
- def _fused_batch_norm(self, inputs, training):
- """Returns the output of fused batch norm."""
- beta = self.beta if self.center else self._beta_const
- gamma = self.gamma if self.scale else self._gamma_const
-
- def _fused_batch_norm_training():
- return nn.fused_batch_norm(
- inputs,
- gamma,
- beta,
- epsilon=self.epsilon,
- data_format=self._data_format)
-
- def _fused_batch_norm_inference():
- return nn.fused_batch_norm(
- inputs,
- gamma,
- beta,
- mean=self.moving_mean,
- variance=self.moving_variance,
- epsilon=self.epsilon,
- is_training=False,
- data_format=self._data_format)
-
- output, mean, variance = utils.smart_cond(
- training, _fused_batch_norm_training, _fused_batch_norm_inference)
- if not self._bessels_correction_test_only:
- # Remove Bessel's correction to be consistent with non-fused batch norm.
- # Note that the variance computed by fused batch norm is
- # with Bessel's correction.
- sample_size = math_ops.cast(
- array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
- factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
- variance *= factor
-
- training_value = utils.constant_value(training)
- if training_value is None:
- momentum = utils.smart_cond(training, lambda: self.momentum, lambda: 1.0)
- else:
- momentum = ops.convert_to_tensor(self.momentum)
- if training_value or training_value is None:
- mean_update = self._assign_moving_average(self.moving_mean, mean,
- momentum)
- variance_update = self._assign_moving_average(self.moving_variance,
- variance, momentum)
- self.add_update(mean_update, inputs=inputs)
- self.add_update(variance_update, inputs=inputs)
-
- return output
-
- def _renorm_correction_and_moments(self, mean, variance, training):
- """Returns the correction and update values for renorm."""
- stddev = math_ops.sqrt(variance + self.epsilon)
- # Compute the average mean and standard deviation, as if they were
- # initialized with this batch's moments.
- mixed_renorm_mean = (self.renorm_mean +
- (1. - self.renorm_mean_weight) * mean)
- mixed_renorm_stddev = (self.renorm_stddev +
- (1. - self.renorm_stddev_weight) * stddev)
- # Compute the corrections for batch renorm.
- r = stddev / mixed_renorm_stddev
- d = (mean - mixed_renorm_mean) / mixed_renorm_stddev
- # Ensure the corrections use pre-update moving averages.
- with ops.control_dependencies([r, d]):
- mean = array_ops.identity(mean)
- stddev = array_ops.identity(stddev)
- rmin, rmax, dmax = [self.renorm_clipping.get(key)
- for key in ['rmin', 'rmax', 'dmax']]
- if rmin is not None:
- r = math_ops.maximum(r, rmin)
- if rmax is not None:
- r = math_ops.minimum(r, rmax)
- if dmax is not None:
- d = math_ops.maximum(d, -dmax)
- d = math_ops.minimum(d, dmax)
- # When not training, use r=1, d=0.
- r = utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
- d = utils.smart_cond(training, lambda: d, lambda: array_ops.zeros_like(d))
-
- def _update_renorm_variable(var, weight, value):
- """Updates a moving average and weight, returns the unbiased value."""
- value = array_ops.identity(value)
- def _do_update():
- """Updates the var and weight, returns their updated ratio."""
- # Update the variables without zero debiasing. The debiasing will be
- # accomplished by dividing the exponential moving average by the weight.
- # For example, after a single update, the moving average would be
- # (1-decay) * value. and the weight will be 1-decay, with their ratio
- # giving the value.
- # Make sure the weight is not updated until before r and d computation.
- with ops.control_dependencies([value]):
- weight_value = array_ops.constant(1., dtype=weight.dtype)
- new_var = self._assign_moving_average(var, value, self.renorm_momentum)
- new_weight = self._assign_moving_average(weight, weight_value,
- self.renorm_momentum)
- # TODO(yuefengz): the updates to var and weighted can not be batched
- # together if we fetch their updated values here. Consider calculating
- # new values and delaying the updates.
- return new_var / new_weight
-
- def _fake_update():
- return array_ops.identity(var)
- return utils.smart_cond(training, _do_update, _fake_update)
-
- # TODO(yuefengz): colocate the operations
- new_mean = _update_renorm_variable(self.renorm_mean,
- self.renorm_mean_weight, mean)
- new_stddev = _update_renorm_variable(self.renorm_stddev,
- self.renorm_stddev_weight, stddev)
- # Make sqrt(moving_variance + epsilon) = new_stddev.
- new_variance = math_ops.square(new_stddev) - self.epsilon
-
- return (r, d, new_mean, new_variance)
+ axis=axis,
+ momentum=momentum,
+ epsilon=epsilon,
+ center=center,
+ scale=scale,
+ beta_initializer=beta_initializer,
+ gamma_initializer=gamma_initializer,
+ moving_mean_initializer=moving_mean_initializer,
+ moving_variance_initializer=moving_variance_initializer,
+ beta_regularizer=beta_regularizer,
+ gamma_regularizer=gamma_regularizer,
+ beta_constraint=beta_constraint,
+ gamma_constraint=gamma_constraint,
+ renorm=renorm,
+ renorm_clipping=renorm_clipping,
+ renorm_momentum=renorm_momentum,
+ fused=fused,
+ trainable=trainable,
+ virtual_batch_size=virtual_batch_size,
+ adjustment=adjustment,
+ name=name,
+ **kwargs)
def call(self, inputs, training=False):
- in_eager_mode = context.executing_eagerly()
- if self.virtual_batch_size is not None:
- # Virtual batches (aka ghost batches) can be simulated by reshaping the
- # Tensor and reusing the existing batch norm implementation
- original_shape = [-1] + inputs.shape.as_list()[1:]
- expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
-
- # Will cause errors if virtual_batch_size does not divide the batch size
- inputs = array_ops.reshape(inputs, expanded_shape)
-
- def undo_virtual_batching(outputs):
- outputs = array_ops.reshape(outputs, original_shape)
- return outputs
-
- if self.fused:
- outputs = self._fused_batch_norm(inputs, training=training)
- if self.virtual_batch_size is not None:
- # Currently never reaches here since fused_batch_norm does not support
- # virtual batching
- return undo_virtual_batching(outputs)
- return outputs
-
- # Compute the axes along which to reduce the mean / variance
- input_shape = inputs.get_shape()
- ndims = len(input_shape)
- reduction_axes = [i for i in range(ndims) if i not in self.axis]
- if self.virtual_batch_size is not None:
- del reduction_axes[1] # Do not reduce along virtual batch dim
-
- # Broadcasting only necessary for single-axis batch norm where the axis is
- # not the last dimension
- broadcast_shape = [1] * ndims
- broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
- def _broadcast(v):
- if (v is not None and
- len(v.get_shape()) != ndims and
- reduction_axes != list(range(ndims - 1))):
- return array_ops.reshape(v, broadcast_shape)
- return v
-
- scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
-
- def _compose_transforms(scale, offset, then_scale, then_offset):
- if then_scale is not None:
- scale *= then_scale
- offset *= then_scale
- if then_offset is not None:
- offset += then_offset
- return (scale, offset)
-
- # Determine a boolean value for `training`: could be True, False, or None.
- training_value = utils.constant_value(training)
- if training_value is not False:
- if self.adjustment:
- adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
- # Adjust only during training.
- adj_scale = utils.smart_cond(training,
- lambda: adj_scale,
- lambda: array_ops.ones_like(adj_scale))
- adj_bias = utils.smart_cond(training,
- lambda: adj_bias,
- lambda: array_ops.zeros_like(adj_bias))
- scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)
-
- # Some of the computations here are not necessary when training==False
- # but not a constant. However, this makes the code simpler.
- keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
- mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
-
- moving_mean = self.moving_mean
- moving_variance = self.moving_variance
-
- mean = utils.smart_cond(training,
- lambda: mean,
- lambda: moving_mean)
- variance = utils.smart_cond(training,
- lambda: variance,
- lambda: moving_variance)
-
- if self.renorm:
- r, d, new_mean, new_variance = self._renorm_correction_and_moments(
- mean, variance, training)
- # When training, the normalized values (say, x) will be transformed as
- # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
- # = x * (r * gamma) + (d * gamma + beta) with renorm.
- r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
- d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
- scale, offset = _compose_transforms(r, d, scale, offset)
- else:
- new_mean, new_variance = mean, variance
-
- if self.virtual_batch_size is not None:
- # This isn't strictly correct since in ghost batch norm, you are
- # supposed to sequentially update the moving_mean and moving_variance
- # with each sub-batch. However, since the moving statistics are only
- # used during evaluation, it is more efficient to just update in one
- # step and should not make a significant difference in the result.
- new_mean = math_ops.reduce_mean(new_mean,
- axis=1, keep_dims=True)
- new_variance = math_ops.reduce_mean(new_variance,
- axis=1, keep_dims=True)
-
- def _do_update(var, value):
- if in_eager_mode and not self.trainable:
- return
-
- return self._assign_moving_average(var, value, self.momentum)
-
- mean_update = utils.smart_cond(
- training,
- lambda: _do_update(self.moving_mean, new_mean),
- lambda: self.moving_mean)
- variance_update = utils.smart_cond(
- training,
- lambda: _do_update(self.moving_variance, new_variance),
- lambda: self.moving_variance)
- if not context.executing_eagerly():
- self.add_update(mean_update, inputs=inputs)
- self.add_update(variance_update, inputs=inputs)
-
- else:
- mean, variance = self.moving_mean, self.moving_variance
-
- outputs = nn.batch_normalization(inputs,
- _broadcast(mean),
- _broadcast(variance),
- offset,
- scale,
- self.epsilon)
- # If some components of the shape got lost due to adjustments, fix that.
- outputs.set_shape(input_shape)
-
- if self.virtual_batch_size is not None:
- return undo_virtual_batching(outputs)
-
- return outputs
-
- def compute_output_shape(self, input_shape):
- return input_shape
+ return super(BatchNormalization, self).call(inputs, training=training)
@tf_export('layers.batch_normalization')
diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py
index 50503ce093..75abe56f51 100644
--- a/tensorflow/python/layers/pooling.py
+++ b/tensorflow/python/layers/pooling.py
@@ -13,92 +13,19 @@
# limitations under the License.
# =============================================================================
-# pylint: disable=unused-import,g-bad-import-order
"""Contains the pooling layer classes and their functional aliases.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import layers as keras_layers
from tensorflow.python.layers import base
-from tensorflow.python.layers import utils
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
-class _Pooling1D(base.Layer):
- """Pooling layer for arbitrary pooling functions, for 1D inputs.
-
- This class only exists for code reuse. It will never be an exposed API.
-
- Arguments:
- pool_function: The pooling function to apply, e.g. `tf.nn.max_pool`.
- pool_size: An integer or tuple/list of a single integer,
- representing the size of the pooling window.
- strides: An integer or tuple/list of a single integer, specifying the
- strides of the pooling operation.
- padding: A string. The padding method, either 'valid' or 'same'.
- Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
- name: A string, the name of the layer.
- """
-
- def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format='channels_last',
- name=None, **kwargs):
- super(_Pooling1D, self).__init__(name=name, **kwargs)
- self.pool_function = pool_function
- self.pool_size = utils.normalize_tuple(pool_size, 1, 'pool_size')
- self.strides = utils.normalize_tuple(strides, 1, 'strides')
- self.padding = utils.normalize_padding(padding)
- self.data_format = utils.normalize_data_format(data_format)
- self.input_spec = base.InputSpec(ndim=3)
-
- def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
- outputs = self.pool_function(
- inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
-
-
@tf_export('layers.AveragePooling1D')
-class AveragePooling1D(_Pooling1D):
+class AveragePooling1D(keras_layers.AveragePooling1D, base.Layer):
"""Average Pooling layer for 1D inputs.
Arguments:
@@ -119,8 +46,9 @@ class AveragePooling1D(_Pooling1D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -165,7 +93,7 @@ def average_pooling1d(inputs, pool_size, strides,
@tf_export('layers.MaxPooling1D')
-class MaxPooling1D(_Pooling1D):
+class MaxPooling1D(keras_layers.MaxPooling1D, base.Layer):
"""Max Pooling layer for 1D inputs.
Arguments:
@@ -186,8 +114,9 @@ class MaxPooling1D(_Pooling1D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(MaxPooling1D, self).__init__(
- nn.max_pool,
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -231,79 +160,8 @@ def max_pooling1d(inputs, pool_size, strides,
return layer.apply(inputs)
-class _Pooling2D(base.Layer):
- """Pooling layer for arbitrary pooling functions, for 2D inputs (e.g. images).
-
- This class only exists for code reuse. It will never be an exposed API.
-
- Arguments:
- pool_function: The pooling function to apply, e.g. `tf.nn.max_pool`.
- pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width)
- specifying the size of the pooling window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 2 integers,
- specifying the strides of the pooling operation.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- padding: A string. The padding method, either 'valid' or 'same'.
- Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, height, width)`.
- name: A string, the name of the layer.
- """
-
- def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format='channels_last',
- name=None, **kwargs):
- super(_Pooling2D, self).__init__(name=name, **kwargs)
- self.pool_function = pool_function
- self.pool_size = utils.normalize_tuple(pool_size, 2, 'pool_size')
- self.strides = utils.normalize_tuple(strides, 2, 'strides')
- self.padding = utils.normalize_padding(padding)
- self.data_format = utils.normalize_data_format(data_format)
- self.input_spec = base.InputSpec(ndim=4)
-
- def call(self, inputs):
- if self.data_format == 'channels_last':
- pool_shape = (1,) + self.pool_size + (1,)
- strides = (1,) + self.strides + (1,)
- else:
- pool_shape = (1, 1) + self.pool_size
- strides = (1, 1) + self.strides
- outputs = self.pool_function(
- inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=utils.convert_data_format(self.data_format, 4))
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.data_format == 'channels_first':
- rows = input_shape[2]
- cols = input_shape[3]
- else:
- rows = input_shape[1]
- cols = input_shape[2]
- rows = utils.conv_output_length(rows, self.pool_size[0], self.padding,
- self.strides[0])
- cols = utils.conv_output_length(cols, self.pool_size[1], self.padding,
- self.strides[1])
- if self.data_format == 'channels_first':
- return tensor_shape.TensorShape(
- [input_shape[0], input_shape[1], rows, cols])
- else:
- return tensor_shape.TensorShape(
- [input_shape[0], rows, cols, input_shape[3]])
-
-
@tf_export('layers.AveragePooling2D')
-class AveragePooling2D(_Pooling2D):
+class AveragePooling2D(keras_layers.AveragePooling2D, base.Layer):
"""Average pooling layer for 2D inputs (e.g. images).
Arguments:
@@ -328,8 +186,9 @@ class AveragePooling2D(_Pooling2D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(AveragePooling2D, self).__init__(
- nn.avg_pool,
pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format, name=name, **kwargs)
@@ -373,7 +232,7 @@ def average_pooling2d(inputs,
@tf_export('layers.MaxPooling2D')
-class MaxPooling2D(_Pooling2D):
+class MaxPooling2D(keras_layers.MaxPooling2D, base.Layer):
"""Max pooling layer for 2D inputs (e.g. images).
Arguments:
@@ -398,8 +257,9 @@ class MaxPooling2D(_Pooling2D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(MaxPooling2D, self).__init__(
- nn.max_pool,
pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format, name=name, **kwargs)
@@ -442,90 +302,8 @@ def max_pooling2d(inputs,
return layer.apply(inputs)
-class _Pooling3D(base.Layer):
- """Pooling layer for arbitrary pooling functions, for 3D inputs.
-
- This class only exists for code reuse. It will never be an exposed API.
-
- Arguments:
- pool_function: The pooling function to apply, e.g. `tf.nn.max_pool`.
- pool_size: An integer or tuple/list of 3 integers:
- (pool_depth, pool_height, pool_width)
- specifying the size of the pooling window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 3 integers,
- specifying the strides of the pooling operation.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- padding: A string. The padding method, either 'valid' or 'same'.
- Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, depth, height, width, channels)`
- while `channels_first` corresponds to
- inputs with shape `(batch, channels, depth, height, width)`.
- name: A string, the name of the layer.
- """
-
- def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format='channels_last',
- name=None, **kwargs):
- super(_Pooling3D, self).__init__(name=name, **kwargs)
- self.pool_function = pool_function
- self.pool_size = utils.normalize_tuple(pool_size, 3, 'pool_size')
- self.strides = utils.normalize_tuple(strides, 3, 'strides')
- self.padding = utils.normalize_padding(padding)
- self.data_format = utils.normalize_data_format(data_format)
- self.input_spec = base.InputSpec(ndim=5)
-
- def call(self, inputs):
- pool_shape = (1,) + self.pool_size + (1,)
- strides = (1,) + self.strides + (1,)
-
- if self.data_format == 'channels_first':
- # TF does not support `channels_first` with 3D pooling operations,
- # so we must handle this case manually.
- # TODO(fchollet): remove this when TF pooling is feature-complete.
- inputs = array_ops.transpose(inputs, (0, 2, 3, 4, 1))
-
- outputs = self.pool_function(
- inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper())
-
- if self.data_format == 'channels_first':
- outputs = array_ops.transpose(outputs, (0, 4, 1, 2, 3))
- return outputs
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.data_format == 'channels_first':
- len_dim1 = input_shape[2]
- len_dim2 = input_shape[3]
- len_dim3 = input_shape[4]
- else:
- len_dim1 = input_shape[1]
- len_dim2 = input_shape[2]
- len_dim3 = input_shape[3]
- len_dim1 = utils.conv_output_length(len_dim1, self.pool_size[0],
- self.padding, self.strides[0])
- len_dim2 = utils.conv_output_length(len_dim2, self.pool_size[1],
- self.padding, self.strides[1])
- len_dim3 = utils.conv_output_length(len_dim3, self.pool_size[2],
- self.padding, self.strides[2])
- if self.data_format == 'channels_first':
- return tensor_shape.TensorShape(
- [input_shape[0], input_shape[1], len_dim1, len_dim2, len_dim3])
- else:
- return tensor_shape.TensorShape(
- [input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]])
-
-
@tf_export('layers.AveragePooling3D')
-class AveragePooling3D(_Pooling3D):
+class AveragePooling3D(keras_layers.AveragePooling3D, base.Layer):
"""Average pooling layer for 3D inputs (e.g. volumes).
Arguments:
@@ -552,8 +330,9 @@ class AveragePooling3D(_Pooling3D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(AveragePooling3D, self).__init__(
- nn.avg_pool3d,
pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format, name=name, **kwargs)
@@ -599,7 +378,7 @@ def average_pooling3d(inputs,
@tf_export('layers.MaxPooling3D')
-class MaxPooling3D(_Pooling3D):
+class MaxPooling3D(keras_layers.MaxPooling3D, base.Layer):
"""Max pooling layer for 3D inputs (e.g. volumes).
Arguments:
@@ -626,8 +405,9 @@ class MaxPooling3D(_Pooling3D):
def __init__(self, pool_size, strides,
padding='valid', data_format='channels_last',
name=None, **kwargs):
+ if strides is None:
+ raise ValueError('Argument `strides` must not be None.')
super(MaxPooling3D, self).__init__(
- nn.max_pool3d,
pool_size=pool_size, strides=strides,
padding=padding, data_format=data_format, name=name, **kwargs)
diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py
index c941aad7bc..7e94dda648 100644
--- a/tensorflow/python/layers/utils_test.py
+++ b/tensorflow/python/layers/utils_test.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.layers import utils
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -89,33 +88,5 @@ class ConvUtilsTest(test.TestCase):
self.assertEqual(6, utils.deconv_output_length(4, 2, 'full', 2))
-class GraphUtilsTest(test.TestCase):
-
- def testGetReachableFromInputs(self):
-
- with self.test_session():
- pl_1 = array_ops.placeholder(shape=None, dtype='float32')
- pl_2 = array_ops.placeholder(shape=None, dtype='float32')
- pl_3 = array_ops.placeholder(shape=None, dtype='float32')
- x_1 = pl_1 + pl_2
- x_2 = pl_2 * 2
- x_3 = pl_3 + 1
- x_4 = x_1 + x_2
- x_5 = x_3 * pl_1
-
- self.assertEqual(
- utils.get_reachable_from_inputs([pl_1]),
- {pl_1, x_1, x_4, x_5})
- self.assertEqual(
- utils.get_reachable_from_inputs([pl_1, pl_2]),
- {pl_1, pl_2, x_1, x_2, x_4, x_5})
- self.assertEqual(
- utils.get_reachable_from_inputs([pl_3]),
- {pl_3, x_3, x_5})
- self.assertEqual(
- utils.get_reachable_from_inputs([x_3]),
- {x_3, x_5})
-
-
if __name__ == '__main__':
test.main()