about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
author: Francois Chollet <fchollet@google.com> 2017-08-03 16:48:48 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-08-03 16:55:56 -0700
commit: 3599fd44d6bfcb16f45e763608a0e5da4e9072f5 (patch)
tree: 65366abce44c2ef95a62c53baeb634a77dfd0fad /tensorflow/contrib/keras
parent: 57626dd38a7867b76c44f3933e7810190174a2ee (diff)
Add Network class.

Networks are directed acyclic graphs of layers that implement the full layer API. You can think of a network as a "bigger layer".

- Rename tf.contrib.keras Container as Network.
- Add a Network class in tf.layers which implements the part of Container that we want to add to core layers.
- Make Keras Network subclass core Network.

PiperOrigin-RevId: 164202674
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r-- tensorflow/contrib/keras/python/keras/engine/topology.py 1455
-rw-r--r-- tensorflow/contrib/keras/python/keras/engine/topology_test.py 29
-rw-r--r-- tensorflow/contrib/keras/python/keras/engine/training.py 2
-rw-r--r-- tensorflow/contrib/keras/python/keras/models.py 20
-rw-r--r-- tensorflow/contrib/keras/python/keras/utils/layer_utils.py 4
5 files changed, 137 insertions, 1373 deletions
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index c8c746e8af..4e8f6f9db1 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -29,10 +29,8 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils import conv_utils
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import has_arg
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
-from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.platform import tf_logging as logging
@@ -51,125 +49,7 @@ except ImportError:
# pylint: enable=g-import-not-at-top
InputSpec = tf_base_layers.InputSpec # pylint: disable=invalid-name
-
-
-class Node(object):
- """A `Node` describes the connectivity between two layers.
-
- Each time a layer is connected to some new input,
- a node is added to `layer.inbound_nodes`.
- Each time the output of a layer is used by another layer,
- a node is added to `layer.outbound_nodes`.
-
- Arguments:
- outbound_layer: the layer that takes
- `input_tensors` and turns them into `output_tensors`
- (the node gets created when the `call`
- method of the layer was called).
- inbound_layers: a list of layers, the same length as `input_tensors`,
- the layers from where `input_tensors` originate.
- node_indices: a list of integers, the same length as `inbound_layers`.
- `node_indices[i]` is the origin node of `input_tensors[i]`
- (necessary since each inbound layer might have several nodes,
- e.g. if the layer is being shared with a different data stream).
- tensor_indices: a list of integers,
- the same length as `inbound_layers`.
- `tensor_indices[i]` is the index of `input_tensors[i]` within the
- output of the inbound layer
- (necessary since each inbound layer might
- have multiple tensor outputs, with each one being
- independently manipulable).
- input_tensors: list of input tensors.
- output_tensors: list of output tensors.
- input_masks: list of input masks (a mask can be a tensor, or None).
- output_masks: list of output masks (a mask can be a tensor, or None).
- arguments: dictionary of keyword arguments that were passed to the
- `call` method of the layer at the call that created the node.
-
- `node_indices` and `tensor_indices` are basically fine-grained coordinates
- describing the origin of the `input_tensors`.
-
- A node from layer A to layer B is added to:
- A.outbound_nodes
- B.inbound_nodes
- """
-
- def __init__(self,
- outbound_layer,
- inbound_layers,
- node_indices,
- tensor_indices,
- input_tensors,
- output_tensors,
- input_masks,
- output_masks,
- arguments=None):
- # Layer instance (NOT a list).
- # this is the layer that takes a list of input tensors
- # and turns them into a list of output tensors.
- # the current node will be added to
- # the inbound_nodes of outbound_layer.
- self.outbound_layer = outbound_layer
-
- # The following 3 properties describe where
- # the input tensors come from: which layers,
- # and for each layer, which node and which
- # tensor output of each node.
-
- # List of layer instances.
- self.inbound_layers = inbound_layers
- # List of integers, 1:1 mapping with inbound_layers.
- self.node_indices = node_indices
- # List of integers, 1:1 mapping with inbound_layers.
- self.tensor_indices = tensor_indices
-
- # Following 2 properties:
- # tensor inputs and outputs of outbound_layer.
-
- # List of tensors. 1:1 mapping with inbound_layers.
- self.input_tensors = input_tensors
- # List of tensors, created by outbound_layer.call().
- self.output_tensors = output_tensors
-
- # Following 2 properties: input and output masks.
- # List of tensors, 1:1 mapping with input_tensor.
- self.input_masks = input_masks
- # List of tensors, created by outbound_layer.compute_mask().
- self.output_masks = output_masks
-
- # Following 2 properties: input and output shapes.
-
- # List of shape tuples, shapes of input_tensors.
- self.input_shapes = [K.int_shape(x) for x in input_tensors]
- # List of shape tuples, shapes of output_tensors.
- self.output_shapes = [K.int_shape(x) for x in output_tensors]
-
- # Optional keyword arguments to layer's `call`.
- self.arguments = arguments
-
- # Add nodes to all layers involved.
- for layer in inbound_layers:
- if layer is not None:
- layer.outbound_nodes.append(self)
- outbound_layer.inbound_nodes.append(self)
-
- def get_config(self):
- inbound_names = []
- for layer in self.inbound_layers:
- if layer:
- inbound_names.append(layer.name)
- else:
- inbound_names.append(None)
- return {
- 'outbound_layer':
- self.outbound_layer.name if self.outbound_layer else None,
- 'inbound_layers':
- inbound_names,
- 'node_indices':
- self.node_indices,
- 'tensor_indices':
- self.tensor_indices
- }
+Node = tf_base_layers.Node # pylint: disable=invalid-name
class Layer(tf_base_layers.Layer):
@@ -274,15 +154,9 @@ class Layer(tf_base_layers.Layer):
super(Layer, self).__init__(name=name, dtype=dtype, trainable=trainable)
# Add properties that are Keras-only for now.
- self.input_spec = None
self.supports_masking = False
self._constraints = {} # dict {tensor: constraint instance}
- # These lists will be filled via successive calls
- # to self._add_inbound_node().
- self.inbound_nodes = []
- self.outbound_nodes = []
-
# Manage input shape information if passed.
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
# In this case we will later create an input layer
@@ -378,52 +252,17 @@ class Layer(tf_base_layers.Layer):
ValueError: in case the layer is missing shape information
for its `build` call.
"""
- if isinstance(inputs, list):
- inputs = inputs[:]
-
- # Handle mask propagation.
- previous_mask = _collect_previous_mask(inputs)
- user_kwargs = copy.copy(kwargs)
- if not _is_all_none(previous_mask):
- # The previous layer generated a mask.
- if has_arg(self.call, 'mask'):
- if 'mask' not in kwargs:
- # If mask is explicitly passed to __call__,
- # we should override the default mask.
- kwargs['mask'] = previous_mask
-
# Actually call the layer (optionally building it).
output = super(Layer, self).__call__(inputs, **kwargs)
- # Handle mask computation.
- with K.name_scope(self.name):
- output_mask = self.compute_mask(inputs, previous_mask)
-
- # If the layer returns tensors from its inputs, unmodified,
- # we copy them to avoid loss of tensor metadata.
- output_ls = _to_list(output)
- inputs_ls = _to_list(inputs)
- output_ls_copy = []
- for x in output_ls:
- if x in inputs_ls:
- x = K.identity(x)
- output_ls_copy.append(x)
- if len(output_ls_copy) == 1:
- output = output_ls_copy[0]
- else:
- output = output_ls_copy
-
- # Add an inbound node to the layer, so that it keeps track
- # of the call and of all new variables created during the call.
- # This also updates the layer history of the output tensor(s).
- # If the input tensor(s) had not previous Keras history,
- # this does nothing.
- self._add_inbound_node(
- input_tensors=inputs,
- output_tensors=output,
- input_masks=previous_mask,
- output_masks=output_mask,
- arguments=user_kwargs)
+ # Update learning phase info.
+ output_tensors = _to_list(output)
+ uses_lp = any(
+ [getattr(x, '_uses_learning_phase', False) for x in _to_list(inputs)])
+ uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
+ for i in range(len(output_tensors)):
+ output_tensors[i]._uses_learning_phase = getattr(
+ output_tensors[i], '_uses_learning_phase', False) or uses_lp
# Optionally load weight values that were specified at layer instantiation.
if hasattr(self, '_initial_weights') and self._initial_weights is not None:
@@ -431,63 +270,6 @@ class Layer(tf_base_layers.Layer):
del self._initial_weights
return output
- def _add_inbound_node(self,
- input_tensors,
- output_tensors,
- input_masks,
- output_masks,
- arguments=None):
- """Internal method to create an inbound node for the layer.
-
- Arguments:
- input_tensors: list of input tensors.
- output_tensors: list of output tensors.
- input_masks: list of input masks (a mask can be a tensor, or None).
- output_masks: list of output masks (a mask can be a tensor, or None).
- arguments: dictionary of keyword arguments that were passed to the
- `call` method of the layer at the call that created the node.
- """
- input_tensors = _to_list(input_tensors)
- output_tensors = _to_list(output_tensors)
- input_masks = _to_list(input_masks)
- output_masks = _to_list(output_masks)
-
- # Collect input tensor(s) coordinates.
- inbound_layers = []
- node_indices = []
- tensor_indices = []
- for x in input_tensors:
- if hasattr(x, '_keras_history'):
- inbound_layer, node_index, tensor_index = x._keras_history
- inbound_layers.append(inbound_layer)
- node_indices.append(node_index)
- tensor_indices.append(tensor_index)
- else:
- inbound_layers.append(None)
- node_indices.append(None)
- tensor_indices.append(None)
-
- # Create node, add it to inbound nodes.
- Node(
- self,
- inbound_layers=inbound_layers,
- node_indices=node_indices,
- tensor_indices=tensor_indices,
- input_tensors=input_tensors,
- output_tensors=output_tensors,
- input_masks=input_masks,
- output_masks=output_masks,
- arguments=arguments)
-
- # Update tensor history and `_uses_learning_phase`.
- for i in range(len(output_tensors)):
- uses_lp = any(
- [getattr(x, '_uses_learning_phase', False) for x in input_tensors])
- uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
- output_tensors[i]._uses_learning_phase = getattr(
- output_tensors[i], '_uses_learning_phase', False) or uses_lp
- output_tensors[i]._keras_history = (self, len(self.inbound_nodes) - 1, i)
-
def _compute_output_shape(self, input_shape):
"""Computes the output shape of the layer.
@@ -534,115 +316,6 @@ class Layer(tf_base_layers.Layer):
# carry over the input mask
return mask
- def build(self, input_shape): # pylint: disable=unused-argument
- """Creates the layer weights.
-
- Must be implemented on all layers that have weights.
-
- Arguments:
- input_shape: Keras tensor (future input to layer)
- or list/tuple of Keras tensors to reference
- for weight shape computations.
- """
- self.built = True
-
- def _get_node_attribute_at_index(self, node_index, attr, attr_name):
- """Retrieves an attribute (e.g. input_tensors) from a node.
-
- This is used to implement the methods:
- - get_input_shape_at
- - get_output_shape_at
- - get_input_at
- etc...
-
- Arguments:
- node_index: Integer index of the node from which
- to retrieve the attribute.
- attr: Exact node attribute name.
- attr_name: Human-readable attribute name, for error messages.
-
- Returns:
- The layer's attribute `attr` at the node of index `node_index`.
-
- Raises:
- RuntimeError: If the layer has no inbound nodes.
- ValueError: If the index is does not match any node.
- """
- if not self.inbound_nodes:
- raise RuntimeError('The layer has never been called '
- 'and thus has no defined ' + attr_name + '.')
- if not len(self.inbound_nodes) > node_index:
- raise ValueError('Asked to get ' + attr_name + ' at node ' +
- str(node_index) + ', but the layer has only ' +
- str(len(self.inbound_nodes)) + ' inbound nodes.')
- values = getattr(self.inbound_nodes[node_index], attr)
- if len(values) == 1:
- return values[0]
- else:
- return values
-
- def get_input_shape_at(self, node_index):
- """Retrieves the input shape(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A shape tuple
- (or list of shape tuples if the layer has multiple inputs).
- """
- return self._get_node_attribute_at_index(node_index, 'input_shapes',
- 'input shape')
-
- def get_output_shape_at(self, node_index):
- """Retrieves the output shape(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A shape tuple
- (or list of shape tuples if the layer has multiple outputs).
- """
- return self._get_node_attribute_at_index(node_index, 'output_shapes',
- 'output shape')
-
- def get_input_at(self, node_index):
- """Retrieves the input tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A tensor (or list of tensors if the layer has multiple inputs).
- """
- return self._get_node_attribute_at_index(node_index, 'input_tensors',
- 'input')
-
- def get_output_at(self, node_index):
- """Retrieves the output tensor(s) of a layer at a given node.
-
- Arguments:
- node_index: Integer, index of the node
- from which to retrieve the attribute.
- E.g. `node_index=0` will correspond to the
- first time the layer was called.
-
- Returns:
- A tensor (or list of tensors if the layer has multiple outputs).
- """
- return self._get_node_attribute_at_index(node_index, 'output_tensors',
- 'output')
-
def get_input_mask_at(self, node_index):
"""Retrieves the input mask tensor(s) of a layer at a given node.
@@ -656,8 +329,11 @@ class Layer(tf_base_layers.Layer):
A mask tensor
(or list of tensors if the layer has multiple inputs).
"""
- return self._get_node_attribute_at_index(node_index, 'input_masks',
- 'input mask')
+ inputs = self.get_input_at(node_index)
+ if isinstance(inputs, list):
+ return [getattr(x, '_keras_mask', None) for x in inputs]
+ else:
+ return getattr(inputs, '_keras_mask', None)
def get_output_mask_at(self, node_index):
"""Retrieves the output mask tensor(s) of a layer at a given node.
@@ -672,57 +348,11 @@ class Layer(tf_base_layers.Layer):
A mask tensor
(or list of tensors if the layer has multiple outputs).
"""
- return self._get_node_attribute_at_index(node_index, 'output_masks',
- 'output mask')
-
- @property
- def input(self):
- """Retrieves the input tensor(s) of a layer.
-
- Only applicable if the layer has exactly one inbound node,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Input tensor or list of input tensors.
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- if len(self.inbound_nodes) > 1:
- raise AttributeError('Layer ' + self.name +
- ' has multiple inbound nodes, '
- 'hence the notion of "layer input" '
- 'is ill-defined. '
- 'Use `get_input_at(node_index)` instead.')
- elif not self.inbound_nodes:
- raise AttributeError('Layer ' + self.name +
- ' is not connected, no input to return.')
- return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
-
- @property
- def output(self):
- """Retrieves the output tensor(s) of a layer.
-
- Only applicable if the layer has exactly one inbound node,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Output tensor or list of output tensors.
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- if not self.inbound_nodes:
- raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
- if len(self.inbound_nodes) > 1:
- raise AttributeError('Layer ' + self.name +
- ' has multiple inbound nodes, '
- 'hence the notion of "layer output" '
- 'is ill-defined. '
- 'Use `get_output_at(node_index)` instead.')
- return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
+ output = self.get_output_at(node_index)
+ if isinstance(output, list):
+ return [getattr(x, '_keras_mask', None) for x in output]
+ else:
+ return getattr(output, '_keras_mask', None)
@property
def input_mask(self):
@@ -739,14 +369,11 @@ class Layer(tf_base_layers.Layer):
AttributeError: if the layer is connected to
more than one incoming layers.
"""
- if len(self.inbound_nodes) != 1:
- raise AttributeError('Layer ' + self.name +
- ' has multiple inbound nodes, ' +
- 'hence the notion of "layer input mask" '
- 'is ill-defined. '
- 'Use `get_input_mask_at(node_index)` '
- 'instead.')
- return self._get_node_attribute_at_index(0, 'input_masks', 'input mask')
+ inputs = self.input
+ if isinstance(inputs, list):
+ return [getattr(x, '_keras_mask', None) for x in inputs]
+ else:
+ return getattr(inputs, '_keras_mask', None)
@property
def output_mask(self):
@@ -763,90 +390,11 @@ class Layer(tf_base_layers.Layer):
AttributeError: if the layer is connected to
more than one incoming layers.
"""
- if len(self.inbound_nodes) != 1:
- raise AttributeError('Layer ' + self.name +
- ' has multiple inbound nodes, '
- 'hence the notion of "layer output mask" '
- 'is ill-defined. '
- 'Use `get_output_mask_at(node_index)` '
- 'instead.')
- return self._get_node_attribute_at_index(0, 'output_masks', 'output mask')
-
- @property
- def input_shape(self):
- """Retrieves the input shape(s) of a layer.
-
- Only applicable if the layer has exactly one inbound node,
- i.e. if it is connected to one incoming layer.
-
- Returns:
- Input shape, as `TensorShape`
- (or list of `TensorShape`, one tuple per input tensor).
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- if not self.inbound_nodes:
- raise AttributeError('The layer has never been called '
- 'and thus has no defined input shape.')
- all_input_shapes = set(
- [str(node.input_shapes) for node in self.inbound_nodes])
- if len(all_input_shapes) == 1:
- input_shapes = self.inbound_nodes[0].input_shapes
- if len(input_shapes) == 1:
- return tuple(tensor_shape.TensorShape(input_shapes[0]).as_list())
- else:
- return [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in input_shapes
- ]
+ output = self.output
+ if isinstance(output, list):
+ return [getattr(x, '_keras_mask', None) for x in output]
else:
- raise AttributeError('The layer "' + str(self.name) +
- ' has multiple inbound nodes, '
- 'with different input shapes. Hence '
- 'the notion of "input shape" is '
- 'ill-defined for the layer. '
- 'Use `get_input_shape_at(node_index)` '
- 'instead.')
-
- @property
- def output_shape(self):
- """Retrieves the output shape(s) of a layer.
-
- Only applicable if the layer has one inbound node,
- or if all inbound nodes have the same output shape.
-
- Returns:
- Output shape, as `TensorShape`
- (or list of `TensorShape`, one tuple per output tensor).
-
- Raises:
- AttributeError: if the layer is connected to
- more than one incoming layers.
- """
- if not self.inbound_nodes:
- raise AttributeError('The layer has never been called '
- 'and thus has no defined output shape.')
- all_output_shapes = set(
- [str(node.output_shapes) for node in self.inbound_nodes])
- if len(all_output_shapes) == 1:
- output_shapes = self.inbound_nodes[0].output_shapes
- if len(output_shapes) == 1:
- return tuple(tensor_shape.TensorShape(output_shapes[0]).as_list())
- else:
- return [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in output_shapes
- ]
- else:
- raise AttributeError('The layer "' + str(self.name) +
- ' has multiple inbound nodes, '
- 'with different output shapes. Hence '
- 'the notion of "output shape" is '
- 'ill-defined for the layer. '
- 'Use `get_output_shape_at(node_index)` '
- 'instead.')
+ return getattr(output, '_keras_mask', None)
def set_weights(self, weights):
"""Sets the weights of the layer, from Numpy arrays.
@@ -951,17 +499,15 @@ class Layer(tf_base_layers.Layer):
return sum([K.count_params(p) for p in self.weights])
-class InputLayer(Layer):
+class InputLayer(tf_base_layers.InputLayer, Layer):
"""Layer to be used as an entry point into a graph.
It can either wrap an existing tensor (pass an `input_tensor` argument)
- or create its a placeholder tensor (pass arguments `input_shape`
- or `batch_input_shape` as well as `dtype`).
+ or create a placeholder tensor (pass argument `input_shape`).
Arguments:
input_shape: Shape tuple, not including the batch axis.
batch_size: Optional input batch size (integer or None).
- batch_input_shape: Shape tuple, including the batch axis.
dtype: Datatype of the input.
input_tensor: Optional tensor to use as layer input
instead of creating a placeholder.
@@ -973,71 +519,44 @@ class InputLayer(Layer):
def __init__(self,
input_shape=None,
batch_size=None,
- batch_input_shape=None,
dtype=None,
input_tensor=None,
sparse=False,
- name=None):
+ name=None,
+ **kwargs):
+ if 'batch_input_shape' in kwargs:
+ batch_input_shape = kwargs.pop('batch_input_shape')
+ if input_shape and batch_input_shape:
+ raise ValueError('Only provide the input_shape OR '
+ 'batch_input_shape argument to '
+ 'InputLayer, not both at the same time.')
+ batch_size = batch_input_shape[0]
+ input_shape = batch_input_shape[1:]
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
+
if not name:
prefix = 'input'
name = prefix + '_' + str(K.get_uid(prefix))
+
if not dtype:
if input_tensor is None:
dtype = K.floatx()
else:
dtype = K.dtype(input_tensor)
- super(InputLayer, self).__init__(dtype=dtype, name=name)
- self.built = True
- self.sparse = sparse
-
- if input_shape and batch_input_shape:
- raise ValueError('Only provide the input_shape OR '
- 'batch_input_shape argument to '
- 'InputLayer, not both at the same time.')
+ super(InputLayer, self).__init__(input_shape=input_shape,
+ batch_size=batch_size,
+ dtype=dtype,
+ input_tensor=input_tensor,
+ sparse=sparse,
+ name=name)
+
if input_tensor is not None:
- # Attempt automatic input shape inference.
- try:
- batch_input_shape = K.int_shape(input_tensor)
- except TypeError:
- if not input_shape and not batch_input_shape:
- raise ValueError('InputLayer was provided '
- 'an input_tensor argument, '
- 'but its input shape cannot be '
- 'automatically inferred. '
- 'You should pass an input_shape or '
- 'batch_input_shape argument.')
- if not batch_input_shape:
- if not input_shape:
- raise ValueError('An Input layer should be passed either '
- 'a `batch_input_shape` or an `input_shape`.')
- else:
- batch_input_shape = (batch_size,) + tuple(input_shape)
+ self.is_placeholder = False
+ self.batch_input_shape = tuple(input_tensor.get_shape().as_list())
else:
- batch_input_shape = tuple(batch_input_shape)
- self.batch_input_shape = batch_input_shape
-
- if input_tensor is None:
self.is_placeholder = True
- input_tensor = K.placeholder(
- shape=batch_input_shape,
- dtype=dtype,
- sparse=self.sparse,
- name=self.name)
- else:
- self.is_placeholder = False
- # Create an input node to add to self.outbound_node
- # and set output_tensors' _keras_history.
- input_tensor._uses_learning_phase = False
- input_tensor._keras_history = (self, 0, 0)
- Node(
- self,
- inbound_layers=[],
- node_indices=[],
- tensor_indices=[],
- input_tensors=[input_tensor],
- output_tensors=[input_tensor],
- input_masks=[None],
- output_masks=[None])
+ self.batch_input_shape = (batch_size,) + tuple(input_shape)
def get_config(self):
config = {
@@ -1051,11 +570,12 @@ class InputLayer(Layer):
def Input( # pylint: disable=invalid-name
shape=None,
- batch_shape=None,
+ batch_size=None,
name=None,
- dtype=K.floatx(),
+ dtype=None,
sparse=False,
- tensor=None):
+ tensor=None,
+ **kwargs):
"""`Input()` is used to instantiate a Keras tensor.
A Keras tensor is a tensor object from the underlying backend
@@ -1073,14 +593,10 @@ def Input( # pylint: disable=invalid-name
recursively.
Arguments:
- shape: A shape tuple (integer), not including the batch size.
+ shape: A shape tuple (integers), not including the batch size.
For instance, `shape=(32,)` indicates that the expected input
will be batches of 32-dimensional vectors.
- batch_shape: A shape tuple (integer), including the batch size.
- For instance, `batch_shape=(10, 32)` indicates that
- the expected input will be batches of 10 32-dimensional vectors.
- `batch_shape=(None, 32)` indicates batches of an arbitrary number
- of 32-dimensional vectors.
+ batch_size: optional static batch size (integer).
name: An optional name string for the layer.
Should be unique in a model (do not reuse the same name twice).
It will be autogenerated if it isn't provided.
@@ -1090,6 +606,7 @@ def Input( # pylint: disable=invalid-name
to be created is sparse.
tensor: Optional existing tensor to wrap into the `Input` layer.
If set, the layer will not create a placeholder tensor.
+ **kwargs: deprecated arguments support.
Returns:
A tensor.
@@ -1102,16 +619,31 @@ def Input( # pylint: disable=invalid-name
y = Dense(16, activation='softmax')(x)
model = Model(x, y)
```
+
+ Raises:
+ ValueError: in case of invalid arguments.
"""
- if not batch_shape and tensor is None:
- assert shape, ('Please provide to Input either a `shape`'
- ' or a `batch_shape` argument. Note that '
- '`shape` does not include the batch '
- 'dimension.')
- if shape and not batch_shape:
- batch_shape = (None,) + tuple(shape)
+ if 'batch_shape' in kwargs:
+ batch_shape = kwargs.pop('batch_shape')
+ if shape and batch_shape:
+ raise ValueError('Only provide the shape OR '
+ 'batch_shape argument to '
+ 'Input, not both at the same time.')
+ batch_size = batch_shape[0]
+ shape = batch_shape[1:]
+ if kwargs:
+ raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
+
+ if dtype is None:
+ dtype = K.floatx()
+ if not shape and tensor is None:
+ raise ValueError('Please provide to Input either a `shape`'
+ ' or a `tensor` argument. Note that '
+ '`shape` does not include the batch '
+ 'dimension.')
input_layer = InputLayer(
- batch_input_shape=batch_shape,
+ input_shape=shape,
+ batch_size=batch_size,
name=name,
dtype=dtype,
sparse=sparse,
@@ -1125,7 +657,7 @@ def Input( # pylint: disable=invalid-name
return outputs
-class Container(Layer):
+class Network(tf_base_layers.Network, Layer):
"""A Container is a directed acyclic graph of layers.
It is the topological form of a "model". A Model
@@ -1162,118 +694,20 @@ class Container(Layer):
from_config
"""
- def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called
- # Handle `name` argument.
- if not name:
- prefix = self.__class__.__name__.lower()
- name = prefix + '_' + str(K.get_uid(prefix))
- self.name = name
- self.supports_masking = False
- self.trainable = True
- self._per_input_losses = {}
- self._per_input_updates = {}
-
- # The following properties are not actually used by Keras;
- # they exist for compatibility with TF.
- self._updates = []
- self._losses = []
- self._scope = None
- self._reuse = None
- self._base_name = name
- self._graph = ops.get_default_graph()
-
- # Container-specific properties.
- if isinstance(inputs, (list, tuple)):
- self.inputs = list(inputs) # Tensor or list of tensors.
- else:
- self.inputs = [inputs]
- if isinstance(outputs, (list, tuple)):
- self.outputs = list(outputs)
- else:
- self.outputs = [outputs]
-
- # Check for redundancy in inputs.
- inputs_set = set(self.inputs)
- if len(inputs_set) != len(self.inputs):
- raise ValueError('The list of inputs passed to the model '
- 'is redundant. '
- 'All inputs should only appear once.'
- ' Found: ' + str(self.inputs))
-
- # List of initial layers (1 to 1 mapping with self.inputs,
- # hence the same layer might appear twice)
- self.input_layers = []
- self.input_layers_node_indices = []
- self.input_layers_tensor_indices = []
- # list of layers (1 to 1 mapping with self.inputs,
- # hence the same layer might appear twice)
- self.output_layers = []
- self.output_layers_node_indices = []
- self.output_layers_tensor_indices = []
- # all layers in order of horizontal graph traversal.
- # Entries are unique. Includes input and output layers.
- self.layers = []
-
- # This is for performance optimization
- # when calling the Container on new inputs.
- # every time the Container is called on a set on input tensors,
- # we compute the output tensors,
- # output masks and output shapes in one pass,
- # then cache them here. When of of these output is queried later,
- # we retrieve it from there instead of recomputing it.
- self._output_mask_cache = {}
- self._output_tensor_cache = {}
- self._output_shape_cache = {}
-
- # User-provided arguments validation.
- for x in self.inputs:
- # Check that x is a Keras tensor.
- if not hasattr(x, '_keras_history'):
- cls_name = self.__class__.__name__
- raise TypeError('Input tensors to a ' + cls_name + ' ' +
- 'must be Keras tensors. Found: ' + str(x) +
- ' (missing Keras metadata).')
- # Check that x is an input tensor.
- layer, node_index, tensor_index = x._keras_history
- if len(layer.inbound_nodes) > 1 or (
- layer.inbound_nodes and layer.inbound_nodes[0].inbound_layers):
- cls_name = self.__class__.__name__
- logging.warning(cls_name + ' inputs must come from '
- 'a Keras Input layer, '
- 'they cannot be the output of '
- 'a previous non-Input layer. '
- 'Here, a tensor specified as '
- 'input to "' + self.name + '" was not an Input tensor, '
- 'it was generated by layer ' + layer.name + '.\n'
- 'Note that input tensors are '
- 'instantiated via `tensor = Input(shape)`.\n'
- 'The tensor that caused the issue was: ' + str(x.name))
- for x in self.outputs:
- if not hasattr(x, '_keras_history'):
- cls_name = self.__class__.__name__
- raise TypeError('Output tensors to a ' + cls_name + ' must be '
- 'Keras tensors. Found: ' + str(x))
- # Build self.output_layers:
- for x in self.outputs:
- layer, node_index, tensor_index = x._keras_history
- self.output_layers.append(layer)
- self.output_layers_node_indices.append(node_index)
- self.output_layers_tensor_indices.append(tensor_index)
+ def __init__(self, inputs, outputs, name=None):
+ super(Network, self).__init__(inputs, outputs, name=name)
+ self.supports_masking = False
# Fill in the output mask cache.
masks = []
for x in self.inputs:
- layer, node_index, tensor_index = x._keras_history
- node = layer.inbound_nodes[node_index]
- mask = node.output_masks[tensor_index]
+ mask = x._keras_mask if hasattr(x, '_keras_mask') else None
masks.append(mask)
- mask_cache_key = ','.join([str(id(x)) for x in self.inputs])
- mask_cache_key += '_' + ','.join([str(id(x)) for x in masks])
+ mask_cache_key = (tf_base_layers._object_list_uid(self.inputs) + '_' +
+ tf_base_layers._object_list_uid(masks))
masks = []
for x in self.outputs:
- layer, node_index, tensor_index = x._keras_history
- node = layer.inbound_nodes[node_index]
- mask = node.output_masks[tensor_index]
+ mask = x._keras_mask if hasattr(x, '_keras_mask') else None
masks.append(mask)
if len(masks) == 1:
mask = masks[0]
@@ -1281,322 +715,24 @@ class Container(Layer):
mask = masks
self._output_mask_cache[mask_cache_key] = mask
- # Build self.input_layers:
- for x in self.inputs:
- layer, node_index, tensor_index = x._keras_history
- # It's supposed to be an input layer, so only one node
- # and one tensor output.
- assert node_index == 0
- assert tensor_index == 0
- self.input_layers.append(layer)
- self.input_layers_node_indices.append(node_index)
- self.input_layers_tensor_indices.append(tensor_index)
-
# Build self.input_names and self.output_names.
self.input_names = []
self.output_names = []
self._feed_input_names = []
self._feed_inputs = []
self._feed_input_shapes = []
- for i, layer in enumerate(self.input_layers):
+ for i, layer in enumerate(self._input_layers):
self.input_names.append(layer.name)
if layer.is_placeholder:
self._feed_input_names.append(layer.name)
self._feed_inputs.append(layer.input)
self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
- for layer in self.output_layers:
+ for layer in self._output_layers:
self.output_names.append(layer.name)
self.internal_input_shapes = [K.int_shape(x) for x in self.inputs]
self.internal_output_shapes = [K.int_shape(x) for x in self.outputs]
- # Container_nodes: set of nodes included in the graph
- # (not all nodes included in the layers
- # are relevant to the current graph).
- container_nodes = set() # ids of all nodes relevant to the Container
- nodes_depths = {} # dict {node: depth value}
- layers_depths = {} # dict {layer: depth value}
- layer_indices = {} # dict {layer: index in traversal}
- nodes_in_decreasing_depth = []
-
- def build_map_of_graph(tensor,
- finished_nodes,
- nodes_in_progress,
- layer=None,
- node_index=None,
- tensor_index=None):
- """Builds a map of the graph of layers.
-
- This recursively updates the map `layer_indices`,
- the list `nodes_in_decreasing_depth` and the set `container_nodes`.
-
- Arguments:
- tensor: Some tensor in a graph.
- finished_nodes: Set of nodes whose subgraphs have been traversed
- completely. Useful to prevent duplicated work.
- nodes_in_progress: Set of nodes that are currently active on the
- recursion stack. Useful to detect cycles.
- layer: Layer from which `tensor` comes from. If not provided,
- will be obtained from `tensor._keras_history`.
- node_index: Node index from which `tensor` comes from.
- tensor_index: Tensor_index from which `tensor` comes from.
-
- Raises:
- RuntimeError: if a cycle is detected.
- """
- if not layer or node_index is None or tensor_index is None:
- layer, node_index, tensor_index = tensor._keras_history
- node = layer.inbound_nodes[node_index]
-
- # Prevent cycles.
- if node in nodes_in_progress:
- raise RuntimeError('The tensor ' + str(tensor) + ' at layer "' +
- layer.name + '" is part of a cycle.')
-
- # Don't repeat work for shared subgraphs
- if node in finished_nodes:
- return
-
- node_key = layer.name + '_ib-' + str(node_index)
- # Update container_nodes.
- container_nodes.add(node_key)
-
- # Store the traversal order for layer sorting.
- if layer not in layer_indices:
- layer_indices[layer] = len(layer_indices)
-
- nodes_in_progress.add(node)
-
- # Propagate to all previous tensors connected to this node.
- for i in range(len(node.inbound_layers)):
- x = node.input_tensors[i]
- layer = node.inbound_layers[i]
- node_index = node.node_indices[i]
- tensor_index = node.tensor_indices[i]
- build_map_of_graph(x, finished_nodes, nodes_in_progress, layer,
- node_index, tensor_index)
-
- finished_nodes.add(node)
- nodes_in_progress.remove(node)
-
- nodes_in_decreasing_depth.append(node)
-
- finished_nodes = set()
- nodes_in_progress = set()
- for x in self.outputs:
- build_map_of_graph(x, finished_nodes, nodes_in_progress)
-
- for node in reversed(nodes_in_decreasing_depth):
- # If the depth is not set, the node has no outbound nodes (depth 0).
- depth = nodes_depths.setdefault(node, 0)
-
- # Update the depth of the corresponding layer
- previous_depth = layers_depths.get(node.outbound_layer, 0)
- # If we've seen this layer before at a higher depth,
- # we should use that depth instead of the node depth.
- # This is necessary for shared layers that have inputs at different
- # depth levels in the graph.
- depth = max(depth, previous_depth)
- layers_depths[node.outbound_layer] = depth
- nodes_depths[node] = depth
-
- # Update the depth of inbound nodes.
- for i in range(len(node.inbound_layers)):
- inbound_layer = node.inbound_layers[i]
- node_index = node.node_indices[i]
- inbound_node = inbound_layer.inbound_nodes[node_index]
- previous_depth = nodes_depths.get(inbound_node, 0)
- nodes_depths[inbound_node] = max(depth + 1, previous_depth)
-
- # Build a dict {depth: list of nodes with this depth}
- nodes_by_depth = {}
- for node, depth in nodes_depths.items():
- if depth not in nodes_by_depth:
- nodes_by_depth[depth] = []
- nodes_by_depth[depth].append(node)
-
- # Build a dict {depth: list of layers with this depth}
- layers_by_depth = {}
- for layer, depth in layers_depths.items():
- if depth not in layers_by_depth:
- layers_by_depth[depth] = []
- layers_by_depth[depth].append(layer)
-
- # Get sorted list of layer depths.
- depth_keys = list(layers_by_depth.keys())
- depth_keys.sort(reverse=True)
-
- # Set self.layers and self.layers_by_depth.
- layers = []
- for depth in depth_keys:
- layers_for_depth = layers_by_depth[depth]
- # Container.layers needs to have a deterministic order:
- # here we order them by traversal order.
- layers_for_depth.sort(key=lambda x: layer_indices[x])
- for layer in layers_for_depth:
- layers.append(layer)
- self.layers = layers
- self.layers_by_depth = layers_by_depth
-
- # Get sorted list of node depths.
- depth_keys = list(nodes_by_depth.keys())
- depth_keys.sort(reverse=True)
-
- # Check that all tensors required are computable.
- # computable_tensors: all tensors in the graph
- # that can be computed from the inputs provided.
- computable_tensors = []
- for x in self.inputs:
- computable_tensors.append(x)
-
- layers_with_complete_input = [] # To provide a better error msg.
- for depth in depth_keys:
- for node in nodes_by_depth[depth]:
- layer = node.outbound_layer
- if layer:
- for x in node.input_tensors:
- if x not in computable_tensors:
- raise RuntimeError('Graph disconnected: '
- 'cannot obtain value for tensor ' + str(x) +
- ' at layer "' + layer.name + '". '
- 'The following previous layers '
- 'were accessed without issue: ' +
- str(layers_with_complete_input))
- for x in node.output_tensors:
- computable_tensors.append(x)
- layers_with_complete_input.append(layer.name)
-
- # Set self.nodes and self.nodes_by_depth.
- self.container_nodes = container_nodes
- self.nodes_by_depth = nodes_by_depth
-
- # Ensure name unicity, which will be crucial for serialization
- # (since serialized nodes refer to layers by their name).
- all_names = [layer.name for layer in self.layers]
- for name in all_names:
- if all_names.count(name) != 1:
- raise RuntimeError('The name "' + name + '" is used ' +
- str(all_names.count(name)) + ' times in the model. '
- 'All layer names should be unique.')
-
- # Layer parameters.
- # The new container starts with a single inbound node
- # for its inputs, and no outbound nodes.
- self.outbound_nodes = [] # Will be appended to by future calls to __call__
- self.inbound_nodes = [
- ] # Will be appended to below, and by future calls to __call__
- # Create the node linking internal inputs to internal outputs.
- Node(
- outbound_layer=self,
- inbound_layers=[],
- node_indices=[],
- tensor_indices=[],
- input_tensors=self.inputs,
- output_tensors=self.outputs,
- # No container-level masking for now.
- input_masks=[None for _ in self.inputs],
- output_masks=[None for _ in self.outputs])
- self.built = True
-
- # The following are implemented as property functions:
- # self.constraints
- # self.trainable_weights
- # self.non_trainable_weights
- # self.input_spec
-
- def get_layer(self, name=None, index=None):
- """Retrieves a layer based on either its name (unique) or index.
-
- Indices are based on order of horizontal graph traversal (bottom-up).
-
- Arguments:
- name: String, name of layer.
- index: Integer, index of layer.
-
- Returns:
- A layer instance.
-
- Raises:
- ValueError: In case of invalid layer name or index.
- """
- # It would be unreliable to build a dictionary
- # based on layer names, because names can potentially
- # be changed at any point by the user
- # without the container being notified of it.
- if index is not None:
- if len(self.layers) <= index:
- raise ValueError('Was asked to retrieve layer at index ' + str(index) +
- ' but model only has ' + str(len(self.layers)) +
- ' layers.')
- else:
- return self.layers[index]
- else:
- if not name:
- raise ValueError('Provide either a layer name or layer index.')
- layer = None
- for layer in self.layers:
- if layer.name == name:
- return layer
- if not layer:
- raise ValueError('No such layer: ' + name)
-
- @property
- def updates(self):
- """Retrieve the model's updates.
-
- Will only include updates that are either
- unconditional, or conditional on inputs to this model
- (e.g. will not include updates that depend on tensors
- that aren't inputs to this model).
-
- Returns:
- A list of update ops.
- """
- updates = []
- for layer in self.layers:
- if hasattr(layer, 'updates'):
- # Collect updates that are dependent on inputs
- # that are part of the model.
- for node_index, node in enumerate(layer.inbound_nodes):
- node_key = layer.name + '_ib-' + str(node_index)
- if node_key in self.container_nodes:
- # The model owns this layer node.
- inputs = node.input_tensors
- updates += layer.get_updates_for(inputs)
- # Collect unconditional updates.
- updates += layer.get_updates_for(None)
- return updates
-
- @property
- def losses(self):
- """Retrieve the model's losses.
-
- Will only include losses that are either
- unconditional, or conditional on inputs to this model
- (e.g. will not include losses that depend on tensors
- that aren't inputs to this model).
-
- Returns:
- A list of loss tensors.
- """
- losses = []
- # Retrieve losses for all internal layers.
- for layer in self.layers:
- if hasattr(layer, 'losses'):
- # Collect losses that are dependent on inputs
- # that are part of the model.
- for node_index, node in enumerate(layer.inbound_nodes):
- node_key = layer.name + '_ib-' + str(node_index)
- if node_key in self.container_nodes:
- # The model owns this layer node.
- inputs = node.input_tensors
- losses += layer.get_losses_for(inputs)
- # Collect unconditional losses.
- losses += layer.get_losses_for(None)
- # Add any potential unconditional model-level loss.
- losses += self.get_losses_for(None)
- return losses
-
@property
def uses_learning_phase(self):
return any([x._uses_learning_phase for x in self.outputs])
@@ -1640,27 +776,6 @@ class Container(Layer):
cons[key] = value
return cons
- @property
- def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
-
- @property
- def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
-
def get_weights(self):
"""Retrieves the weights of the model.
@@ -1688,60 +803,6 @@ class Container(Layer):
weights = weights[num_param:]
K.batch_set_value(tuples)
- @property
- def input_spec(self):
- """Gets the model's input specs.
-
- Returns:
- A list of `InputSpec` instances (one per input to the model)
- or a single instance if the model has only one input.
- """
- specs = []
- for layer in getattr(self, 'input_layers', []):
- if layer.input_spec is None:
- specs.append(None)
- else:
- if not isinstance(layer.input_spec, list):
- raise TypeError('Layer ' + layer.name +
- ' has an input_spec attribute that '
- 'is not a list. We expect a list. '
- 'Found input_spec = ' + str(layer.input_spec))
- specs += layer.input_spec
- if len(specs) == 1:
- return specs[0]
- return specs
-
- def call(self, inputs, mask=None):
- """Call the model on new inputs.
-
- In this case `call` just reapplies
- all ops in the graph to the new inputs
- (e.g. build a new computational graph from the provided inputs).
-
- A model is callable on non-Keras tensors.
-
- Arguments:
- inputs: A tensor or list of tensors.
- mask: A mask or list of masks. A mask can be
- either a tensor or None (no mask).
-
- Returns:
- A tensor if there is a single output, or
- a list of tensors if there are more than one outputs.
- """
- inputs = _to_list(inputs)
- if mask is None:
- masks = [None for _ in range(len(inputs))]
- else:
- masks = _to_list(mask)
- cache_key = ','.join([str(id(x)) for x in inputs])
- cache_key += '_' + ','.join([str(id(x)) for x in masks])
- if cache_key in self._output_tensor_cache:
- return self._output_tensor_cache[cache_key]
- else:
- output_tensors, _, _ = self.run_internal_graph(inputs, masks)
- return output_tensors
-
def compute_mask(self, inputs, mask):
inputs = _to_list(inputs)
if mask is None:
@@ -1753,278 +814,25 @@ class Container(Layer):
if cache_key in self._output_mask_cache:
return self._output_mask_cache[cache_key]
else:
- _, output_masks, _ = self.run_internal_graph(inputs, masks)
+ _, output_masks, _ = self._run_internal_graph(inputs, masks)
return output_masks
- def _compute_output_shape(self, input_shape):
- if isinstance(input_shape, list):
- input_shapes = []
- for shape in input_shape:
- if shape is not None:
- input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
- else:
- input_shapes.append(None)
- else:
- if input_shape is not None:
- input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
- else:
- input_shapes = [None]
-
- if len(input_shapes) != len(self.input_layers):
- raise ValueError('Invalid input_shape argument ' + str(input_shape) +
- ': model has ' + str(len(self.input_layers)) +
- ' tensor inputs.')
-
- cache_key = ','.join([str(x) for x in input_shapes])
- if cache_key in self._output_shape_cache:
- output_shapes = self._output_shape_cache[cache_key]
- if isinstance(output_shapes, list):
- if len(output_shapes) == 1:
- return tensor_shape.TensorShape(output_shapes[0])
- else:
- return [tensor_shape.TensorShape(shape) for shape in output_shapes]
- else:
- return tensor_shape.TensorShape(output_shapes)
- else:
- # Bad luck, we have to run the graph manually.
- layers_to_output_shapes = {}
- for i in range(len(input_shapes)):
- layer = self.input_layers[i]
- input_shape = input_shapes[i]
- # It's an input layer: compute_output_shape is identity,
- # and there is only one node and one tensor output.
- shape_key = layer.name + '_0_0'
- layers_to_output_shapes[shape_key] = input_shape
-
- depth_keys = list(self.nodes_by_depth.keys())
- depth_keys.sort(reverse=True)
- # Iterate over nodes, by depth level.
- if len(depth_keys) > 1:
- for depth in depth_keys:
- nodes = self.nodes_by_depth[depth]
- for node in nodes:
- # This is always a single layer, never a list.
- layer = node.outbound_layer
- if layer in self.input_layers:
- # We've already covered the input layers
- # a few lines above.
- continue
- # Potentially redundant list,
- # same size of node.input_tensors.
- input_shapes = []
- for j in range(len(node.inbound_layers)):
- inbound_layer = node.inbound_layers[j]
- node_index = node.node_indices[j]
- tensor_index = node.tensor_indices[j]
- shape_key = inbound_layer.name + '_%s_%s' % (node_index,
- tensor_index)
- input_shape = layers_to_output_shapes[shape_key]
- input_shapes.append(input_shape)
-
- if len(input_shapes) == 1:
- output_shape = layer._compute_output_shape(input_shapes[0])
- else:
- output_shape = layer._compute_output_shape(input_shapes)
- if isinstance(output_shape, list):
- output_shapes = [
- tuple(tensor_shape.TensorShape(shape).as_list())
- for shape in output_shape
- ]
- else:
- output_shapes = [
- tuple(tensor_shape.TensorShape(output_shape).as_list())
- ]
-
- node_index = layer.inbound_nodes.index(node)
- for j in range(len(output_shapes)):
- shape_key = layer.name + '_%s_%s' % (node_index, j)
- layers_to_output_shapes[shape_key] = output_shapes[j]
-
- # Read final output shapes from layers_to_output_shapes.
- output_shapes = []
- output_shape_keys = []
- for i in range(len(self.output_layers)):
- layer = self.output_layers[i]
- node_index = self.output_layers_node_indices[i]
- tensor_index = self.output_layers_tensor_indices[i]
- shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
- output_shape_keys.append(shape_key)
-
- for i, key in enumerate(output_shape_keys):
- assert key in layers_to_output_shapes
- output_shapes.append(layers_to_output_shapes[key])
- # Store in cache.
- self._output_shape_cache[cache_key] = output_shapes
- if isinstance(output_shapes, list):
- if len(output_shapes) == 1:
- return tensor_shape.TensorShape(output_shapes[0])
- else:
- return [tensor_shape.TensorShape(shape) for shape in output_shapes]
- else:
- return tensor_shape.TensorShape(output_shapes)
-
- def run_internal_graph(self, inputs, masks=None):
- """Computes output tensors for new inputs.
-
- # Note:
- - Expects `inputs` to be a list (potentially with 1 element).
- - Can be run on non-Keras tensors.
-
- Arguments:
- inputs: List of tensors
- masks: List of masks (tensors or None).
-
- Returns:
- Three lists: output_tensors, output_masks, output_shapes
- """
- if masks is None:
- masks = [None for _ in range(len(inputs))]
-
- # Dictionary mapping reference tensors to tuples
- # (computed tensor, compute mask)
- # we assume a 1:1 mapping from tensor to mask
- # TODO(fchollet): raise exception when a `.compute_mask()` call
- # does not return a list the same size as `call`
- tensor_map = {}
- for x, y, mask in zip(self.inputs, inputs, masks):
- tensor_map[str(id(x))] = (y, mask)
-
- depth_keys = list(self.nodes_by_depth.keys())
- depth_keys.sort(reverse=True)
- for depth in depth_keys:
- nodes = self.nodes_by_depth[depth]
- for node in nodes:
- # This is always a single layer, never a list.
- layer = node.outbound_layer
-
- reference_input_tensors = node.input_tensors
- reference_output_tensors = node.output_tensors
-
- # If all previous input tensors are available in tensor_map,
- # then call node.inbound_layer on them.
- computed_data = [] # List of tuples (input, mask).
- for x in reference_input_tensors:
- if str(id(x)) in tensor_map:
- computed_data.append(tensor_map[str(id(x))])
-
- if len(computed_data) == len(reference_input_tensors):
- # call layer
- with K.name_scope(layer.name):
- if node.arguments:
- kwargs = node.arguments
- else:
- kwargs = {}
- if len(computed_data) == 1:
- computed_tensor, computed_mask = computed_data[0]
- if has_arg(layer.call, 'mask'):
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_mask
- output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
- output_masks = _to_list(
- layer.compute_mask(computed_tensor, computed_mask))
- computed_tensors = [computed_tensor]
- computed_masks = [computed_mask]
- else:
- computed_tensors = [x[0] for x in computed_data]
- computed_masks = [x[1] for x in computed_data]
- if has_arg(layer.call, 'mask'):
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_masks
- output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
- output_masks = _to_list(
- layer.compute_mask(computed_tensors, computed_masks))
-
- # Apply activity regularizer if any:
- if hasattr(layer, 'activity_regularizer'
- ) and layer.activity_regularizer is not None:
- regularization_losses = [
- layer.activity_regularizer(x) for x in computed_tensors
- ]
- layer.add_loss(regularization_losses, computed_tensors)
-
- # Update model updates and losses:
- # Keep track of updates that depend on the inputs
- # (e.g. BN updates).
- self.add_update(layer.get_updates_for(computed_tensors), inputs)
- # Keep track of unconditional updates (e.g. a counter).
- self.add_update(layer.get_updates_for(None), None)
- # Keep track of losses that depend on the inputs
- # (e.g. activity regularizers).
- self.add_loss(layer.get_losses_for(computed_tensors), inputs)
- # Keep track of unconditional losses
- # (e.g. weight regularizers).
- self.add_loss(layer.get_losses_for(None), None)
-
- # Update `_uses_learning_phase`.
- if len(computed_tensors) == 1:
- uses_learning_phase = getattr(computed_tensors[0],
- '_uses_learning_phase', False)
- else:
- uses_learning_phase = any([
- getattr(x, '_uses_learning_phase', False)
- for x in computed_tensors
- ])
- for x in output_tensors:
- x._uses_learning_phase = getattr(x, '_uses_learning_phase',
- False) or uses_learning_phase
-
- # Update tensor_map.
- for x, y, mask in zip(reference_output_tensors, output_tensors,
- output_masks):
- tensor_map[str(id(x))] = (y, mask)
-
- output_tensors = []
- output_masks = []
- output_shapes = []
- for x in self.outputs:
- assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
- tensor, mask = tensor_map[str(id(x))]
- output_shapes.append(K.int_shape(x))
- output_tensors.append(tensor)
- output_masks.append(mask)
-
- # Update cache;
- # keys are based on ids on input tensors and inputs masks.
- cache_key = ','.join([str(id(x)) for x in inputs])
- cache_key += '_' + ','.join([str(id(x)) for x in masks])
-
- if len(output_tensors) == 1:
- output_tensors = output_tensors[0]
- self._output_tensor_cache[cache_key] = output_tensors
- else:
- self._output_tensor_cache[cache_key] = output_tensors
-
- if len(output_masks) == 1:
- output_masks = output_masks[0]
- self._output_mask_cache[cache_key] = output_masks
- else:
- self._output_mask_cache[cache_key] = output_masks
-
- if output_shapes is not None:
- input_shapes = [K.int_shape(x) for x in inputs]
- cache_key = ','.join([str(x) for x in input_shapes])
- if len(output_shapes) == 1:
- output_shapes = output_shapes[0]
- self._output_shape_cache[cache_key] = output_shapes
- else:
- self._output_shape_cache[cache_key] = output_shapes
- return output_tensors, output_masks, output_shapes
-
def get_config(self):
config = {
'name': self.name,
}
node_conversion_map = {}
for layer in self.layers:
- if issubclass(layer.__class__, Container):
+ if issubclass(layer.__class__, Network):
# Containers start with a pre-existing node
# linking their input to output.
kept_nodes = 1
else:
kept_nodes = 0
for original_node_index, node in enumerate(layer.inbound_nodes):
- node_key = layer.name + '_ib-' + str(original_node_index)
- if node_key in self.container_nodes:
+ node_key = tf_base_layers._make_node_key(layer.name,
+ original_node_index)
+ if node_key in self._network_nodes:
node_conversion_map[node_key] = kept_nodes
kept_nodes += 1
layer_configs = []
@@ -2033,8 +841,9 @@ class Container(Layer):
layer_config = layer.get_config()
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer.inbound_nodes):
- node_key = layer.name + '_ib-' + str(original_node_index)
- if node_key in self.container_nodes:
+ node_key = tf_base_layers._make_node_key(layer.name,
+ original_node_index)
+ if node_key in self._network_nodes:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
if node.arguments:
@@ -2057,7 +866,8 @@ class Container(Layer):
inbound_layer = node.inbound_layers[i]
node_index = node.node_indices[i]
tensor_index = node.tensor_indices[i]
- node_key = inbound_layer.name + '_ib-' + str(node_index)
+ node_key = tf_base_layers._make_node_key(inbound_layer.name,
+ node_index)
new_node_index = node_conversion_map.get(node_key, 0)
node_data.append(
[inbound_layer.name, new_node_index, tensor_index, kwargs])
@@ -2072,21 +882,19 @@ class Container(Layer):
# Gather info about inputs and outputs.
model_inputs = []
- for i in range(len(self.input_layers)):
- layer = self.input_layers[i]
- node_index = self.input_layers_node_indices[i]
- node_key = layer.name + '_ib-' + str(node_index)
+ for i in range(len(self._input_layers)):
+ layer, node_index, tensor_index = self._input_coordinates[i]
+ node_key = tf_base_layers._make_node_key(layer.name,
+ node_index)
new_node_index = node_conversion_map[node_key]
- tensor_index = self.input_layers_tensor_indices[i]
model_inputs.append([layer.name, new_node_index, tensor_index])
config['input_layers'] = model_inputs
model_outputs = []
- for i in range(len(self.output_layers)):
- layer = self.output_layers[i]
- node_index = self.output_layers_node_indices[i]
- node_key = layer.name + '_ib-' + str(node_index)
+ for i in range(len(self._output_layers)):
+ layer, node_index, tensor_index = self._output_coordinates[i]
+ node_key = tf_base_layers._make_node_key(layer.name,
+ node_index)
new_node_index = node_conversion_map[node_key]
- tensor_index = self.output_layers_tensor_indices[i]
model_outputs.append([layer.name, new_node_index, tensor_index])
config['output_layers'] = model_outputs
return copy.deepcopy(config)
@@ -2373,6 +1181,10 @@ class Container(Layer):
print_fn=print_fn)
+# Alias for legacy support.
+Container = Network
+
+
def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
@@ -2436,41 +1248,6 @@ def _object_list_uid(object_list):
return ', '.join([str(abs(id(x))) for x in object_list])
-def _is_all_none(iterable_or_element):
- if not isinstance(iterable_or_element, (list, tuple)):
- iterable = [iterable_or_element]
- else:
- iterable = iterable_or_element
- for element in iterable:
- if element is not None:
- return False
- return True
-
-
-def _collect_previous_mask(input_tensors):
- """Retrieves the output mask(s) of the previous node.
-
- Arguments:
- input_tensors: A tensor or list of tensors.
-
- Returns:
- A mask tensor or list of mask tensors.
- """
- input_tensors = _to_list(input_tensors)
- masks = []
- for x in input_tensors:
- if hasattr(x, '_keras_history'):
- inbound_layer, node_index, tensor_index = x._keras_history
- node = inbound_layer.inbound_nodes[node_index]
- mask = node.output_masks[tensor_index]
- masks.append(mask)
- else:
- masks.append(None)
- if len(masks) == 1:
- return masks[0]
- return masks
-
-
def _to_snake_case(name):
intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
index ec4fa2eed8..f6c0b8a607 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
@@ -164,11 +164,9 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual(node.inbound_layers, [])
self.assertListEqual(node.input_tensors, [a])
- self.assertListEqual(node.input_masks, [None])
self.assertListEqual(node.input_shapes, [(None, 32)])
self.assertListEqual(node.output_tensors, [a])
self.assertListEqual(node.output_shapes, [(None, 32)])
- self.assertListEqual(node.output_masks, [None])
dense = keras.layers.Dense(16, name='dense_1')
a_2 = dense(a)
@@ -189,22 +187,9 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual(test_layer.kernel.get_shape().as_list(), [32, 16])
self.assertEqual(test_layer.input, a)
self.assertEqual(test_layer.output, a_test)
- self.assertEqual(test_layer.input_mask, None)
- self.assertEqual(test_layer.output_mask, None)
self.assertEqual(test_layer.input_shape, (None, 32))
self.assertEqual(test_layer.output_shape, (None, 16))
- # pylint: disable=pointless-statement
- with self.assertRaises(AttributeError):
- dense.input
- with self.assertRaises(AttributeError):
- dense.output
- with self.assertRaises(AttributeError):
- dense.input_mask
- with self.assertRaises(AttributeError):
- dense.output_mask
- # pylint: enable=pointless-statement
-
self.assertEqual(dense.get_input_at(0), a)
self.assertEqual(dense.get_input_at(1), b)
self.assertEqual(dense.get_output_at(0), a_2)
@@ -256,9 +241,9 @@ class TopologyConstructionTest(test.TestCase):
# ordering of same-level layers is not fixed
self.assertListEqual([l.name for l in model.layers][2:],
['dense_1', 'merge', 'dense_2', 'dense_3'])
- self.assertListEqual([l.name for l in model.input_layers],
+ self.assertListEqual([l.name for l in model._input_layers],
['input_a', 'input_b'])
- self.assertListEqual([l.name for l in model.output_layers],
+ self.assertListEqual([l.name for l in model._output_layers],
['dense_2', 'dense_3'])
# actually run model
@@ -278,9 +263,9 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([l.name for l in recreated_model.layers][2:],
['dense_1', 'merge', 'dense_2', 'dense_3'])
- self.assertListEqual([l.name for l in recreated_model.input_layers],
+ self.assertListEqual([l.name for l in recreated_model._input_layers],
['input_a', 'input_b'])
- self.assertListEqual([l.name for l in recreated_model.output_layers],
+ self.assertListEqual([l.name for l in recreated_model._output_layers],
['dense_2', 'dense_3'])
fn = keras.backend.function(recreated_model.inputs,
@@ -507,6 +492,12 @@ class TopologyConstructionTest(test.TestCase):
x = keras.layers.Input(tensor=x)
keras.layers.Dense(2)(x)
+ def test_basic_masking(self):
+ a = keras.layers.Input(shape=(10, 32), name='input_a')
+ b = keras.layers.Masking()(a)
+ model = keras.models.Model(a, b)
+ self.assertEqual(model.output_mask.get_shape().as_list(), [None, 10])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py
index 1563cf8c41..bd8badd902 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training.py
+++ b/tensorflow/contrib/keras/python/keras/engine/training.py
@@ -856,7 +856,7 @@ class Model(Container):
def append_metric(layer_num, metric_name, metric_tensor):
"""Helper function used in loop below."""
if len(self.output_names) > 1:
- metric_name = self.output_layers[layer_num].name + '_' + metric_name
+ metric_name = self._output_layers[layer_num].name + '_' + metric_name
self.metrics_names.append(metric_name)
self.metrics_tensors.append(metric_tensor)
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index 8864f5e69d..813a462f55 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -414,6 +414,7 @@ class Sequential(Model):
self.outputs = [] # List of length 1: the output tensor (unique).
self._trainable = True
self._initial_weights = None
+ self._input_layers = []
# Model attributes.
self.inbound_nodes = []
@@ -501,10 +502,7 @@ class Sequential(Model):
node_indices=[],
tensor_indices=[],
input_tensors=self.inputs,
- output_tensors=self.outputs,
- # no model-level masking for now
- input_masks=[None for _ in self.inputs],
- output_masks=[None])
+ output_tensors=self.outputs)
else:
output_tensor = layer(self.outputs[0])
if isinstance(output_tensor, list):
@@ -578,14 +576,12 @@ class Sequential(Model):
self._output_mask_cache = self.model._output_mask_cache
self._output_tensor_cache = self.model._output_tensor_cache
self._output_shape_cache = self.model._output_shape_cache
- self.input_layers = self.model.input_layers
- self.input_layers_node_indices = self.model.input_layers_node_indices
- self.input_layers_tensor_indices = self.model.input_layers_tensor_indices
- self.output_layers = self.model.output_layers
- self.output_layers_node_indices = self.model.output_layers_node_indices
- self.output_layers_tensor_indices = self.model.output_layers_tensor_indices
- self.nodes_by_depth = self.model.nodes_by_depth
- self.container_nodes = self.model.container_nodes
+ self._input_layers = self.model._input_layers
+ self._output_layers = self.model._output_layers
+ self._input_coordinates = self.model._input_coordinates
+ self._output_coordinates = self.model._output_coordinates
+ self._nodes_by_depth = self.model._nodes_by_depth
+ self._network_nodes = self.model._network_nodes
self.output_names = self.model.output_names
self.input_names = self.model.input_names
self._feed_input_names = self.model._feed_input_names
diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
index 1c3481fdb8..12d5368b08 100644
--- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
@@ -46,7 +46,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
sequential_like = True
else:
sequential_like = True
- for v in model.nodes_by_depth.values():
+ for v in model._nodes_by_depth.values(): # pylint: disable=protected-access
if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
# If the model has multiple nodes or if the nodes have
# multiple inbound_layers, the model is no longer sequential.
@@ -68,7 +68,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
relevant_nodes = []
- for v in model.nodes_by_depth.values():
+ for v in model._nodes_by_depth.values(): # pylint: disable=protected-access
relevant_nodes += v
def print_row(fields, positions):