path: root/tensorflow/python/keras/layers/local.py
Diffstat (limited to 'tensorflow/python/keras/layers/local.py')
-rw-r--r--  tensorflow/python/keras/layers/local.py  340
1 file changed, 312 insertions(+), 28 deletions(-)
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 0ebafe07cc..33d09a1660 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -85,6 +85,28 @@ class LocallyConnected1D(Layer):
the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
+ implementation: implementation mode, either `1` or `2`.
+ `1` loops over input spatial locations to perform the forward pass.
+ It is memory-efficient but performs a lot of (small) ops.
+
+ `2` stores layer weights in a dense but sparsely-populated 2D matrix
+ and implements the forward pass as a single matrix-multiply. It uses
+ a lot of RAM but performs few (large) ops.
+
+ Depending on the inputs, layer parameters, hardware, and
+ `tf.executing_eagerly()`, one implementation can be dramatically faster
+ (e.g. 50X) than the other.
+
+ It is recommended to benchmark both in the setting of interest to pick
+ the most efficient one (in terms of speed and memory usage).
+
+ The following scenarios could benefit from setting `implementation=2`:
+ - eager execution;
+ - inference;
+ - running on CPU;
+ - large amount of RAM available;
+ - small models (few filters, small kernel);
+ - using `padding="same"` (only possible with `implementation=2`).
Input shape:
3D tensor with shape: `(batch_size, steps, input_dim)`
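
A minimal usage sketch of the new argument, assuming the layer is exposed
as `tf.keras.layers.LocallyConnected1D` and eager execution is enabled:

    import tensorflow as tf

    # `implementation=2` also unlocks `padding='same'`, which mode 1 rejects.
    layer = tf.keras.layers.LocallyConnected1D(
        filters=4, kernel_size=3, padding='same', implementation=2)
    y = layer(tf.zeros((1, 8, 2)))  # (batch, steps=8, input_dim=2)
    print(y.shape)                  # (1, 8, 4): 'same' keeps all 8 steps
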
@@ -109,15 +131,17 @@ class LocallyConnected1D(Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
+ implementation=1,
**kwargs):
super(LocallyConnected1D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
self.padding = conv_utils.normalize_padding(padding)
- if self.padding != 'valid':
+ if self.padding != 'valid' and implementation == 1:
raise ValueError('Invalid border mode for LocallyConnected1D '
- '(only "valid" is supported): ' + padding)
+ '(only "valid" is supported if implementation is 1): '
+ + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -128,6 +152,7 @@ class LocallyConnected1D(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
+ self.implementation = implementation
self.input_spec = InputSpec(ndim=3)
@tf_utils.shape_type_conversion
@@ -142,14 +167,45 @@ class LocallyConnected1D(Layer):
'Found shape:', input_shape)
self.output_length = conv_utils.conv_output_length(
input_length, self.kernel_size[0], self.padding, self.strides[0])
- self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
- self.filters)
- self.kernel = self.add_weight(
- shape=self.kernel_shape,
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
+
+ if self.implementation == 1:
+ self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
+ self.filters)
+
+ self.kernel = self.add_weight(
+ shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ elif self.implementation == 2:
+ if self.data_format == 'channels_first':
+ self.kernel_shape = (input_dim, input_length,
+ self.filters, self.output_length)
+ else:
+ self.kernel_shape = (input_length, input_dim,
+ self.output_length, self.filters)
+
+ self.kernel = self.add_weight(shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ self.kernel_mask = get_locallyconnected_mask(
+ input_shape=(input_length,),
+ kernel_shape=self.kernel_size,
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dtype=self.kernel.dtype
+ )
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
+
if self.use_bias:
self.bias = self.add_weight(
shape=(self.output_length, self.filters),
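
The two build branches above allocate differently shaped kernels. A small
sketch with assumed sizes (channels_last, input_length=8, input_dim=2,
kernel_size=3, strides=1, 'valid' padding, hence output_length=6):

    import tensorflow as tf

    for impl in (1, 2):
        layer = tf.keras.layers.LocallyConnected1D(
            filters=4, kernel_size=3, implementation=impl)
        layer.build((None, 8, 2))
        print(impl, layer.kernel.shape)
    # 1 (6, 6, 4)    -- (output_length, kernel_size * input_dim, filters)
    # 2 (8, 2, 6, 4) -- (input_length, input_dim, output_length, filters)
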
@@ -182,8 +238,17 @@ class LocallyConnected1D(Layer):
return (input_shape[0], length, self.filters)
def call(self, inputs):
- output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_length,), self.data_format)
+ if self.implementation == 1:
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_length,), self.data_format)
+
+ elif self.implementation == 2:
+ output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+ self.compute_output_shape(inputs.shape))
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -220,7 +285,9 @@ class LocallyConnected1D(Layer):
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
- constraints.serialize(self.bias_constraint)
+ constraints.serialize(self.bias_constraint),
+ 'implementation':
+ self.implementation
}
base_config = super(LocallyConnected1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
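
Since `implementation` is now part of the config above, it should survive a
serialization round-trip; a quick sanity sketch under the same assumptions:

    layer = tf.keras.layers.LocallyConnected1D(
        filters=4, kernel_size=3, padding='same', implementation=2)
    cfg = layer.get_config()
    assert cfg['implementation'] == 2
    clone = tf.keras.layers.LocallyConnected1D.from_config(cfg)
    assert clone.implementation == 2
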
@@ -284,9 +351,31 @@ class LocallyConnected2D(Layer):
the `kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
+ the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
+ implementation: implementation mode, either `1` or `2`.
+ `1` loops over input spatial locations to perform the forward pass.
+ It is memory-efficient but performs a lot of (small) ops.
+
+ `2` stores layer weights in a dense but sparsely-populated 2D matrix
+ and implements the forward pass as a single matrix-multiply. It uses
+ a lot of RAM but performs few (large) ops.
+
+ Depending on the inputs, layer parameters, hardware, and
+ `tf.executing_eagerly()`, one implementation can be dramatically faster
+ (e.g. 50X) than the other.
+
+ It is recommended to benchmark both in the setting of interest to pick
+ the most efficient one (in terms of speed and memory usage).
+
+ The following scenarios could benefit from setting `implementation=2`:
+ - eager execution;
+ - inference;
+ - running on CPU;
+ - large amount of RAM available;
+ - small models (few filters, small kernel);
+ - using `padding="same"` (only possible with `implementation=2`).
Input shape:
4D tensor with shape:
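
The docstring recommends benchmarking both modes; one rough way to do so,
assuming eager execution (sizes are arbitrary and timings vary by hardware):

    import time
    import tensorflow as tf

    x = tf.zeros((8, 28, 28, 3))
    for impl in (1, 2):
        layer = tf.keras.layers.LocallyConnected2D(
            filters=8, kernel_size=3, implementation=impl)
        layer(x)  # builds the layer and warms it up
        start = time.time()
        for _ in range(10):
            layer(x)
        print('implementation=%d: %.3fs' % (impl, time.time() - start))
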
@@ -317,15 +406,17 @@ class LocallyConnected2D(Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
+ implementation=1,
**kwargs):
super(LocallyConnected2D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
- if self.padding != 'valid':
+ if self.padding != 'valid' and implementation == 1:
raise ValueError('Invalid border mode for LocallyConnected2D '
- '(only "valid" is supported): ' + padding)
+ '(only "valid" is supported if implementation is 1): '
+ + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -336,6 +427,7 @@ class LocallyConnected2D(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
+ self.implementation = implementation
self.input_spec = InputSpec(ndim=4)
@tf_utils.shape_type_conversion
@@ -357,15 +449,47 @@ class LocallyConnected2D(Layer):
self.padding, self.strides[1])
self.output_row = output_row
self.output_col = output_col
- self.kernel_shape = (
- output_row * output_col,
- self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters)
- self.kernel = self.add_weight(
- shape=self.kernel_shape,
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
+
+ if self.implementation == 1:
+ self.kernel_shape = (
+ output_row * output_col,
+ self.kernel_size[0] * self.kernel_size[1] * input_filter,
+ self.filters)
+
+ self.kernel = self.add_weight(
+ shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ elif self.implementation == 2:
+ if self.data_format == 'channels_first':
+ self.kernel_shape = (input_filter, input_row, input_col,
+ self.filters, self.output_row, self.output_col)
+ else:
+ self.kernel_shape = (input_row, input_col, input_filter,
+ self.output_row, self.output_col, self.filters)
+
+ self.kernel = self.add_weight(shape=self.kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
+
+ self.kernel_mask = get_locallyconnected_mask(
+ input_shape=(input_row, input_col),
+ kernel_shape=self.kernel_size,
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dtype=self.kernel.dtype
+ )
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
+
if self.use_bias:
self.bias = self.add_weight(
shape=(output_row, output_col, self.filters),
@@ -401,8 +525,18 @@ class LocallyConnected2D(Layer):
return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_row, self.output_col), self.data_format)
+ if self.implementation == 1:
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_row, self.output_col),
+ self.data_format)
+
+ elif self.implementation == 2:
+ output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+ self.compute_output_shape(inputs.shape))
+
+ else:
+ raise ValueError('Unrecognized implementation mode: %d.'
+ % self.implementation)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -439,7 +573,157 @@ class LocallyConnected2D(Layer):
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
- constraints.serialize(self.bias_constraint)
+ constraints.serialize(self.bias_constraint),
+ 'implementation':
+ self.implementation
}
base_config = super(LocallyConnected2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+
+def get_locallyconnected_mask(input_shape,
+ kernel_shape,
+ strides,
+ padding,
+ data_format,
+ dtype):
+ """Return a mask representing connectivity of a locally-connected operation.
+
+ This method returns a masking tensor of 0s and 1s (of type `dtype`) that,
+ when element-wise multiplied with a fully-connected weight tensor, masks out
+ the weights between disconnected input-output pairs and thus implements local
+ connectivity through a sparse fully-connected weight tensor.
+
+ Assume an unshared convolution with given parameters is applied to an input
+ having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
+ to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
+ by layer parameters such as `strides`).
+
+ This method returns a mask which can be broadcast-multiplied (element-wise)
+ with a 2*(N+1)-D weight tensor (equivalent to a fully-connected layer between
+ (N+1)-D activations (N spatial + 1 channel dimensions for input and output))
+ to make it perform an unshared convolution with given `kernel_shape`,
+ `strides`, `padding` and `data_format`.
+
+ Arguments:
+ input_shape: tuple of size N: `(d_in1, ..., d_inN)`
+ spatial shape of the input.
+ kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+ / receptive field.
+ strides: tuple of size N, strides along each spatial dimension.
+ padding: type of padding, string `"same"` or `"valid"`.
+ data_format: a string, `"channels_first"` or `"channels_last"`.
+ dtype: type of the layer operation, e.g. `tf.float64`.
+
+ Returns:
+ a `dtype`-tensor of shape
+ `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
+ if `data_format == "channels_first"`, or
+ `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
+ if `data_format == "channels_last"`.
+
+ Raises:
+ ValueError: if `data_format` is neither `"channels_first"` nor
+ `"channels_last"`.
+ """
+ mask = conv_utils.conv_kernel_mask(
+ input_shape=input_shape,
+ kernel_shape=kernel_shape,
+ strides=strides,
+ padding=padding
+ )
+
+ ndims = int(mask.ndim / 2)
+ mask = K.variable(mask, dtype)
+
+ if data_format == 'channels_first':
+ mask = K.expand_dims(mask, 0)
+ mask = K.expand_dims(mask, -ndims - 1)
+
+ elif data_format == 'channels_last':
+ mask = K.expand_dims(mask, ndims)
+ mask = K.expand_dims(mask, -1)
+
+ else:
+ raise ValueError('Unrecognized data_format: ' + str(data_format))
+
+ return mask
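
For intuition, a tiny sketch assuming this module's imports are in scope:
with `input_shape=(4,)`, `kernel_shape=(3,)`, stride 1 and 'valid' padding,
the output length is 2; input 0 feeds only output 0, inputs 1 and 2 feed
both outputs, and input 3 feeds only output 1:

    mask = get_locallyconnected_mask(input_shape=(4,), kernel_shape=(3,),
                                     strides=(1,), padding='valid',
                                     data_format='channels_last',
                                     dtype='float32')
    print(K.int_shape(mask))  # (4, 1, 2, 1); broadcasts against a kernel
                              # of shape (4, input_dim, 2, filters)
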
+
+
+def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
+ """Apply N-D convolution with un-shared weights using a single matmul call.
+
+ This method outputs `inputs . (kernel * kernel_mask)`
+ (with `.` standing for matrix-multiply and `*` for element-wise multiply)
+ and requires a precomputed `kernel_mask` to zero out weights in `kernel` and
+ hence perform the same operation as a convolution with un-shared weights
+ (the remaining entries in `kernel`). It also does the necessary reshapes
+ to make `inputs` and `kernel` 2-D and `output` (N+2)-D.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ `(batch_size, channels_in, d_in1, ..., d_inN)`
+ or
+ `(batch_size, d_in1, ..., d_inN, channels_in)`.
+ kernel: the unshared weights for N-D convolution,
+ an (N+2)-D tensor of shape:
+ `(d_in1, ..., d_inN, channels_in, d_out1, ..., d_outN, channels_out)`
+ or
+ `(channels_in, d_in1, ..., d_inN, channels_out, d_out1, ..., d_outN)`,
+ with the ordering of channels and spatial dimensions matching
+ that of the input.
+ Each entry is the weight between a particular input and
+ output location, similarly to a fully-connected weight matrix.
+ kernel_mask: a float 0/1 mask tensor of shape:
+ `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
+ or
+ `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`,
+ with the ordering of singleton and spatial dimensions
+ matching that of the input.
+ Mask represents the connectivity pattern of the layer and is
+ precomputed elsewhere based on layer parameters: stride,
+ padding, and the receptive field shape.
+ output_shape: a tuple of (N+2) elements representing the output shape:
+ `(batch_size, channels_out, d_out1, ..., d_outN)`
+ or
+ `(batch_size, d_out1, ..., d_outN, channels_out)`,
+ with the ordering of channels and spatial dimensions matching that of
+ the input.
+
+ Returns:
+ Output (N+2)-D tensor with shape `output_shape`.
+ """
+ inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))
+
+ kernel = kernel_mask * kernel
+ kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)
+
+ output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
+ output = K.reshape(output_flat,
+ [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
+ return output
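
As a sanity check on the reshapes, a shape walk-through for the 1D
channels_last sizes used earlier (assumed: batch `b`, input_length=8,
input_dim=2, output_length=6, filters=4):

    # inputs      (b, 8, 2)    -> inputs_flat          (b, 16)
    # kernel      (8, 2, 6, 4) -> make_2d(split_dim=2) (16, 24)
    # matmul      (b, 16) . (16, 24) = output_flat     (b, 24)
    # reshape to output_shape[1:]                   -> (b, 6, 4)
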
+
+
+def make_2d(tensor, split_dim):
+ """Reshapes an N-dimensional tensor into a 2D tensor.
+
+ Dimensions before (excluding) and after (including) `split_dim` are grouped
+ together.
+
+ Arguments:
+ tensor: a tensor of shape `(d0, ..., d(N-1))`.
+ split_dim: an integer from 1 to N-1; dimensions before it (exclusive)
+ are grouped into the first output dimension and dimensions from it on
+ (inclusive) into the second.
+
+ Returns:
+ Tensor of shape
+ `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
+ """
+ shape = K.array_ops.shape(tensor)
+ in_dims = shape[:split_dim]
+ out_dims = shape[split_dim:]
+
+ in_size = K.math_ops.reduce_prod(in_dims)
+ out_size = K.math_ops.reduce_prod(out_dims)
+
+ return K.array_ops.reshape(tensor, (in_size, out_size))
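
Finally, a small sketch of `make_2d`, assuming eager execution so the
resulting static shape is visible:

    t = K.constant(0.0, shape=(2, 3, 4, 5))
    print(make_2d(t, split_dim=2).shape)  # (6, 20), i.e. (2*3, 4*5)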