Diffstat (limited to 'tensorflow/python/keras/layers/wrappers.py')
-rw-r--r--  tensorflow/python/keras/layers/wrappers.py  123
1 file changed, 119 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index e61acf8e77..f0c1e76156 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -47,7 +47,6 @@ class Wrapper(Layer):
def __init__(self, layer, **kwargs):
assert isinstance(layer, Layer)
self.layer = layer
- self._track_checkpointable(layer, name='layer')
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
@@ -168,6 +167,39 @@ class TimeDistributed(Wrapper):
'`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
+ self._track_checkpointable(layer, name='layer')
+
+ def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
+ """Finds non-specific dimensions in the static shapes.
+
+ The static shapes are replaced with the corresponding dynamic shapes of the
+ tensor.
+
+ Arguments:
+ init_tuple: a tuple, the first part of the output shape
+ tensor: the tensor from which to get the (static and dynamic) shapes
+ as the last part of the output shape
+ start_idx: int, which indicates the first dimension to take from
+ the static shape of the tensor
+ int_shape: an alternative static shape to take as the last part
+ of the output shape
+ Returns:
+ The new int_shape with the first part from init_tuple
+ and the last part from either `int_shape` (if provided)
+ or `tensor.shape`, where every `None` is replaced by
+ the corresponding dimension from `tf.shape(tensor)`.
+ """
+ # replace every None in int_shape with the matching entry of K.shape(tensor)
+ if int_shape is None:
+ int_shape = K.int_shape(tensor)[start_idx:]
+ if not any(not s for s in int_shape):
+ return init_tuple + tuple(int_shape)
+ shape = K.shape(tensor)
+ int_shape = list(int_shape)
+ for i, s in enumerate(int_shape):
+ if not s:
+ int_shape[i] = shape[start_idx + i]
+ return init_tuple + tuple(int_shape)
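For illustration, here is a standalone sketch (not part of the diff) of what the helper above computes, assuming a TF 1.x-style graph so that placeholders can carry unknown dimensions; the name `get_shape_tuple` and the placeholder shapes are made up for this example:

import tensorflow as tf

def get_shape_tuple(init_tuple, tensor, start_idx, int_shape=None):
  # Standalone mirror of TimeDistributed._get_shape_tuple: take the static
  # shape of `tensor` from `start_idx` onward (or the supplied `int_shape`)
  # and substitute the dynamic dimension wherever the static one is unknown.
  if int_shape is None:
    int_shape = tensor.shape.as_list()[start_idx:]
  if all(int_shape):
    return init_tuple + tuple(int_shape)  # fully static tail
  dynamic = tf.shape(tensor)
  int_shape = list(int_shape)
  for i, s in enumerate(int_shape):
    if not s:
      int_shape[i] = dynamic[start_idx + i]  # scalar tensor for unknown dim
  return init_tuple + tuple(int_shape)

# (batch, timesteps, features) with only the feature dimension known.
x = tf.placeholder(tf.float32, shape=(None, None, 16))
print(get_shape_tuple((-1,), x, 2))    # (-1, 16)

# (batch, timesteps, height, width, channels) with an unknown width.
v = tf.placeholder(tf.float32, shape=(None, None, 32, None, 3))
print(get_shape_tuple((-1,), v, 2))    # (-1, 32, <scalar Tensor>, 3)

The mixed tuple of Python ints and scalar tensors is what `array_ops.reshape` receives below, which is why the wrapper can now flatten inputs whose inner dimensions are only known at run time.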
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
@@ -224,18 +256,24 @@ class TimeDistributed(Wrapper):
input_length = input_shape[1]
if not input_length:
input_length = array_ops.shape(inputs)[1]
+ inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = generic_utils.object_list_uid(inputs)
- inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
+ inputs = array_ops.reshape(inputs, inner_input_shape)
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
+ if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ kwargs['mask'] = K.reshape(mask, inner_mask_shape)
y = self.layer.call(inputs, **kwargs)
if hasattr(y, '_uses_learning_phase'):
uses_learning_phase = y._uses_learning_phase
# Shape: (num_samples, timesteps, ...)
output_shape = self.compute_output_shape(input_shape).as_list()
- y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+ output_shape = self._get_shape_tuple(
+ (-1, input_length), y, 1, output_shape[2:])
+ y = array_ops.reshape(y, output_shape)
# Apply activity regularizer if any:
if (hasattr(self.layer, 'activity_regularizer') and
@@ -247,6 +285,80 @@ class TimeDistributed(Wrapper):
y._uses_learning_phase = True
return y
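To see what this buys at the user level, here is a hedged usage sketch (not taken from the diff or its tests): wrapping a layer whose inner spatial dimensions are unknown at graph-construction time, which the removed static `(-1,) + input_shape[2:]` reshape could not express; the layer sizes are arbitrary:

import tensorflow as tf

# Video-style input: (batch, time, height, width, channels) with unknown
# spatial size at build time.
frames = tf.keras.Input(shape=(None, None, None, 3))
features = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Conv2D(8, 3, padding='same'))(frames)
model = tf.keras.Model(frames, features)
print(model.output_shape)  # (None, None, None, None, 8)

With the change above, both the flattening to (batch * time, height, width, channels) and the restoring reshape are assembled from `tf.shape` at run time wherever a static dimension is missing.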
+ def compute_mask(self, inputs, mask=None):
+ """Computes an output mask tensor for Embedding layer.
+
+ This is based on the inputs, mask, and the inner layer.
+ If batch size is specified:
+ Simply return the input `mask`. (An RNN-based implementation with
+ more than one RNN input would be required, but is not supported in
+ tf.keras yet.)
+ Otherwise we call `compute_mask` of the inner layer at each time step.
+ If the output mask at each time step is not `None`:
+ (E.g., inner layer is Masking or RNN)
+ Concatenate all of them and return the concatenation.
+ If the output mask at each time step is `None` and the input mask is not
+ `None`: (E.g., inner layer is Dense)
+ Reduce the input_mask to 2 dimensions and return it.
+ Otherwise (both the output mask and the input mask are `None`):
+ (E.g., `mask` is not used at all)
+ Return `None`.
+
+ Arguments:
+ inputs: Tensor with shape [batch size, timesteps, ...] indicating the
+ input to TimeDistributed. If static shape information is available for
+ "batch size", `mask` is returned unmodified.
+ mask: Either None (indicating no masking) or a Tensor indicating the
+ input mask for TimeDistributed. The shape can be static or dynamic.
+
+ Returns:
+ Either `None` (no masking), or the output mask for the TimeDistributed
+ layer, whose leading dimensions are [batch size, timesteps]:
+ If the inner layer computes an output mask, that mask is reshaped so
+ that its leading dimensions are [batch size, timesteps] and returned.
+ If the inner layer computes no mask but the input `mask` is not `None`,
+ the input mask is reduced over its trailing dimensions to shape
+ [batch size, timesteps] and returned.
+ Otherwise `None` is returned.
+
+ """
+ # Cases where the inner layer's compute_mask must be called even when
+ # input_mask is None: a Masking layer, or an Embedding layer with mask_zero.
+ input_shape = K.int_shape(inputs)
+ if input_shape[0]:
+ # The batch size is statically known; we currently do not handle the mask
+ # explicitly in this case.
+ return mask
+ inner_mask = mask
+ if inner_mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ inner_mask = K.reshape(inner_mask, inner_mask_shape)
+ input_uid = generic_utils.object_list_uid(inputs)
+ inner_inputs = self._input_map[input_uid]
+ output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
+ if output_mask is None:
+ if mask is None:
+ return None
+ # input_mask is not None, and output_mask is None:
+ # we should return a non-None mask
+ output_mask = mask
+ for _ in range(2, len(K.int_shape(mask))):
+ output_mask = K.any(output_mask, axis=-1)
+ else:
+ # output_mask is not None; reshape it back to (batch, timesteps, ...)
+ input_length = input_shape[1]
+ if not input_length:
+ input_length = K.shape(inputs)[1]
+ output_mask_int_shape = K.int_shape(output_mask)
+ if output_mask_int_shape is None:
+ # if the output_mask does not have a static shape,
+ # its shape must be the same as mask's
+ if mask is not None:
+ output_mask_int_shape = K.int_shape(mask)
+ else:
+ output_mask_int_shape = self.compute_output_shape(input_shape).as_list()[:-1]
+ output_mask_shape = self._get_shape_tuple(
+ (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
+ output_mask = K.reshape(output_mask, output_mask_shape)
+ return output_mask
+
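A hedged, NumPy-only sketch of the three branches described in the docstring above; the helper name `timedistributed_mask` and the boolean arrays are illustrative only and do not appear in the diff:

import numpy as np

def timedistributed_mask(input_mask, inner_output_mask, batch, timesteps):
  # Illustrative mirror of the branch structure in TimeDistributed.compute_mask.
  if inner_output_mask is not None:
    # Inner layer (e.g. Masking, an RNN) produced a mask over the flattened
    # (batch * timesteps, ...) inputs: restore the leading two axes.
    return inner_output_mask.reshape(
        (batch, timesteps) + inner_output_mask.shape[1:])
  if input_mask is not None:
    # Inner layer (e.g. Dense) produced no mask: reduce the input mask over
    # its trailing axes so it becomes one boolean per timestep.
    reduced = input_mask
    while reduced.ndim > 2:
      reduced = reduced.any(axis=-1)
    return reduced
  return None  # Neither the input nor the inner layer uses masking.

# Input mask with a trailing axis; the inner layer returns no mask.
m = np.array([[[True, False], [False, False]]])   # shape (1, 2, 2)
print(timedistributed_mask(m, None, 1, 2))        # [[ True False]]

# Inner layer returned a per-row mask over the flattened batch of 2 * 2 rows.
inner = np.array([True, False, True, True])
print(timedistributed_mask(None, inner, 2, 2))    # [[ True False] [ True  True]]

The real implementation additionally handles dynamic timestep counts via `_get_shape_tuple` and returns the incoming mask untouched when the batch size is statically known.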
@tf_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
@@ -305,6 +417,8 @@ class Bidirectional(Wrapper):
self._num_constants = None
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
+ self._track_checkpointable(self.forward_layer, name='forward_layer')
+ self._track_checkpointable(self.backward_layer, name='backward_layer')
@property
def trainable(self):
@@ -414,7 +528,8 @@ class Bidirectional(Wrapper):
else:
return super(Bidirectional, self).__call__(inputs, **kwargs)
- def call(self, inputs,
+ def call(self,
+ inputs,
training=None,
mask=None,
initial_state=None,