aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 10:06:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 10:10:08 -0700
commitd3ab92cf907e15da2ba70bccd65e5b4ccbfad575 (patch)
tree3632a8617efe2d370916d330beb9d3a36dc339db /tensorflow/python/keras/backend.py
parent293b21eddc34ee0ceda1143ec7699e54c9768a1c (diff)
Replace unshared convolution backend for LocallyConnected1D and LocallyConnected2D layers with a common dimension-agnostic implementation.
PiperOrigin-RevId: 201542873
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py188
1 files changed, 108 insertions, 80 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index c55a756bcc..fed779650e 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import itertools
import json
import os
import weakref
@@ -4245,58 +4246,115 @@ def pool3d(x,
return x
-def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
- """Apply 1D conv with un-shared weights.
-
- Arguments:
- inputs: 3D tensor with shape:
- (batch_size, steps, input_dim)
- if data_format is "channels_last" or
- (batch_size, input_dim, steps)
- if data_format is "channels_first".
- kernel: the unshared weight for convolution,
- with shape (output_length, feature_dim, filters)
- kernel_size: a tuple of a single integer,
- specifying the length of the 1D convolution window
- strides: a tuple of a single integer,
- specifying the stride length of the convolution
- data_format: the data format, channels_first or channels_last
-
- Returns:
- the tensor after 1d conv with un-shared weights, with shape (batch_size,
- output_length, filters)
+def local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format=None):
+ """Apply N-D convolution with un-shared weights.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ (batch_size, channels_in, d_in1, ..., d_inN)
+ if data_format='channels_first', or
+ (batch_size, d_in1, ..., d_inN, channels_in)
+ if data_format='channels_last'.
+ kernel: the unshared weight for N-D convolution,
+ with shape (output_items, feature_dim, channels_out), where
+ feature_dim = np.prod(kernel_size) * channels_in,
+ output_items = np.prod(output_shape).
+ kernel_size: a tuple of N integers, specifying the
+ spatial dimensions of the N-D convolution window.
+ strides: a tuple of N integers, specifying the strides
+ of the convolution along the spatial dimensions.
+ output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
+ dimensionality of the output.
+ data_format: string, "channels_first" or "channels_last".
+
+ Returns:
+ An (N+2)-D tensor with shape:
+ (batch_size, channels_out) + output_shape
+ if data_format='channels_first', or:
+ (batch_size,) + output_shape + (channels_out,)
+ if data_format='channels_last'.
Raises:
- ValueError: if `data_format` is neither `channels_last` or
- `channels_first`.
+ ValueError: if `data_format` is neither
+ `channels_last` nor `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
- stride = strides[0]
kernel_shape = int_shape(kernel)
- output_length = kernel_shape[0]
feature_dim = kernel_shape[1]
+ channels_out = kernel_shape[-1]
+ ndims = len(output_shape)
+ spatial_dimensions = list(range(ndims))
xs = []
- for i in range(output_length):
- slice_length = slice(i * stride, i * stride + kernel_size[0])
+ output_axes_ticks = [range(axis_max) for axis_max in output_shape]
+ for position in itertools.product(*output_axes_ticks):
+ slices = [slice(None)]
+
if data_format == 'channels_first':
- xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim)))
- else:
- xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+ slices.append(slice(None))
+
+ slices.extend([slice(position[d] * strides[d],
+ position[d] * strides[d] + kernel_size[d])
+ for d in spatial_dimensions])
+
+ if data_format == 'channels_last':
+ slices.append(slice(None))
+
+ xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
x_aggregate = concatenate(xs, axis=0)
- # Shape: `(output_length, batch_size, filters)`.
output = batch_dot(x_aggregate, kernel)
+ output = reshape(output, output_shape + (-1, channels_out))
if data_format == 'channels_first':
- output = permute_dimensions(output, (1, 2, 0))
+ permutation = [ndims, ndims + 1] + spatial_dimensions
else:
- output = permute_dimensions(output, (1, 0, 2))
- return output
+ permutation = [ndims] + spatial_dimensions + [ndims + 1]
+
+ return permute_dimensions(output, permutation)
+
+
+def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
+ """Apply 1D conv with un-shared weights.
+
+ Arguments:
+ inputs: 3D tensor with shape:
+ (batch_size, steps, input_dim)
+ if data_format is "channels_last" or
+ (batch_size, input_dim, steps)
+ if data_format is "channels_first".
+ kernel: the unshared weight for convolution,
+ with shape (output_length, feature_dim, filters).
+ kernel_size: a tuple of a single integer,
+ specifying the length of the 1D convolution window.
+ strides: a tuple of a single integer,
+ specifying the stride length of the convolution.
+ data_format: the data format, channels_first or channels_last.
+
+ Returns:
+ A 3d tensor with shape:
+ (batch_size, output_length, filters)
+ if data_format='channels_first'
+ or 3D tensor with shape:
+ (batch_size, filters, output_length)
+ if data_format='channels_last'.
+ """
+ output_shape = (kernel.shape[0],)
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
def local_conv2d(inputs,
@@ -4309,64 +4367,34 @@ def local_conv2d(inputs,
Arguments:
inputs: 4D tensor with shape:
- (batch_size, filters, new_rows, new_cols)
- if data_format='channels_first'
- or 4D tensor with shape:
- (batch_size, new_rows, new_cols, filters)
- if data_format='channels_last'.
+ (batch_size, filters, new_rows, new_cols)
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ (batch_size, new_rows, new_cols, filters)
+ if data_format='channels_last'.
kernel: the unshared weight for convolution,
- with shape (output_items, feature_dim, filters)
+ with shape (output_items, feature_dim, filters).
kernel_size: a tuple of 2 integers, specifying the
- width and height of the 2D convolution window.
+ width and height of the 2D convolution window.
strides: a tuple of 2 integers, specifying the strides
- of the convolution along the width and height.
- output_shape: a tuple with (output_row, output_col)
- data_format: the data format, channels_first or channels_last
+ of the convolution along the width and height.
+ output_shape: a tuple with (output_row, output_col).
+ data_format: the data format, channels_first or channels_last.
Returns:
- A 4d tensor with shape:
+ A 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
-
- Raises:
- ValueError: if `data_format` is neither
- `channels_last` or `channels_first`.
"""
- if data_format is None:
- data_format = image_data_format()
- if data_format not in {'channels_first', 'channels_last'}:
- raise ValueError('Unknown data_format: ' + str(data_format))
-
- stride_row, stride_col = strides
- output_row, output_col = output_shape
- kernel_shape = int_shape(kernel)
- feature_dim = kernel_shape[1]
- filters = kernel_shape[2]
-
- xs = []
- for i in range(output_row):
- for j in range(output_col):
- slice_row = slice(i * stride_row, i * stride_row + kernel_size[0])
- slice_col = slice(j * stride_col, j * stride_col + kernel_size[1])
- if data_format == 'channels_first':
- xs.append(
- reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim)))
- else:
- xs.append(
- reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim)))
-
- x_aggregate = concatenate(xs, axis=0)
- output = batch_dot(x_aggregate, kernel)
- output = reshape(output, (output_row, output_col, -1, filters))
-
- if data_format == 'channels_first':
- output = permute_dimensions(output, (2, 3, 0, 1))
- else:
- output = permute_dimensions(output, (2, 0, 1, 3))
- return output
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
@tf_export('keras.backend.bias_add')