aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-01-06 12:55:10 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-06 12:55:10 -0800
commit9edd13a9c96edd9b7fdca71f694bb045f265ad66 (patch)
tree9381bc2a8bb45b9c3e61ea4146ba8156629c8983
parent3ae059ec4952493366510e4180773a2e2f398510 (diff)
Rename deconv2d to conv2d_transpose and expose publicly
deconv2d is a misleading name for an operation which is exactly the transpose of conv2d. conv2d_transpose is much better. Fixes #256. Change: 111529902
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md1
-rw-r--r--tensorflow/g3doc/api_docs/python/nn.md37
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py2
-rw-r--r--tensorflow/python/ops/nn.py1
-rw-r--r--tensorflow/python/ops/nn_grad.py2
-rw-r--r--tensorflow/python/ops/nn_ops.py13
-rw-r--r--tensorflow/python/ops/nn_test.py20
7 files changed, 61 insertions, 15 deletions
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index efd38ed148..3adb42218f 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -289,6 +289,7 @@
* [`bias_add`](../../api_docs/python/nn.md#bias_add)
* [`compute_accidental_hits`](../../api_docs/python/nn.md#compute_accidental_hits)
* [`conv2d`](../../api_docs/python/nn.md#conv2d)
+ * [`conv2d_transpose`](../../api_docs/python/nn.md#conv2d_transpose)
* [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d)
* [`dropout`](../../api_docs/python/nn.md#dropout)
* [`elu`](../../api_docs/python/nn.md#elu)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 767677a3e2..3cc2244fc7 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -412,6 +412,43 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.
+- - -
+
+### `tf.nn.conv2d_transpose(value, filter, output_shape, strides, padding='SAME', name=None)` {#conv2d_transpose}
+
+The transpose of `conv2d`.
+
+This operation is sometimes called "deconvolution" after
+(Deconvolutional Networks)[http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf],
+but is actually the transpose (gradient) of `conv2d` rather than an actual
+deconvolution.
+
+##### Args:
+
+
+* <b>`value`</b>: A 4-D `Tensor` of type `float` and shape
+ `[batch, height, width, in_channels]`.
+* <b>`filter`</b>: A 4-D `Tensor` with the same type as `value` and shape
+ `[height, width, output_channels, in_channels]`. `filter`'s
+ `in_channels` dimension must match that of `value`.
+* <b>`output_shape`</b>: A 1-D `Tensor` representing the output shape of the
+ deconvolution op.
+* <b>`strides`</b>: A list of ints. The stride of the sliding window for each
+ dimension of the input tensor.
+* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+* <b>`name`</b>: Optional name for the returned tensor.
+
+##### Returns:
+
+ A `Tensor` with the same type as `value`.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If input/output depth does not match `filter`'s shape, or if
+ padding is other than `'VALID'` or `'SAME'`.
+
+
## Pooling
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index e998310a69..b75377a91b 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -92,7 +92,7 @@ def all_libraries(module_to_name, members, documented):
prefix=PREFIX_TEXT),
library("python_io", "Data IO (Python functions)", tf.python_io),
library("nn", "Neural Network", tf.nn,
- exclude_symbols=["deconv2d", "conv2d_backprop_input",
+ exclude_symbols=["conv2d_backprop_input",
"conv2d_backprop_filter", "avg_pool_grad",
"max_pool_grad", "max_pool_grad_with_argmax",
"batch_norm_with_global_normalization_grad",
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 6fecea8666..7688070269 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -106,6 +106,7 @@ concatenated.
@@conv2d
@@depthwise_conv2d
@@separable_conv2d
+@@conv2d_transpose
## Pooling
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index b4b6b3b0c1..a1a4c0c6c4 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_nn_ops
@ops.RegisterGradient("Conv2DBackpropInput")
-def _DeConv2DGrad(op, grad):
+def _Conv2DBackpropGrad(op, grad):
"""The derivatives for deconvolution.
Args:
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 6b2a1f45a6..b6e459c27f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -39,12 +39,14 @@ from tensorflow.python.ops.gen_nn_ops import *
local_response_normalization = gen_nn_ops.lrn
-def deconv2d(value, filter, output_shape, strides, padding="SAME",
- name=None):
+def conv2d_transpose(value, filter, output_shape, strides, padding="SAME",
+ name=None):
"""The transpose of `conv2d`.
- This used to be called "deconvolution", but it is actually the transpose
- (gradient) of `conv2d`, not an actual deconvolution.
+ This operation is sometimes called "deconvolution" after (Deconvolutional
+ Networks)[http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf], but is
+ actually the transpose (gradient) of `conv2d` rather than an actual
+ deconvolution.
Args:
value: A 4-D `Tensor` of type `float` and shape
@@ -66,7 +68,8 @@ def deconv2d(value, filter, output_shape, strides, padding="SAME",
ValueError: If input/output depth does not match `filter`'s shape, or if
padding is other than `'VALID'` or `'SAME'`.
"""
- with ops.op_scope([value, filter, output_shape], name, "DeConv2D") as name:
+ with ops.op_scope([value, filter, output_shape], name,
+ "conv2d_transpose") as name:
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]):
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 4146803c25..35bbaeea0b 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -146,9 +146,9 @@ class SoftmaxTest(tf.test.TestCase):
self.assertLess(err, eps)
-class DeConv2DTest(tf.test.TestCase):
+class Conv2DTransposeTest(tf.test.TestCase):
- def testDeConv2DSingleStride(self):
+ def testConv2DTransposeSingleStride(self):
with self.test_session():
strides = [1, 1, 1, 1]
@@ -161,7 +161,8 @@ class DeConv2DTest(tf.test.TestCase):
x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
- output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
+ padding="SAME")
value = output.eval()
# We count the number of cells being added at the locations in the output.
@@ -183,7 +184,7 @@ class DeConv2DTest(tf.test.TestCase):
target += 2 * 3.0
self.assertAllClose(target, value[n, h, w, k])
- def testDeConv2DSame(self):
+ def testConv2DTransposeSame(self):
with self.test_session():
strides = [1, 2, 2, 1]
@@ -196,7 +197,8 @@ class DeConv2DTest(tf.test.TestCase):
x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
- output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
+ padding="SAME")
value = output.eval()
for n in xrange(x_shape[0]):
@@ -213,7 +215,7 @@ class DeConv2DTest(tf.test.TestCase):
target += 3.0
self.assertAllClose(target, value[n, h, w, k])
- def testDeConv2DValid(self):
+ def testConv2DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 1]
@@ -226,7 +228,8 @@ class DeConv2DTest(tf.test.TestCase):
x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
- output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID")
+ output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
+ padding="VALID")
value = output.eval()
cache_values = np.zeros(y_shape, dtype=np.float32)
@@ -269,7 +272,8 @@ class DeConv2DTest(tf.test.TestCase):
with self.test_session():
x = tf.constant(x_val, name="x", dtype=tf.float32)
f = tf.constant(f_val, name="f", dtype=tf.float32)
- output = tf.nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
+ padding="SAME")
err = tf.test.compute_gradient_error(
[x, f], [x_shape, f_shape], output, y_shape)
print("DeConv gradient err = %g " % err)