Diffstat (limited to 'tensorflow/python/keras/layers/wrappers.py')
-rw-r--r-- | tensorflow/python/keras/layers/wrappers.py | 123
1 file changed, 119 insertions, 4 deletions
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index e61acf8e77..f0c1e76156 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -47,7 +47,6 @@ class Wrapper(Layer):
   def __init__(self, layer, **kwargs):
     assert isinstance(layer, Layer)
     self.layer = layer
-    self._track_checkpointable(layer, name='layer')
     # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
     # the inner layer has update ops that depend on its inputs (as opposed
     # to the inputs to the Wrapper layer).
@@ -168,6 +167,39 @@ class TimeDistributed(Wrapper):
           '`Layer` instance. You passed: {input}'.format(input=layer))
     super(TimeDistributed, self).__init__(layer, **kwargs)
     self.supports_masking = True
+    self._track_checkpointable(layer, name='layer')
+
+  def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
+    """Finds non-specific dimensions in the static shapes.
+
+    The static shapes are replaced with the corresponding dynamic shapes of
+    the tensor.
+
+    Arguments:
+      init_tuple: a tuple, the first part of the output shape
+      tensor: the tensor from which to get the (static and dynamic) shapes
+        as the last part of the output shape
+      start_idx: int, which indicates the first dimension to take from
+        the static shape of the tensor
+      int_shape: an alternative static shape to take as the last part
+        of the output shape
+    Returns:
+      The new shape with the first part from `init_tuple`
+      and the last part from either `int_shape` (if provided)
+      or `tensor.shape`, where every `None` is replaced by
+      the corresponding dimension from `tf.shape(tensor)`.
+    """
+    # Replace every None in int_shape with the matching dynamic dimension.
+    if int_shape is None:
+      int_shape = K.int_shape(tensor)[start_idx:]
+    if not any(not s for s in int_shape):
+      return init_tuple + tuple(int_shape)
+    shape = K.shape(tensor)
+    int_shape = list(int_shape)
+    for i, s in enumerate(int_shape):
+      if not s:
+        int_shape[i] = shape[start_idx + i]
+    return init_tuple + tuple(int_shape)
 
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
@@ -224,18 +256,24 @@ class TimeDistributed(Wrapper):
       input_length = input_shape[1]
       if not input_length:
         input_length = array_ops.shape(inputs)[1]
+      inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
       # Shape: (num_samples * timesteps, ...). And track the
       # transformation in self._input_map.
       input_uid = generic_utils.object_list_uid(inputs)
-      inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
+      inputs = array_ops.reshape(inputs, inner_input_shape)
       self._input_map[input_uid] = inputs
       # (num_samples * timesteps, ...)
+      if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
+        inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+        kwargs['mask'] = K.reshape(mask, inner_mask_shape)
       y = self.layer.call(inputs, **kwargs)
       if hasattr(y, '_uses_learning_phase'):
         uses_learning_phase = y._uses_learning_phase
       # Shape: (num_samples, timesteps, ...)
       output_shape = self.compute_output_shape(input_shape).as_list()
-      y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+      output_shape = self._get_shape_tuple(
+          (-1, input_length), y, 1, output_shape[2:])
+      y = array_ops.reshape(y, output_shape)
 
       # Apply activity regularizer if any:
       if (hasattr(self.layer, 'activity_regularizer') and
@@ -247,6 +285,80 @@ class TimeDistributed(Wrapper):
         y._uses_learning_phase = True
     return y
 
+  def compute_mask(self, inputs, mask=None):
+    """Computes an output mask tensor for the TimeDistributed layer.
+
+    This is based on the inputs, mask, and the inner layer.
+    If batch size is specified:
+    Simply return the input `mask`. (An RNN-based implementation with
+    more than one RNN input would be required, but that is not supported
+    in tf.keras yet.)
+    Otherwise we call `compute_mask` of the inner layer at each time step.
+    If the output mask at each time step is not `None`:
+    (E.g., inner layer is Masking or RNN)
+    Concatenate all of them and return the concatenation.
+    If the output mask at each time step is `None` and the input mask is not
+    `None`: (E.g., inner layer is Dense)
+    Reduce the input_mask to 2 dimensions and return it.
+    Otherwise (both the output mask and the input mask are `None`):
+    (E.g., `mask` is not used at all)
+    Return `None`.
+
+    Arguments:
+      inputs: Tensor with shape [batch size, timesteps, ...] indicating the
+        input to TimeDistributed. If static shape information is available
+        for "batch size", `mask` is returned unmodified.
+      mask: Either None (indicating no masking) or a Tensor indicating the
+        input mask for TimeDistributed. The shape can be static or dynamic.
+
+    Returns:
+      Either None (no masking), or a [batch size, timesteps, ...] Tensor
+      carrying the output mask for the TimeDistributed layer. Its shape
+      beyond the second dimension follows the input mask shape if the
+      computed output mask is None; otherwise its shape beyond the first
+      dimension follows either the mask shape (if `mask` is not None) or
+      the computed output shape.
+    """
+    # Cases that need layer.compute_mask even when input_mask is None:
+    # a Masking layer, or an Embedding layer with mask_zero=True.
+    input_shape = K.int_shape(inputs)
+    if input_shape[0]:
+      # Batch size matters; we currently do not handle the mask explicitly.
+      return mask
+    inner_mask = mask
+    if inner_mask is not None:
+      inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+      inner_mask = K.reshape(inner_mask, inner_mask_shape)
+    input_uid = generic_utils.object_list_uid(inputs)
+    inner_inputs = self._input_map[input_uid]
+    output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
+    if output_mask is None:
+      if mask is None:
+        return None
+      # input_mask is not None and output_mask is None:
+      # we should return a not-None mask.
+      output_mask = mask
+      for _ in range(2, len(K.int_shape(mask))):
+        output_mask = K.any(output_mask, axis=-1)
+    else:
+      # output_mask is not None, so we need to reshape it.
+      input_length = input_shape[1]
+      if not input_length:
+        input_length = K.shape(inputs)[1]
+      output_mask_int_shape = K.int_shape(output_mask)
+      if output_mask_int_shape is None:
+        # If the output_mask does not have a static shape,
+        # its shape must be the same as mask's.
+        if mask is not None:
+          output_mask_int_shape = K.int_shape(mask)
+        else:
+          output_mask_int_shape = K.compute_output_shape(input_shape)[:-1]
+      output_mask_shape = self._get_shape_tuple(
+          (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
+      output_mask = K.reshape(output_mask, output_mask_shape)
+    return output_mask
+
 
 @tf_export('keras.layers.Bidirectional')
 class Bidirectional(Wrapper):
@@ -305,6 +417,8 @@ class Bidirectional(Wrapper):
     self._num_constants = None
     super(Bidirectional, self).__init__(layer, **kwargs)
     self.input_spec = layer.input_spec
+    self._track_checkpointable(self.forward_layer, name='forward_layer')
+    self._track_checkpointable(self.backward_layer, name='backward_layer')
 
   @property
   def trainable(self):
@@ -414,7 +528,8 @@ class Bidirectional(Wrapper):
     else:
       return super(Bidirectional, self).__call__(inputs, **kwargs)
 
-  def call(self, inputs,
+  def call(self,
+           inputs,
            training=None,
            mask=None,
            initial_state=None,
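The `_get_shape_tuple` helper added above mixes static and dynamic shape information: keep each statically known dimension as a plain Python int and substitute the matching `tf.shape` slice for every `None`. The following is a standalone sketch of the same trick, not code from the patch (the name `mixed_shape` is illustrative, and the patch's helper additionally accepts an `int_shape` override):

import tensorflow as tf

def mixed_shape(init_tuple, tensor, start_idx):
  """Static dims where known; dynamic tf.shape() values elsewhere."""
  static = tensor.shape.as_list()[start_idx:]
  if all(s is not None for s in static):
    return init_tuple + tuple(static)      # fully static: plain ints suffice
  dynamic = tf.shape(tensor)
  return init_tuple + tuple(
      s if s is not None else dynamic[start_idx + i]
      for i, s in enumerate(static))

# Collapse batch and time into one axis, as TimeDistributed.call does
# before handing inputs to the wrapped layer.
x = tf.zeros([2, 3, 5])                         # (batch, timesteps, features)
flat = tf.reshape(x, mixed_shape((-1,), x, 2))  # shape (6, 5)

Passing a tuple that mixes ints and scalar tensors to `tf.reshape` works because TensorFlow packs it into a 1-D shape tensor, which is why the patch can feed `_get_shape_tuple`'s result directly to `array_ops.reshape` and `K.reshape`.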
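For context, a minimal usage sketch of what the new `compute_mask` enables, assuming a recent tf.keras; it is not part of the patch, and the layer sizes and data are illustrative:

import numpy as np
import tensorflow as tf

# Embedding emits a (batch, timesteps) mask for zero-padded positions.
# With this patch, TimeDistributed.compute_mask flattens that mask,
# asks the wrapped layer for its own mask, and reshapes the result back
# to (batch, timesteps) so the LSTM still skips padded steps.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=100, output_dim=8, mask_zero=True),
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(4)),
    tf.keras.layers.LSTM(16),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.randint(0, 100, size=(32, 10))   # zero entries are padding
y = np.random.randint(0, 2, size=(32, 1)).astype('float32')
model.fit(x, y, epochs=1, verbose=0)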