Diffstat (limited to 'tensorflow/python/keras/layers/local.py')
-rw-r--r--  tensorflow/python/keras/layers/local.py | 340
1 file changed, 312 insertions(+), 28 deletions(-)
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 0ebafe07cc..33d09a1660 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -85,6 +85,28 @@ class LocallyConnected1D(Layer):
       the output of the layer (its "activation")..
     kernel_constraint: Constraint function applied to the kernel matrix.
     bias_constraint: Constraint function applied to the bias vector.
+    implementation: implementation mode, either `1` or `2`.
+      `1` loops over input spatial locations to perform the forward pass.
+      It is memory-efficient but performs a lot of (small) ops.
+
+      `2` stores layer weights in a dense but sparsely-populated 2D matrix
+      and implements the forward pass as a single matrix-multiply. It uses
+      a lot of RAM but performs few (large) ops.
+
+      Depending on the inputs, layer parameters, hardware, and
+      `tf.executing_eagerly()` one implementation can be dramatically faster
+      (e.g. 50X) than another.
+
+      It is recommended to benchmark both in the setting of interest to pick
+      the most efficient one (in terms of speed and memory usage).
+
+      Following scenarios could benefit from setting `implementation=2`:
+      - eager execution;
+      - inference;
+      - running on CPU;
+      - large amount of RAM available;
+      - small models (few filters, small kernel);
+      - using `padding=same` (only possible with `implementation=2`).
 
   Input shape:
       3D tensor with shape: `(batch_size, steps, input_dim)`
@@ -109,15 +131,17 @@ class LocallyConnected1D(Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               implementation=1,
                **kwargs):
     super(LocallyConnected1D, self).__init__(**kwargs)
     self.filters = filters
     self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
     self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
     self.padding = conv_utils.normalize_padding(padding)
-    if self.padding != 'valid':
+    if self.padding != 'valid' and implementation == 1:
       raise ValueError('Invalid border mode for LocallyConnected1D '
-                       '(only "valid" is supported): ' + padding)
+                       '(only "valid" is supported if implementation is 1): '
+                       + padding)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.activation = activations.get(activation)
     self.use_bias = use_bias
@@ -128,6 +152,7 @@ class LocallyConnected1D(Layer):
     self.activity_regularizer = regularizers.get(activity_regularizer)
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)
+    self.implementation = implementation
     self.input_spec = InputSpec(ndim=3)
 
   @tf_utils.shape_type_conversion
@@ -142,14 +167,45 @@ class LocallyConnected1D(Layer):
                        'Found shape:', input_shape)
     self.output_length = conv_utils.conv_output_length(
         input_length, self.kernel_size[0], self.padding, self.strides[0])
-    self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
-                         self.filters)
-    self.kernel = self.add_weight(
-        shape=self.kernel_shape,
-        initializer=self.kernel_initializer,
-        name='kernel',
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+
+    if self.implementation == 1:
+      self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
+                           self.filters)
+
+      self.kernel = self.add_weight(
+          shape=self.kernel_shape,
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
+
+    elif self.implementation == 2:
+      if self.data_format == 'channels_first':
+        self.kernel_shape = (input_dim, input_length,
+                             self.filters, self.output_length)
+      else:
+        self.kernel_shape = (input_length, input_dim,
+                             self.output_length, self.filters)
+
+      self.kernel = self.add_weight(shape=self.kernel_shape,
+                                    initializer=self.kernel_initializer,
+                                    name='kernel',
+                                    regularizer=self.kernel_regularizer,
+                                    constraint=self.kernel_constraint)
+
+      self.kernel_mask = get_locallyconnected_mask(
+          input_shape=(input_length,),
+          kernel_shape=self.kernel_size,
+          strides=self.strides,
+          padding=self.padding,
+          data_format=self.data_format,
+          dtype=self.kernel.dtype
+      )
+
+    else:
+      raise ValueError('Unrecognized implementation mode: %d.'
+                       % self.implementation)
+
     if self.use_bias:
       self.bias = self.add_weight(
           shape=(self.output_length, self.filters),
@@ -182,8 +238,17 @@ class LocallyConnected1D(Layer):
     return (input_shape[0], length, self.filters)
 
   def call(self, inputs):
-    output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
-                          (self.output_length,), self.data_format)
+    if self.implementation == 1:
+      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+                            (self.output_length,), self.data_format)
+
+    elif self.implementation == 2:
+      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+                                 self.compute_output_shape(inputs.shape))
+
+    else:
+      raise ValueError('Unrecognized implementation mode: %d.'
+                       % self.implementation)
 
     if self.use_bias:
       output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -220,7 +285,9 @@ class LocallyConnected1D(Layer):
         'kernel_constraint':
             constraints.serialize(self.kernel_constraint),
         'bias_constraint':
-            constraints.serialize(self.bias_constraint)
+            constraints.serialize(self.bias_constraint),
+        'implementation':
+            self.implementation
     }
     base_config = super(LocallyConnected1D, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -284,9 +351,31 @@ class LocallyConnected2D(Layer):
         the `kernel` weights matrix.
     bias_regularizer: Regularizer function applied to the bias vector.
     activity_regularizer: Regularizer function applied to
-        the output of the layer (its "activation")..
+        the output of the layer (its "activation").
     kernel_constraint: Constraint function applied to the kernel matrix.
     bias_constraint: Constraint function applied to the bias vector.
+    implementation: implementation mode, either `1` or `2`.
+      `1` loops over input spatial locations to perform the forward pass.
+      It is memory-efficient but performs a lot of (small) ops.
+
+      `2` stores layer weights in a dense but sparsely-populated 2D matrix
+      and implements the forward pass as a single matrix-multiply. It uses
+      a lot of RAM but performs few (large) ops.
+
+      Depending on the inputs, layer parameters, hardware, and
+      `tf.executing_eagerly()` one implementation can be dramatically faster
+      (e.g. 50X) than another.
+
+      It is recommended to benchmark both in the setting of interest to pick
+      the most efficient one (in terms of speed and memory usage).
+
+      Following scenarios could benefit from setting `implementation=2`:
+      - eager execution;
+      - inference;
+      - running on CPU;
+      - large amount of RAM available;
+      - small models (few filters, small kernel);
+      - using `padding=same` (only possible with `implementation=2`).
 
   Input shape:
       4D tensor with shape:
@@ -317,15 +406,17 @@ class LocallyConnected2D(Layer):
                activity_regularizer=None,
                kernel_constraint=None,
                bias_constraint=None,
+               implementation=1,
                **kwargs):
     super(LocallyConnected2D, self).__init__(**kwargs)
     self.filters = filters
     self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
     self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
     self.padding = conv_utils.normalize_padding(padding)
-    if self.padding != 'valid':
+    if self.padding != 'valid' and implementation == 1:
       raise ValueError('Invalid border mode for LocallyConnected2D '
-                       '(only "valid" is supported): ' + padding)
+                       '(only "valid" is supported if implementation is 1): '
+                       + padding)
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.activation = activations.get(activation)
     self.use_bias = use_bias
@@ -336,6 +427,7 @@ class LocallyConnected2D(Layer):
     self.activity_regularizer = regularizers.get(activity_regularizer)
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)
+    self.implementation = implementation
     self.input_spec = InputSpec(ndim=4)
 
   @tf_utils.shape_type_conversion
@@ -357,15 +449,47 @@ class LocallyConnected2D(Layer):
                                                self.padding, self.strides[1])
     self.output_row = output_row
     self.output_col = output_col
-    self.kernel_shape = (
-        output_row * output_col,
-        self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters)
-    self.kernel = self.add_weight(
-        shape=self.kernel_shape,
-        initializer=self.kernel_initializer,
-        name='kernel',
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+
+    if self.implementation == 1:
+      self.kernel_shape = (
+          output_row * output_col,
+          self.kernel_size[0] * self.kernel_size[1] * input_filter,
+          self.filters)
+
+      self.kernel = self.add_weight(
+          shape=self.kernel_shape,
+          initializer=self.kernel_initializer,
+          name='kernel',
+          regularizer=self.kernel_regularizer,
+          constraint=self.kernel_constraint)
+
+    elif self.implementation == 2:
+      if self.data_format == 'channels_first':
+        self.kernel_shape = (input_filter, input_row, input_col,
+                             self.filters, self.output_row, self.output_col)
+      else:
+        self.kernel_shape = (input_row, input_col, input_filter,
+                             self.output_row, self.output_col, self.filters)
+
+      self.kernel = self.add_weight(shape=self.kernel_shape,
+                                    initializer=self.kernel_initializer,
+                                    name='kernel',
+                                    regularizer=self.kernel_regularizer,
+                                    constraint=self.kernel_constraint)
+
+      self.kernel_mask = get_locallyconnected_mask(
+          input_shape=(input_row, input_col),
+          kernel_shape=self.kernel_size,
+          strides=self.strides,
+          padding=self.padding,
+          data_format=self.data_format,
+          dtype=self.kernel.dtype
+      )
+
+    else:
+      raise ValueError('Unrecognized implementation mode: %d.'
+                       % self.implementation)
+
     if self.use_bias:
       self.bias = self.add_weight(
           shape=(output_row, output_col, self.filters),
@@ -401,8 +525,18 @@ class LocallyConnected2D(Layer):
     return (input_shape[0], rows, cols, self.filters)
 
   def call(self, inputs):
-    output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
-                          (self.output_row, self.output_col), self.data_format)
+    if self.implementation == 1:
+      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+                            (self.output_row, self.output_col),
+                            self.data_format)
+
+    elif self.implementation == 2:
+      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
+                                 self.compute_output_shape(inputs.shape))
+
+    else:
+      raise ValueError('Unrecognized implementation mode: %d.'
+                       % self.implementation)
 
     if self.use_bias:
       output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -439,7 +573,157 @@ class LocallyConnected2D(Layer):
         'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
         'bias_constraint':
-            constraints.serialize(self.bias_constraint)
+            constraints.serialize(self.bias_constraint),
+        'implementation':
+            self.implementation
     }
     base_config = super(LocallyConnected2D, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
+
+
+def get_locallyconnected_mask(input_shape,
+                              kernel_shape,
+                              strides,
+                              padding,
+                              data_format,
+                              dtype):
+  """Return a mask representing connectivity of a locally-connected operation.
+
+  This method returns a masking tensor of 0s and 1s (of type `dtype`) that,
+  when element-wise multiplied with a fully-connected weight tensor, masks out
+  the weights between disconnected input-output pairs and thus implements local
+  connectivity through a sparse fully-connected weight tensor.
+
+  Assume an unshared convolution with given parameters is applied to an input
+  having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
+  to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
+  by layer parameters such as `strides`).
+
+  This method returns a mask which can be broadcast-multiplied (element-wise)
+  with a 2*(N+1)-D weight matrix (equivalent to a fully-connected layer between
+  (N+1)-D activations (N spatial + 1 channel dimensions for input and output)
+  to make it perform an unshared convolution with given `kernel_shape`,
+  `strides`, `padding` and `data_format`.
+
+  Arguments:
+    input_shape: tuple of size N: `(d_in1, ..., d_inN)`
+                 spatial shape of the input.
+    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
+                  / receptive field.
+    strides: tuple of size N, strides along each spatial dimension.
+    padding: type of padding, string `"same"` or `"valid"`.
+    data_format: a string, `"channels_first"` or `"channels_last"`.
+    dtype: type of the layer operation, e.g. `tf.float64`.
+
+  Returns:
+    a `dtype`-tensor of shape
+    `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
+    if `data_format == `"channels_first"`, or
+    `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
+    if `data_format == "channels_last"`.
+
+  Raises:
+    ValueError: if `data_format` is neither `"channels_first"` nor
+                `"channels_last"`.
+  """
+  mask = conv_utils.conv_kernel_mask(
+      input_shape=input_shape,
+      kernel_shape=kernel_shape,
+      strides=strides,
+      padding=padding
+  )
+
+  ndims = int(mask.ndim / 2)
+  mask = K.variable(mask, dtype)
+
+  if data_format == 'channels_first':
+    mask = K.expand_dims(mask, 0)
+    mask = K.expand_dims(mask, - ndims - 1)
+
+  elif data_format == 'channels_last':
+    mask = K.expand_dims(mask, ndims)
+    mask = K.expand_dims(mask, -1)
+
+  else:
+    raise ValueError('Unrecognized data_format: ' + str(data_format))
+
+  return mask
+
+
+def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
+  """Apply N-D convolution with un-shared weights using a single matmul call.
+
+  This method outputs `inputs . (kernel * kernel_mask)`
+  (with `.` standing for matrix-multiply and `*` for element-wise multiply)
+  and requires a precomputed `kernel_mask` to zero-out weights in `kernel` and
+  hence perform the same operation as a convolution with un-shared
+  (the remaining entries in `kernel`) weights. It also does the necessary
+  reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D.
+
+  Arguments:
+      inputs: (N+2)-D tensor with shape
+          `(batch_size, channels_in, d_in1, ..., d_inN)`
+          or
+          `(batch_size, d_in1, ..., d_inN, channels_in)`.
+      kernel: the unshared weights for N-D convolution,
+          an (N+2)-D tensor of shape:
+          `(d_in1, ..., d_inN, channels_in, d_out2, ..., d_outN, channels_out)`
+          or
+          `(channels_in, d_in1, ..., d_inN, channels_out, d_out2, ..., d_outN)`,
+          with the ordering of channels and spatial dimensions matching
+          that of the input.
+          Each entry is the weight between a particular input and
+          output location, similarly to a fully-connected weight matrix.
+      kernel_mask: a float 0/1 mask tensor of shape:
+          `(d_in1, ..., d_inN, 1, d_out2, ..., d_outN, 1)`
+          or
+          `(1, d_in1, ..., d_inN, 1, d_out2, ..., d_outN)`,
+          with the ordering of singleton and spatial dimensions
+          matching that of the input.
+          Mask represents the connectivity pattern of the layer and is
+          precomputed elsewhere based on layer parameters: stride,
+          padding, and the receptive field shape.
+      output_shape: a tuple of (N+2) elements representing the output shape:
+          `(batch_size, channels_out, d_out1, ..., d_outN)`
+          or
+          `(batch_size, d_out1, ..., d_outN, channels_out)`,
+          with the ordering of channels and spatial dimensions matching that of
+          the input.
+
+  Returns:
+      Output (N+2)-D tensor with shape `output_shape`.
+  """
+  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))
+
+  kernel = kernel_mask * kernel
+  kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)
+
+  output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
+  output = K.reshape(output_flat,
+                     [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
+  return output
+
+
+def make_2d(tensor, split_dim):
+  """Reshapes an N-dimensional tensor into a 2D tensor.
+
+  Dimensions before (excluding) and after (including) `split_dim` are grouped
+  together.

+  Arguments:
+    tensor: a tensor of shape `(d0, ..., d(N-1))`.
+    split_dim: an integer from 1 to N-1, index of the dimension to group
+        dimensions before (excluding) and after (including).
+
+  Returns:
+    Tensor of shape
+    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
+  """
+  shape = K.array_ops.shape(tensor)
+  in_dims = shape[:split_dim]
+  out_dims = shape[split_dim:]
+
+  in_size = K.math_ops.reduce_prod(in_dims)
+  out_size = K.math_ops.reduce_prod(out_dims)
+
+  return K.array_ops.reshape(tensor, (in_size, out_size))