From 9edd13a9c96edd9b7fdca71f694bb045f265ad66 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 6 Jan 2016 12:55:10 -0800
Subject: Rename deconv2d to conv2d_transpose and expose publicly

deconv2d is a misleading name for an operation which is exactly the
transpose of conv2d. conv2d_transpose is much better.

Fixes #256.
Change: 111529902
---
 tensorflow/g3doc/api_docs/python/index.md        |  1 +
 tensorflow/g3doc/api_docs/python/nn.md           | 37 ++++++++++++++++++++++++
 tensorflow/python/framework/gen_docs_combined.py |  2 +-
 tensorflow/python/ops/nn.py                      |  1 +
 tensorflow/python/ops/nn_grad.py                 |  2 +-
 tensorflow/python/ops/nn_ops.py                  | 13 +++++----
 tensorflow/python/ops/nn_test.py                 | 20 ++++++++-----
 7 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index efd38ed148..3adb42218f 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -289,6 +289,7 @@
   * [`bias_add`](../../api_docs/python/nn.md#bias_add)
   * [`compute_accidental_hits`](../../api_docs/python/nn.md#compute_accidental_hits)
   * [`conv2d`](../../api_docs/python/nn.md#conv2d)
+  * [`conv2d_transpose`](../../api_docs/python/nn.md#conv2d_transpose)
   * [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d)
   * [`dropout`](../../api_docs/python/nn.md#dropout)
   * [`elu`](../../api_docs/python/nn.md#elu)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 767677a3e2..3cc2244fc7 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -412,6 +412,43 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
   A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.


+- - -
+
+### `tf.nn.conv2d_transpose(value, filter, output_shape, strides, padding='SAME', name=None)` {#conv2d_transpose}
+
+The transpose of `conv2d`.
+
+This operation is sometimes called "deconvolution" after
+[Deconvolutional Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf),
+but is actually the transpose (gradient) of `conv2d` rather than an actual
+deconvolution.
+
+##### Args:
+
+
+* `value`: A 4-D `Tensor` of type `float` and shape
+  `[batch, height, width, in_channels]`.
+* `filter`: A 4-D `Tensor` with the same type as `value` and shape
+  `[height, width, output_channels, in_channels]`. `filter`'s
+  `in_channels` dimension must match that of `value`.
+* `output_shape`: A 1-D `Tensor` representing the output shape of the
+  deconvolution op.
+* `strides`: A list of ints. The stride of the sliding window for each
+  dimension of the input tensor.
+* `padding`: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+* `name`: Optional name for the returned tensor.
+
+##### Returns:
+
+  A `Tensor` with the same type as `value`.
+
+##### Raises:
+
+
+* `ValueError`: If input/output depth does not match `filter`'s shape, or if
+  padding is other than `'VALID'` or `'SAME'`.
+
+

 ## Pooling

diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index e998310a69..b75377a91b 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -92,7 +92,7 @@ def all_libraries(module_to_name, members, documented):
             prefix=PREFIX_TEXT),
     library("python_io", "Data IO (Python functions)", tf.python_io),
     library("nn", "Neural Network", tf.nn,
-            exclude_symbols=["deconv2d", "conv2d_backprop_input",
+            exclude_symbols=["conv2d_backprop_input",
                              "conv2d_backprop_filter", "avg_pool_grad",
                              "max_pool_grad", "max_pool_grad_with_argmax",
                              "batch_norm_with_global_normalization_grad",
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 6fecea8666..7688070269 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -106,6 +106,7 @@ concatenated.
 @@conv2d
 @@depthwise_conv2d
 @@separable_conv2d
+@@conv2d_transpose

 ## Pooling

diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index b4b6b3b0c1..a1a4c0c6c4 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_nn_ops


 @ops.RegisterGradient("Conv2DBackpropInput")
-def _DeConv2DGrad(op, grad):
+def _Conv2DBackpropGrad(op, grad):
   """The derivatives for deconvolution.

   Args:
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 6b2a1f45a6..b6e459c27f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -39,12 +39,14 @@ from tensorflow.python.ops.gen_nn_ops import *
 local_response_normalization = gen_nn_ops.lrn


-def deconv2d(value, filter, output_shape, strides, padding="SAME",
-             name=None):
+def conv2d_transpose(value, filter, output_shape, strides, padding="SAME",
+                     name=None):
   """The transpose of `conv2d`.

-  This used to be called "deconvolution", but it is actually the transpose
-  (gradient) of `conv2d`, not an actual deconvolution.
+  This operation is sometimes called "deconvolution" after [Deconvolutional
+  Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
+  actually the transpose (gradient) of `conv2d` rather than an actual
+  deconvolution.

   Args:
     value: A 4-D `Tensor` of type `float` and shape
@@ -66,7 +68,8 @@ def deconv2d(value, filter, output_shape, strides, padding="SAME",
     ValueError: If input/output depth does not match `filter`'s shape, or if
       padding is other than `'VALID'` or `'SAME'`.
""" - with ops.op_scope([value, filter, output_shape], name, "DeConv2D") as name: + with ops.op_scope([value, filter, output_shape], name, + "conv2d_transpose") as name: value = ops.convert_to_tensor(value, name="value") filter = ops.convert_to_tensor(filter, name="filter") if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]): diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 4146803c25..35bbaeea0b 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -146,9 +146,9 @@ class SoftmaxTest(tf.test.TestCase): self.assertLess(err, eps) -class DeConv2DTest(tf.test.TestCase): +class Conv2DTransposeTest(tf.test.TestCase): - def testDeConv2DSingleStride(self): + def testConv2DTransposeSingleStride(self): with self.test_session(): strides = [1, 1, 1, 1] @@ -161,7 +161,8 @@ class DeConv2DTest(tf.test.TestCase): x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) - output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides, + padding="SAME") value = output.eval() # We count the number of cells being added at the locations in the output. @@ -183,7 +184,7 @@ class DeConv2DTest(tf.test.TestCase): target += 2 * 3.0 self.assertAllClose(target, value[n, h, w, k]) - def testDeConv2DSame(self): + def testConv2DTransposeSame(self): with self.test_session(): strides = [1, 2, 2, 1] @@ -196,7 +197,8 @@ class DeConv2DTest(tf.test.TestCase): x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) - output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides, + padding="SAME") value = output.eval() for n in xrange(x_shape[0]): @@ -213,7 +215,7 @@ class DeConv2DTest(tf.test.TestCase): target += 3.0 self.assertAllClose(target, value[n, h, w, k]) - def testDeConv2DValid(self): + def testConv2DTransposeValid(self): with self.test_session(): strides = [1, 2, 2, 1] @@ -226,7 +228,8 @@ class DeConv2DTest(tf.test.TestCase): x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32) f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32) - output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID") + output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides, + padding="VALID") value = output.eval() cache_values = np.zeros(y_shape, dtype=np.float32) @@ -269,7 +272,8 @@ class DeConv2DTest(tf.test.TestCase): with self.test_session(): x = tf.constant(x_val, name="x", dtype=tf.float32) f = tf.constant(f_val, name="f", dtype=tf.float32) - output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") + output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides, + padding="SAME") err = tf.test.compute_gradient_error( [x, f], [x_shape, f_shape], output, y_shape) print("DeConv gradient err = %g " % err) -- cgit v1.2.3