diff options
author | Francois Chollet <fchollet@google.com> | 2017-08-03 16:48:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-03 16:55:56 -0700 |
commit | 3599fd44d6bfcb16f45e763608a0e5da4e9072f5 (patch) | |
tree | 65366abce44c2ef95a62c53baeb634a77dfd0fad /tensorflow/contrib/keras | |
parent | 57626dd38a7867b76c44f3933e7810190174a2ee (diff) |
Add Network class. Networks are directed acyclic graphs of layers that implement the full layer API. You can think of a network as a "bigger layer".
- Rename tf.contrib.keras Container as Network
- Add a Network class in tf.layers which implements the part of Container that we want to add to core layers.
- Make Keras Network subclass core Network.
PiperOrigin-RevId: 164202674
Diffstat (limited to 'tensorflow/contrib/keras')
5 files changed, 137 insertions, 1373 deletions
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index c8c746e8af..4e8f6f9db1 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -29,10 +29,8 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.contrib.keras.python.keras import backend as K from tensorflow.contrib.keras.python.keras.utils import conv_utils -from tensorflow.contrib.keras.python.keras.utils.generic_utils import has_arg from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.platform import tf_logging as logging @@ -51,125 +49,7 @@ except ImportError: # pylint: enable=g-import-not-at-top InputSpec = tf_base_layers.InputSpec # pylint: disable=invalid-name - - -class Node(object): - """A `Node` describes the connectivity between two layers. - - Each time a layer is connected to some new input, - a node is added to `layer.inbound_nodes`. - Each time the output of a layer is used by another layer, - a node is added to `layer.outbound_nodes`. - - Arguments: - outbound_layer: the layer that takes - `input_tensors` and turns them into `output_tensors` - (the node gets created when the `call` - method of the layer was called). - inbound_layers: a list of layers, the same length as `input_tensors`, - the layers from where `input_tensors` originate. - node_indices: a list of integers, the same length as `inbound_layers`. - `node_indices[i]` is the origin node of `input_tensors[i]` - (necessary since each inbound layer might have several nodes, - e.g. 
if the layer is being shared with a different data stream). - tensor_indices: a list of integers, - the same length as `inbound_layers`. - `tensor_indices[i]` is the index of `input_tensors[i]` within the - output of the inbound layer - (necessary since each inbound layer might - have multiple tensor outputs, with each one being - independently manipulable). - input_tensors: list of input tensors. - output_tensors: list of output tensors. - input_masks: list of input masks (a mask can be a tensor, or None). - output_masks: list of output masks (a mask can be a tensor, or None). - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - - `node_indices` and `tensor_indices` are basically fine-grained coordinates - describing the origin of the `input_tensors`. - - A node from layer A to layer B is added to: - A.outbound_nodes - B.inbound_nodes - """ - - def __init__(self, - outbound_layer, - inbound_layers, - node_indices, - tensor_indices, - input_tensors, - output_tensors, - input_masks, - output_masks, - arguments=None): - # Layer instance (NOT a list). - # this is the layer that takes a list of input tensors - # and turns them into a list of output tensors. - # the current node will be added to - # the inbound_nodes of outbound_layer. - self.outbound_layer = outbound_layer - - # The following 3 properties describe where - # the input tensors come from: which layers, - # and for each layer, which node and which - # tensor output of each node. - - # List of layer instances. - self.inbound_layers = inbound_layers - # List of integers, 1:1 mapping with inbound_layers. - self.node_indices = node_indices - # List of integers, 1:1 mapping with inbound_layers. - self.tensor_indices = tensor_indices - - # Following 2 properties: - # tensor inputs and outputs of outbound_layer. - - # List of tensors. 1:1 mapping with inbound_layers. 
- self.input_tensors = input_tensors - # List of tensors, created by outbound_layer.call(). - self.output_tensors = output_tensors - - # Following 2 properties: input and output masks. - # List of tensors, 1:1 mapping with input_tensor. - self.input_masks = input_masks - # List of tensors, created by outbound_layer.compute_mask(). - self.output_masks = output_masks - - # Following 2 properties: input and output shapes. - - # List of shape tuples, shapes of input_tensors. - self.input_shapes = [K.int_shape(x) for x in input_tensors] - # List of shape tuples, shapes of output_tensors. - self.output_shapes = [K.int_shape(x) for x in output_tensors] - - # Optional keyword arguments to layer's `call`. - self.arguments = arguments - - # Add nodes to all layers involved. - for layer in inbound_layers: - if layer is not None: - layer.outbound_nodes.append(self) - outbound_layer.inbound_nodes.append(self) - - def get_config(self): - inbound_names = [] - for layer in self.inbound_layers: - if layer: - inbound_names.append(layer.name) - else: - inbound_names.append(None) - return { - 'outbound_layer': - self.outbound_layer.name if self.outbound_layer else None, - 'inbound_layers': - inbound_names, - 'node_indices': - self.node_indices, - 'tensor_indices': - self.tensor_indices - } +Node = tf_base_layers.Node # pylint: disable=invalid-name class Layer(tf_base_layers.Layer): @@ -274,15 +154,9 @@ class Layer(tf_base_layers.Layer): super(Layer, self).__init__(name=name, dtype=dtype, trainable=trainable) # Add properties that are Keras-only for now. - self.input_spec = None self.supports_masking = False self._constraints = {} # dict {tensor: constraint instance} - # These lists will be filled via successive calls - # to self._add_inbound_node(). - self.inbound_nodes = [] - self.outbound_nodes = [] - # Manage input shape information if passed. 
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: # In this case we will later create an input layer @@ -378,52 +252,17 @@ class Layer(tf_base_layers.Layer): ValueError: in case the layer is missing shape information for its `build` call. """ - if isinstance(inputs, list): - inputs = inputs[:] - - # Handle mask propagation. - previous_mask = _collect_previous_mask(inputs) - user_kwargs = copy.copy(kwargs) - if not _is_all_none(previous_mask): - # The previous layer generated a mask. - if has_arg(self.call, 'mask'): - if 'mask' not in kwargs: - # If mask is explicitly passed to __call__, - # we should override the default mask. - kwargs['mask'] = previous_mask - # Actually call the layer (optionally building it). output = super(Layer, self).__call__(inputs, **kwargs) - # Handle mask computation. - with K.name_scope(self.name): - output_mask = self.compute_mask(inputs, previous_mask) - - # If the layer returns tensors from its inputs, unmodified, - # we copy them to avoid loss of tensor metadata. - output_ls = _to_list(output) - inputs_ls = _to_list(inputs) - output_ls_copy = [] - for x in output_ls: - if x in inputs_ls: - x = K.identity(x) - output_ls_copy.append(x) - if len(output_ls_copy) == 1: - output = output_ls_copy[0] - else: - output = output_ls_copy - - # Add an inbound node to the layer, so that it keeps track - # of the call and of all new variables created during the call. - # This also updates the layer history of the output tensor(s). - # If the input tensor(s) had not previous Keras history, - # this does nothing. - self._add_inbound_node( - input_tensors=inputs, - output_tensors=output, - input_masks=previous_mask, - output_masks=output_mask, - arguments=user_kwargs) + # Update learning phase info. 
+ output_tensors = _to_list(output) + uses_lp = any( + [getattr(x, '_uses_learning_phase', False) for x in _to_list(inputs)]) + uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp + for i in range(len(output_tensors)): + output_tensors[i]._uses_learning_phase = getattr( + output_tensors[i], '_uses_learning_phase', False) or uses_lp # Optionally load weight values that were specified at layer instantiation. if hasattr(self, '_initial_weights') and self._initial_weights is not None: @@ -431,63 +270,6 @@ class Layer(tf_base_layers.Layer): del self._initial_weights return output - def _add_inbound_node(self, - input_tensors, - output_tensors, - input_masks, - output_masks, - arguments=None): - """Internal method to create an inbound node for the layer. - - Arguments: - input_tensors: list of input tensors. - output_tensors: list of output tensors. - input_masks: list of input masks (a mask can be a tensor, or None). - output_masks: list of output masks (a mask can be a tensor, or None). - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - """ - input_tensors = _to_list(input_tensors) - output_tensors = _to_list(output_tensors) - input_masks = _to_list(input_masks) - output_masks = _to_list(output_masks) - - # Collect input tensor(s) coordinates. - inbound_layers = [] - node_indices = [] - tensor_indices = [] - for x in input_tensors: - if hasattr(x, '_keras_history'): - inbound_layer, node_index, tensor_index = x._keras_history - inbound_layers.append(inbound_layer) - node_indices.append(node_index) - tensor_indices.append(tensor_index) - else: - inbound_layers.append(None) - node_indices.append(None) - tensor_indices.append(None) - - # Create node, add it to inbound nodes. 
- Node( - self, - inbound_layers=inbound_layers, - node_indices=node_indices, - tensor_indices=tensor_indices, - input_tensors=input_tensors, - output_tensors=output_tensors, - input_masks=input_masks, - output_masks=output_masks, - arguments=arguments) - - # Update tensor history and `_uses_learning_phase`. - for i in range(len(output_tensors)): - uses_lp = any( - [getattr(x, '_uses_learning_phase', False) for x in input_tensors]) - uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp - output_tensors[i]._uses_learning_phase = getattr( - output_tensors[i], '_uses_learning_phase', False) or uses_lp - output_tensors[i]._keras_history = (self, len(self.inbound_nodes) - 1, i) - def _compute_output_shape(self, input_shape): """Computes the output shape of the layer. @@ -534,115 +316,6 @@ class Layer(tf_base_layers.Layer): # carry over the input mask return mask - def build(self, input_shape): # pylint: disable=unused-argument - """Creates the layer weights. - - Must be implemented on all layers that have weights. - - Arguments: - input_shape: Keras tensor (future input to layer) - or list/tuple of Keras tensors to reference - for weight shape computations. - """ - self.built = True - - def _get_node_attribute_at_index(self, node_index, attr, attr_name): - """Retrieves an attribute (e.g. input_tensors) from a node. - - This is used to implement the methods: - - get_input_shape_at - - get_output_shape_at - - get_input_at - etc... - - Arguments: - node_index: Integer index of the node from which - to retrieve the attribute. - attr: Exact node attribute name. - attr_name: Human-readable attribute name, for error messages. - - Returns: - The layer's attribute `attr` at the node of index `node_index`. - - Raises: - RuntimeError: If the layer has no inbound nodes. - ValueError: If the index is does not match any node. 
- """ - if not self.inbound_nodes: - raise RuntimeError('The layer has never been called ' - 'and thus has no defined ' + attr_name + '.') - if not len(self.inbound_nodes) > node_index: - raise ValueError('Asked to get ' + attr_name + ' at node ' + - str(node_index) + ', but the layer has only ' + - str(len(self.inbound_nodes)) + ' inbound nodes.') - values = getattr(self.inbound_nodes[node_index], attr) - if len(values) == 1: - return values[0] - else: - return values - - def get_input_shape_at(self, node_index): - """Retrieves the input shape(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A shape tuple - (or list of shape tuples if the layer has multiple inputs). - """ - return self._get_node_attribute_at_index(node_index, 'input_shapes', - 'input shape') - - def get_output_shape_at(self, node_index): - """Retrieves the output shape(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A shape tuple - (or list of shape tuples if the layer has multiple outputs). - """ - return self._get_node_attribute_at_index(node_index, 'output_shapes', - 'output shape') - - def get_input_at(self, node_index): - """Retrieves the input tensor(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A tensor (or list of tensors if the layer has multiple inputs). - """ - return self._get_node_attribute_at_index(node_index, 'input_tensors', - 'input') - - def get_output_at(self, node_index): - """Retrieves the output tensor(s) of a layer at a given node. 
- - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A tensor (or list of tensors if the layer has multiple outputs). - """ - return self._get_node_attribute_at_index(node_index, 'output_tensors', - 'output') - def get_input_mask_at(self, node_index): """Retrieves the input mask tensor(s) of a layer at a given node. @@ -656,8 +329,11 @@ class Layer(tf_base_layers.Layer): A mask tensor (or list of tensors if the layer has multiple inputs). """ - return self._get_node_attribute_at_index(node_index, 'input_masks', - 'input mask') + inputs = self.get_input_at(node_index) + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) def get_output_mask_at(self, node_index): """Retrieves the output mask tensor(s) of a layer at a given node. @@ -672,57 +348,11 @@ class Layer(tf_base_layers.Layer): A mask tensor (or list of tensors if the layer has multiple outputs). """ - return self._get_node_attribute_at_index(node_index, 'output_masks', - 'output mask') - - @property - def input(self): - """Retrieves the input tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Input tensor or list of input tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - if len(self.inbound_nodes) > 1: - raise AttributeError('Layer ' + self.name + - ' has multiple inbound nodes, ' - 'hence the notion of "layer input" ' - 'is ill-defined. 
' - 'Use `get_input_at(node_index)` instead.') - elif not self.inbound_nodes: - raise AttributeError('Layer ' + self.name + - ' is not connected, no input to return.') - return self._get_node_attribute_at_index(0, 'input_tensors', 'input') - - @property - def output(self): - """Retrieves the output tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Output tensor or list of output tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - if not self.inbound_nodes: - raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') - if len(self.inbound_nodes) > 1: - raise AttributeError('Layer ' + self.name + - ' has multiple inbound nodes, ' - 'hence the notion of "layer output" ' - 'is ill-defined. ' - 'Use `get_output_at(node_index)` instead.') - return self._get_node_attribute_at_index(0, 'output_tensors', 'output') + output = self.get_output_at(node_index) + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] + else: + return getattr(output, '_keras_mask', None) @property def input_mask(self): @@ -739,14 +369,11 @@ class Layer(tf_base_layers.Layer): AttributeError: if the layer is connected to more than one incoming layers. """ - if len(self.inbound_nodes) != 1: - raise AttributeError('Layer ' + self.name + - ' has multiple inbound nodes, ' + - 'hence the notion of "layer input mask" ' - 'is ill-defined. ' - 'Use `get_input_mask_at(node_index)` ' - 'instead.') - return self._get_node_attribute_at_index(0, 'input_masks', 'input mask') + inputs = self.input + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) @property def output_mask(self): @@ -763,90 +390,11 @@ class Layer(tf_base_layers.Layer): AttributeError: if the layer is connected to more than one incoming layers. 
""" - if len(self.inbound_nodes) != 1: - raise AttributeError('Layer ' + self.name + - ' has multiple inbound nodes, ' - 'hence the notion of "layer output mask" ' - 'is ill-defined. ' - 'Use `get_output_mask_at(node_index)` ' - 'instead.') - return self._get_node_attribute_at_index(0, 'output_masks', 'output mask') - - @property - def input_shape(self): - """Retrieves the input shape(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Input shape, as `TensorShape` - (or list of `TensorShape`, one tuple per input tensor). - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - if not self.inbound_nodes: - raise AttributeError('The layer has never been called ' - 'and thus has no defined input shape.') - all_input_shapes = set( - [str(node.input_shapes) for node in self.inbound_nodes]) - if len(all_input_shapes) == 1: - input_shapes = self.inbound_nodes[0].input_shapes - if len(input_shapes) == 1: - return tuple(tensor_shape.TensorShape(input_shapes[0]).as_list()) - else: - return [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in input_shapes - ] + output = self.output + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] else: - raise AttributeError('The layer "' + str(self.name) + - ' has multiple inbound nodes, ' - 'with different input shapes. Hence ' - 'the notion of "input shape" is ' - 'ill-defined for the layer. ' - 'Use `get_input_shape_at(node_index)` ' - 'instead.') - - @property - def output_shape(self): - """Retrieves the output shape(s) of a layer. - - Only applicable if the layer has one inbound node, - or if all inbound nodes have the same output shape. - - Returns: - Output shape, as `TensorShape` - (or list of `TensorShape`, one tuple per output tensor). - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. 
- """ - if not self.inbound_nodes: - raise AttributeError('The layer has never been called ' - 'and thus has no defined output shape.') - all_output_shapes = set( - [str(node.output_shapes) for node in self.inbound_nodes]) - if len(all_output_shapes) == 1: - output_shapes = self.inbound_nodes[0].output_shapes - if len(output_shapes) == 1: - return tuple(tensor_shape.TensorShape(output_shapes[0]).as_list()) - else: - return [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in output_shapes - ] - else: - raise AttributeError('The layer "' + str(self.name) + - ' has multiple inbound nodes, ' - 'with different output shapes. Hence ' - 'the notion of "output shape" is ' - 'ill-defined for the layer. ' - 'Use `get_output_shape_at(node_index)` ' - 'instead.') + return getattr(output, '_keras_mask', None) def set_weights(self, weights): """Sets the weights of the layer, from Numpy arrays. @@ -951,17 +499,15 @@ class Layer(tf_base_layers.Layer): return sum([K.count_params(p) for p in self.weights]) -class InputLayer(Layer): +class InputLayer(tf_base_layers.InputLayer, Layer): """Layer to be used as an entry point into a graph. It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass arguments `input_shape` - or `batch_input_shape` as well as `dtype`). + or create its a placeholder tensor (pass argument `input_shape`. Arguments: input_shape: Shape tuple, not including the batch axis. batch_size: Optional input batch size (integer or None). - batch_input_shape: Shape tuple, including the batch axis. dtype: Datatype of the input. input_tensor: Optional tensor to use as layer input instead of creating a placeholder. 
@@ -973,71 +519,44 @@ class InputLayer(Layer): def __init__(self, input_shape=None, batch_size=None, - batch_input_shape=None, dtype=None, input_tensor=None, sparse=False, - name=None): + name=None, + **kwargs): + if 'batch_input_shape' in kwargs: + batch_input_shape = kwargs.pop('batch_input_shape') + if input_shape and batch_input_shape: + raise ValueError('Only provide the input_shape OR ' + 'batch_input_shape argument to ' + 'InputLayer, not both at the same time.') + batch_size = batch_input_shape[0] + input_shape = batch_input_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + if not name: prefix = 'input' name = prefix + '_' + str(K.get_uid(prefix)) + if not dtype: if input_tensor is None: dtype = K.floatx() else: dtype = K.dtype(input_tensor) - super(InputLayer, self).__init__(dtype=dtype, name=name) - self.built = True - self.sparse = sparse - - if input_shape and batch_input_shape: - raise ValueError('Only provide the input_shape OR ' - 'batch_input_shape argument to ' - 'InputLayer, not both at the same time.') + super(InputLayer, self).__init__(input_shape=input_shape, + batch_size=batch_size, + dtype=dtype, + input_tensor=input_tensor, + sparse=sparse, + name=name) + if input_tensor is not None: - # Attempt automatic input shape inference. - try: - batch_input_shape = K.int_shape(input_tensor) - except TypeError: - if not input_shape and not batch_input_shape: - raise ValueError('InputLayer was provided ' - 'an input_tensor argument, ' - 'but its input shape cannot be ' - 'automatically inferred. 
' - 'You should pass an input_shape or ' - 'batch_input_shape argument.') - if not batch_input_shape: - if not input_shape: - raise ValueError('An Input layer should be passed either ' - 'a `batch_input_shape` or an `input_shape`.') - else: - batch_input_shape = (batch_size,) + tuple(input_shape) + self.is_placeholder = False + self.batch_input_shape = tuple(input_tensor.get_shape().as_list()) else: - batch_input_shape = tuple(batch_input_shape) - self.batch_input_shape = batch_input_shape - - if input_tensor is None: self.is_placeholder = True - input_tensor = K.placeholder( - shape=batch_input_shape, - dtype=dtype, - sparse=self.sparse, - name=self.name) - else: - self.is_placeholder = False - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._uses_learning_phase = False - input_tensor._keras_history = (self, 0, 0) - Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor], - input_masks=[None], - output_masks=[None]) + self.batch_input_shape = (batch_size,) + tuple(input_shape) def get_config(self): config = { @@ -1051,11 +570,12 @@ class InputLayer(Layer): def Input( # pylint: disable=invalid-name shape=None, - batch_shape=None, + batch_size=None, name=None, - dtype=K.floatx(), + dtype=None, sparse=False, - tensor=None): + tensor=None, + **kwargs): """`Input()` is used to instantiate a Keras tensor. A Keras tensor is a tensor object from the underlying backend @@ -1073,14 +593,10 @@ def Input( # pylint: disable=invalid-name recursively. Arguments: - shape: A shape tuple (integer), not including the batch size. + shape: A shape tuple (integers), not including the batch size. For instance, `shape=(32,)` indicates that the expected input will be batches of 32-dimensional vectors. - batch_shape: A shape tuple (integer), including the batch size. 
- For instance, `batch_shape=(10, 32)` indicates that - the expected input will be batches of 10 32-dimensional vectors. - `batch_shape=(None, 32)` indicates batches of an arbitrary number - of 32-dimensional vectors. + batch_size: optional static batch size (integer). name: An optional name string for the layer. Should be unique in a model (do not reuse the same name twice). It will be autogenerated if it isn't provided. @@ -1090,6 +606,7 @@ def Input( # pylint: disable=invalid-name to be created is sparse. tensor: Optional existing tensor to wrap into the `Input` layer. If set, the layer will not create a placeholder tensor. + **kwargs: deprecated arguments support. Returns: A tensor. @@ -1102,16 +619,31 @@ def Input( # pylint: disable=invalid-name y = Dense(16, activation='softmax')(x) model = Model(x, y) ``` + + Raises: + ValueError: in case of invalid arguments. """ - if not batch_shape and tensor is None: - assert shape, ('Please provide to Input either a `shape`' - ' or a `batch_shape` argument. Note that ' - '`shape` does not include the batch ' - 'dimension.') - if shape and not batch_shape: - batch_shape = (None,) + tuple(shape) + if 'batch_shape' in kwargs: + batch_shape = kwargs.pop('batch_shape') + if shape and batch_shape: + raise ValueError('Only provide the shape OR ' + 'batch_shape argument to ' + 'Input, not both at the same time.') + batch_size = batch_shape[0] + shape = batch_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + + if dtype is None: + dtype = K.floatx() + if not shape and tensor is None: + raise ValueError('Please provide to Input either a `shape`' + ' or a `tensor` argument. 
Note that ' + '`shape` does not include the batch ' + 'dimension.') input_layer = InputLayer( - batch_input_shape=batch_shape, + input_shape=shape, + batch_size=batch_size, name=name, dtype=dtype, sparse=sparse, @@ -1125,7 +657,7 @@ def Input( # pylint: disable=invalid-name return outputs -class Container(Layer): +class Network(tf_base_layers.Network, Layer): """A Container is a directed acyclic graph of layers. It is the topological form of a "model". A Model @@ -1162,118 +694,20 @@ class Container(Layer): from_config """ - def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called - # Handle `name` argument. - if not name: - prefix = self.__class__.__name__.lower() - name = prefix + '_' + str(K.get_uid(prefix)) - self.name = name - self.supports_masking = False - self.trainable = True - self._per_input_losses = {} - self._per_input_updates = {} - - # The following properties are not actually used by Keras; - # they exist for compatibility with TF. - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._base_name = name - self._graph = ops.get_default_graph() - - # Container-specific properties. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] - - # Check for redundancy in inputs. - inputs_set = set(self.inputs) - if len(inputs_set) != len(self.inputs): - raise ValueError('The list of inputs passed to the model ' - 'is redundant. ' - 'All inputs should only appear once.' 
- ' Found: ' + str(self.inputs)) - - # List of initial layers (1 to 1 mapping with self.inputs, - # hence the same layer might appear twice) - self.input_layers = [] - self.input_layers_node_indices = [] - self.input_layers_tensor_indices = [] - # list of layers (1 to 1 mapping with self.inputs, - # hence the same layer might appear twice) - self.output_layers = [] - self.output_layers_node_indices = [] - self.output_layers_tensor_indices = [] - # all layers in order of horizontal graph traversal. - # Entries are unique. Includes input and output layers. - self.layers = [] - - # This is for performance optimization - # when calling the Container on new inputs. - # every time the Container is called on a set on input tensors, - # we compute the output tensors, - # output masks and output shapes in one pass, - # then cache them here. When of of these output is queried later, - # we retrieve it from there instead of recomputing it. - self._output_mask_cache = {} - self._output_tensor_cache = {} - self._output_shape_cache = {} - - # User-provided arguments validation. - for x in self.inputs: - # Check that x is a Keras tensor. - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise TypeError('Input tensors to a ' + cls_name + ' ' + - 'must be Keras tensors. Found: ' + str(x) + - ' (missing Keras metadata).') - # Check that x is an input tensor. - layer, node_index, tensor_index = x._keras_history - if len(layer.inbound_nodes) > 1 or ( - layer.inbound_nodes and layer.inbound_nodes[0].inbound_layers): - cls_name = self.__class__.__name__ - logging.warning(cls_name + ' inputs must come from ' - 'a Keras Input layer, ' - 'they cannot be the output of ' - 'a previous non-Input layer. 
' - 'Here, a tensor specified as ' - 'input to "' + self.name + '" was not an Input tensor, ' - 'it was generated by layer ' + layer.name + '.\n' - 'Note that input tensors are ' - 'instantiated via `tensor = Input(shape)`.\n' - 'The tensor that caused the issue was: ' + str(x.name)) - for x in self.outputs: - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise TypeError('Output tensors to a ' + cls_name + ' must be ' - 'Keras tensors. Found: ' + str(x)) - # Build self.output_layers: - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history - self.output_layers.append(layer) - self.output_layers_node_indices.append(node_index) - self.output_layers_tensor_indices.append(tensor_index) + def __init__(self, inputs, outputs, name=None): + super(Network, self).__init__(inputs, outputs, name=name) + self.supports_masking = False # Fill in the output mask cache. masks = [] for x in self.inputs: - layer, node_index, tensor_index = x._keras_history - node = layer.inbound_nodes[node_index] - mask = node.output_masks[tensor_index] + mask = x._keras_mask if hasattr(x, '_keras_mask') else None masks.append(mask) - mask_cache_key = ','.join([str(id(x)) for x in self.inputs]) - mask_cache_key += '_' + ','.join([str(id(x)) for x in masks]) + mask_cache_key = (tf_base_layers._object_list_uid(self.inputs) + '_' + + tf_base_layers._object_list_uid(masks)) masks = [] for x in self.outputs: - layer, node_index, tensor_index = x._keras_history - node = layer.inbound_nodes[node_index] - mask = node.output_masks[tensor_index] + mask = x._keras_mask if hasattr(x, '_keras_mask') else None masks.append(mask) if len(masks) == 1: mask = masks[0] @@ -1281,322 +715,24 @@ class Container(Layer): mask = masks self._output_mask_cache[mask_cache_key] = mask - # Build self.input_layers: - for x in self.inputs: - layer, node_index, tensor_index = x._keras_history - # It's supposed to be an input layer, so only one node - # and one tensor output. 
- assert node_index == 0 - assert tensor_index == 0 - self.input_layers.append(layer) - self.input_layers_node_indices.append(node_index) - self.input_layers_tensor_indices.append(tensor_index) - # Build self.input_names and self.output_names. self.input_names = [] self.output_names = [] self._feed_input_names = [] self._feed_inputs = [] self._feed_input_shapes = [] - for i, layer in enumerate(self.input_layers): + for i, layer in enumerate(self._input_layers): self.input_names.append(layer.name) if layer.is_placeholder: self._feed_input_names.append(layer.name) self._feed_inputs.append(layer.input) self._feed_input_shapes.append(K.int_shape(self.inputs[i])) - for layer in self.output_layers: + for layer in self._output_layers: self.output_names.append(layer.name) self.internal_input_shapes = [K.int_shape(x) for x in self.inputs] self.internal_output_shapes = [K.int_shape(x) for x in self.outputs] - # Container_nodes: set of nodes included in the graph - # (not all nodes included in the layers - # are relevant to the current graph). - container_nodes = set() # ids of all nodes relevant to the Container - nodes_depths = {} # dict {node: depth value} - layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map_of_graph(tensor, - finished_nodes, - nodes_in_progress, - layer=None, - node_index=None, - tensor_index=None): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `container_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. 
- node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - RuntimeError: if a cycle is detected. - """ - if not layer or node_index is None or tensor_index is None: - layer, node_index, tensor_index = tensor._keras_history - node = layer.inbound_nodes[node_index] - - # Prevent cycles. - if node in nodes_in_progress: - raise RuntimeError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = layer.name + '_ib-' + str(node_index) - # Update container_nodes. - container_nodes.add(node_key) - - # Store the traversal order for layer sorting. - if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - tensor_index = node.tensor_indices[i] - build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, - node_index, tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in self.outputs: - build_map_of_graph(x, finished_nodes, nodes_in_progress) - - for node in reversed(nodes_in_decreasing_depth): - # If the depth is not set, the node has no outbound nodes (depth 0). - depth = nodes_depths.setdefault(node, 0) - - # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) - # If we've seen this layer before at a higher depth, - # we should use that depth instead of the node depth. - # This is necessary for shared layers that have inputs at different - # depth levels in the graph. 
- depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth - nodes_depths[node] = depth - - # Update the depth of inbound nodes. - for i in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[i] - node_index = node.node_indices[i] - inbound_node = inbound_layer.inbound_nodes[node_index] - previous_depth = nodes_depths.get(inbound_node, 0) - nodes_depths[inbound_node] = max(depth + 1, previous_depth) - - # Build a dict {depth: list of nodes with this depth} - nodes_by_depth = {} - for node, depth in nodes_depths.items(): - if depth not in nodes_by_depth: - nodes_by_depth[depth] = [] - nodes_by_depth[depth].append(node) - - # Build a dict {depth: list of layers with this depth} - layers_by_depth = {} - for layer, depth in layers_depths.items(): - if depth not in layers_by_depth: - layers_by_depth[depth] = [] - layers_by_depth[depth].append(layer) - - # Get sorted list of layer depths. - depth_keys = list(layers_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Set self.layers and self.layers_by_depth. - layers = [] - for depth in depth_keys: - layers_for_depth = layers_by_depth[depth] - # Container.layers needs to have a deterministic order: - # here we order them by traversal order. - layers_for_depth.sort(key=lambda x: layer_indices[x]) - for layer in layers_for_depth: - layers.append(layer) - self.layers = layers - self.layers_by_depth = layers_by_depth - - # Get sorted list of node depths. - depth_keys = list(nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Check that all tensors required are computable. - # computable_tensors: all tensors in the graph - # that can be computed from the inputs provided. - computable_tensors = [] - for x in self.inputs: - computable_tensors.append(x) - - layers_with_complete_input = [] # To provide a better error msg. 
- for depth in depth_keys: - for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in node.input_tensors: - if x not in computable_tensors: - raise RuntimeError('Graph disconnected: ' - 'cannot obtain value for tensor ' + str(x) + - ' at layer "' + layer.name + '". ' - 'The following previous layers ' - 'were accessed without issue: ' + - str(layers_with_complete_input)) - for x in node.output_tensors: - computable_tensors.append(x) - layers_with_complete_input.append(layer.name) - - # Set self.nodes and self.nodes_by_depth. - self.container_nodes = container_nodes - self.nodes_by_depth = nodes_by_depth - - # Ensure name unicity, which will be crucial for serialization - # (since serialized nodes refer to layers by their name). - all_names = [layer.name for layer in self.layers] - for name in all_names: - if all_names.count(name) != 1: - raise RuntimeError('The name "' + name + '" is used ' + - str(all_names.count(name)) + ' times in the model. ' - 'All layer names should be unique.') - - # Layer parameters. - # The new container starts with a single inbound node - # for its inputs, and no outbound nodes. - self.outbound_nodes = [] # Will be appended to by future calls to __call__ - self.inbound_nodes = [ - ] # Will be appended to below, and by future calls to __call__ - # Create the node linking internal inputs to internal outputs. - Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs, - # No container-level masking for now. - input_masks=[None for _ in self.inputs], - output_masks=[None for _ in self.outputs]) - self.built = True - - # The following are implemented as property functions: - # self.constraints - # self.trainable_weights - # self.non_trainable_weights - # self.input_spec - - def get_layer(self, name=None, index=None): - """Retrieves a layer based on either its name (unique) or index. 
- - Indices are based on order of horizontal graph traversal (bottom-up). - - Arguments: - name: String, name of layer. - index: Integer, index of layer. - - Returns: - A layer instance. - - Raises: - ValueError: In case of invalid layer name or index. - """ - # It would be unreliable to build a dictionary - # based on layer names, because names can potentially - # be changed at any point by the user - # without the container being notified of it. - if index is not None: - if len(self.layers) <= index: - raise ValueError('Was asked to retrieve layer at index ' + str(index) + - ' but model only has ' + str(len(self.layers)) + - ' layers.') - else: - return self.layers[index] - else: - if not name: - raise ValueError('Provide either a layer name or layer index.') - layer = None - for layer in self.layers: - if layer.name == name: - return layer - if not layer: - raise ValueError('No such layer: ' + name) - - @property - def updates(self): - """Retrieve the model's updates. - - Will only include updates that are either - unconditional, or conditional on inputs to this model - (e.g. will not include updates that depend on tensors - that aren't inputs to this model). - - Returns: - A list of update ops. - """ - updates = [] - for layer in self.layers: - if hasattr(layer, 'updates'): - # Collect updates that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer.inbound_nodes): - node_key = layer.name + '_ib-' + str(node_index) - if node_key in self.container_nodes: - # The model owns this layer node. - inputs = node.input_tensors - updates += layer.get_updates_for(inputs) - # Collect unconditional updates. - updates += layer.get_updates_for(None) - return updates - - @property - def losses(self): - """Retrieve the model's losses. - - Will only include losses that are either - unconditional, or conditional on inputs to this model - (e.g. will not include losses that depend on tensors - that aren't inputs to this model). 
- - Returns: - A list of loss tensors. - """ - losses = [] - # Retrieve losses for all internal layers. - for layer in self.layers: - if hasattr(layer, 'losses'): - # Collect losses that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer.inbound_nodes): - node_key = layer.name + '_ib-' + str(node_index) - if node_key in self.container_nodes: - # The model owns this layer node. - inputs = node.input_tensors - losses += layer.get_losses_for(inputs) - # Collect unconditional losses. - losses += layer.get_losses_for(None) - # Add any potential unconditional model-level loss. - losses += self.get_losses_for(None) - return losses - @property def uses_learning_phase(self): return any([x._uses_learning_phase for x in self.outputs]) @@ -1640,27 +776,6 @@ class Container(Layer): cons[key] = value return cons - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights - def get_weights(self): """Retrieves the weights of the model. @@ -1688,60 +803,6 @@ class Container(Layer): weights = weights[num_param:] K.batch_set_value(tuples) - @property - def input_spec(self): - """Gets the model's input specs. - - Returns: - A list of `InputSpec` instances (one per input to the model) - or a single instance if the model has only one input. - """ - specs = [] - for layer in getattr(self, 'input_layers', []): - if layer.input_spec is None: - specs.append(None) - else: - if not isinstance(layer.input_spec, list): - raise TypeError('Layer ' + layer.name + - ' has an input_spec attribute that ' - 'is not a list. 
We expect a list. ' - 'Found input_spec = ' + str(layer.input_spec)) - specs += layer.input_spec - if len(specs) == 1: - return specs[0] - return specs - - def call(self, inputs, mask=None): - """Call the model on new inputs. - - In this case `call` just reapplies - all ops in the graph to the new inputs - (e.g. build a new computational graph from the provided inputs). - - A model is callable on non-Keras tensors. - - Arguments: - inputs: A tensor or list of tensors. - mask: A mask or list of masks. A mask can be - either a tensor or None (no mask). - - Returns: - A tensor if there is a single output, or - a list of tensors if there are more than one outputs. - """ - inputs = _to_list(inputs) - if mask is None: - masks = [None for _ in range(len(inputs))] - else: - masks = _to_list(mask) - cache_key = ','.join([str(id(x)) for x in inputs]) - cache_key += '_' + ','.join([str(id(x)) for x in masks]) - if cache_key in self._output_tensor_cache: - return self._output_tensor_cache[cache_key] - else: - output_tensors, _, _ = self.run_internal_graph(inputs, masks) - return output_tensors - def compute_mask(self, inputs, mask): inputs = _to_list(inputs) if mask is None: @@ -1753,278 +814,25 @@ class Container(Layer): if cache_key in self._output_mask_cache: return self._output_mask_cache[cache_key] else: - _, output_masks, _ = self.run_internal_graph(inputs, masks) + _, output_masks, _ = self._run_internal_graph(inputs, masks) return output_masks - def _compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shapes = [] - for shape in input_shape: - if shape is not None: - input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) - else: - input_shapes.append(None) - else: - if input_shape is not None: - input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] - else: - input_shapes = [None] - - if len(input_shapes) != len(self.input_layers): - raise ValueError('Invalid input_shape argument ' + str(input_shape) + - 
': model has ' + str(len(self.input_layers)) + - ' tensor inputs.') - - cache_key = ','.join([str(x) for x in input_shapes]) - if cache_key in self._output_shape_cache: - output_shapes = self._output_shape_cache[cache_key] - if isinstance(output_shapes, list): - if len(output_shapes) == 1: - return tensor_shape.TensorShape(output_shapes[0]) - else: - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - else: - return tensor_shape.TensorShape(output_shapes) - else: - # Bad luck, we have to run the graph manually. - layers_to_output_shapes = {} - for i in range(len(input_shapes)): - layer = self.input_layers[i] - input_shape = input_shapes[i] - # It's an input layer: compute_output_shape is identity, - # and there is only one node and one tensor output. - shape_key = layer.name + '_0_0' - layers_to_output_shapes[shape_key] = input_shape - - depth_keys = list(self.nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - # Iterate over nodes, by depth level. - if len(depth_keys) > 1: - for depth in depth_keys: - nodes = self.nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - if layer in self.input_layers: - # We've already covered the input layers - # a few lines above. - continue - # Potentially redundant list, - # same size of node.input_tensors. 
- input_shapes = [] - for j in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[j] - node_index = node.node_indices[j] - tensor_index = node.tensor_indices[j] - shape_key = inbound_layer.name + '_%s_%s' % (node_index, - tensor_index) - input_shape = layers_to_output_shapes[shape_key] - input_shapes.append(input_shape) - - if len(input_shapes) == 1: - output_shape = layer._compute_output_shape(input_shapes[0]) - else: - output_shape = layer._compute_output_shape(input_shapes) - if isinstance(output_shape, list): - output_shapes = [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in output_shape - ] - else: - output_shapes = [ - tuple(tensor_shape.TensorShape(output_shape).as_list()) - ] - - node_index = layer.inbound_nodes.index(node) - for j in range(len(output_shapes)): - shape_key = layer.name + '_%s_%s' % (node_index, j) - layers_to_output_shapes[shape_key] = output_shapes[j] - - # Read final output shapes from layers_to_output_shapes. - output_shapes = [] - output_shape_keys = [] - for i in range(len(self.output_layers)): - layer = self.output_layers[i] - node_index = self.output_layers_node_indices[i] - tensor_index = self.output_layers_tensor_indices[i] - shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) - output_shape_keys.append(shape_key) - - for i, key in enumerate(output_shape_keys): - assert key in layers_to_output_shapes - output_shapes.append(layers_to_output_shapes[key]) - # Store in cache. - self._output_shape_cache[cache_key] = output_shapes - if isinstance(output_shapes, list): - if len(output_shapes) == 1: - return tensor_shape.TensorShape(output_shapes[0]) - else: - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - else: - return tensor_shape.TensorShape(output_shapes) - - def run_internal_graph(self, inputs, masks=None): - """Computes output tensors for new inputs. - - # Note: - - Expects `inputs` to be a list (potentially with 1 element). - - Can be run on non-Keras tensors. 
- - Arguments: - inputs: List of tensors - masks: List of masks (tensors or None). - - Returns: - Three lists: output_tensors, output_masks, output_shapes - """ - if masks is None: - masks = [None for _ in range(len(inputs))] - - # Dictionary mapping reference tensors to tuples - # (computed tensor, compute mask) - # we assume a 1:1 mapping from tensor to mask - # TODO(fchollet): raise exception when a `.compute_mask()` call - # does not return a list the same size as `call` - tensor_map = {} - for x, y, mask in zip(self.inputs, inputs, masks): - tensor_map[str(id(x))] = (y, mask) - - depth_keys = list(self.nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - for depth in depth_keys: - nodes = self.nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - - reference_input_tensors = node.input_tensors - reference_output_tensors = node.output_tensors - - # If all previous input tensors are available in tensor_map, - # then call node.inbound_layer on them. - computed_data = [] # List of tuples (input, mask). 
- for x in reference_input_tensors: - if str(id(x)) in tensor_map: - computed_data.append(tensor_map[str(id(x))]) - - if len(computed_data) == len(reference_input_tensors): - # call layer - with K.name_scope(layer.name): - if node.arguments: - kwargs = node.arguments - else: - kwargs = {} - if len(computed_data) == 1: - computed_tensor, computed_mask = computed_data[0] - if has_arg(layer.call, 'mask'): - if 'mask' not in kwargs: - kwargs['mask'] = computed_mask - output_tensors = _to_list(layer.call(computed_tensor, **kwargs)) - output_masks = _to_list( - layer.compute_mask(computed_tensor, computed_mask)) - computed_tensors = [computed_tensor] - computed_masks = [computed_mask] - else: - computed_tensors = [x[0] for x in computed_data] - computed_masks = [x[1] for x in computed_data] - if has_arg(layer.call, 'mask'): - if 'mask' not in kwargs: - kwargs['mask'] = computed_masks - output_tensors = _to_list(layer.call(computed_tensors, **kwargs)) - output_masks = _to_list( - layer.compute_mask(computed_tensors, computed_masks)) - - # Apply activity regularizer if any: - if hasattr(layer, 'activity_regularizer' - ) and layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - layer.add_loss(regularization_losses, computed_tensors) - - # Update model updates and losses: - # Keep track of updates that depend on the inputs - # (e.g. BN updates). - self.add_update(layer.get_updates_for(computed_tensors), inputs) - # Keep track of unconditional updates (e.g. a counter). - self.add_update(layer.get_updates_for(None), None) - # Keep track of losses that depend on the inputs - # (e.g. activity regularizers). - self.add_loss(layer.get_losses_for(computed_tensors), inputs) - # Keep track of unconditional losses - # (e.g. weight regularizers). - self.add_loss(layer.get_losses_for(None), None) - - # Update `_uses_learning_phase`. 
- if len(computed_tensors) == 1: - uses_learning_phase = getattr(computed_tensors[0], - '_uses_learning_phase', False) - else: - uses_learning_phase = any([ - getattr(x, '_uses_learning_phase', False) - for x in computed_tensors - ]) - for x in output_tensors: - x._uses_learning_phase = getattr(x, '_uses_learning_phase', - False) or uses_learning_phase - - # Update tensor_map. - for x, y, mask in zip(reference_output_tensors, output_tensors, - output_masks): - tensor_map[str(id(x))] = (y, mask) - - output_tensors = [] - output_masks = [] - output_shapes = [] - for x in self.outputs: - assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) - tensor, mask = tensor_map[str(id(x))] - output_shapes.append(K.int_shape(x)) - output_tensors.append(tensor) - output_masks.append(mask) - - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = ','.join([str(id(x)) for x in inputs]) - cache_key += '_' + ','.join([str(id(x)) for x in masks]) - - if len(output_tensors) == 1: - output_tensors = output_tensors[0] - self._output_tensor_cache[cache_key] = output_tensors - else: - self._output_tensor_cache[cache_key] = output_tensors - - if len(output_masks) == 1: - output_masks = output_masks[0] - self._output_mask_cache[cache_key] = output_masks - else: - self._output_mask_cache[cache_key] = output_masks - - if output_shapes is not None: - input_shapes = [K.int_shape(x) for x in inputs] - cache_key = ','.join([str(x) for x in input_shapes]) - if len(output_shapes) == 1: - output_shapes = output_shapes[0] - self._output_shape_cache[cache_key] = output_shapes - else: - self._output_shape_cache[cache_key] = output_shapes - return output_tensors, output_masks, output_shapes - def get_config(self): config = { 'name': self.name, } node_conversion_map = {} for layer in self.layers: - if issubclass(layer.__class__, Container): + if issubclass(layer.__class__, Network): # Containers start with a pre-existing node # linking their input to 
output. kept_nodes = 1 else: kept_nodes = 0 for original_node_index, node in enumerate(layer.inbound_nodes): - node_key = layer.name + '_ib-' + str(original_node_index) - if node_key in self.container_nodes: + node_key = tf_base_layers._make_node_key(layer.name, + original_node_index) + if node_key in self._network_nodes: node_conversion_map[node_key] = kept_nodes kept_nodes += 1 layer_configs = [] @@ -2033,8 +841,9 @@ class Container(Layer): layer_config = layer.get_config() filtered_inbound_nodes = [] for original_node_index, node in enumerate(layer.inbound_nodes): - node_key = layer.name + '_ib-' + str(original_node_index) - if node_key in self.container_nodes: + node_key = tf_base_layers._make_node_key(layer.name, + original_node_index) + if node_key in self._network_nodes: # The node is relevant to the model: # add to filtered_inbound_nodes. if node.arguments: @@ -2057,7 +866,8 @@ class Container(Layer): inbound_layer = node.inbound_layers[i] node_index = node.node_indices[i] tensor_index = node.tensor_indices[i] - node_key = inbound_layer.name + '_ib-' + str(node_index) + node_key = tf_base_layers._make_node_key(inbound_layer.name, + node_index) new_node_index = node_conversion_map.get(node_key, 0) node_data.append( [inbound_layer.name, new_node_index, tensor_index, kwargs]) @@ -2072,21 +882,19 @@ class Container(Layer): # Gather info about inputs and outputs. 
model_inputs = [] - for i in range(len(self.input_layers)): - layer = self.input_layers[i] - node_index = self.input_layers_node_indices[i] - node_key = layer.name + '_ib-' + str(node_index) + for i in range(len(self._input_layers)): + layer, node_index, tensor_index = self._input_coordinates[i] + node_key = tf_base_layers._make_node_key(layer.name, + node_index) new_node_index = node_conversion_map[node_key] - tensor_index = self.input_layers_tensor_indices[i] model_inputs.append([layer.name, new_node_index, tensor_index]) config['input_layers'] = model_inputs model_outputs = [] - for i in range(len(self.output_layers)): - layer = self.output_layers[i] - node_index = self.output_layers_node_indices[i] - node_key = layer.name + '_ib-' + str(node_index) + for i in range(len(self._output_layers)): + layer, node_index, tensor_index = self._output_coordinates[i] + node_key = tf_base_layers._make_node_key(layer.name, + node_index) new_node_index = node_conversion_map[node_key] - tensor_index = self.output_layers_tensor_indices[i] model_outputs.append([layer.name, new_node_index, tensor_index]) config['output_layers'] = model_outputs return copy.deepcopy(config) @@ -2373,6 +1181,10 @@ class Container(Layer): print_fn=print_fn) +# Alias for legacy support. +Container = Network + + def get_source_inputs(tensor, layer=None, node_index=None): """Returns the list of input tensors necessary to compute `tensor`. @@ -2436,41 +1248,6 @@ def _object_list_uid(object_list): return ', '.join([str(abs(id(x))) for x in object_list]) -def _is_all_none(iterable_or_element): - if not isinstance(iterable_or_element, (list, tuple)): - iterable = [iterable_or_element] - else: - iterable = iterable_or_element - for element in iterable: - if element is not None: - return False - return True - - -def _collect_previous_mask(input_tensors): - """Retrieves the output mask(s) of the previous node. - - Arguments: - input_tensors: A tensor or list of tensors. 
- - Returns: - A mask tensor or list of mask tensors. - """ - input_tensors = _to_list(input_tensors) - masks = [] - for x in input_tensors: - if hasattr(x, '_keras_history'): - inbound_layer, node_index, tensor_index = x._keras_history - node = inbound_layer.inbound_nodes[node_index] - mask = node.output_masks[tensor_index] - masks.append(mask) - else: - masks.append(None) - if len(masks) == 1: - return masks[0] - return masks - - def _to_snake_case(name): intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py index ec4fa2eed8..f6c0b8a607 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py @@ -164,11 +164,9 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual(node.inbound_layers, []) self.assertListEqual(node.input_tensors, [a]) - self.assertListEqual(node.input_masks, [None]) self.assertListEqual(node.input_shapes, [(None, 32)]) self.assertListEqual(node.output_tensors, [a]) self.assertListEqual(node.output_shapes, [(None, 32)]) - self.assertListEqual(node.output_masks, [None]) dense = keras.layers.Dense(16, name='dense_1') a_2 = dense(a) @@ -189,22 +187,9 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual(test_layer.kernel.get_shape().as_list(), [32, 16]) self.assertEqual(test_layer.input, a) self.assertEqual(test_layer.output, a_test) - self.assertEqual(test_layer.input_mask, None) - self.assertEqual(test_layer.output_mask, None) self.assertEqual(test_layer.input_shape, (None, 32)) self.assertEqual(test_layer.output_shape, (None, 16)) - # pylint: disable=pointless-statement - with self.assertRaises(AttributeError): - dense.input - with self.assertRaises(AttributeError): - dense.output - with self.assertRaises(AttributeError): 
- dense.input_mask - with self.assertRaises(AttributeError): - dense.output_mask - # pylint: enable=pointless-statement - self.assertEqual(dense.get_input_at(0), a) self.assertEqual(dense.get_input_at(1), b) self.assertEqual(dense.get_output_at(0), a_2) @@ -256,9 +241,9 @@ class TopologyConstructionTest(test.TestCase): # ordering of same-level layers is not fixed self.assertListEqual([l.name for l in model.layers][2:], ['dense_1', 'merge', 'dense_2', 'dense_3']) - self.assertListEqual([l.name for l in model.input_layers], + self.assertListEqual([l.name for l in model._input_layers], ['input_a', 'input_b']) - self.assertListEqual([l.name for l in model.output_layers], + self.assertListEqual([l.name for l in model._output_layers], ['dense_2', 'dense_3']) # actually run model @@ -278,9 +263,9 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([l.name for l in recreated_model.layers][2:], ['dense_1', 'merge', 'dense_2', 'dense_3']) - self.assertListEqual([l.name for l in recreated_model.input_layers], + self.assertListEqual([l.name for l in recreated_model._input_layers], ['input_a', 'input_b']) - self.assertListEqual([l.name for l in recreated_model.output_layers], + self.assertListEqual([l.name for l in recreated_model._output_layers], ['dense_2', 'dense_3']) fn = keras.backend.function(recreated_model.inputs, @@ -507,6 +492,12 @@ class TopologyConstructionTest(test.TestCase): x = keras.layers.Input(tensor=x) keras.layers.Dense(2)(x) + def test_basic_masking(self): + a = keras.layers.Input(shape=(10, 32), name='input_a') + b = keras.layers.Masking()(a) + model = keras.models.Model(a, b) + self.assertEqual(model.output_mask.get_shape().as_list(), [None, 10]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py index 1563cf8c41..bd8badd902 100644 --- a/tensorflow/contrib/keras/python/keras/engine/training.py +++ 
b/tensorflow/contrib/keras/python/keras/engine/training.py @@ -856,7 +856,7 @@ class Model(Container): def append_metric(layer_num, metric_name, metric_tensor): """Helper function used in loop below.""" if len(self.output_names) > 1: - metric_name = self.output_layers[layer_num].name + '_' + metric_name + metric_name = self._output_layers[layer_num].name + '_' + metric_name self.metrics_names.append(metric_name) self.metrics_tensors.append(metric_tensor) diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index 8864f5e69d..813a462f55 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -414,6 +414,7 @@ class Sequential(Model): self.outputs = [] # List of length 1: the output tensor (unique). self._trainable = True self._initial_weights = None + self._input_layers = [] # Model attributes. self.inbound_nodes = [] @@ -501,10 +502,7 @@ class Sequential(Model): node_indices=[], tensor_indices=[], input_tensors=self.inputs, - output_tensors=self.outputs, - # no model-level masking for now - input_masks=[None for _ in self.inputs], - output_masks=[None]) + output_tensors=self.outputs) else: output_tensor = layer(self.outputs[0]) if isinstance(output_tensor, list): @@ -578,14 +576,12 @@ class Sequential(Model): self._output_mask_cache = self.model._output_mask_cache self._output_tensor_cache = self.model._output_tensor_cache self._output_shape_cache = self.model._output_shape_cache - self.input_layers = self.model.input_layers - self.input_layers_node_indices = self.model.input_layers_node_indices - self.input_layers_tensor_indices = self.model.input_layers_tensor_indices - self.output_layers = self.model.output_layers - self.output_layers_node_indices = self.model.output_layers_node_indices - self.output_layers_tensor_indices = self.model.output_layers_tensor_indices - self.nodes_by_depth = self.model.nodes_by_depth - self.container_nodes = 
self.model.container_nodes + self._input_layers = self.model._input_layers + self._output_layers = self.model._output_layers + self._input_coordinates = self.model._input_coordinates + self._output_coordinates = self.model._output_coordinates + self._nodes_by_depth = self.model._nodes_by_depth + self._network_nodes = self.model._network_nodes self.output_names = self.model.output_names self.input_names = self.model.input_names self._feed_input_names = self.model._feed_input_names diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py index 1c3481fdb8..12d5368b08 100644 --- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py @@ -46,7 +46,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): sequential_like = True else: sequential_like = True - for v in model.nodes_by_depth.values(): + for v in model._nodes_by_depth.values(): # pylint: disable=protected-access if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1): # If the model has multiple nodes or if the nodes have # multiple inbound_layers, the model is no longer sequential. @@ -68,7 +68,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): # header names for the different log elements to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to'] relevant_nodes = [] - for v in model.nodes_by_depth.values(): + for v in model._nodes_by_depth.values(): # pylint: disable=protected-access relevant_nodes += v def print_row(fields, positions): |