aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 18:17:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 18:20:25 -0700
commit898f9664488f0036ccc02bbb34379cb613f07a55 (patch)
tree199084150e9027e6f37e15908ae4725b1c35e4ad /tensorflow/python/keras/backend_test.py
parent9070f24ae15a4f589219d4cb9c962b14612c2d8c (diff)
Make LocallyConnected1D layer respect the data_format parameter.
PiperOrigin-RevId: 199879521
Diffstat (limited to 'tensorflow/python/keras/backend_test.py')
-rw-r--r--tensorflow/python/keras/backend_test.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 58df263a4f..53e30e0e4a 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -810,6 +810,53 @@ class BackendNNOpsTest(test.TestCase):
padding='same', data_format='channels_last')
self.assertEqual(y.get_shape().as_list(), [10, 5, 5])
+ def test_local_conv1d_channels_dim(self):
+ input_length = 5
+ input_dim = 3
+ batch_size = 2
+
+ inputs = np.random.normal(0, 1, (batch_size, input_dim, input_length))
+ inputs_cf = keras.backend.variable(inputs)
+
+ filters = 4
+ for kernel_size in [(1,), (2,), (3,)]:
+ for strides in [(1,), (2,), (3,)]:
+ output_length = (input_length - kernel_size[0]
+ + strides[0]) // strides[0]
+
+ kernel_shape = (output_length, kernel_size[0] * input_dim, filters)
+ kernel = np.random.normal(0, 1, (output_length,
+ input_dim,
+ kernel_size[0],
+ filters))
+ kernel_cf = np.reshape(kernel, kernel_shape)
+ kernel_cf = keras.backend.variable(kernel_cf)
+
+ conv_cf = keras.backend.local_conv1d(inputs_cf,
+ kernel_cf,
+ kernel_size,
+ strides,
+ 'channels_first')
+
+ inputs_cl = np.transpose(inputs, (0, 2, 1))
+ inputs_cl = keras.backend.variable(inputs_cl)
+
+ kernel_cl = np.reshape(np.transpose(kernel, (0, 2, 1, 3)),
+ kernel_shape)
+ kernel_cl = keras.backend.variable(kernel_cl)
+
+ conv_cl = keras.backend.local_conv1d(inputs_cl,
+ kernel_cl,
+ kernel_size,
+ strides,
+ 'channels_last')
+ with self.test_session():
+ conv_cf = keras.backend.eval(conv_cf)
+ conv_cl = keras.backend.eval(conv_cl)
+
+ self.assertAllCloseAccordingToType(conv_cf,
+ np.transpose(conv_cl, (0, 2, 1)))
+
def test_conv2d(self):
val = np.random.random((10, 4, 10, 10))
x = keras.backend.variable(val)