aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/convolutional_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/convolutional_test.py')
-rw-r--r--tensorflow/python/layers/convolutional_test.py169
1 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index da962b2f99..42a2d77534 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -651,5 +651,174 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 4)
+class Conv3DTransposeTest(test.TestCase):
+
+ def testInvalidDataFormat(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'data_format'):
+ conv_layers.conv3d_transpose(volumes, 4, 3, data_format='invalid')
+
+ def testInvalidStrides(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'strides'):
+ conv_layers.conv3d_transpose(volumes, 4, 3, strides=(1, 2))
+
+ with self.assertRaisesRegexp(ValueError, 'strides'):
+ conv_layers.conv3d_transpose(volumes, 4, 3, strides=None)
+
+ def testInvalidKernelSize(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'kernel_size'):
+ conv_layers.conv3d_transpose(volumes, 4, (1, 2))
+
+ with self.assertRaisesRegexp(ValueError, 'kernel_size'):
+ conv_layers.conv3d_transpose(volumes, 4, None)
+
+ def testCreateConv3DTranspose(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32))
+ layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], activation=nn_ops.relu)
+ output = layer.apply(volumes)
+ self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth + 2, height + 2, width + 2, 4])
+ self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
+ self.assertListEqual(layer.bias.get_shape().as_list(), [4])
+
+ def testCreateConv3DTransposeIntegerKernelSize(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32))
+ layer = conv_layers.Conv3DTranspose(4, 3)
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth + 2, height + 2, width + 2, 4])
+ self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
+ self.assertListEqual(layer.bias.get_shape().as_list(), [4])
+
+ def testCreateConv3DTransposeChannelsFirst(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, 32, depth, height, width))
+ layer = conv_layers.Conv3DTranspose(
+ 4, [3, 3, 3], data_format='channels_first')
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, 4, depth + 2, height + 2, width + 2])
+ self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
+ self.assertListEqual(layer.bias.get_shape().as_list(), [4])
+
+ def testConv3DTransposePaddingSame(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 64), seed=1)
+ layer = conv_layers.Conv3DTranspose(
+ 32, volumes.get_shape()[1:4], padding='same')
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth, height, width, 32])
+
+ def testCreateConv3DTransposeWithStrides(self):
+ depth, height, width = 4, 6, 8
+ # Test strides tuple.
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ layer = conv_layers.Conv3DTranspose(
+ 4, [3, 3, 3], strides=(2, 2, 2), padding='same')
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth * 2, height * 2, width * 2, 4])
+
+ # Test strides integer.
+ layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, padding='same')
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth * 2, height * 2, width * 2, 4])
+
+ # Test unequal strides.
+ layer = conv_layers.Conv3DTranspose(
+ 4, [3, 3, 3], strides=(2, 1, 1), padding='same')
+ output = layer.apply(volumes)
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth * 2, height, width, 4])
+
+ def testConv3DTransposeKernelRegularizer(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32))
+ reg = lambda x: 0.1 * math_ops.reduce_sum(x)
+ layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], kernel_regularizer=reg)
+ layer.apply(volumes)
+ loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(loss_keys), 1)
+ self.assertListEqual(layer.losses, loss_keys)
+
+ def testConv3DTransposeBiasRegularizer(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32))
+ reg = lambda x: 0.1 * math_ops.reduce_sum(x)
+ layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], bias_regularizer=reg)
+ layer.apply(volumes)
+ loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(loss_keys), 1)
+ self.assertListEqual(layer.losses, loss_keys)
+
+ def testConv3DTransposeNoBias(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32))
+ layer = conv_layers.Conv3DTranspose(
+ 4, [3, 3, 3], activation=nn_ops.relu, use_bias=False)
+ output = layer.apply(volumes)
+ self.assertEqual(output.op.name, 'conv3d_transpose/Relu')
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth + 2, height + 2, width + 2, 4])
+ self.assertListEqual(layer.kernel.get_shape().as_list(), [3, 3, 3, 4, 32])
+ self.assertEqual(layer.bias, None)
+
+ def testFunctionalConv3DTransposeReuse(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
+ self.assertEqual(len(variables.trainable_variables()), 2)
+ conv_layers.conv3d_transpose(
+ volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
+ self.assertEqual(len(variables.trainable_variables()), 2)
+
+ def testFunctionalConv3DTransposeReuseFromScope(self):
+ with variable_scope.variable_scope('scope'):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
+ self.assertEqual(len(variables.trainable_variables()), 2)
+ with variable_scope.variable_scope('scope', reuse=True):
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
+ self.assertEqual(len(variables.trainable_variables()), 2)
+
+ def testFunctionalConv3DTransposeInitializerFromScope(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ 'scope', initializer=init_ops.ones_initializer()):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform(
+ (5, depth, height, width, 32), seed=1)
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
+ weights = variables.trainable_variables()
+ # Check the names of weights in order.
+ self.assertTrue('kernel' in weights[0].name)
+ self.assertTrue('bias' in weights[1].name)
+ sess.run(variables.global_variables_initializer())
+ weights = sess.run(weights)
+ # Check that the kernel weights got initialized to ones (from scope)
+ self.assertAllClose(weights[0], np.ones((3, 3, 3, 4, 32)))
+ # Check that the bias still got initialized to zeros.
+ self.assertAllClose(weights[1], np.zeros((4)))
+
+ def testFunctionalConv3DTransposeNoReuse(self):
+ depth, height, width = 5, 7, 9
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
+ self.assertEqual(len(variables.trainable_variables()), 2)
+ conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3])
+ self.assertEqual(len(variables.trainable_variables()), 4)
+
+
if __name__ == '__main__':
test.main()