author	A. Unique TensorFlower <nobody@tensorflow.org>	2016-05-04 07:46:46 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2016-05-04 08:52:26 -0700
commit	6a187ccddaebb741ea77fc3201c6e36625f0aadb (patch)
tree	48097a7dbc49a3256de30ef0d9ae631e940af83e
parent	e5df6adc63653f31e9d5d6c539f799539cfbbed1 (diff)
Add support for 3d convolutions and pooling. CPU kernels use Eigen, GPU kernels use CuDNN.
Change: 121484787
-rw-r--r--  tensorflow/core/kernels/BUILD | 21
-rw-r--r--  tensorflow/core/kernels/conv_2d.h | 87
-rw-r--r--  tensorflow/core/kernels/conv_3d.h | 48
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.cc | 96
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops_3d.cc | 739
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 280
-rw-r--r--  tensorflow/core/kernels/conv_ops_3d.cc | 355
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 202
-rw-r--r--  tensorflow/core/kernels/cudnn_pooling_gpu.cc | 216
-rw-r--r--  tensorflow/core/kernels/cudnn_pooling_gpu.h | 65
-rw-r--r--  tensorflow/core/kernels/ops_util.cc | 28
-rw-r--r--  tensorflow/core/kernels/ops_util.h | 15
-rw-r--r--  tensorflow/core/kernels/pooling_ops_3d.cc | 515
-rw-r--r--  tensorflow/core/kernels/pooling_ops_common.cc | 24
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 153
-rw-r--r--  tensorflow/core/util/tensor_format.h | 52
-rw-r--r--  tensorflow/python/BUILD | 1
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_3d_test.py | 420
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_3d_test.py | 340
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py | 4
-rw-r--r--  tensorflow/python/ops/common_shapes.py | 124
-rw-r--r--  tensorflow/python/ops/nn.py | 3
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 101
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 158
25 files changed, 3591 insertions, 460 deletions
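
For orientation: the new 3-D ops follow the same per-spatial-dimension output-size rule as the existing 2-D ops, applied independently to planes, rows, and cols by the kernels below. A minimal sketch of that rule with hypothetical sizes:

#include <cstdio>

int main() {
  const int in = 10, filter = 3, stride = 2;
  // VALID: ceil((in - filter + 1) / stride); SAME: ceil(in / stride).
  const int out_valid = (in - filter + stride) / stride;  // 4
  const int out_same = (in + stride - 1) / stride;        // 5
  std::printf("VALID: %d, SAME: %d\n", out_valid, out_same);
  return 0;
}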
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0307765ac0..a6a9fee68a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -72,6 +72,16 @@ cc_library(
)
cc_library(
+ name = "conv_3d",
+ hdrs = ["conv_3d.h"],
+ deps = [
+ ":eigen_helpers",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "fill_functor",
hdrs = ["fill_functor.h"],
deps = [
@@ -1106,11 +1116,15 @@ tf_cuda_cc_test(
# conv_ops_gpu.h has to be separated into its own library.
tf_kernel_library(
name = "conv_ops",
- srcs = ["conv_grad_ops.cc"],
+ srcs = [
+ "conv_grad_ops.cc",
+ "conv_grad_ops_3d.cc",
+ ],
prefix = "conv_ops",
deps = [
":bounds_check",
":conv_2d",
+ ":conv_3d",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -1238,11 +1252,14 @@ tf_kernel_library(
name = "pooling_ops",
srcs = [
"avgpooling_op.cc",
+ "cudnn_pooling_gpu.cc",
"maxpooling_op.cc",
+ "pooling_ops_3d.cc",
"pooling_ops_common.cc",
],
hdrs = [
"avgpooling_op.h",
+ "cudnn_pooling_gpu.h",
"maxpooling_op.h",
"pooling_ops_common.h",
],
@@ -1257,6 +1274,8 @@ tf_kernel_library(
],
deps = [
":conv_2d",
+ ":conv_3d",
+ ":conv_ops",
":eigen_helpers",
":ops_util",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index c7d5c3aeeb..9bbc67520f 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -116,23 +116,31 @@ struct MatMulConvFunctor {
}
};
-template <typename Device, typename T, typename IndexType>
+// Shuffles a filter tensor from:
+// [<spatial_dims>, in, out]
+// to:
+// [out, in, <spatial_dims>]
+template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
void operator()(const Device& d,
- typename TTypes<T, 4, IndexType>::ConstTensor in,
- typename TTypes<T, 4, IndexType>::Tensor out) {
- // We want a 3, 2, 0, 1 shuffle. We can merge dimensions 0 and 1 together
- // to help speedup the shuffle operation.
+ typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
+ typename TTypes<T, NDIMS, IndexType>::Tensor out) {
+ // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
+ // to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
- merged_dims[0] = in.dimension(0) * in.dimension(1);
- merged_dims[1] = in.dimension(2);
- merged_dims[2] = in.dimension(3);
-
- Eigen::DSizes<IndexType, 4> expanded_dims;
- expanded_dims[0] = in.dimension(3);
- expanded_dims[1] = in.dimension(2);
- expanded_dims[2] = in.dimension(0);
- expanded_dims[3] = in.dimension(1);
+ merged_dims[0] = in.dimension(0); // spatial dimensions
+ for (int i = 1; i < NDIMS - 2; ++i) {
+ merged_dims[0] *= in.dimension(i);
+ }
+ merged_dims[1] = in.dimension(NDIMS - 2); // input filters
+ merged_dims[2] = in.dimension(NDIMS - 1); // output filters
+
+ Eigen::DSizes<IndexType, NDIMS> expanded_dims;
+ expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
+ expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
+ for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
+ expanded_dims[i + 2] = in.dimension(i);
+ }
out.device(d) = in.reshape(merged_dims)
.shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
.reshape(expanded_dims);
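
The reshape/shuffle/reshape above generalizes the old fixed 4-D transpose to any NDIMS. A standalone Eigen sketch of the same sequence for NDIMS = 5, with a hypothetical 3x3x3x16x32 filter:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // TensorFlow filter layout: [planes, rows, cols, in, out].
  Eigen::Tensor<float, 5, Eigen::RowMajor> filter(3, 3, 3, 16, 32);
  filter.setRandom();
  Eigen::DSizes<Eigen::Index, 3> merged(3 * 3 * 3, 16, 32);  // spatial dims collapsed
  Eigen::DSizes<Eigen::Index, 5> expanded(32, 16, 3, 3, 3);  // [out, in, <spatial>]
  Eigen::Tensor<float, 5, Eigen::RowMajor> transformed =
      filter.reshape(merged)
          .shuffle(Eigen::DSizes<Eigen::Index, 3>(2, 1, 0))
          .reshape(expanded);
  return 0;
}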
@@ -194,41 +202,50 @@ struct TransformDepth {
}
};
-template <typename Device, typename T, typename IndexType>
+template <typename Device, typename T, typename IndexType, int NDIMS>
struct PadInput {
void operator()(const Device& d,
- typename TTypes<T, 4, IndexType>::ConstTensor in,
- int padding_rows_left, int padding_rows_right,
- int padding_cols_left, int padding_cols_right,
- typename TTypes<T, 4, IndexType>::Tensor out,
+ typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
+ const std::array<int, NDIMS - 2>& padding_left,
+ const std::array<int, NDIMS - 2>& padding_right,
+ typename TTypes<T, NDIMS, IndexType>::Tensor out,
TensorFormat format) {
- Eigen::array<std::pair<IndexType, IndexType>, 4> padding;
- padding[GetTensorDimIndex(format, 'N')] = std::make_pair(0, 0);
- padding[GetTensorDimIndex(format, 'H')] =
- std::make_pair(padding_rows_left, padding_rows_right);
- padding[GetTensorDimIndex(format, 'W')] =
- std::make_pair(padding_cols_left, padding_cols_right);
- padding[GetTensorDimIndex(format, 'C')] = std::make_pair(0, 0);
+ Eigen::array<std::pair<IndexType, IndexType>, NDIMS> padding;
+ padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = std::make_pair(0, 0);
+ for (int i = 0; i < NDIMS - 2; ++i) {
+ padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] =
+ std::make_pair(padding_left[i], padding_right[i]);
+ }
+ padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = std::make_pair(0, 0);
out.device(d) = in.pad(padding);
}
};
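
A standalone sketch of the padding array built above, for a hypothetical 5-D NDHWC tensor: batch and channels get zero padding, each spatial dimension its own (left, right) pair:

#include <unsupported/Eigen/CXX11/Tensor>
#include <array>
#include <utility>

int main() {
  Eigen::Tensor<float, 5, Eigen::RowMajor> in(2, 4, 4, 4, 8);  // N, D, H, W, C
  in.setZero();
  const std::array<int, 3> left{{1, 0, 2}};
  const std::array<int, 3> right{{1, 1, 2}};
  Eigen::array<std::pair<int, int>, 5> padding;
  padding[0] = std::make_pair(0, 0);  // batch
  for (int i = 0; i < 3; ++i) {       // spatial dims
    padding[i + 1] = std::make_pair(left[i], right[i]);
  }
  padding[4] = std::make_pair(0, 0);  // channels
  Eigen::Tensor<float, 5, Eigen::RowMajor> out = in.pad(padding);  // [2, 6, 5, 8, 8]
  return 0;
}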
-template <typename Device, typename T>
+// Converts a tensor from:
+// [batch, <spatial>, filters]
+// to:
+// [batch, filters, <spatial>]
+template <typename Device, typename T, int NDIMS>
struct NHWCToNCHW {
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out);
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out);
};
-template <typename Device, typename T>
+// Converts a tensor from:
+// [batch, filters, <spatial>]
+// to:
+// [batch, <spatial>, filters]
+template <typename Device, typename T, int NDIMS>
struct NCHWToNHWC {
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out);
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out);
};
-template <typename Device, typename T>
+// Reverses the effect of TransformFilter above.
+template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out);
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out);
};
} // namespace functor
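
On plain Eigen tensors, the 5-D instantiations of the NHWCToNCHW / NCHWToNHWC converters declared above amount to two inverse shuffles; the GPU specializations live in the .cu.cc files. A hypothetical standalone sketch:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 5, Eigen::RowMajor> ndhwc(2, 4, 4, 4, 8);
  ndhwc.setRandom();
  // [batch, <spatial>, filters] -> [batch, filters, <spatial>]
  Eigen::array<int, 5> to_ncdhw{{0, 4, 1, 2, 3}};
  Eigen::Tensor<float, 5, Eigen::RowMajor> ncdhw = ndhwc.shuffle(to_ncdhw);
  // [batch, filters, <spatial>] -> [batch, <spatial>, filters]
  Eigen::array<int, 5> to_ndhwc{{0, 2, 3, 4, 1}};
  Eigen::Tensor<float, 5, Eigen::RowMajor> back = ncdhw.shuffle(to_ndhwc);
  return 0;
}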
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
new file mode 100644
index 0000000000..af3841ad4a
--- /dev/null
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Functors for 3d convolution.
+
+#ifndef TENSORFLOW_KERNELS_CONV_3D_H_
+#define TENSORFLOW_KERNELS_CONV_3D_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Applies a 3D convolution to a batch of multi-channel volumes.
+template <typename Device, typename T>
+struct CuboidConvolution;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+struct CuboidConvolution<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T, 5>::Tensor output,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor filter, int stride_planes,
+ int stride_rows, int stride_cols,
+ const Eigen::PaddingType& padding) {
+ output.device(d) = Eigen::CuboidConvolution(
+ input, filter, stride_planes, stride_rows, stride_cols, padding);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CONV_3D_H_
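
A hypothetical usage sketch of the primitive the functor wraps (assumes a TensorFlow source tree for the eigen_cuboid_convolution.h include; shapes follow the NDHWC and [planes, rows, cols, in, out] conventions used by the kernels in this change):

#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 5, Eigen::RowMajor> input(2, 8, 8, 8, 4);    // NDHWC
  Eigen::Tensor<float, 5, Eigen::RowMajor> filter(3, 3, 3, 4, 16);  // [p, r, c, in, out]
  input.setRandom();
  filter.setRandom();
  // VALID padding, unit strides: output is [2, 6, 6, 6, 16].
  Eigen::Tensor<float, 5, Eigen::RowMajor> output = Eigen::CuboidConvolution(
      input, filter, /*stridePlanes=*/1, /*strideRows=*/1, /*strideCols=*/1,
      Eigen::PADDING_VALID);
  return 0;
}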
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index f5daa9b2ec..84cc7017c4 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -946,7 +946,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
filter_rows, filter_cols}),
&transformed_filter));
- functor::TransformFilter<Device, T, int>()(
+ functor::TransformFilter<Device, T, int, 4>()(
context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
@@ -959,9 +959,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
output_cols, out_depth),
&transformed_out_backprop));
- functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
- out_backprop.tensor<T, 4>(),
- transformed_out_backprop.tensor<T, 4>());
+ functor::NHWCToNCHW<Device, T, 4>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
+ transformed_out_backprop.tensor<T, 4>());
} else {
transformed_out_backprop = out_backprop;
}
@@ -1022,11 +1022,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
&in_backprop_remove_padding));
// Remove the padding for odd rows or cols.
- functor::PadInput<GPUDevice, T, int>()(
+ functor::PadInput<GPUDevice, T, int, 4>()(
context->template eigen_device<GPUDevice>(),
To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
.tensor<T, 4>()),
- 0, -rows_odd, 0, -cols_odd,
+ {{0, 0}}, {{-rows_odd, -cols_odd}},
To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);
pre_transformed_in_backprop = in_backprop_remove_padding;
@@ -1034,7 +1034,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
if (data_format_ == FORMAT_NHWC) {
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::NCHWToNHWC<Device, T>()(
+ functor::NCHWToNHWC<Device, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
in_backprop->tensor<T, 4>());
@@ -1167,9 +1167,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
input_cols + cols_odd, in_depth),
&compatible_input));
- functor::PadInput<GPUDevice, T, int>()(
+ functor::PadInput<GPUDevice, T, int, 4>()(
context->template eigen_device<GPUDevice>(),
- To32Bit(input.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
+ To32Bit(input.tensor<T, 4>()), {{0, 0}}, {{rows_odd, cols_odd}},
To32Bit(compatible_input.tensor<T, 4>()), data_format_);
} else {
compatible_input = input;
@@ -1227,9 +1227,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
ShapeFromFormat(FORMAT_NCHW, batch, output_rows,
output_cols, out_depth),
&transformed_out_backprop));
- functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
- out_backprop.tensor<T, 4>(),
- transformed_out_backprop.tensor<T, 4>());
+ functor::NHWCToNCHW<Device, T, 4>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
+ transformed_out_backprop.tensor<T, 4>());
} else {
transformed_out_backprop = out_backprop;
}
@@ -1246,7 +1246,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
GetTensorDim(compatible_input, data_format_, 'W'),
GetTensorDim(compatible_input, data_format_, 'C')),
&transformed_input));
- functor::NHWCToNCHW<Device, T>()(
+ functor::NHWCToNCHW<Device, T, 4>()(
context->eigen_device<Device>(),
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
@@ -1284,7 +1284,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
}
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::ReverseTransformFilter<Device, T>()(
+ functor::ReverseTransformFilter<Device, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
filter_backprop->tensor<T, 4>());
@@ -1301,40 +1301,40 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
- const Eigen::DSizes<int, 4>& order, \
- const Eigen::array<bool, 4>& reverse_dims, \
- typename TTypes<T, 4, int>::Tensor output); \
- extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
- template <> \
- void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
- const Eigen::DSizes<int, 4>& strides, \
- const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
- const Eigen::DSizes<int, 4>& order, \
- typename TTypes<T, 4, int>::Tensor output); \
- extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
- template <> \
- void TransformFilter<GPUDevice, T, int>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
- typename TTypes<T, 4, int>::Tensor out); \
- extern template struct TransformFilter<GPUDevice, T, int>; \
- template <> \
- void TransformDepth<GPUDevice, T, int>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
- const Eigen::DSizes<int, 4>& shuffle, \
- typename TTypes<T, 4, int>::Tensor out); \
- extern template struct TransformDepth<GPUDevice, T, int>; \
- template <> \
- void PadInput<GPUDevice, T, int>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
- int padding_rows_left, int padding_rows_right, int padding_cols_left, \
- int padding_cols_right, typename TTypes<T, 4, int>::Tensor out, \
- TensorFormat data_format); \
- extern template struct PadInput<GPUDevice, T, int>;
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
+ const Eigen::DSizes<int, 4>& order, \
+ const Eigen::array<bool, 4>& reverse_dims, \
+ typename TTypes<T, 4, int>::Tensor output); \
+ extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
+ template <> \
+ void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
+ const Eigen::DSizes<int, 4>& strides, \
+ const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
+ const Eigen::DSizes<int, 4>& order, \
+ typename TTypes<T, 4, int>::Tensor output); \
+ extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
+ template <> \
+ void TransformFilter<GPUDevice, T, int, 4>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ typename TTypes<T, 4, int>::Tensor out); \
+ extern template struct TransformFilter<GPUDevice, T, int, 4>; \
+ template <> \
+ void TransformDepth<GPUDevice, T, int>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const Eigen::DSizes<int, 4>& shuffle, \
+ typename TTypes<T, 4, int>::Tensor out); \
+ extern template struct TransformDepth<GPUDevice, T, int>; \
+ template <> \
+ void PadInput<GPUDevice, T, int, 4>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const std::array<int, 2>& padding_left, \
+ const std::array<int, 2>& padding_right, \
+ typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
+ extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
new file mode 100644
index 0000000000..1be72034a4
--- /dev/null
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -0,0 +1,739 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/conv_3d.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+using perftools::gputools::dnn::DimIndex;
+#endif
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// TODO(mjanusz): Get rid of the macro and return shapes directly.
+#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
+ const Tensor& input = context->input(0); \
+ const Tensor& filter = context->input(1); \
+ const Tensor& out_backprop = context->input(2); \
+ OP_REQUIRES( \
+ context, input.dims() == 5, \
+ errors::InvalidArgument(label, ": input must be 5-dimensional")); \
+ OP_REQUIRES( \
+ context, filter.dims() == 5, \
+ errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
+ OP_REQUIRES( \
+ context, out_backprop.dims() == 5, \
+ errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
+ const int64 batch = input.dim_size(0); \
+ OP_REQUIRES( \
+ context, batch == out_backprop.dim_size(0), \
+ errors::InvalidArgument( \
+ label, ": input and out_backprop must have the same batch size")); \
+ const std::array<int64, 3> input_size = { \
+ {input.dim_size(1), input.dim_size(2), input.dim_size(3)}}; \
+ const std::array<int64, 3> filter_size = { \
+ {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}}; \
+ const int64 output_cols = out_backprop.dim_size(3); \
+ const int64 output_rows = out_backprop.dim_size(2); \
+ const int64 output_planes = out_backprop.dim_size(1); \
+ const int64 in_depth = input.dim_size(4); \
+ OP_REQUIRES(context, in_depth == filter.dim_size(3), \
+ errors::InvalidArgument( \
+ label, ": input and filter must have the same depth")); \
+ const int64 out_depth = filter.dim_size(4); \
+ OP_REQUIRES( \
+ context, out_depth == out_backprop.dim_size(4), \
+ errors::InvalidArgument( \
+ label, ": filter and out_backprop must have the same out_depth")); \
+ const std::array<int64, 3> strides = {{stride_[1], stride_[2], stride_[3]}}; \
+ std::array<int64, 3> out, padding; \
+ OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides, \
+ padding_, &out, &padding)); \
+ OP_REQUIRES(context, output_planes == out[0], \
+ errors::InvalidArgument( \
+ label, \
+ ": Number of planes of out_backprop doesn't match " \
+ "computed: actual = ", \
+ output_planes, ", computed = ", out[0])); \
+ OP_REQUIRES( \
+ context, output_rows == out[1], \
+ errors::InvalidArgument( \
+ label, ": Number of rows of out_backprop doesn't match computed: ", \
+ "actual = ", output_rows, ", computed = ", out[1])); \
+ OP_REQUIRES( \
+ context, output_cols == out[2], \
+ errors::InvalidArgument( \
+ label, ": Number of cols of out_backprop doesn't match computed: ", \
+ "actual = ", output_cols, ", computed = ", out[2])); \
+ const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
+ const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
+ const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
+ const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
+ const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
+ const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
+ const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
+ const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
+ const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
+ const auto bottom_pad_planes = \
+ padded_out_planes - expanded_out_planes - top_pad_planes; \
+ const auto bottom_pad_rows = \
+ padded_out_rows - expanded_out_rows - top_pad_rows; \
+ const auto right_pad_cols = \
+ padded_out_cols - expanded_out_cols - left_pad_cols; \
+ VLOG(2) << "Conv3d: " << label \
+ << ": expanded_out_planes = " << expanded_out_planes \
+ << ": expanded_out_rows = " << expanded_out_rows \
+ << ", expanded_out_cols = " << expanded_out_cols \
+ << ", padded_out_planes = " << padded_out_planes \
+ << ", padded_out_rows = " << padded_out_rows \
+ << ", padded_out_cols = " << padded_out_cols \
+ << ", top_pad_planes = " << top_pad_planes \
+ << ", top_pad_rows = " << top_pad_rows \
+ << ", left_pad_cols = " << left_pad_cols \
+ << ", bottom_pad_planes = " << bottom_pad_planes \
+ << ", bottom_pad_rows = " << bottom_pad_rows \
+ << ", right_pad_cols = " << right_pad_cols
+
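The expanded/padded quantities above implement input backprop as a VALID convolution of the stride-inflated out_backprop with the reversed filter. A one-dimensional numeric check with hypothetical sizes (input 7, filter 3, stride 2, SAME padding):

#include <cstdio>

int main() {
  const int input = 7, filter = 3, stride = 2;
  const int output = (input + stride - 1) / stride;                    // SAME: 4
  const int pad_before = ((output - 1) * stride + filter - input) / 2; // 1
  const int expanded_out = (output - 1) * stride + 1;                  // 7
  const int padded_out = input + filter - 1;                           // 9
  const int top_pad = filter - 1 - pad_before;                         // 1
  const int bottom_pad = padded_out - expanded_out - top_pad;          // 1
  // VALID convolution of the padded signal: 9 - 3 + 1 = 7 = input extent.
  std::printf("%d %d %d %d\n", expanded_out, padded_out, top_pad, bottom_pad);
  return 0;
}
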
+// Backprop for input.
+template <typename Device, class T>
+class Conv3DBackpropInputOp : public OpKernel {
+ public:
+ explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context, (stride_[0] == 1 && stride_[4] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
+ {0, 0},
+ {top_pad_planes, bottom_pad_planes},
+ {top_pad_rows, bottom_pad_rows},
+ {left_pad_cols, right_pad_cols},
+ {0, 0}};
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &in_backprop));
+
+ // Fill out a padded out_backprop.
+ TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
+ padded_out_cols, out_depth});
+ Tensor padded_output;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ padded_out_shape, &padded_output));
+ Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
+ Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
+ strides[2], 1};
+ functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
+ eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
+ const Tensor& padded_output_cref = padded_output;
+
+ // Fill a new "reverted" filter. We need to transpose the in_depth and
+ // out_depth for the filter and reverse the planes, rows and cols.
+ TensorShape r_filter_shape(
+ {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
+ Tensor r_filter;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ r_filter_shape, &r_filter));
+ Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
+ Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
+ functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
+ context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
+ filter_rev_dims, r_filter.tensor<T, 5>());
+ const Tensor& r_filter_cref = r_filter;
+
+ // Now we can call conv_3d directly.
+ functor::CuboidConvolution<Device, T>()(
+ context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
+ padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
+ 1, BrainPadding2EigenPadding(VALID));
+ }
+
+ private:
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ Conv3DBackpropInputOp<CPUDevice, float>);
+#ifndef __ANDROID__
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<double>("T"),
+ Conv3DBackpropInputOp<CPUDevice, double>);
+#endif
+
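The "reverted" filter used by the CPU kernel above swaps the channel dimensions and flips the spatial ones. A standalone Eigen sketch with a hypothetical [3, 3, 3, 4, 16] filter:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 5, Eigen::RowMajor> filter(3, 3, 3, 4, 16);
  filter.setRandom();
  Eigen::array<int, 5> order{{0, 1, 2, 4, 3}};                  // swap in/out
  Eigen::array<bool, 5> rev{{true, true, true, false, false}};  // flip p, r, c
  Eigen::Tensor<float, 5, Eigen::RowMajor> r_filter =
      filter.shuffle(order).reverse(rev);  // shape [3, 3, 3, 16, 4]
  return 0;
}
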
+// Backprop for filter.
+template <typename Device, class T>
+class Conv3DBackpropFilterOp : public OpKernel {
+ public:
+ explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context, (stride_[0] == 1 && stride_[4] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
+ {0, 0},
+ {top_pad_planes, bottom_pad_planes},
+ {top_pad_rows, bottom_pad_rows},
+ {left_pad_cols, right_pad_cols},
+ {0, 0}};
+ Tensor* filter_backprop;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, filter.shape(), &filter_backprop));
+
+ // For the backprop of the filter, we need to also transpose the
+ // out_backprop.
+ // The shape of backprop is
+ // [batch, out_z, out_y, out_x, out_depth]
+ // And we need to change it to
+ // [out_depth, out_z, out_y, out_x, batch]
+ Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
+ TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
+ padded_out_cols, batch});
+ Tensor padded_output;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ padded_out_shape, &padded_output));
+ Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
+ strides[2], 1};
+ functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
+ eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
+ const Tensor& padded_output_cref = padded_output;
+
+ // For the backprop of the filter, we need to transpose the input.
+ // The shape of input is
+ // [batch, in_z, in_y, in_x, in_depth]
+ // And we need to change it to
+ // [in_z, in_y, in_x, batch, in_depth]
+ Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
+ TensorShape in_shuffle_shape(
+ {input_size[0], input_size[1], input_size[2], batch, in_depth});
+ Tensor in_shuffle;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ in_shuffle_shape, &in_shuffle));
+ // No need for reversing this time.
+ Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
+ functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
+ no_reverse, in_shuffle.tensor<T, 5>());
+ const Tensor& in_shuffle_cref = in_shuffle;
+
+ // The output of the conv_3d would be
+ // [out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth]
+ // and we need to shuffle it back to
+ // [filter_size[0], filter_size[1], filter_size[2], in_depth, out_depth];
+ // And we need to reverse the filter backprops.
+ // So we need to allocate (sigh) yet another piece of memory to hold the
+ // output.
+ TensorShape filter_shuffle_shape(
+ {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
+ Tensor filter_shuffle;
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ filter_shuffle_shape, &filter_shuffle));
+ functor::CuboidConvolution<Device, T>()(
+ context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
+ padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
+ 1, BrainPadding2EigenPadding(VALID));
+
+ // Now copy the filter_backprop back to the destination.
+ Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
+ Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
+ const Tensor& filter_shuffle_cref = filter_shuffle;
+ functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
+ context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
+ filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ Conv3DBackpropFilterOp<CPUDevice, float>);
+#ifndef __ANDROID__
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<double>("T"),
+ Conv3DBackpropFilterOp<CPUDevice, double>);
+#endif
+
+// GPU definitions of both ops.
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+// This ensures that the custom implementation is used instead of the default
+// Eigen one (which is used for CPU).
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void TransformFilter<GPUDevice, T, int, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ typename TTypes<T, 5, int>::Tensor out); \
+ template <> \
+ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in, \
+ typename TTypes<T, 5>::Tensor out); \
+ template <> \
+ void PadInput<GPUDevice, T, int, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const std::array<int, 3>& padding_left, \
+ const std::array<int, 3>& padding_right, \
+ typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+template <typename T>
+class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
+ public:
+ explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context, (stride_[0] == 1 && stride_[4] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+ void Compute(OpKernelContext* context) override {
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &in_backprop));
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
+ stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1) {
+ const uint64 m = batch * input_size[1] * input_size[2] * input_size[0];
+ const uint64 k = out_depth;
+ const uint64 n = in_depth;
+
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
+ in_backprop->template flat<T>().size());
+
+ auto transpose = perftools::gputools::blas::Transpose::kTranspose;
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
+ a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
+
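With a 1x1x1 filter and unit strides, the convolution is a per-voxel matmul over channels, so the input gradient reduces to dX = dY * W^T; that is what the cuBLAS call above computes. A plain-C++ rendering with hypothetical sizes:

#include <vector>

int main() {
  const int m = 2 * 4 * 4 * 4;  // batch * planes * rows * cols
  const int k = 16;             // out_depth
  const int n = 8;              // in_depth
  std::vector<float> dy(m * k, 1.0f);  // out_backprop, row-major [m, k]
  std::vector<float> w(n * k, 0.5f);   // 1x1x1 filter viewed as [in, out]
  std::vector<float> dx(m * n, 0.0f);  // in_backprop, row-major [m, n]
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int l = 0; l < k; ++l) acc += dy[i * k + l] * w[j * k + l];
      dx[i * n + j] = acc;  // dX = dY * W^T
    }
  return 0;
}
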
+ int padding_rows = 0, padding_cols = 0, padding_planes = 0;
+
+ if (padding_ == Padding::SAME) {
+ padding_planes =
+ (output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
+ padding_cols =
+ (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
+ padding_rows =
+ (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+ }
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
+ const bool planes_odd = (padding_planes % 2 != 0);
+
+ TensorShape compatible_input_shape;
+ if (rows_odd || cols_odd || planes_odd) {
+ // cuDNN only supports the same amount of padding on both sides.
+ compatible_input_shape = {
+ batch,
+ in_depth,
+ input_size[0] + planes_odd,
+ input_size[1] + rows_odd,
+ input_size[2] + cols_odd,
+ };
+ } else {
+ compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
+ input_size[2]};
+ }
+
+ perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ input_desc.set_count(batch)
+ .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
+ .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
+ .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
+ .set_feature_map_count(in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ output_desc.set_count(batch)
+ .set_spatial_dim(DimIndex::X, output_cols)
+ .set_spatial_dim(DimIndex::Y, output_rows)
+ .set_spatial_dim(DimIndex::Z, output_planes)
+ .set_feature_map_count(out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
+ .set_spatial_dim(DimIndex::Y, filter_size[1])
+ .set_spatial_dim(DimIndex::Z, filter_size[0])
+ .set_input_feature_map_count(in_depth)
+ .set_output_feature_map_count(out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ conv_desc.set_filter_stride(DimIndex::X, strides[2])
+ .set_filter_stride(DimIndex::Y, strides[1])
+ .set_filter_stride(DimIndex::Z, strides[0])
+ .set_zero_padding(DimIndex::X, padding_cols / 2)
+ .set_zero_padding(DimIndex::Y, padding_rows / 2)
+ .set_zero_padding(DimIndex::Z, padding_planes / 2);
+
+ // Shape: out, in, z, y, x.
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({out_depth, in_depth, filter_size[0],
+ filter_size[1], filter_size[2]}),
+ &transformed_filter));
+ functor::TransformFilter<GPUDevice, T, int, 5>()(
+ context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ To32Bit(transformed_filter.tensor<T, 5>()));
+
+ // Shape: batch, filters, z, y, x.
+ Tensor transformed_out_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ {batch, out_depth, output_planes,
+ output_rows, output_cols},
+ &transformed_out_backprop));
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
+ transformed_out_backprop.tensor<T, 5>());
+
+ // Shape: batch, filters, z, y, x.
+ Tensor pre_transformed_in_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
+ &pre_transformed_in_backprop));
+
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto in_backprop_ptr =
+ AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
+ pre_transformed_in_backprop.template flat<T>().size());
+
+ static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
+
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ context);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithScratch(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ context->SetStatus(errors::Internal(
+ "cuDNN Backward Data function launch failure : input shape(",
+ input.shape().DebugString(), ") filter shape(",
+ filter.shape().DebugString(), ")"));
+ }
+
+ if (rows_odd || cols_odd || planes_odd) {
+ Tensor in_backprop_remove_padding;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ {batch, in_depth, input_size[0],
+ input_size[1], input_size[2]},
+ &in_backprop_remove_padding));
+
+ // Remove the padding for odd spatial dimensions.
+ functor::PadInput<GPUDevice, T, int, 5>()(
+ context->eigen_device<GPUDevice>(),
+ To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
+ .tensor<T, 5>()),
+ {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
+ To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);
+
+ pre_transformed_in_backprop = in_backprop_remove_padding;
+ }
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::NCHWToNHWC<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(),
+ toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
+ in_backprop->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+template <typename T>
+class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
+ public:
+ explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context, (stride_[0] == 1 && stride_[4] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+
+ Tensor* filter_backprop;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, filter.shape(), &filter_backprop));
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
+ strides[2] == 1 && strides[1] == 1 && strides[0] == 1) {
+ const uint64 m = in_depth;
+ const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
+ const uint64 n = out_depth;
+
+ // The shape of output backprop is
+ // [batch, out_z, out_y, out_x, out_depth]
+ // From cublas's perspective, it is: n x k
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+
+ // The shape of input is:
+ // [batch, in_z, in_y, in_x, in_depth],
+ // From cublas's perspective, it is: m x k
+ auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+
+ // The shape of the filter backprop is:
+ // [1, 1, 1, in_depth, out_depth]
+ // From cublas's perspective, it is: n x m
+ auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
+ filter_backprop->template flat<T>().size());
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose,
+ n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
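Counterpart of the fast path in the input-backprop kernel: for a 1x1x1 filter with unit strides, the filter gradient is dW = X^T * dY over the flattened batch/spatial axis. A plain-C++ rendering with hypothetical sizes:

#include <vector>

int main() {
  const int k = 2 * 4 * 4 * 4;  // batch * planes * rows * cols
  const int m = 8;              // in_depth
  const int n = 16;             // out_depth
  std::vector<float> x(k * m, 1.0f);   // input, row-major [k, m]
  std::vector<float> dy(k * n, 1.0f);  // out_backprop, row-major [k, n]
  std::vector<float> dw(m * n, 0.0f);  // filter backprop, row-major [m, n]
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int l = 0; l < k; ++l) acc += x[l * m + i] * dy[l * n + j];
      dw[i * n + j] = acc;  // dW = X^T * dY
    }
  return 0;
}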
+ int padding_rows = 0, padding_cols = 0, padding_planes = 0;
+
+ if (padding_ == Padding::SAME) {
+ padding_planes =
+ (output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
+ padding_cols =
+ (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
+ padding_rows =
+ (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+ }
+ bool rows_odd = (padding_rows % 2 != 0);
+ bool cols_odd = (padding_cols % 2 != 0);
+ bool planes_odd = (padding_planes % 2 != 0);
+
+ Tensor compatible_input;
+ if (rows_odd || cols_odd || planes_odd) {
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DataTypeToEnum<T>::value,
+ {batch, input_size[0] + planes_odd,
+ input_size[1] + rows_odd,
+ input_size[2] + cols_odd, in_depth},
+ &compatible_input));
+
+ functor::PadInput<GPUDevice, T, int, 5>()(
+ context->template eigen_device<GPUDevice>(),
+ To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
+ {{planes_odd, rows_odd, cols_odd}},
+ To32Bit(compatible_input.tensor<T, 5>()), FORMAT_NHWC);
+ } else {
+ compatible_input = input;
+ }
+
+ perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ input_desc.set_count(batch)
+ .set_spatial_dim(DimIndex::X, compatible_input.dim_size(3))
+ .set_spatial_dim(DimIndex::Y, compatible_input.dim_size(2))
+ .set_spatial_dim(DimIndex::Z, compatible_input.dim_size(1))
+ .set_feature_map_count(in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ output_desc.set_count(batch)
+ .set_spatial_dim(DimIndex::X, output_cols)
+ .set_spatial_dim(DimIndex::Y, output_rows)
+ .set_spatial_dim(DimIndex::Z, output_planes)
+ .set_feature_map_count(out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
+ .set_spatial_dim(DimIndex::Y, filter_size[1])
+ .set_spatial_dim(DimIndex::Z, filter_size[0])
+ .set_input_feature_map_count(in_depth)
+ .set_output_feature_map_count(out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ conv_desc.set_filter_stride(DimIndex::X, strides[2])
+ .set_filter_stride(DimIndex::Y, strides[1])
+ .set_filter_stride(DimIndex::Z, strides[0])
+ .set_zero_padding(DimIndex::X, padding_cols / 2)
+ .set_zero_padding(DimIndex::Y, padding_rows / 2)
+ .set_zero_padding(DimIndex::Z, padding_planes / 2);
+
+ Tensor pre_transformed_filter_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({out_depth, in_depth, filter_size[0],
+ filter_size[1], filter_size[2]}),
+ &pre_transformed_filter_backprop));
+
+ Tensor transformed_out_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ {batch, out_depth, output_planes,
+ output_rows, output_cols},
+ &transformed_out_backprop));
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
+ transformed_out_backprop.tensor<T, 5>());
+
+ Tensor transformed_input;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {batch, in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2),
+ compatible_input.dim_size(3)},
+ &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
+ transformed_input.tensor<T, 5>());
+
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_backprop_ptr = AsDeviceMemory(
+ pre_transformed_filter_backprop.template flat<T>().data(),
+ pre_transformed_filter_backprop.template flat<T>().size());
+ auto input_ptr =
+ AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+
+ static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+ context);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithScratch(
+ input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
+ filter_desc, &filter_backprop_ptr, &scratch_allocator)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ context->SetStatus(errors::Internal(
+ "cuDNN Backward Filter function launch failure : input shape(",
+ input.shape().DebugString(), ") filter shape(",
+ filter.shape().DebugString(), ")"));
+ }
+
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::ReverseTransformFilter<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(),
+ toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
+ filter_backprop->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ Conv3DBackpropInputOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ Conv3DBackpropFilterOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 69fd0dc00b..d88c1025af 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -303,145 +303,145 @@ struct LaunchConvOp<GPUDevice, T> {
", n=", n, ", k=", k));
}
return;
- }
- int padding_rows = 0;
- int padding_cols = 0;
- const int64 in_batch = GetTensorDim(input, data_format, 'N');
- int64 in_rows = GetTensorDim(input, data_format, 'H');
- int64 in_cols = GetTensorDim(input, data_format, 'W');
- const int64 in_depths = GetTensorDim(input, data_format, 'C');
- const int64 out_batch = GetTensorDim(*output, data_format, 'N');
- const int64 out_rows = GetTensorDim(*output, data_format, 'H');
- const int64 out_cols = GetTensorDim(*output, data_format, 'W');
- const int64 out_depths = GetTensorDim(*output, data_format, 'C');
- const int64 patch_rows = filter.dim_size(0);
- const int64 patch_cols = filter.dim_size(1);
- if (padding == Eigen::PADDING_SAME) {
- // Total padding on rows and cols is
- // Pr = (R' - 1) * S + Kr - R
- // Pc = (C' - 1) * S + Kc - C
- // where (R', C') are output dimensions, (R, C) are input dimensions, S
- // is stride, (Kr, Kc) are filter dimensions.
- // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
- // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
- // we pad more on the right and bottom than on the top and left.
- padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows;
- padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols;
- const bool rows_odd = (padding_rows % 2 != 0);
- const bool cols_odd = (padding_cols % 2 != 0);
- if (rows_odd || cols_odd) {
- Tensor transformed_input;
- int64 new_in_rows = in_rows + rows_odd;
- int64 new_in_cols = in_cols + cols_odd;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format, in_batch, new_in_rows,
- new_in_cols, in_depths),
- &transformed_input));
-
- functor::PadInput<GPUDevice, T, int>()(
- ctx->eigen_device<GPUDevice>(),
- To32Bit(input_param.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
- To32Bit(transformed_input.tensor<T, 4>()), data_format);
- input = transformed_input;
- in_rows = new_in_rows;
- in_cols = new_in_cols;
- }
- }
-
- if (data_format == FORMAT_NHWC) {
- // Convert the input tensor from NHWC to NCHW.
+ }
+ int padding_rows = 0;
+ int padding_cols = 0;
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 out_batch = GetTensorDim(*output, data_format, 'N');
+ const int64 out_rows = GetTensorDim(*output, data_format, 'H');
+ const int64 out_cols = GetTensorDim(*output, data_format, 'W');
+ const int64 out_depths = GetTensorDim(*output, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ if (padding == Eigen::PADDING_SAME) {
+ // Total padding on rows and cols is
+ // Pr = (R' - 1) * S + Kr - R
+ // Pc = (C' - 1) * S + Kc - C
+ // where (R', C') are output dimensions, (R, C) are input dimensions, S
+ // is stride, (Kr, Kc) are filter dimensions.
+ // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
+ // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
+ // we pad more on the right and bottom than on the top and left.
+ padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows;
+ padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols;
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
+ if (rows_odd || cols_odd) {
Tensor transformed_input;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows,
- in_cols, in_depths),
- &transformed_input));
- functor::NHWCToNCHW<GPUDevice, T>()(
- ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(input).tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
+ int64 new_in_rows = in_rows + rows_odd;
+ int64 new_in_cols = in_cols + cols_odd;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, in_batch, new_in_rows,
+ new_in_cols, in_depths),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
+ {{0, 0}}, {{rows_odd, cols_odd}},
+ To32Bit(transformed_input.tensor<T, 4>()), data_format);
input = transformed_input;
+ in_rows = new_in_rows;
+ in_cols = new_in_cols;
}
+ }
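
Numeric illustration of the odd-padding workaround above, with hypothetical sizes: cuDNN pads both sides equally, so when the total SAME padding is odd the kernel first grows the input by one row/col at the bottom/right and then uses pad/2 on each side:

#include <cstdio>

int main() {
  const int in = 7, filter = 2, stride = 2;
  const int out = (in + stride - 1) / stride;        // SAME: 4
  const int pad = (out - 1) * stride + filter - in;  // 1, odd
  const int in_compat = in + (pad % 2);              // pad input to 8
  const int pad_sym = pad / 2;                       // 0 per side
  // (in_compat + 2 * pad_sym - filter) / stride + 1 = (8 - 2) / 2 + 1 = 4 = out
  std::printf("%d %d %d\n", out, in_compat, pad_sym);
  return 0;
}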
- perftools::gputools::dnn::BatchDescriptor input_desc;
- input_desc.set_count(in_batch)
- .set_feature_map_count(in_depths)
- .set_height(in_rows)
- .set_width(in_cols)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
- output_desc.set_count(out_batch)
- .set_height(out_rows)
- .set_width(out_cols)
- .set_feature_map_count(out_depths)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter.dim_size(0))
- .set_input_filter_width(filter.dim_size(1))
- .set_input_feature_map_count(filter.dim_size(2))
- .set_output_feature_map_count(filter.dim_size(3));
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(row_stride)
- .set_horizontal_filter_stride(col_stride)
- .set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
-
- Tensor transformed_filter;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({filter.dim_size(3), filter.dim_size(2),
- filter.dim_size(0), filter.dim_size(1)}),
- &transformed_filter));
-
- functor::TransformFilter<GPUDevice, T, int>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
- To32Bit(transformed_filter.tensor<T, 4>()));
-
- Tensor transformed_output;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
- out_cols, out_depths),
- &transformed_output));
-
- auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
- auto filter_ptr =
- AsDeviceMemory(transformed_filter.template flat<T>().data(),
- transformed_filter.template flat<T>().size());
- auto output_ptr =
- AsDeviceMemory(transformed_output.template flat<T>().data(),
- transformed_output.template flat<T>().size());
-
- static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
- );
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
- filter_ptr, conv_desc, output_desc,
- &output_ptr, &scratch_allocator)
- .ok();
+ if (data_format == FORMAT_NHWC) {
+ // Convert the input tensor from NHWC to NCHW.
+ Tensor transformed_input;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW, in_batch,
+ in_rows, in_cols, in_depths),
+ &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
+ input = transformed_input;
+ }
- if (!cudnn_launch_status) {
- ctx->SetStatus(errors::Internal(
- "cuDNN launch failure : input shape(", input.shape().DebugString(),
- ") filter shape(", filter.shape().DebugString(), ")"));
- }
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_depths)
+ .set_height(in_rows)
+ .set_width(in_cols)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(out_batch)
+ .set_height(out_rows)
+ .set_width(out_cols)
+ .set_feature_map_count(out_depths)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter.dim_size(0))
+ .set_input_filter_width(filter.dim_size(1))
+ .set_input_feature_map_count(filter.dim_size(2))
+ .set_output_feature_map_count(filter.dim_size(3));
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(row_stride)
+ .set_horizontal_filter_stride(col_stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({filter.dim_size(3), filter.dim_size(2),
+ filter.dim_size(0), filter.dim_size(1)}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ To32Bit(transformed_filter.tensor<T, 4>()));
+
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW, out_batch,
+ out_rows, out_cols, out_depths),
+ &transformed_output));
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto output_ptr =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
+ );
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
+ filter_ptr, conv_desc, output_desc,
+ &output_ptr, &scratch_allocator)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
+ }
- // Convert the output tensor back from NHWC to NCHW.
- if (data_format == FORMAT_NHWC) {
- functor::NCHWToNHWC<GPUDevice, T>()(
- ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
- output->tensor<T, 4>());
- } else {
- *output = transformed_output;
- }
+    // Convert the output tensor back from NCHW to NHWC.
+ if (data_format == FORMAT_NHWC) {
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+ output->tensor<T, 4>());
+ } else {
+ *output = transformed_output;
+ }
}
};
@@ -466,17 +466,17 @@ namespace functor {
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair); \
extern template struct MatMulConvFunctor<GPUDevice, T>; \
template <> \
- void TransformFilter<GPUDevice, T, int>::operator()( \
+ void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
- extern template struct TransformFilter<GPUDevice, T, int>; \
+ extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
- void PadInput<GPUDevice, T, int>::operator()( \
+ void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
- int padding_rows_left, int padding_rows_right, int padding_cols_left, \
- int padding_cols_right, typename TTypes<T, 4, int>::Tensor out, \
- TensorFormat data_format); \
- extern template struct PadInput<GPUDevice, T, int>
+ const std::array<int, 2>& padding_left, \
+ const std::array<int, 2>& padding_right, \
+ typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
+ extern template struct PadInput<GPUDevice, T, int, 4>
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
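For reference, a minimal standalone sketch of the SAME-padding arithmetic that the 2d launcher above and the 3d launcher below both rely on: the total padding per dimension is (out - 1) * stride + filter - in, cuDNN receives only the symmetric half (total / 2), and an odd total has to be absorbed by pre-padding the input by one element on the high end. All names below are illustrative, not part of the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Total SAME padding for one spatial dimension, matching the formula used
// by the GPU launchers: (out - 1) * stride + filter - in.
int64_t SameTotalPadding(int64_t in, int64_t filter, int64_t stride) {
  const int64_t out = (in + stride - 1) / stride;  // SAME output size
  return std::max<int64_t>((out - 1) * stride + filter - in, 0);
}

int main() {
  const int64_t in = 6, filter = 3, stride = 2;
  const int64_t total = SameTotalPadding(in, filter, stride);  // 1
  const int64_t symmetric = total / 2;  // what cuDNN is given: 0
  const bool odd = (total % 2 != 0);    // true: pre-pad the input by 1
  std::printf("total=%lld symmetric=%lld odd=%d\n",
              static_cast<long long>(total),
              static_cast<long long>(symmetric), odd ? 1 : 0);
  return 0;
}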
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
new file mode 100644
index 0000000000..ea3c4c90cb
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -0,0 +1,355 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_3d.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+using perftools::gputools::dnn::DimIndex;
+#endif
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+struct LaunchConvOp;
+
+template <typename T>
+struct LaunchConvOp<CPUDevice, T> {
+ static void launch(OpKernelContext* context, const Tensor& input,
+ const Tensor& filter, const std::array<int64, 3>& strides,
+ const Padding padding, Tensor* output) {
+ functor::CuboidConvolution<CPUDevice, T>()(
+ context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
+ input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[0], strides[1],
+ strides[2], BrainPadding2EigenPadding(padding));
+ }
+};
+
+template <typename Device, typename T>
+class Conv3DOp : public BinaryOp<T> {
+ public:
+ explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context, (stride_[0] == 1 && stride_[4] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_z, in_y, in_x, in_channels ]
+ const Tensor& input = context->input(0);
+
+ // Input filter is of the following dimensions:
+ // [ filter_z, filter_y, filter_x, in_channels, out_channels]
+ const Tensor& filter = context->input(1);
+
+ // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
+ // kept consistent between input/filter/output.
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional"));
+ OP_REQUIRES(context, filter.dims() == 5,
+ errors::InvalidArgument("filter must be 5-dimensional"));
+
+ const int64 in_depth = input.dim_size(4);
+ const int64 in_batch = input.dim_size(0);
+
+ const int64 out_depth = filter.dim_size(4);
+ OP_REQUIRES(
+ context, in_depth == filter.dim_size(3),
+ errors::InvalidArgument("input and filter must have the same depth"));
+
+ std::array<int64, 3> input_size = {
+ {input.dim_size(1), input.dim_size(2), input.dim_size(3)}};
+ std::array<int64, 3> filter_size = {
+ {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
+ std::array<int64, 3> strides = {{stride_[1], stride_[2], stride_[3]}};
+ std::array<int64, 3> out, padding;
+
+ OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,
+ padding_, &out, &padding));
+
+ TensorShape out_shape = {in_batch, out[0], out[1], out[2], out_depth};
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // Return early if nothing to do.
+ if (out_shape.num_elements() == 0) return;
+
+ LaunchConvOp<Device, T>::launch(context, input, filter, strides, padding_,
+ output);
+ }
+
+ private:
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ Conv3DOp<CPUDevice, float>);
+
+#ifndef __ANDROID__
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<double>("T"),
+ Conv3DOp<CPUDevice, double>);
+#endif
+
+#if GOOGLE_CUDA
+
+// TODO(mjanusz): Share logic with 2d implementation as much as possible.
+template <typename T>
+struct LaunchConvOp<GPUDevice, T> {
+ static void launch(OpKernelContext* ctx, const Tensor& input_param,
+ const Tensor& filter, const std::array<int64, 3>& strides,
+ const Padding padding, Tensor* output) {
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ Tensor input = input_param;
+
+ const int64 in_batch = input.dim_size(0);
+ int64 in_planes = input.dim_size(1);
+ int64 in_rows = input.dim_size(2);
+ int64 in_cols = input.dim_size(3);
+ const int64 in_depth = input.dim_size(4);
+
+ const int64 filter_planes = filter.dim_size(0);
+ const int64 filter_rows = filter.dim_size(1);
+ const int64 filter_cols = filter.dim_size(2);
+ const int64 out_depth = filter.dim_size(4);
+
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ int64 out_planes = output->dim_size(1);
+ int64 out_rows = output->dim_size(2);
+ int64 out_cols = output->dim_size(3);
+
+ if (padding == Padding::SAME) {
+ pad_planes = (out_planes - 1) * strides[0] + filter_planes - in_planes;
+ pad_rows = (out_rows - 1) * strides[1] + filter_rows - in_rows;
+ pad_cols = (out_cols - 1) * strides[2] + filter_cols - in_cols;
+ }
+
+ // NOTE: This only works in NHWC.
+ if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
+ strides[0] == 1 && strides[1] == 1 && strides[2] == 1) {
+      // 1x1x1 filter, so call cublas directly.
+ const uint64 m = in_batch * in_cols * in_rows * in_planes;
+ const uint64 k = in_depth;
+ const uint64 n = out_depth;
+
+ auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
+ output->template flat<T>().size());
+
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
+ n, a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
+
+ if (padding == Padding::SAME) {
+ const bool rows_odd = (pad_rows % 2 != 0);
+ const bool cols_odd = (pad_cols % 2 != 0);
+ const bool planes_odd = (pad_planes % 2 != 0);
+
+ // Necessary because cuDNN only supports symmetric padding.
+ // TODO(mjanusz): Consider making this optional? This would save some
+ // overhead and would work as long as an op trained this way is only
+ // used on GPU.
+ if (rows_odd || cols_odd || planes_odd) {
+ Tensor transformed_input;
+ int64 new_in_rows = in_rows + rows_odd;
+ int64 new_in_cols = in_cols + cols_odd;
+ int64 new_in_planes = in_planes + planes_odd;
+
+ TensorShape transformed_shape(
+ {in_batch, new_in_planes, new_in_rows, new_in_cols, in_depth});
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, transformed_shape,
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int, 5>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 5>()),
+ {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}},
+ To32Bit(transformed_input.tensor<T, 5>()), FORMAT_NHWC);
+ input = transformed_input;
+ in_rows = new_in_rows;
+ in_cols = new_in_cols;
+ in_planes = new_in_planes;
+ }
+ }
+
+ Tensor transformed_input;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({in_batch, in_depth, in_planes, in_rows, in_cols}),
+ &transformed_input));
+ // input: [b, x, y, z, d]
+ // t_input: [b, d, x, y, z]
+ // NCDHW is the only format universally supported by cuDNN.
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(input).tensor<T, 5>(),
+ transformed_input.tensor<T, 5>());
+ input = transformed_input;
+
+ perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_depth)
+ .set_spatial_dim(DimIndex::X, in_cols)
+ .set_spatial_dim(DimIndex::Y, in_rows)
+ .set_spatial_dim(DimIndex::Z, in_planes)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ output_desc.set_count(in_batch)
+ .set_spatial_dim(DimIndex::X, out_cols)
+ .set_spatial_dim(DimIndex::Y, out_rows)
+ .set_spatial_dim(DimIndex::Z, out_planes)
+ .set_feature_map_count(out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
+ .set_spatial_dim(DimIndex::Y, filter_rows)
+ .set_spatial_dim(DimIndex::Z, filter_planes)
+ .set_input_feature_map_count(in_depth)
+ .set_output_feature_map_count(out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ conv_desc.set_filter_stride(DimIndex::X, strides[2])
+ .set_filter_stride(DimIndex::Y, strides[1])
+ .set_filter_stride(DimIndex::Z, strides[0])
+ .set_zero_padding(DimIndex::X, pad_cols / 2)
+ .set_zero_padding(DimIndex::Y, pad_rows / 2)
+ .set_zero_padding(DimIndex::Z, pad_planes / 2);
+
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({out_depth, in_depth, filter_planes,
+ filter_rows, filter_cols}),
+ &transformed_filter));
+ // filter: [x, y, z, in, out]
+ // t_filter: [out, in, x, y, z]
+ functor::TransformFilter<GPUDevice, T, int, 5>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ To32Bit(transformed_filter.tensor<T, 5>()));
+
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({in_batch, out_depth, out_planes,
+ out_rows, out_cols}),
+ &transformed_output));
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto output_ptr =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
+ filter_ptr, conv_desc, output_desc,
+ &output_ptr, &scratch_allocator)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
+ }
+
+ // t_output: [b, out, x, y, z]
+ // output: [b, x, y, z, out]
+ functor::NCHWToNHWC<GPUDevice, T, 5>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(transformed_output).tensor<T, 5>(),
+ output->tensor<T, 5>());
+ }
+};
+
+// Forward declarations of the functor specializations for GPU.
+// This ensures that the custom implementation is used instead of the default
+// Eigen one (which is used for CPU).
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void TransformFilter<GPUDevice, T, int, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ typename TTypes<T, 5, int>::Tensor out); \
+ template <> \
+ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in, \
+ typename TTypes<T, 5>::Tensor out); \
+ template <> \
+ void PadInput<GPUDevice, T, int, 5>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const std::array<int, 3>& padding_left, \
+ const std::array<int, 3>& padding_right, \
+ typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+REGISTER_KERNEL_BUILDER(
+ Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ Conv3DOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
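The 1x1x1 fast path above works because, with unit strides and a 1x1x1 filter, every output voxel is a dot product over the channel dimension alone; the convolution therefore collapses to a single [b*planes*rows*cols, in_depth] x [in_depth, out_depth] matrix product, which is what the ThenBlasGemm call computes. A minimal CPU sketch of that equivalence (shapes and values are illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t b = 2, planes = 3, rows = 4, cols = 5, in = 6, out = 7;
  const int64_t m = b * planes * rows * cols, k = in, n = out;
  std::vector<float> input(m * k, 1.0f);   // NDHWC input viewed as [m, k]
  std::vector<float> filter(k * n, 0.5f);  // [1, 1, 1, in, out] viewed as [k, n]
  std::vector<float> output(m * n, 0.0f);  // [m, n]
  // output = input * filter: spatial structure never enters the loop.
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t l = 0; l < k; ++l) {
        acc += input[i * k + l] * filter[l * n + j];
      }
      output[i * n + j] = acc;
    }
  }
  std::printf("%f\n", output[0]);  // 6 * 1.0 * 0.5 = 3.0
  return 0;
}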
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index ccd983833d..5af9bc0e5b 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include <algorithm>
+#include <array>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
@@ -30,6 +31,7 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
+// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
@@ -65,6 +67,11 @@ struct Array {
data[i] = DefaultValue;
}
}
+ EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
+ for (int i = 0; i < IndexCount; i++) {
+ data[i] = array[i];
+ }
+ }
T data[IndexCount];
};
@@ -78,6 +85,8 @@ struct Dimension : Array<int, IndexCount, 1> {
: Base(a0, a1) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
: Base(a0, a1, a2) {}
+ EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
+ : Base(array) {}
};
// An index type with compile-time known size.
@@ -248,25 +257,28 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
// A Cuda custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
-template <typename T>
+template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
- Dimension<4> input_dims, T* output,
- Dimension<4> output_dims,
- int padding_rows_left,
- int padding_cols_left) {
+ Dimension<NDIMS> input_dims, T* output,
+ Dimension<NDIMS> output_dims,
+ Dimension<NDIMS - 2> padding_left) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
- Index<4> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<4> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[1] - padding_rows_left;
- input_tensor_index[2] = output_tensor_index[2] - padding_cols_left;
- input_tensor_index[3] = output_tensor_index[3];
+ Index<NDIMS> output_tensor_index =
+ FlatToTensorIndex(output_index, output_dims);
+
+ Index<NDIMS> input_tensor_index;
+ input_tensor_index[0] = output_tensor_index[0]; // batch
+ bool ok = true;
+ for (int i = 1; i < NDIMS - 1; i++) {
+ input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
+ ok &=
+ (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
+ }
+ input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1]; // channels
- if (input_tensor_index[1] >= 0 && input_tensor_index[1] < input_dims[1] &&
- input_tensor_index[2] >= 0 && input_tensor_index[2] < input_dims[2]) {
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
+ if (ok) {
+ const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
output[output_index] = T(0);
@@ -274,25 +286,28 @@ __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
}
}
-template <typename T>
+template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
- Dimension<4> input_dims, T* output,
- Dimension<4> output_dims,
- int padding_rows_left,
- int padding_cols_left) {
+ Dimension<NDIMS> input_dims, T* output,
+ Dimension<NDIMS> output_dims,
+ Dimension<NDIMS - 2> padding_left) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
- Index<4> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<4> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[1];
- input_tensor_index[2] = output_tensor_index[2] - padding_rows_left;
- input_tensor_index[3] = output_tensor_index[3] - padding_cols_left;
+ Index<NDIMS> output_tensor_index =
+ FlatToTensorIndex(output_index, output_dims);
+
+ Index<NDIMS> input_tensor_index;
+ input_tensor_index[0] = output_tensor_index[0]; // batch
+ input_tensor_index[1] = output_tensor_index[1]; // channels
+ bool ok = true;
+ for (int i = 2; i < NDIMS; i++) {
+ input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
+ ok &=
+ (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
+ }
- if (input_tensor_index[2] >= 0 && input_tensor_index[2] < input_dims[2] &&
- input_tensor_index[3] >= 0 && input_tensor_index[3] < input_dims[3]) {
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
+ if (ok) {
+ const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
output[output_index] = T(0);
@@ -302,15 +317,19 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
// A GPU helper function that converts TensorFlow filter format to Cudnn filter
// format.
-template <typename T>
-struct TransformFilter<GPUDevice, T, int> {
+template <typename T, int NDIMS>
+struct TransformFilter<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d, typename TTypes<T, 4, int>::ConstTensor in,
- typename TTypes<T, 4, int>::Tensor out) {
+ void operator()(const Device& d,
+ typename TTypes<T, NDIMS, int>::ConstTensor in,
+ typename TTypes<T, NDIMS, int>::Tensor out) {
Dimension<3> combined_dims;
- combined_dims[0] = in.dimension(0) * in.dimension(1);
- combined_dims[1] = in.dimension(2);
- combined_dims[2] = in.dimension(3);
+ combined_dims[0] = in.dimension(0); // spatial dimensions
+ for (int i = 1; i < NDIMS - 2; i++) {
+ combined_dims[0] *= in.dimension(i);
+ }
+ combined_dims[1] = in.dimension(NDIMS - 2); // input filters
+ combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension0And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
@@ -319,15 +338,18 @@ struct TransformFilter<GPUDevice, T, int> {
};
// Converts Cudnn filter format back to TensorFlow filter format.
-template <typename T>
-struct ReverseTransformFilter<GPUDevice, T> {
+template <typename T, int NDIMS>
+struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out) {
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out) {
Dimension<3> combined_dims;
- combined_dims[0] = in.dimension(0);
- combined_dims[1] = in.dimension(1);
- combined_dims[2] = in.dimension(2) * in.dimension(3);
+ combined_dims[0] = in.dimension(0); // output filters
+ combined_dims[1] = in.dimension(1); // input filters
+ combined_dims[2] = in.dimension(2); // spatial dimensions
+ for (int i = 3; i < NDIMS; ++i) {
+ combined_dims[2] *= in.dimension(i);
+ }
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension0And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
@@ -337,33 +359,37 @@ struct ReverseTransformFilter<GPUDevice, T> {
// A GPU helper function that converts an input tensor to a larger output tensor,
// given proper padding values. The padded value is zero.
-template <typename T>
-struct PadInput<GPUDevice, T, int> {
+template <typename T, int NDIMS>
+struct PadInput<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d, typename TTypes<T, 4, int>::ConstTensor in,
- int padding_rows_left, int padding_rows_right,
- int padding_cols_left, int padding_cols_right,
- typename TTypes<T, 4, int>::Tensor out, TensorFormat format) {
+ void operator()(const Device& d,
+ typename TTypes<T, NDIMS, int>::ConstTensor in,
+ const std::array<int, NDIMS - 2>& padding_left,
+ const std::array<int, NDIMS - 2>& padding_right,
+ typename TTypes<T, NDIMS, int>::Tensor out,
+ TensorFormat format) {
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- Dimension<4> input_dims;
- for (int i = 0; i < 4; i++) {
+ Dimension<NDIMS> input_dims;
+ for (int i = 0; i < NDIMS; ++i) {
input_dims[i] = in.dimension(i);
}
- Dimension<4> output_dims;
- for (int i = 0; i < 4; i++) {
+ Dimension<NDIMS> output_dims;
+ for (int i = 0; i < NDIMS; ++i) {
output_dims[i] = out.dimension(i);
}
+ const Dimension<NDIMS - 2> padding_left_dim(padding_left);
+
if (format == FORMAT_NHWC) {
- PadInputCustomKernelNHWC<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ PadInputCustomKernelNHWC<T, NDIMS><<<
+ config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), input_dims, out.data(),
- output_dims, padding_rows_left, padding_cols_left);
+ output_dims, padding_left_dim);
} else if (format == FORMAT_NCHW) {
- PadInputCustomKernelNCHW<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ PadInputCustomKernelNCHW<T, NDIMS><<<
+ config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), input_dims, out.data(),
- output_dims, padding_rows_left, padding_cols_left);
+ output_dims, padding_left_dim);
} else {
LOG(FATAL) << "Invalid data format: " << format;
}
@@ -405,30 +431,36 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
-template <typename T>
-struct NHWCToNCHW<GPUDevice, T> {
+template <typename T, int NDIMS>
+struct NHWCToNCHW<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out) {
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out) {
Dimension<3> combined_dims;
- combined_dims[0] = in.dimension(0);
- combined_dims[1] = in.dimension(1) * in.dimension(2);
- combined_dims[2] = in.dimension(3);
+ combined_dims[0] = in.dimension(0); // N (batch)
+ combined_dims[1] = in.dimension(1); // spatial dimensions (HW)
+ for (int i = 2; i < NDIMS - 1; ++i) {
+ combined_dims[1] *= in.dimension(i);
+ }
+ combined_dims[2] = in.dimension(NDIMS - 1); // C (channels)
RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
}
};
// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// format.
-template <typename T>
-struct NCHWToNHWC<GPUDevice, T> {
+template <typename T, int NDIMS>
+struct NCHWToNHWC<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
- typename TTypes<T, 4>::Tensor out) {
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
+ typename TTypes<T, NDIMS>::Tensor out) {
Dimension<3> combined_dims;
- combined_dims[0] = in.dimension(0);
- combined_dims[1] = in.dimension(1);
- combined_dims[2] = in.dimension(2) * in.dimension(3);
+ combined_dims[0] = in.dimension(0); // N (batch)
+ combined_dims[1] = in.dimension(1); // C (channel)
+ combined_dims[2] = in.dimension(2); // spatial dimensions (HW)
+ for (int i = 3; i < NDIMS; ++i) {
+ combined_dims[2] *= in.dimension(i);
+ }
RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
}
};
@@ -440,17 +472,21 @@ template struct functor::ShuffleAndReverse<GPUDevice, float, 4, int>;
template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
Eigen::DenseIndex>;
-template struct functor::TransformFilter<GPUDevice, float, int>;
-
-template struct functor::ReverseTransformFilter<GPUDevice, float>;
-
-template struct functor::PadInput<GPUDevice, float, int>;
-
template struct functor::TransformDepth<GPUDevice, float, int>;
-template struct functor::NHWCToNCHW<GPUDevice, float>;
-
-template struct functor::NCHWToNHWC<GPUDevice, float>;
+// For 2d ops.
+template struct functor::TransformFilter<GPUDevice, float, int, 4>;
+template struct functor::ReverseTransformFilter<GPUDevice, float, 4>;
+template struct functor::NHWCToNCHW<GPUDevice, float, 4>;
+template struct functor::NCHWToNHWC<GPUDevice, float, 4>;
+template struct functor::PadInput<GPUDevice, float, int, 4>;
+
+// For 3d ops.
+template struct functor::TransformFilter<GPUDevice, float, int, 5>;
+template struct functor::ReverseTransformFilter<GPUDevice, float, 5>;
+template struct functor::NHWCToNCHW<GPUDevice, float, 5>;
+template struct functor::NCHWToNHWC<GPUDevice, float, 5>;
+template struct functor::PadInput<GPUDevice, float, int, 5>;
} // namespace tensorflow
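All of the transpose functors above use the same rank-collapsing trick: view the tensor as three combined dimensions and swap two of them, e.g. NHWC -> NCHW becomes a swap of dimensions 1 and 2 of a [batch, prod(spatial), channels] view, regardless of how many spatial dimensions there are. A CPU reference sketch of that trick (illustrative; the CUDA kernels above do the same thing, tiled):

#include <cstdint>
#include <vector>

// Swap dimensions 1 and 2 of a row-major [d0, d1, d2] tensor.
std::vector<float> SwapDim1And2(const std::vector<float>& in, int64_t d0,
                                int64_t d1, int64_t d2) {
  std::vector<float> out(in.size());
  for (int64_t i = 0; i < d0; ++i) {
    for (int64_t j = 0; j < d1; ++j) {
      for (int64_t l = 0; l < d2; ++l) {
        out[(i * d2 + l) * d1 + j] = in[(i * d1 + j) * d2 + l];
      }
    }
  }
  return out;
}

int main() {
  // A [1, 2x2, 3] NHWC block becomes a [1, 3, 2x2] NCHW block.
  std::vector<float> nhwc(12);
  for (int i = 0; i < 12; ++i) nhwc[i] = static_cast<float>(i);
  const std::vector<float> nchw = SwapDim1And2(nhwc, 1, 4, 3);
  return nchw[4] == 1.0f ? 0 : 1;  // channel 1, first spatial position
}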
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
new file mode 100644
index 0000000000..35bbca53e6
--- /dev/null
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
@@ -0,0 +1,216 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_3d.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+
+template <typename T>
+void DnnPooling3dOp<T>::Compute(
+ OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::array<int64, 3>& window, const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding, const Tensor& tensor_in,
+ Tensor* output) {
+ const auto in_shape = tensor_in.shape();
+ const auto out_shape = output->shape();
+
+ const int64 in_batch = in_shape.dim_size(0);
+ const int64 in_features = in_shape.dim_size(4);
+
+ Tensor transformed_input;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {in_shape.dim_size(0), in_shape.dim_size(4),
+ in_shape.dim_size(1), in_shape.dim_size(2),
+ in_shape.dim_size(3)},
+ &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(context->eigen_device<GPUDevice>(),
+ tensor_in.tensor<T, 5>(),
+ transformed_input.tensor<T, 5>());
+ Tensor transformed_output;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {out_shape.dim_size(0), out_shape.dim_size(4),
+ out_shape.dim_size(1), out_shape.dim_size(2),
+ out_shape.dim_size(3)},
+ &transformed_output));
+
+ perftools::gputools::dnn::PoolingDescriptor pooling_desc(3);
+ pooling_desc.set_pooling_mode(pooling_mode);
+ perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_features)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ output_desc.set_count(in_batch)
+ .set_feature_map_count(in_features)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ for (size_t i = 0; i < window.size(); ++i) {
+ const auto dim_i = static_cast<perftools::gputools::dnn::DimIndex>(i);
+ pooling_desc.set_window(dim_i, window.rbegin()[i]);
+ pooling_desc.set_stride(dim_i, stride.rbegin()[i]);
+ pooling_desc.set_padding(dim_i, padding.rbegin()[i]);
+ input_desc.set_spatial_dim(dim_i, in_shape.dim_size(3 - i));
+ output_desc.set_spatial_dim(dim_i, out_shape.dim_size(3 - i));
+ }
+
+ auto input_data = AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+ auto output_data =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ bool status = stream
+ ->ThenPoolForward(pooling_desc, input_desc, input_data,
+ output_desc, &output_data)
+ .ok();
+ OP_REQUIRES(context, status,
+ errors::Internal("cudnn PoolForward launch failed"));
+
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::NCHWToNHWC<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(),
+ toConstTensor(transformed_output).template tensor<T, 5>(),
+ output->tensor<T, 5>());
+}
+
+template <typename T>
+void DnnPooling3dGradOp<T>::Compute(
+ OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::array<int64, 3>& window, const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding,
+ const std::array<int64, 3>& output_size, const Tensor& out_backprop,
+ const TensorShape& tensor_in_shape, const Tensor* tensor_in,
+ const Tensor* tensor_out, Tensor* input_backprop) {
+ CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) ||
+ (tensor_in && tensor_out))
+      << "For MaxPoolGrad, both tensor_in and tensor_out need to be "
+         "specified";
+
+ const int64 in_batch = tensor_in_shape.dim_size(0);
+ const int64 in_features = tensor_in_shape.dim_size(4);
+
+ Tensor transformed_input;
+ TensorShape transformed_input_shape = {
+ in_batch, in_features, tensor_in_shape.dim_size(1),
+ tensor_in_shape.dim_size(2), tensor_in_shape.dim_size(3)};
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ transformed_input_shape,
+ &transformed_input));
+ Tensor transformed_output;
+ TensorShape transformed_output_shape = {
+ out_backprop.dim_size(0), out_backprop.dim_size(4),
+ out_backprop.dim_size(1), out_backprop.dim_size(2),
+ out_backprop.dim_size(3)};
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ transformed_output_shape,
+ &transformed_output));
+ Tensor transformed_input_backprop;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ transformed_input_shape,
+ &transformed_input_backprop));
+ Tensor transformed_output_backprop;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ transformed_output_shape,
+ &transformed_output_backprop));
+ if (tensor_in != nullptr) {
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(context->eigen_device<GPUDevice>(),
+ tensor_in->tensor<T, 5>(),
+ transformed_input.tensor<T, 5>());
+ }
+ if (tensor_out != nullptr) {
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(context->eigen_device<GPUDevice>(),
+ tensor_out->tensor<T, 5>(),
+ transformed_output.tensor<T, 5>());
+ }
+ functor::NHWCToNCHW<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
+ transformed_output_backprop.tensor<T, 5>());
+
+ perftools::gputools::dnn::PoolingDescriptor pooling_desc(3);
+ pooling_desc.set_pooling_mode(pooling_mode);
+
+ perftools::gputools::dnn::BatchDescriptor orig_output_desc(3);
+ orig_output_desc.set_count(in_batch)
+ .set_feature_map_count(in_features)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+ perftools::gputools::dnn::BatchDescriptor orig_input_desc(3);
+ orig_input_desc.set_count(in_batch)
+ .set_feature_map_count(in_features)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+ for (size_t i = 0; i < window.size(); ++i) {
+ const auto dim_i = static_cast<perftools::gputools::dnn::DimIndex>(i);
+ pooling_desc.set_window(dim_i, window[i]);
+ pooling_desc.set_stride(dim_i, stride[i]);
+ pooling_desc.set_padding(dim_i, padding[i]);
+ orig_input_desc.set_spatial_dim(dim_i, tensor_in_shape.dim_size(3 - i));
+ orig_output_desc.set_spatial_dim(dim_i, output_size[i]);
+ }
+
+ auto orig_output_data =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+ auto orig_input_data =
+ AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+ auto output_backprop_data =
+ AsDeviceMemory(transformed_output_backprop.template flat<T>().data(),
+ transformed_output_backprop.template flat<T>().size());
+ auto input_backprop_data =
+ AsDeviceMemory(transformed_input_backprop.template flat<T>().data(),
+ transformed_input_backprop.template flat<T>().size());
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ bool status =
+ stream
+ ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
+ orig_output_desc, orig_output_data,
+ output_backprop_data, &input_backprop_data)
+ .ok();
+ OP_REQUIRES(context, status,
+ errors::Internal("cudnn PoolBackward launch failed"));
+
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::NCHWToNHWC<GPUDevice, T, 5>()(
+ context->eigen_device<GPUDevice>(),
+ toConstTensor(transformed_input_backprop).template tensor<T, 5>(),
+ input_backprop->tensor<T, 5>());
+}
+
+template class DnnPooling3dOp<float>;
+template class DnnPooling3dGradOp<float>;
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
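Two details in the file above are easy to misread. First, the cuDNN temporaries are allocated with channels moved to dimension 1, i.e. NDHWC -> NCDHW. Second, window.rbegin()[i] is simply the element at index window.size() - 1 - i, so the forward descriptor loop walks the window/stride/padding arrays in reverse. A small sketch of the shape permutation, assuming a plain 5-element shape array with illustrative names:

#include <array>
#include <cassert>
#include <cstdint>

// [batch, planes, rows, cols, channels] ->
// [batch, channels, planes, rows, cols], mirroring the allocate_temp
// shapes used for the cuDNN temporaries above.
std::array<int64_t, 5> NdhwcToNcdhw(const std::array<int64_t, 5>& s) {
  return {s[0], s[4], s[1], s[2], s[3]};
}

int main() {
  const std::array<int64_t, 5> in = {8, 4, 16, 16, 3};
  const std::array<int64_t, 5> t = NdhwcToNcdhw(in);
  assert(t[1] == 3 && t[2] == 4);  // channels first, then planes
  // Reversed indexing, as in the pooling descriptor loop: for a
  // 3-element array w, w.rbegin()[i] == w[2 - i].
  return 0;
}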
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h
new file mode 100644
index 0000000000..2e28d69601
--- /dev/null
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h
@@ -0,0 +1,65 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions to run 3d pooling on GPU using CuDNN.
+
+#ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif
+
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+
+// Runs (avg/max) pooling on GPU.
+template <typename T>
+class DnnPooling3dOp {
+ public:
+ static void Compute(OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::array<int64, 3>& size,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding,
+ const Tensor& tensor_in, Tensor* output);
+};
+
+// Computes the gradient of (avg/max) pooling on GPU.
+template <typename T>
+class DnnPooling3dGradOp {
+ public:
+ static void Compute(OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding,
+ const std::array<int64, 3>& output_size,
+ const Tensor& out_backprop,
+ const TensorShape& tensor_in_shape,
+ const Tensor* tensor_in, const Tensor* tensor_out,
+ Tensor* input_backprop);
+};
+
+#endif
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc
index a32443b841..c0e939c845 100644
--- a/tensorflow/core/kernels/ops_util.cc
+++ b/tensorflow/core/kernels/ops_util.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <algorithm>
#include <cmath>
#include "tensorflow/core/kernels/ops_util.h"
@@ -73,6 +74,33 @@ Status Get2dOutputSizeVerbose(const int in_height, const int in_width,
return Status::OK();
}
+Status Get3dOutputSize(const std::array<int64, 3>& input,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& strides,
+ Padding padding_type, std::array<int64, 3>* output_ptr,
+ std::array<int64, 3>* padding_ptr) {
+ auto& output = *output_ptr;
+ auto& padding = *padding_ptr;
+ switch (padding_type) {
+ case Padding::VALID:
+ for (size_t i = 0; i < input.size(); ++i) {
+ output[i] = (input[i] - window[i] + strides[i]) / strides[i];
+ padding[i] = 0;
+ }
+ break;
+ case Padding::SAME:
+ for (size_t i = 0; i < input.size(); ++i) {
+ output[i] = (input[i] + strides[i] - 1) / strides[i];
+ const int64 delta = (output[i] - 1) * strides[i] + window[i] - input[i];
+ // For odd values of total padding, add more padding at the 'right'
+ // side of the given dimension.
+ padding[i] = std::max(delta / 2, 0ll);
+ }
+ break;
+ }
+ return Status::OK();
+}
+
Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) {
switch (padding) {
case Padding::VALID:
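A worked example of the two branches of Get3dOutputSize, reduced to a single dimension (the function applies the same formulas to all three): VALID yields out = (in - window + stride) / stride with no padding; SAME yields out = ceil(in / stride) and pads the low end with delta / 2, where delta is the total padding. A standalone sketch:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in = 7, window = 3, stride = 2;
  // VALID: out = (in - window + stride) / stride, no padding.
  const int64_t valid_out = (in - window + stride) / stride;    // 3
  // SAME: out = ceil(in / stride); low-end padding is delta / 2.
  const int64_t same_out = (in + stride - 1) / stride;          // 4
  const int64_t delta = (same_out - 1) * stride + window - in;  // 2
  const int64_t pad_low = std::max<int64_t>(delta / 2, 0);      // 1
  std::printf("valid=%lld same=%lld pad_low=%lld\n",
              static_cast<long long>(valid_out),
              static_cast<long long>(same_out),
              static_cast<long long>(pad_low));
  return 0;
}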
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
index f27a5bc423..2cdad1c415 100644
--- a/tensorflow/core/kernels/ops_util.h
+++ b/tensorflow/core/kernels/ops_util.h
@@ -18,6 +18,8 @@ limitations under the License.
// This file contains utilities for various operations.
+#include <array>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
@@ -89,6 +91,19 @@ Status Get2dOutputSizeVerbose(const int in_height, const int in_width,
int* new_height, int* new_width, int* pad_top,
int* pad_bottom, int* pad_left, int* pad_right);
+// Given an input tensor, kernel, stride and padding type, populates the 3D size
+// of the output tensor and padding to be applied to the input tensor at the
+// lower end of every dimension. Use for 3D convolutions, where the input data
+// is padded with zeros, as well as for 3D avg/max pooling, where the input data
+// is padded with invalid values that are not considered for pooling.
+//
+// TODO(mjanusz): Unify this with Get2dOutputSize by using a common template.
+Status Get3dOutputSize(const std::array<int64, 3>& input,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& strides,
+ Padding padding_type, std::array<int64, 3>* output,
+ std::array<int64, 3>* padding);
+
// Calculates broadcast starting index and size. For SAME padding, additional
// padding could be applied to the right, left, top and bottom. Depending on the
// current index, input size, kernel size, stride, padding size, the starting
diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc
new file mode 100644
index 0000000000..e9a95b7240
--- /dev/null
+++ b/tensorflow/core/kernels/pooling_ops_3d.cc
@@ -0,0 +1,515 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#include <array>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/eigen_pooling.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
+#endif
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+enum PoolingType { MAX, AVG };
+
+template <typename Device, typename T, PoolingType Type>
+struct LaunchPoolingOp;
+
+template <typename T>
+struct LaunchPoolingOp<CPUDevice, T, AVG> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding, Padding padding_type,
+ Tensor* output) {
+ output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
+ Eigen::CuboidAvgPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
+ window[2], stride[0], stride[1], stride[2],
+ BrainPadding2EigenPadding(padding_type));
+ }
+};
+
+template <typename T>
+struct LaunchPoolingOp<CPUDevice, T, MAX> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding, Padding padding_type,
+ Tensor* output) {
+ output->tensor<T, 5>().device(context->eigen_device<CPUDevice>()) =
+ Eigen::CuboidMaxPooling(tensor_in.tensor<T, 5>(), window[0], window[1],
+ window[2], stride[0], stride[1], stride[2],
+ BrainPadding2EigenPadding(padding_type));
+ }
+};
+
+template <typename Device, typename T, PoolingType Type>
+class Pooling3DOp : public UnaryOp<T> {
+ public:
+ explicit Pooling3DOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 5,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the depth dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ OP_REQUIRES(context, tensor_in.dims() == 5,
+ errors::InvalidArgument("tensor_in must be 5-dimensional"));
+ const int64 depth = tensor_in.dim_size(4);
+ const int64 in_batch = tensor_in.dim_size(0);
+
+ std::array<int64, 3> input_size{
+ {tensor_in.dim_size(3), tensor_in.dim_size(2), tensor_in.dim_size(1)}};
+ std::array<int64, 3> window({{ksize_[3], ksize_[2], ksize_[1]}});
+ std::array<int64, 3> stride({{stride_[3], stride_[2], stride_[1]}});
+ std::array<int64, 3> padding, out;
+
+ OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
+ padding_, &out, &padding));
+
+ TensorShape out_shape({in_batch, out[2], out[1], out[0], depth});
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+ LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
+ padding, padding_, output);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ Pooling3DOp<CPUDevice, float, AVG>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool3D").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ Pooling3DOp<CPUDevice, float, MAX>);
+
+template <typename Device, typename T>
+struct LaunchMaxPooling3dGradOp;
+
+template <typename T>
+struct LaunchMaxPooling3dGradOp<CPUDevice, T> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const Tensor& tensor_out, const Tensor& out_backprop,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& out,
+ const std::array<int64, 3>& padding, Tensor* output) {
+ output->flat<T>().setZero();
+ for (int64 p = 0; p < out_backprop.dim_size(3); ++p) {
+ // Calculate broadcast size for planes/rows/cols. For SAME padding,
+ // current index could be in the padding area, and
+ // p * stride_planes + window_planes
+ // could be beyond the input tensor's boundary. In such cases, change
+ // the starting index and reduce the broadcast size.
+ //
+ // The same procedure is repeated for every spatial dimension in the
+ // nested loops below.
+ int pindex, psize;
+ std::array<int64, 3> input_size{{tensor_in.dim_size(3),
+ tensor_in.dim_size(2),
+ tensor_in.dim_size(1)}};
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(p, input_size[0], window[0], stride[0],
+ padding[0], &pindex, &psize));
+ for (int64 r = 0; r < out_backprop.dim_size(2); ++r) {
+ int rindex, rsize;
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(r, input_size[1], window[1], stride[1],
+ padding[1], &rindex, &rsize));
+ for (int64 c = 0; c < out_backprop.dim_size(1); ++c) {
+ int cindex, csize;
+ OP_REQUIRES_OK(
+ context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
+ padding[2], &cindex, &csize));
+ TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
+ TensorSlice dst{{0, -1},
+ {cindex, csize},
+ {rindex, rsize},
+ {pindex, psize},
+ {0, -1}};
+ Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
+ Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
+ src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
+ &src_sizes);
+ dst.FillIndicesAndSizes<5>(tensor_in.shape(), &dst_indices,
+ &dst_sizes);
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 5> bcast = {1, csize, rsize, psize, 1};
+#else
+ Eigen::IndexList<Eigen::type2index<1>, int, int, int,
+ Eigen::type2index<1> >
+ bcast;
+ bcast.set(1, csize);
+ bcast.set(2, rsize);
+ bcast.set(3, psize);
+#endif
+
+ // Slice from tensor_in.
+ Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_in_slice(dst_sizes);
+ tensor_in_slice.device(context->eigen_cpu_device()) =
+ tensor_in.tensor<T, 5>().slice(dst_indices, dst_sizes);
+
+ // Slice from tensor_out.
+ Eigen::Tensor<T, 5, Eigen::RowMajor> tensor_out_slice(src_sizes);
+ tensor_out_slice.device(context->eigen_cpu_device()) =
+ tensor_out.tensor<T, 5>().slice(src_indices, src_sizes);
+
+ // Backprop slice.
+ Eigen::Tensor<T, 5, Eigen::RowMajor> out_backprop_slice(src_sizes);
+ out_backprop_slice.device(context->eigen_cpu_device()) =
+ out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);
+
+ // The true backprop slice: if an element is the max, choose
+ // the backprop slice; otherwise set to 0.
+ Eigen::Tensor<T, 5, Eigen::RowMajor> select_slice(dst_sizes);
+ Eigen::Tensor<T, 5, Eigen::RowMajor> mat0(dst_sizes);
+ mat0.setZero();
+ select_slice =
+ ((tensor_in_slice - tensor_out_slice.broadcast(bcast)).abs() <
+ tensor_in_slice.constant(1e-5))
+ .select(out_backprop_slice.broadcast(bcast), mat0);
+
+ output->tensor<T, 5>()
+ .slice(dst_indices, dst_sizes)
+ .device(context->eigen_cpu_device()) += select_slice;
+ }
+ }
+ }
+ }
+};
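The GetBroadcastSize calls above clamp each broadcast window to the valid input range whenever SAME padding places part of the window outside the tensor, as the comment at the top of the loop explains. Its implementation is not part of this diff; the following is a standalone sketch of the clamping it is documented to perform, with illustrative names:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Clamp the input span covered by output index `p` to [0, in_size),
// mirroring the start-index/size adjustment described above.
void BroadcastSpan(int64_t p, int64_t in_size, int64_t window, int64_t stride,
                   int64_t pad_low, int64_t* index, int64_t* size) {
  const int64_t start = p * stride - pad_low;             // may be negative
  const int64_t end = std::min(start + window, in_size);  // may overshoot
  *index = std::max<int64_t>(start, 0);
  *size = end - *index;
}

int main() {
  int64_t index = 0, size = 0;
  // Output index 0 with pad_low = 1: the window hangs one element off
  // the left edge, so the span shrinks from 3 to 2.
  BroadcastSpan(0, 7, 3, 2, 1, &index, &size);
  std::printf("index=%lld size=%lld\n", static_cast<long long>(index),
              static_cast<long long>(size));  // index=0 size=2
  return 0;
}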
+
+template <class Device, class T>
+class MaxPooling3dGradOp : public OpKernel {
+ public:
+ explicit MaxPooling3dGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 5,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the depth dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& tensor_out = context->input(1);
+ const Tensor& out_backprop = context->input(2);
+ OP_REQUIRES(context, tensor_in.dims() == 5,
+ errors::InvalidArgument("tensor_in must be 5-dimensional"));
+ OP_REQUIRES(context, tensor_out.dims() == 5,
+ errors::InvalidArgument("tensor_out must be 5-dimensional"));
+ OP_REQUIRES(context, out_backprop.dims() == 5,
+ errors::InvalidArgument("out_backprop must be 5-dimensional"));
+
+ const TensorShape& output_shape = tensor_in.shape();
+ Tensor* input_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &input_backprop));
+
+ std::array<int64, 3> input_size = {{output_shape.dim_size(3),
+ output_shape.dim_size(2),
+ output_shape.dim_size(1)}};
+ std::array<int64, 3> window = {{ksize_[3], ksize_[2], ksize_[1]}};
+ std::array<int64, 3> stride = {{stride_[3], stride_[2], stride_[1]}};
+ std::array<int64, 3> out, padding;
+
+ OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
+ padding_, &out, &padding));
+ LaunchMaxPooling3dGradOp<Device, T>::launch(context, tensor_in, tensor_out,
+ out_backprop, window, stride,
+ out, padding, input_backprop);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool3DGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ MaxPooling3dGradOp<CPUDevice, float>);
+
+template <typename Device, typename T>
+struct LaunchAvgPooling3dGradOp;
+
+template <typename T>
+struct LaunchAvgPooling3dGradOp<CPUDevice, T> {
+ static void launch(OpKernelContext* context,
+ const TensorShape& tensor_in_shape,
+ const Tensor& out_backprop,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& output_shape,
+ const std::array<int64, 3>& padding, Tensor* output) {
+ output->flat<T>().setZero();
+ std::array<int64, 3> input_size = {{tensor_in_shape.dim_size(3),
+ tensor_in_shape.dim_size(2),
+ tensor_in_shape.dim_size(1)}};
+ for (int64 p = 0; p < out_backprop.dim_size(3); ++p) {
+ // Calculate broadcast size for planes/rows/cols. For SAME padding,
+ // current index could be in the padding area, and
+ // p * stride_planes + window_planes
+ // could be beyond the input tensor's boundary. In such cases, change
+ // the starting index and reduce the broadcast size.
+ //
+ // The same procedure is repeated for every spatial dimension in the
+ // nested loops below.
+ int pindex, psize;
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(p, input_size[0], window[0], stride[0],
+ padding[0], &pindex, &psize));
+ for (int64 r = 0; r < out_backprop.dim_size(2); ++r) {
+ int rindex, rsize;
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(r, input_size[1], window[1], stride[1],
+ padding[1], &rindex, &rsize));
+ for (int64 c = 0; c < out_backprop.dim_size(1); ++c) {
+ int cindex, csize;
+ OP_REQUIRES_OK(
+ context, GetBroadcastSize(c, input_size[2], window[2], stride[2],
+ padding[2], &cindex, &csize));
+ TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}};
+ TensorSlice dst{{0, -1},
+ {cindex, csize},
+ {rindex, rsize},
+ {pindex, psize},
+ {0, -1}};
+ Eigen::DSizes<Eigen::DenseIndex, 5> src_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 5> src_sizes;
+ Eigen::DSizes<Eigen::DenseIndex, 5> dst_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 5> dst_sizes;
+ src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices,
+ &src_sizes);
+ dst.FillIndicesAndSizes<5>(tensor_in_shape, &dst_indices, &dst_sizes);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 5> bcast = {1, csize, rsize, psize, 1};
+#else
+ Eigen::IndexList<Eigen::type2index<1>, int, int, int,
+ Eigen::type2index<1> >
+ bcast;
+ bcast.set(1, csize);
+ bcast.set(2, rsize);
+ bcast.set(3, psize);
+#endif
+ Eigen::Tensor<T, 5, Eigen::RowMajor> slices(src_sizes);
+ slices.device(context->eigen_cpu_device()) =
+ out_backprop.tensor<T, 5>().slice(src_indices, src_sizes);
+ // Divide by the size of the actual patch (psize * rsize * csize).
+ float divide_size = rsize * csize * psize * 1.0f;
+ slices *= slices.constant(1.0f / divide_size);
+
+ output->tensor<T, 5>()
+ .slice(dst_indices, dst_sizes)
+ .device(context->eigen_cpu_device()) += slices.broadcast(bcast);
+ }
+ }
+ }
+ }
+};
+
+template <class Device, class T>
+class AvgPooling3dGradOp : public OpKernel {
+ public:
+ explicit AvgPooling3dGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 5,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the depth dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in_shape = context->input(0);
+ const Tensor& out_backprop = context->input(1);
+ OP_REQUIRES(context, tensor_in_shape.dims() == 1 &&
+ tensor_in_shape.NumElements() == 5,
+ errors::InvalidArgument("tensor_in must be 1-dimensional and 5 "
+ "elements"));
+ OP_REQUIRES(context, out_backprop.dims() == 5,
+ errors::InvalidArgument("out_backprop must be 5-dimensional"));
+
+ TensorShape output_shape;
+ auto shape_vec = tensor_in_shape.vec<int32>();
+ for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
+ output_shape.AddDim(shape_vec(i));
+ }
+
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ std::array<int64, 3> input_size = {{output_shape.dim_size(3),
+ output_shape.dim_size(2),
+ output_shape.dim_size(1)}};
+ std::array<int64, 3> window = {{ksize_[3], ksize_[2], ksize_[1]}};
+ std::array<int64, 3> stride = {{stride_[3], stride_[2], stride_[1]}};
+ std::array<int64, 3> padding, out;
+
+ OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
+ padding_, &out, &padding));
+
+ LaunchAvgPooling3dGradOp<Device, T>::launch(context, output_shape,
+ out_backprop, window, stride,
+ out, padding, output);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPooling3dGradOp<CPUDevice, float>);
+
+#if GOOGLE_CUDA
+
+template <typename T>
+struct LaunchPoolingOp<GPUDevice, T, AVG> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding, Padding padding_type,
+ Tensor* output) {
+ DnnPooling3dOp<T>::Compute(context,
+ perftools::gputools::dnn::PoolingMode::kAverage,
+ window, stride, padding, tensor_in, output);
+ }
+};
+
+template <typename T>
+struct LaunchPoolingOp<GPUDevice, T, MAX> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding, Padding padding_type,
+ Tensor* output) {
+ DnnPooling3dOp<T>::Compute(context,
+ perftools::gputools::dnn::PoolingMode::kMaximum,
+ window, stride, padding, tensor_in, output);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ Pooling3DOp<GPUDevice, float, AVG>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ Pooling3DOp<GPUDevice, float, MAX>);
+
+template <typename T>
+struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
+ static void launch(OpKernelContext* context, const Tensor& tensor_in,
+ const Tensor& tensor_out, const Tensor& out_backprop,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& out,
+ const std::array<int64, 3>& padding,
+ Tensor* input_backprop) {
+ const TensorShape output_shape = tensor_in.shape();
+ DnnPooling3dGradOp<T>::Compute(
+ context, perftools::gputools::dnn::PoolingMode::kMaximum, window,
+ stride, padding, out, out_backprop, output_shape, &tensor_in,
+ &tensor_out, input_backprop);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool3DGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ MaxPooling3dGradOp<GPUDevice, float>);
+
+template <typename T>
+struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
+ static void launch(OpKernelContext* context,
+ const TensorShape& tensor_in_shape,
+ const Tensor& out_backprop,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& out,
+ const std::array<int64, 3>& padding, Tensor* output) {
+ DnnPooling3dGradOp<T>::Compute(
+ context, perftools::gputools::dnn::PoolingMode::kAverage, window,
+ stride, padding, out, out_backprop, tensor_in_shape, nullptr, nullptr,
+ output);
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPooling3dGradOp<GPUDevice, float>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index eecaf25c2b..017b789473 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -153,9 +153,9 @@ void DnnPoolingOp<T>::Compute(
ShapeFromFormat(FORMAT_NCHW, tensor_in.shape(),
data_format),
&transformed_input));
- functor::NHWCToNCHW<GPUDevice, T>()(context->eigen_device<Device>(),
- tensor_in.tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
+ tensor_in.tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
} else {
transformed_input = tensor_in;
}
@@ -213,7 +213,7 @@ void DnnPoolingOp<T>::Compute(
if (data_format == FORMAT_NHWC) {
/// Transform the output data from NCHW back to NHWC
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::NCHWToNHWC<GPUDevice, T>()(
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(transformed_output).template tensor<T, 4>(),
tensor_out->tensor<T, 4>());
@@ -292,19 +292,19 @@ void DnnPoolingGradOp<T>::Compute(
// For AvgPoolGrad, the original input tensor is not necessary. However,
// cudnn still requires them to run, although they do not affect the
// results.
- functor::NHWCToNCHW<GPUDevice, T>()(context->eigen_device<Device>(),
- tensor_in->tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
+ tensor_in->tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
}
if (tensor_out) {
// For AvgPoolGrad, the original output tensor is not necessary. However,
// cudnn still requires them to run, although they do not affect the
// results.
- functor::NHWCToNCHW<GPUDevice, T>()(context->eigen_device<Device>(),
- tensor_out->tensor<T, 4>(),
- transformed_output.tensor<T, 4>());
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<Device>(),
+ tensor_out->tensor<T, 4>(),
+ transformed_output.tensor<T, 4>());
}
- functor::NHWCToNCHW<GPUDevice, T>()(
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
transformed_output_backprop.tensor<T, 4>());
}
@@ -361,7 +361,7 @@ void DnnPoolingGradOp<T>::Compute(
if (data_format == FORMAT_NHWC) {
/// Transform the output data from NCHW back to NHWC.
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::NCHWToNHWC<GPUDevice, T>()(
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
input_backprop->tensor<T, 4>());
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 5cb9b5fa26..e68be084a4 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -416,6 +416,159 @@ output: 4-D with shape
)doc");
// --------------------------------------------------------------------------
+REGISTER_OP("Conv3D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes a 3-D convolution given 5-D `input` and `filter` tensors.
+
+In signal processing, cross-correlation is a measure of similarity of
+two waveforms as a function of a time-lag applied to one of them. This
+is also known as a sliding dot product or sliding inner-product.
+
+Our Conv3D implements a form of cross-correlation.
+
+input: Shape `[batch, in_depth, in_height, in_width, in_channels]`.
+filter: Shape `[filter_depth, filter_height, filter_width, in_channels, out_channels]`.
+ `in_channels` must match between `input` and `filter`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+
+)doc");
+
+REGISTER_OP("Conv3DBackpropInput")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of 3D convolution with respect to the input.
+
+input: Shape `[batch, depth, rows, cols, in_channels]`.
+filter: Shape `[depth, rows, cols, in_channels, out_channels]`.
+ `in_channels` must match between `input` and `filter`.
+out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, out_channels]`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+
+)doc");
+
+REGISTER_OP("Conv3DBackpropFilter")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of 3D convolution with respect to the filter.
+
+input: Shape `[batch, depth, rows, cols, in_channels]`.
+filter: Shape `[depth, rows, cols, in_channels, out_channels]`.
+ `in_channels` must match between `input` and `filter`.
+out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, out_channels]`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("AvgPool3D")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Performs 3D average pooling on the input.
+
+ksize: 1-D tensor of length 5. The size of the window for each dimension of
+ the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
+output: The average pooled output tensor.
+)doc");
+
+REGISTER_OP("AvgPool3DGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes gradients of average pooling function.
+
+ksize: 1-D tensor of length 5. The size of the window for each dimension of
+ the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+orig_input_shape: The original input dimensions.
+grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+output: The backprop for input.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("MaxPool3D")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Performs 3D max pooling on the input.
+
+ksize: 1-D tensor of length 5. The size of the window for each dimension of
+ the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
+output: The max pooled output tensor.
+)doc");
+
+REGISTER_OP("MaxPool3DGrad")
+ .Input("orig_input: float")
+ .Input("orig_output: float")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 5 ")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes gradients of max pooling function.
+
+ksize: 1-D tensor of length 5. The size of the window for each dimension of
+ the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+output: The backprop for input.
+)doc");
+
+// --------------------------------------------------------------------------
REGISTER_OP("L2Loss")
.Input("t: T")
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index ee5f3703ce..4115afb2b1 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -36,18 +36,26 @@ bool FormatFromString(const string& format_str, TensorFormat* format);
string ToString(TensorFormat format);
// Return the position index from a format given a dimension specification with
-// a char.
+// a char. The chars can be N (batch), C (channels), H (y), W (x), or
+// 0 .. (NDIMS-1).
+template <int NDIMS>
inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
if (format == FORMAT_NHWC) {
switch (dimension) {
case 'N':
return 0;
- case 'H':
+ case '0':
return 1;
- case 'W':
+ case '1':
return 2;
- case 'C':
+ case '2':
return 3;
+ case 'H':
+ return NDIMS - 1;
+ case 'W':
+ return NDIMS;
+ case 'C':
+ return 1 + NDIMS;
default:
LOG(FATAL) << "Invalid dimension: " << dimension;
}
@@ -57,10 +65,16 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
return 0;
case 'C':
return 1;
- case 'H':
+ case '0':
return 2;
- case 'W':
+ case '1':
return 3;
+ case '2':
+ return 4;
+ case 'H':
+ return NDIMS;
+ case 'W':
+ return NDIMS + 1;
default:
LOG(FATAL) << "Invalid dimension: " << dimension;
}
@@ -69,11 +83,15 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
}
}
+inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
+ return GetTensorDimIndex<2>(format, dimension);
+}
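The table the template encodes is easier to see spelled out. A hedged Python mirror of GetTensorDimIndex&lt;3&gt; for FORMAT_NHWC (i.e. NDHWC); the function name is illustrative:

    def get_tensor_dim_index_nhwc(dimension, ndims=3):
        if dimension == 'N':
            return 0
        if dimension in '012':       # explicit spatial axis
            return 1 + int(dimension)
        if dimension == 'H':         # second-to-last spatial axis
            return ndims - 1
        if dimension == 'W':         # last spatial axis
            return ndims
        if dimension == 'C':
            return ndims + 1
        raise ValueError("Invalid dimension: %r" % dimension)

    # ndims=3: N->0, '0'->1, '1'->2, '2'->3, H->2, W->3, C->4.
    # ndims=2 reproduces the old NHWC mapping: N->0, H->1, W->2, C->3.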
+
// Return the given tensor dimension from a tensor. The tensor is interpreted
// using the specified format, and a dimension specification using a char.
inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -86,7 +104,7 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
// specification using a char.
inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor_shape.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -99,7 +117,7 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
template <typename T>
T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < attributes.size())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -113,10 +131,10 @@ string GetConvnetDataFormatAttrString();
inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
int64 W, int64 C) {
std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex(format, 'N')] = N;
- dim_sizes[GetTensorDimIndex(format, 'H')] = H;
- dim_sizes[GetTensorDimIndex(format, 'W')] = W;
- dim_sizes[GetTensorDimIndex(format, 'C')] = C;
+ dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N;
+ dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H;
+ dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W;
+ dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C;
return TensorShape(dim_sizes);
}
@@ -128,13 +146,13 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format,
return src_shape;
}
std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex(dst_format, 'N')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] =
GetTensorDim(src_shape, src_format, 'N');
- dim_sizes[GetTensorDimIndex(dst_format, 'H')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] =
GetTensorDim(src_shape, src_format, 'H');
- dim_sizes[GetTensorDimIndex(dst_format, 'W')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] =
GetTensorDim(src_shape, src_format, 'W');
- dim_sizes[GetTensorDimIndex(dst_format, 'C')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] =
GetTensorDim(src_shape, src_format, 'C');
return TensorShape(dim_sizes);
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 05df66e4bb..66963145f7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1119,6 +1119,7 @@ sharded_kernel_test_list = glob([
"kernel_tests/cwise_ops_test.py",
"kernel_tests/embedding_ops_test.py",
"kernel_tests/linalg_grad_test.py",
+ "kernel_tests/conv_ops_3d_test.py",
])
cpu_only_kernel_test_list = glob([
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
new file mode 100644
index 0000000000..a86d1b60ea
--- /dev/null
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -0,0 +1,420 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for 3d convolutional operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import tensorflow as tf
+
+
+class Conv3DTest(tf.test.TestCase):
+
+ def _VerifyValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
+ padding, expected, use_gpu):
+ total_size_1 = 1
+ total_size_2 = 1
+ for s in tensor_in_sizes:
+ total_size_1 *= s
+ for s in filter_in_sizes:
+ total_size_2 *= s
+
+ # Initializes the input tensor with an array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
+ x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t1 = tf.constant(x1, shape=tensor_in_sizes)
+ t2 = tf.constant(x2, shape=filter_in_sizes)
+ conv = tf.nn.conv3d(t1,
+ t2, [1, stride, stride, stride, 1],
+ padding=padding)
+ value = sess.run(conv)
+ print("expected = ", expected)
+ print("actual = ", value)
+ self.assertArrayNear(expected, value.flatten(), 1e-5)
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
+ expected):
+ self._VerifyValuesForDevice(tensor_in_sizes,
+ filter_in_sizes,
+ stride,
+ padding,
+ expected,
+ use_gpu=False)
+ self._VerifyValuesForDevice(tensor_in_sizes,
+ filter_in_sizes,
+ stride,
+ padding,
+ expected,
+ use_gpu=True)
+
+ def testConv3D1x1x1Filter(self):
+ expected_output = [30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0,
+ 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0,
+ 312.0]
+
+ # These are equivalent to the Conv2D1x1 case.
+ self._VerifyValues(tensor_in_sizes=[1, 2, 3, 1, 3],
+ filter_in_sizes=[1, 1, 1, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+ self._VerifyValues(tensor_in_sizes=[1, 2, 1, 3, 3],
+ filter_in_sizes=[1, 1, 1, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+ self._VerifyValues(tensor_in_sizes=[1, 1, 2, 3, 3],
+ filter_in_sizes=[1, 1, 1, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ # Expected values computed using scipy's correlate function.
+ def testConv3D2x2x2Filter(self):
+ expected_output = [19554., 19962., 20370., 22110., 22590., 23070., 34890.,
+ 35730., 36570., 37446., 38358., 39270., 50226., 51498.,
+ 52770., 52782., 54126., 55470.]
+ # expected_shape = [1, 3, 1, 2, 3]
+ self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3], # b, z, y, x, fin
+ filter_in_sizes=[2, 2, 2, 3, 3], # z, y, x, fin, fout
+ stride=1, padding="VALID",
+ expected=expected_output)
+
+ def testConv3D2x2x2FilterStride2(self):
+ expected_output = [19554., 19962., 20370., 50226., 51498., 52770.]
+ self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3],
+ filter_in_sizes=[2, 2, 2, 3, 3],
+ stride=2,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv3DStride3(self):
+ expected_output = [
+ 36564., 38022., 39480., 37824., 39354., 40884., 39084., 40686., 42288.,
+ 46644., 48678., 50712., 47904., 50010., 52116., 49164., 51342., 53520.,
+ 107124., 112614., 118104., 108384., 113946., 119508., 109644., 115278.,
+ 120912., 117204., 123270., 129336., 118464., 124602., 130740., 119724.,
+ 125934., 132144.
+ ]
+ self._VerifyValues(tensor_in_sizes=[1, 6, 7, 8, 2],
+ filter_in_sizes=[3, 2, 1, 2, 3],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv3D2x2x2FilterStride2Same(self):
+ expected_output = [
+ 19554., 19962., 20370., 10452., 10710., 10968., 50226., 51498., 52770.,
+ 23844., 24534., 25224.
+ ]
+ self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3],
+ filter_in_sizes=[2, 2, 2, 3, 3],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+
+ def testKernelSmallerThanStride(self):
+ expected_output = [1., 3., 7., 9., 19., 21., 25., 27.]
+ self._VerifyValues(tensor_in_sizes=[1, 3, 3, 3, 1],
+ filter_in_sizes=[1, 1, 1, 1, 1],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+ self._VerifyValues(tensor_in_sizes=[1, 3, 3, 3, 1],
+ filter_in_sizes=[1, 1, 1, 1, 1],
+ stride=2,
+ padding="VALID",
+ expected=expected_output)
+
+ expected_output = [1484., 1592., 770.,
+ 2240., 2348., 1106.,
+ 1149., 1191., 539.,
+
+ 6776., 6884., 3122.,
+ 7532., 7640., 3458.,
+ 3207., 3249., 1421.,
+
+ 3005., 3035., 1225.,
+ 3215., 3245., 1309.,
+ 1013., 1022., 343.]
+ self._VerifyValues(tensor_in_sizes=[1, 7, 7, 7, 1],
+ filter_in_sizes=[2, 2, 2, 1, 1],
+ stride=3,
+ padding="SAME",
+ expected=expected_output)
+
+ expected_output = [1484., 1592.,
+ 2240., 2348.,
+
+ 6776., 6884.,
+ 7532., 7640.]
+ self._VerifyValues(tensor_in_sizes=[1, 7, 7, 7, 1],
+ filter_in_sizes=[2, 2, 2, 1, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def ConstructAndTestGradient(self, batch, input_planes, input_rows,
+ input_cols, filter_planes, filter_rows,
+ filter_cols, in_depth, out_depth, stride,
+ padding, test_input, use_gpu):
+ input_shape = [batch, input_planes, input_rows, input_cols, in_depth]
+ filter_shape = [filter_planes, filter_rows, filter_cols, in_depth,
+ out_depth]
+ if padding == "VALID":
+ output_planes = int(math.ceil((input_planes - filter_planes + 1.0) /
+ stride))
+ output_rows = int(math.ceil((input_rows - filter_rows + 1.0) / stride))
+ output_cols = int(math.ceil((input_cols - filter_cols + 1.0) / stride))
+ else:
+ output_planes = int(math.ceil(float(input_planes) / stride))
+ output_rows = int(math.ceil(float(input_rows) / stride))
+ output_cols = int(math.ceil(float(input_cols) / stride))
+ output_shape = [batch, output_planes, output_rows, output_cols, out_depth]
+ input_size = 1
+ for x in input_shape:
+ input_size *= x
+ filter_size = 1
+ for x in filter_shape:
+ filter_size *= x
+ input_data = [x * 1.0 / input_size for x in range(0, input_size)]
+ filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
+ if use_gpu:
+ data_type = tf.float32
+ tolerance = 4e-3
+ else:
+ data_type = tf.float64
+ tolerance = 1e-8
+ with self.test_session(use_gpu=use_gpu):
+ input_tensor = tf.constant(input_data,
+ shape=input_shape,
+ dtype=data_type,
+ name="input")
+ filter_tensor = tf.constant(filter_data,
+ shape=filter_shape,
+ dtype=data_type,
+ name="filter")
+ conv = tf.nn.conv3d(input_tensor,
+ filter_tensor, [1, stride, stride, stride, 1],
+ padding,
+ name="conv")
+
+ if test_input:
+ err = tf.test.compute_gradient_error(input_tensor, input_shape, conv,
+ output_shape)
+ else:
+ err = tf.test.compute_gradient_error(filter_tensor, filter_shape, conv,
+ output_shape)
+ print("conv3d gradient error = ", err)
+ self.assertLess(err, tolerance)
+
+ def testInputGradientValidPaddingStrideOne(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=3,
+ input_rows=5,
+ input_cols=4,
+ filter_planes=3,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientValidPaddingStrideOne(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=4,
+ input_planes=4,
+ input_rows=6,
+ input_cols=5,
+ filter_planes=2,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="VALID",
+ test_input=False,
+ use_gpu=use_gpu)
+
+ def testInputGradientValidPaddingStrideTwo(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=6,
+ input_rows=3,
+ input_cols=5,
+ filter_planes=3,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientValidPaddingStrideTwo(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=7,
+ input_rows=6,
+ input_cols=5,
+ filter_planes=2,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="VALID",
+ test_input=False,
+ use_gpu=use_gpu)
+
+ def testInputGradientValidPaddingStrideThree(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=3,
+ input_rows=7,
+ input_cols=6,
+ filter_planes=3,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="VALID",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientValidPaddingStrideThree(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=4,
+ input_rows=4,
+ input_cols=7,
+ filter_planes=4,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="VALID",
+ test_input=False,
+ use_gpu=use_gpu)
+
+ def testInputGradientSamePaddingStrideOne(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=3,
+ input_rows=2,
+ input_cols=2,
+ filter_planes=3,
+ filter_rows=2,
+ filter_cols=1,
+ in_depth=2,
+ out_depth=1,
+ stride=1,
+ padding="SAME",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientSamePaddingStrideOne(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=3,
+ input_rows=6,
+ input_cols=5,
+ filter_planes=2,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=1,
+ padding="SAME",
+ test_input=False,
+ use_gpu=use_gpu)
+
+ def testInputGradientSamePaddingStrideTwo(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=6,
+ input_rows=3,
+ input_cols=4,
+ filter_planes=3,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientSamePaddingStrideTwo(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=4,
+ input_planes=7,
+ input_rows=3,
+ input_cols=5,
+ filter_planes=2,
+ filter_rows=2,
+ filter_cols=2,
+ in_depth=2,
+ out_depth=3,
+ stride=2,
+ padding="SAME",
+ test_input=False,
+ use_gpu=use_gpu)
+
+ def testInputGradientSamePaddingStrideThree(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=9,
+ input_rows=3,
+ input_cols=6,
+ filter_planes=3,
+ filter_rows=3,
+ filter_cols=3,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="SAME",
+ test_input=True,
+ use_gpu=use_gpu)
+
+ def testFilterGradientSamePaddingStrideThree(self):
+ for use_gpu in [False, True]:
+ self.ConstructAndTestGradient(batch=2,
+ input_planes=9,
+ input_rows=4,
+ input_cols=7,
+ filter_planes=4,
+ filter_rows=4,
+ filter_cols=4,
+ in_depth=2,
+ out_depth=3,
+ stride=3,
+ padding="SAME",
+ test_input=False,
+ use_gpu=use_gpu)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 6a9b9978c1..b59bb5f3b8 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -862,14 +862,14 @@ class Conv2DTest(tf.test.TestCase):
# Filter larger than input.
with self.assertRaisesRegexp(ValueError,
- "filter must not be larger than the input"):
+ "Filter must not be larger than the input"):
tf.nn.conv2d(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
tf.placeholder(tf.float32,
shape=[20, 21, 3, 2]),
strides=[1, 1, 1, 1], padding="SAME")
with self.assertRaisesRegexp(ValueError,
- "filter must not be larger than the input"):
+ "Filter must not be larger than the input"):
tf.nn.conv2d(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
tf.placeholder(tf.float32,
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
new file mode 100644
index 0000000000..e7686871e1
--- /dev/null
+++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -0,0 +1,340 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for 3d pooling operations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class PoolingTest(tf.test.TestCase):
+
+ def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
+ expected, use_gpu):
+ """Verifies the output values of the pooling function.
+
+ Args:
+ pool_func: Function to be called: tf.nn.max_pool3d or tf.nn.avg_pool3d.
+ input_sizes: Input tensor dimensions.
+ window: Tuple of kernel dims: planes, rows, cols.
+ strides: Tuple of strides for dims: planes, rows, cols.
+ padding: Padding type.
+ expected: An array containing the expected operation outputs.
+ use_gpu: Whether we are running on GPU.
+ """
+ total_size = 1
+ for s in input_sizes:
+ total_size *= s
+ # Initializes the input tensor with an array containing incrementing
+ # numbers from 1.
+ x = [f * 1.0 for f in range(1, total_size + 1)]
+ with self.test_session(use_gpu=use_gpu) as sess:
+ t = tf.constant(x, shape=input_sizes)
+ t = pool_func(t,
+ ksize=[1, window[0], window[1], window[2], 1],
+ strides=[1, strides[0], strides[1], strides[2], 1],
+ padding=padding)
+ vals = sess.run(t)
+ # Verifies values.
+ actual = vals.flatten()
+ self.assertAllClose(expected, actual)
+
+ def _testAvgPool3dValidPadding(self, use_gpu):
+ expected_output = [20.5, 21.5, 22.5]
+ self._VerifyValues(tf.nn.avg_pool3d,
+ input_sizes=[1, 3, 3, 3, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="VALID",
+ expected=expected_output,
+ use_gpu=use_gpu)
+
+ def _testAvgPool3dSamePadding(self, use_gpu):
+ expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
+ self._VerifyValues(tf.nn.avg_pool3d,
+ input_sizes=[1, 2, 2, 4, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="SAME",
+ expected=expected_output,
+ use_gpu=use_gpu)
+
+ def _testMaxPool3dValidPadding(self, use_gpu):
+ expected_output = [40.0, 41.0, 42.0]
+ self._VerifyValues(tf.nn.max_pool3d,
+ input_sizes=[1, 3, 3, 3, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="VALID",
+ expected=expected_output,
+ use_gpu=use_gpu)
+
+ def _testMaxPool3dSamePadding(self, use_gpu):
+ expected_output = [31., 32., 33., 34., 35., 36.]
+ self._VerifyValues(tf.nn.max_pool3d,
+ input_sizes=[1, 2, 2, 3, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="SAME",
+ expected=expected_output,
+ use_gpu=use_gpu)
+
+ def testAvgPooling3d(self):
+ for use_gpu in [False, True]:
+ self._testAvgPool3dValidPadding(use_gpu)
+ self._testAvgPool3dSamePadding(use_gpu)
+
+ def testMaxPooling3d(self):
+ for use_gpu in [False, True]:
+ self._testMaxPool3dValidPadding(use_gpu)
+ self._testMaxPool3dSamePadding(use_gpu)
+
+ def testKernelSmallerThanStride(self):
+ for use_gpu in [True, False]:
+ self._VerifyValues(tf.nn.max_pool3d, input_sizes=[1, 3, 3, 3, 1],
+ window=[1, 1, 1], strides=[2, 2, 2],
+ padding="SAME",
+ expected=[1, 3, 7, 9, 19, 21, 25, 27],
+ use_gpu=use_gpu)
+
+ self._VerifyValues(tf.nn.max_pool3d, input_sizes=[1, 7, 7, 7, 1],
+ window=[2, 2, 2], strides=[3, 3, 3],
+ padding="VALID",
+ expected=[58, 61, 79, 82, 205, 208, 226, 229],
+ use_gpu=use_gpu)
+
+ self._VerifyValues(tf.nn.avg_pool3d, input_sizes=[1, 3, 3, 3, 1],
+ window=[1, 1, 1], strides=[2, 2, 2],
+ padding="SAME",
+ expected=[1, 3, 7, 9, 19, 21, 25, 27],
+ use_gpu=use_gpu)
+
+ self._VerifyValues(tf.nn.avg_pool3d, input_sizes=[1, 7, 7, 7, 1],
+ window=[2, 2, 2], strides=[3, 3, 3],
+ padding="VALID",
+ expected=[29.5, 32.5, 50.5, 53.5,
+ 176.5, 179.5, 197.5, 200.5],
+ use_gpu=use_gpu)
+
+ def _ConstructAndTestGradient(self,
+ pool_func,
+ input_sizes,
+ output_sizes,
+ window,
+ strides,
+ padding,
+ x_init_value=None,
+ use_gpu=False):
+ """Verifies the gradients of the avg pooling function.
+
+ Args:
+ pool_func: Function to be called: tf.nn.max_pool3d or
+ tf.nn.avg_pool3d.
+ input_sizes: Input tensor dimensions.
+ output_sizes: Output tensor dimensions.
+ window: Tuple of kernel dims: planes, rows, cols.
+ strides: Tuple of strides for dims: planes, rows, cols.
+ padding: Padding type.
+ x_init_value: Values to be passed to the gradient checker.
+ use_gpu: Whether to run pooling on GPU.
+ """
+ total_size = 1
+ for s in input_sizes:
+ total_size *= s
+ # Initializes the input tensor with an array containing incrementing
+ # numbers from 1.
+ x = [f * 1.0 for f in range(1, total_size + 1)]
+ with self.test_session(use_gpu=use_gpu):
+ input_tensor = tf.constant(x, shape=input_sizes, name="input")
+ err_margin = 1e-3
+ if pool_func == tf.nn.avg_pool3d:
+ func_name = "avg_pool3d"
+ else:
+ if x_init_value is None:
+ x_init_value = np.asfarray(
+ np.arange(1, total_size + 1),
+ dtype=np.float32).reshape(input_sizes)
+ func_name = "max_pool3d"
+
+ t = pool_func(input_tensor,
+ ksize=[1, window[0], window[1], window[2], 1],
+ strides=[1, strides[0], strides[1], strides[2], 1],
+ padding=padding,
+ name=func_name)
+
+ err = tf.test.compute_gradient_error(input_tensor,
+ input_sizes,
+ t,
+ output_sizes,
+ x_init_value=x_init_value,
+ delta=1e-2)
+ print("%s gradient error = " % func_name, err)
+ self.assertLess(err, err_margin)
+
+ def testMaxPoolGradValidPadding1_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[1, 3, 3, 3, 1],
+ output_sizes=[1, 3, 3, 3, 1],
+ window=(1, 1, 1),
+ strides=(1, 1, 1),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradValidPadding2_1_6_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 3, 3, 6, 3],
+ output_sizes=[2, 2, 2, 5, 3],
+ window=(2, 2, 2),
+ strides=(1, 1, 1),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradValidPadding2_1_7_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 3, 5, 7, 3],
+ output_sizes=[2, 2, 4, 6, 3],
+ window=(2, 2, 2),
+ strides=(1, 1, 1),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradValidPadding2_2_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 1, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradSamePadding1_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 3, 2, 4, 1],
+ output_sizes=[2, 3, 2, 4, 1],
+ window=(1, 1, 1),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradSamePadding2_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 3, 2, 4, 1],
+ output_sizes=[2, 3, 2, 4, 1],
+ window=(2, 2, 2),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradSamePadding2_2_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[2, 5, 2, 4, 3],
+ output_sizes=[2, 3, 1, 2, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testMaxPoolGradSamePadding3_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.max_pool3d,
+ input_sizes=[1, 3, 3, 7, 1],
+ output_sizes=[1, 3, 3, 7, 1],
+ window=(3, 3, 3),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradValidPadding1_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[2, 3, 3, 3, 3],
+ output_sizes=[2, 3, 3, 3, 3],
+ window=(1, 1, 1),
+ strides=(1, 1, 1),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradValidPadding2_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[2, 3, 3, 3, 3],
+ output_sizes=[2, 2, 2, 2, 3],
+ window=(2, 2, 2),
+ strides=(1, 1, 1),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradValidPadding2_2_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[2, 2, 2, 2, 3],
+ output_sizes=[2, 1, 1, 1, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="VALID",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradSamePadding1_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[2, 3, 2, 4, 3],
+ output_sizes=[2, 3, 2, 4, 3],
+ window=(1, 1, 1),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradSamePadding2_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[1, 2, 2, 2, 1],
+ output_sizes=[1, 2, 2, 2, 1],
+ window=(2, 2, 2),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradSamePadding2_2_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[2, 5, 2, 4, 3],
+ output_sizes=[2, 3, 1, 2, 3],
+ window=(2, 2, 2),
+ strides=(2, 2, 2),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+ def testAvgPoolGradSamePadding3_1_3d(self):
+ for use_gpu in (False, True):
+ self._ConstructAndTestGradient(tf.nn.avg_pool3d,
+ input_sizes=[1, 3, 6, 7, 1],
+ output_sizes=[1, 3, 6, 7, 1],
+ window=(3, 3, 3),
+ strides=(1, 1, 1),
+ padding="SAME",
+ use_gpu=use_gpu)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 9165748579..7e0993bba3 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -902,12 +902,12 @@ class PoolingTest(tf.test.TestCase):
for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
tf.nn.max_pool_with_argmax]:
with self.assertRaisesRegexp(ValueError,
- "filter must not be larger than the input"):
+ "Filter must not be larger than the input"):
pool_func(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
with self.assertRaisesRegexp(ValueError,
- "filter must not be larger than the input"):
+ "Filter must not be larger than the input"):
pool_func(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py
index db5ed6c551..d746cb1ab3 100644
--- a/tensorflow/python/ops/common_shapes.py
+++ b/tensorflow/python/ops/common_shapes.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
@@ -41,8 +40,10 @@ def unchanged_shape_with_rank(rank):
A shape function for ops that output a tensor of the same size as their
input, with a particular rank.
"""
+
def _ShapeFunction(op):
return [op.inputs[0].get_shape().with_rank(rank)]
+
return _ShapeFunction
@@ -56,8 +57,10 @@ def unchanged_shape_with_rank_at_least(rank):
A shape function for ops that output a tensor of the same size as their
input, with a particular rank.
"""
+
def _ShapeFunction(op):
return [op.inputs[0].get_shape().with_rank_at_least(rank)]
+
return _ShapeFunction
@@ -71,8 +74,10 @@ def unchanged_shape_with_rank_at_most(rank):
A shape function for ops that output a tensor of the same size as their
input, with a particular rank.
"""
+
def _ShapeFunction(op):
return [op.inputs[0].get_shape().with_rank_at_most(rank)]
+
return _ShapeFunction
@@ -103,12 +108,11 @@ def bias_add_shape(op):
data_format = None
if data_format == b"NCHW":
# Merge the length of bias_shape into the third-to-last dimension.
- output_shape = input_shape[0:-3].concatenate(
- input_shape[-3].merge_with(bias_shape[0])).concatenate(
- input_shape[-2:])
+ output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
+ bias_shape[0])).concatenate(input_shape[-2:])
else:
- output_shape = input_shape[0:-1].concatenate(
- input_shape[-1].merge_with(bias_shape[0]))
+ output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
+ bias_shape[0]))
else:
output_shape = tensor_shape.unknown_shape()
return [output_shape]
@@ -130,47 +134,54 @@ def bias_add_grad_shape(op):
return [output_shape]
-def get2d_conv_output_size(input_height, input_width, filter_height,
- filter_width, row_stride, col_stride, padding_type):
- """Returns the number of rows and columns in a convolution/pooling output."""
- input_height = tensor_shape.as_dimension(input_height)
- input_width = tensor_shape.as_dimension(input_width)
- filter_height = tensor_shape.as_dimension(filter_height)
- filter_width = tensor_shape.as_dimension(filter_width)
- row_stride = int(row_stride)
- col_stride = int(col_stride)
-
- if filter_height.value == 1 and filter_width.value == 1 and (
- row_stride == 1 and col_stride == 1):
- return input_height, input_width
+def get_conv_output_size(input_size, filter_size, strides, padding_type):
+ """Returns the spatial size of a n-d convolution/pooling output."""
+ input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
+ filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size])
+ strides = [int(x) for x in strides]
+
+ if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size):
+ return input_size
+
+ if any(x is not None and y is not None and x > y for x, y in
+ zip(filter_size, input_size)):
+ raise ValueError("Filter must not be larger than the input: "
+ "Filter: %r Input: %r" % (filter_size, input_size))
+
+ if padding_type == b"VALID":
+
+ def _valid(in_dim, k_dim, s_dim):
+ if in_dim is not None and k_dim is not None:
+ return (in_dim - k_dim + s_dim) // s_dim
+ else:
+ return None
+
+ output_size = [
+ _valid(in_dim, k_dim, s_dim)
+ for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides)
+ ]
+ elif padding_type == b"SAME":
+
+ def _same(in_dim, s_dim):
+ if in_dim is not None:
+ return (in_dim + s_dim - 1) // s_dim
+ else:
+ return None
+
+ output_size = [_same(in_dim, s_dim)
+ for in_dim, s_dim in zip(input_size, strides)]
else:
- if filter_height > input_height or filter_width > input_width:
- raise ValueError(
- "filter must not be larger than the input: "
- "Filter: [%sx%s] Input: [%sx%s]"
- % (filter_height, filter_width, input_height, input_width))
-
- # Compute number of rows in the output, based on the padding.
- if input_height.value is None or filter_height.value is None:
- out_rows = None
- elif padding_type == b"VALID":
- out_rows = ((input_height.value - filter_height.value + row_stride) //
- row_stride)
- elif padding_type == b"SAME":
- out_rows = (input_height.value + row_stride - 1) // row_stride
- else:
- raise ValueError("Invalid value for padding: %r" % padding_type)
+ raise ValueError("Invalid padding: %r" % padding_type)
+
+ return tuple(output_size)
- # Compute number of columns in the output, based on the padding.
- if input_width.value is None or filter_width.value is None:
- out_cols = None
- elif padding_type == b"VALID":
- out_cols = ((input_width.value - filter_width.value + col_stride) //
- col_stride)
- elif padding_type == b"SAME":
- out_cols = (input_width.value + col_stride - 1) // col_stride
- return out_rows, out_cols
+def get2d_conv_output_size(input_height, input_width, filter_height,
+ filter_width, row_stride, col_stride, padding_type):
+ """Returns the number of rows and columns in a convolution/pooling output."""
+ return get_conv_output_size((input_height, input_width),
+ (filter_height, filter_width),
+ (row_stride, col_stride), padding_type)
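A worked example of the two padding formulas, matching the VALID case exercised in pooling_ops_3d_test.py (input 7, window 2, stride 3 per axis):

    # VALID: (in - k + s) // s = (7 - 2 + 3) // 3 = 2
    # SAME:  (in + s - 1) // s = (7 + 3 - 1) // 3 = 3
    get_conv_output_size((7, 7, 7), (2, 2, 2), (3, 3, 3), b"VALID")
    # -> (2, 2, 2)
    get_conv_output_size((7, 7, 7), (2, 2, 2), (3, 3, 3), b"SAME")
    # -> (3, 3, 3)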
def conv2d_shape(op):
@@ -230,8 +241,9 @@ def conv2d_shape(op):
# information in the input to be ignored. This will require a change
# in the kernel implementation.
padding = op.get_attr("padding")
- out_rows, out_cols = get2d_conv_output_size(
- in_rows, in_cols, filter_rows, filter_cols, stride_r, stride_c, padding)
+ out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
+ filter_cols, stride_r, stride_c,
+ padding)
output_shape = [batch_size, out_rows, out_cols, depth_out]
if data_format == b"NCHW":
@@ -290,8 +302,9 @@ def depthwise_conv2d_native_shape(op):
# in the kernel implementation.
stride = stride_r
padding = op.get_attr("padding")
- out_rows, out_cols = get2d_conv_output_size(
- in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+ out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
+ filter_cols, stride, stride,
+ padding)
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
@@ -352,8 +365,9 @@ def separable_conv2d_shape(op):
# in the kernel implementation.
stride = stride_r
padding = op.get_attr("padding")
- out_rows, out_cols = get2d_conv_output_size(
- in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+ out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
+ filter_cols, stride, stride,
+ padding)
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
@@ -414,8 +428,9 @@ def avg_pool_shape(op):
# in the kernel implementation.
padding = op.get_attr("padding")
- out_rows, out_cols = get2d_conv_output_size(
- in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+ out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
+ ksize_c, stride_r, stride_c,
+ padding)
output_shape = [batch_size, out_rows, out_cols, depth]
if data_format == b"NCHW":
@@ -485,8 +500,9 @@ def max_pool_shape(op):
# in the kernel implementation.
if ksize_d == 1:
padding = op.get_attr("padding")
- out_rows, out_cols = get2d_conv_output_size(
- in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+ out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r,
+ ksize_c, stride_r, stride_c,
+ padding)
output_shape = [batch_size, out_rows, out_cols, depth]
else:
if depth % ksize_d > 0:
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index c911e33fe3..cf26be31e9 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -108,6 +108,7 @@ concatenated.
@@separable_conv2d
@@atrous_conv2d
@@conv2d_transpose
+@@conv3d
## Pooling
@@ -127,6 +128,8 @@ to the `Convolution` section for details about the padding calculation.
@@avg_pool
@@max_pool
@@max_pool_with_argmax
+@@avg_pool3d
+@@max_pool3d
## Normalization
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 9dac828896..188d936c0e 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Gradients for operators defined in nn_ops.py."""
from __future__ import absolute_import
@@ -40,14 +39,48 @@ def _Conv2DBackpropGrad(op, grad):
the gradients w.r.t. the input and the filter
"""
return [None,
- nn_ops.conv2d_backprop_filter(
- grad, array_ops.shape(op.inputs[1]), op.inputs[2],
- op.get_attr("strides"), op.get_attr("padding"),
- op.get_attr("use_cudnn_on_gpu"), op.get_attr("data_format")),
- nn_ops.conv2d(
- grad, op.inputs[1], op.get_attr("strides"),
- op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format"))]
+ nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]),
+ op.inputs[2], op.get_attr("strides"),
+ op.get_attr("padding"),
+ op.get_attr("use_cudnn_on_gpu"),
+ op.get_attr("data_format")),
+ nn_ops.conv2d(grad, op.inputs[1], op.get_attr("strides"),
+ op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
+ op.get_attr("data_format"))]
+
+
+@ops.RegisterGradient("Conv3D")
+def _Conv3DGrad(op, grad):
+ return [nn_ops.conv3d_backprop_input(op.inputs[0],
+ op.inputs[1],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding")),
+ nn_ops.conv3d_backprop_filter(op.inputs[0],
+ op.inputs[1],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"))]
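Once registered, the gradient is picked up by ordinary backprop; a hedged sketch with illustrative shapes:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[1, 4, 4, 4, 1])
    w = tf.placeholder(tf.float32, shape=[2, 2, 2, 1, 1])
    y = tf.nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding="VALID")
    # tf.gradients dispatches to _Conv3DGrad above and returns the
    # backprops w.r.t. the input and the filter.
    dx, dw = tf.gradients(y, [x, w])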
+
+
+@ops.RegisterGradient("AvgPool3D")
+def _AvgPool3DGrad(op, grad):
+ return nn_ops.avg_pool3d_grad(
+ array_ops.shape(op.inputs[0]),
+ grad,
+ ksize=op.get_attr("ksize"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"))
+
+
+@ops.RegisterGradient("MaxPool3D")
+def _MaxPool3DGrad(op, grad):
+ return nn_ops.max_pool3d_grad(op.inputs[0],
+ op.outputs[0],
+ grad,
+ ksize=op.get_attr("ksize"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"))
@ops.RegisterGradient("Softmax")
@@ -74,10 +107,8 @@ def _SoftmaxGrad(op, grad_softmax):
# graph-construction time? Alternatively: do different things
# depending on the dimensionality of the input tensors.
softmax = op.outputs[0]
- grad_x = ((grad_softmax -
- array_ops.reshape(math_ops.reduce_sum(grad_softmax * softmax, [1]),
- [-1, 1]))
- * softmax)
+ grad_x = ((grad_softmax - array_ops.reshape(
+ math_ops.reduce_sum(grad_softmax * softmax, [1]), [-1, 1])) * softmax)
return grad_x
@@ -128,7 +159,8 @@ def _BiasAddGradV1(unused_bias_op, received_grad):
the second one for the "bias" input of the BiasOp.
"""
reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
- return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
+ return (received_grad, math_ops.reduce_sum(received_grad,
+ reduction_dim_tensor))
@ops.RegisterGradient("Relu")
@@ -159,8 +191,8 @@ def _SoftsignGrad(op, grad):
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._relu_grad(grad, x),
- array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+ return (gen_nn_ops._relu_grad(grad, x), array_ops.zeros(
+ shape=array_ops.shape(x), dtype=x.dtype))
def _BroadcastMul(vec, mat):
@@ -196,12 +228,10 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
- return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]),
- op.inputs[1], grad,
- op.get_attr("strides"),
- op.get_attr("padding"),
- op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format")),
+ return [nn_ops.conv2d_backprop_input(
+ array_ops.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr("strides"),
+ op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
+ op.get_attr("data_format")),
nn_ops.conv2d_backprop_filter(op.inputs[0],
array_ops.shape(op.inputs[1]), grad,
op.get_attr("strides"),
@@ -228,28 +258,30 @@ def _LRNGrad(op, grad):
bias = op.get_attr("bias")
alpha = op.get_attr("alpha")
beta = op.get_attr("beta")
- return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0],
- depth_radius, bias, alpha, beta)]
+ return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius,
+ bias, alpha, beta)]
@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
- return gen_nn_ops._avg_pool_grad(array_ops.shape(op.inputs[0]), grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- op.get_attr("padding"),
- data_format=op.get_attr("data_format")
- )
+ return gen_nn_ops._avg_pool_grad(
+ array_ops.shape(op.inputs[0]),
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format"))
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
- return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad,
+ return gen_nn_ops._max_pool_grad(op.inputs[0],
+ op.outputs[0],
+ grad,
op.get_attr("ksize"),
op.get_attr("strides"),
padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format")
- )
+ data_format=op.get_attr("data_format"))
@ops.RegisterGradient("BatchNormWithGlobalNormalization")
@@ -328,5 +360,4 @@ def _TopKGrad(op, grad, _):
array_ops.reshape(grad, [-1]),
validate_indices=False),
in_shape), array_ops.zeros(
- [],
- dtype=dtypes.int32)]
+ [], dtype=dtypes.int32)]
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 661dc790d2..68fc9a364f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Wrappers for primitive Neural Net (NN) Operations."""
# pylint: disable=invalid-name
@@ -37,7 +36,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
-
# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn
@@ -153,12 +151,13 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
"value's input channels does not match filters' input channels, "
"{} != {}".format(value_shape[3], filter_shape[2]))
if rate < 1:
- raise ValueError(
- "rate {} cannot be less than one".format(rate))
+ raise ValueError("rate {} cannot be less than one".format(rate))
if rate == 1:
- value = gen_nn_ops.conv2d(input=value, filter=filters,
- strides=[1, 1, 1, 1], padding=padding)
+ value = gen_nn_ops.conv2d(input=value,
+ filter=filters,
+ strides=[1, 1, 1, 1],
+ padding=padding)
return value
# We have two padding contributions. The first is used for converting "SAME"
@@ -201,20 +200,30 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
# The paddings argument to space_to_batch includes both padding components
space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra],
[pad_left, pad_right + pad_right_extra]]
- value = array_ops.space_to_batch(
- input=value, paddings=space_to_batch_pad, block_size=rate)
+ value = array_ops.space_to_batch(input=value,
+ paddings=space_to_batch_pad,
+ block_size=rate)
- value = gen_nn_ops.conv2d(input=value, filter=filters, strides=[1, 1, 1, 1],
- padding="VALID", name=name)
+ value = gen_nn_ops.conv2d(input=value,
+ filter=filters,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name=name)
# The crops argument to batch_to_space is just the extra padding component
batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]]
- value = array_ops.batch_to_space(
- input=value, crops=batch_to_space_crop, block_size=rate)
+ value = array_ops.batch_to_space(input=value,
+ crops=batch_to_space_crop,
+ block_size=rate)
return value
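
The space_to_batch / batch_to_space sandwich above is the core of atrous_conv2d: a rate-r dilated convolution becomes an ordinary convolution over r * r interleaved subsampled grids. A small worked check of the SAME-padding arithmetic, following the usual effective-filter-size formula (values are illustrative only):

# For rate=2 and a 3x3 filter, the effective filter spans
# k + (k - 1) * (rate - 1) = 3 + 2 * 1 = 5 input pixels.
rate, k = 2, 3
effective_k = k + (k - 1) * (rate - 1)
pad_total = effective_k - 1       # 4
pad_top = pad_total // 2          # 2
pad_bottom = pad_total - pad_top  # 2
assert (pad_top, pad_bottom) == (2, 2)
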
-def conv2d_transpose(value, filter, output_shape, strides, padding="SAME",
+
+def conv2d_transpose(value,
+ filter,
+ output_shape,
+ strides,
+ padding="SAME",
name=None):
"""The transpose of `conv2d`.
@@ -248,9 +257,9 @@ def conv2d_transpose(value, filter, output_shape, strides, padding="SAME",
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]):
- raise ValueError(
- "input channels does not match filter's input channels, "
- "{} != {}".format(value.get_shape()[3], filter.get_shape()[3]))
+ raise ValueError("input channels does not match filter's input channels, "
+ "{} != {}".format(value.get_shape()[3], filter.get_shape(
+ )[3]))
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)):
@@ -302,8 +311,8 @@ def bias_add(value, bias, data_format=None, name=None):
bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
-ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
+ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
@@ -338,7 +347,6 @@ def bias_add_v1(value, bias, name=None):
ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
-
ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
@@ -490,7 +498,9 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
"""
with ops.op_scope([value], name, "AvgPool") as name:
value = ops.convert_to_tensor(value, name="input")
- return gen_nn_ops._avg_pool(value, ksize=ksize, strides=strides,
+ return gen_nn_ops._avg_pool(value,
+ ksize=ksize,
+ strides=strides,
padding=padding,
data_format=data_format,
name=name)
@@ -515,7 +525,9 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
"""
with ops.op_scope([value], name, "MaxPool") as name:
value = ops.convert_to_tensor(value, name="input")
- return gen_nn_ops._max_pool(value, ksize=ksize, strides=strides,
+ return gen_nn_ops._max_pool(value,
+ ksize=ksize,
+ strides=strides,
padding=padding,
data_format=data_format,
name=name)
@@ -547,7 +559,6 @@ def _BinaryElementwiseShape(op):
ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
-
ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
@@ -560,12 +571,9 @@ def _LRNGradShape(op):
return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
-ops.RegisterShape("Softmax")(
- common_shapes.unchanged_shape_with_rank(2))
-
+ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2))
-ops.RegisterShape("LogSoftmax")(
- common_shapes.unchanged_shape_with_rank(2))
+ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2))
@ops.RegisterShape("InTopK")
@@ -744,6 +752,93 @@ def _calc_conv_weight_params(graph, node):
filter_in_depth * filter_out_depth))
+@ops.RegisterShape("Conv3D")
+def _Conv3DShape(op):
+ """Shape function for Conv3D."""
+ input_shape = op.inputs[0].get_shape().with_rank(5)
+ filter_shape = op.inputs[1].get_shape().with_rank(5)
+
+ batch_size = input_shape[0]
+ out_channels = filter_shape[4]
+  # Check that the number of input channels is compatible between the
+  # input data and the filter.
+ input_shape[4].assert_is_compatible_with(filter_shape[3])
+
+ stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
+ assert stride_b == 1
+ assert stride_d == 1
+
+ padding_type = op.get_attr("padding")
+ out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
+ input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
+ padding_type)
+
+ return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
+ out_channels])]
+
+
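
A quick sanity check of the Conv3D shape logic above, assuming get_conv_output_size follows the standard formulas (out = ceil(in / stride) for SAME, (in - filter) // stride + 1 for VALID); the concrete numbers are illustrative only:

# input [2, 8, 8, 8, 3], filter [3, 3, 3, 3, 16], all strides 1:
#   SAME  -> [2, 8, 8, 8, 16]
#   VALID -> [2, 6, 6, 6, 16]
def valid_out(in_size, k, s):
  return (in_size - k) // s + 1

assert valid_out(8, 3, 1) == 6
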
+@ops.RegisterShape("MaxPool3D")
+@ops.RegisterShape("AvgPool3D")
+def _Pool3DShape(op):
+ """Shape function for Max/AvgPool3D."""
+ input_shape = op.inputs[0].get_shape().with_rank(5)
+ ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
+ assert ksize_b == 1
+ assert ksize_d == 1
+
+ stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
+ assert stride_b == 1
+ assert stride_d == 1
+
+ batch_size = input_shape[0]
+ channels = input_shape[4]
+
+ padding = op.get_attr("padding")
+ out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
+ input_shape[1:4], (ksize_p, ksize_r, ksize_c),
+ (stride_p, stride_r, stride_c), padding)
+ return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
+ channels])]
+
+
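
The same per-dimension output-size rule applies to the pooling shape function above; a hedged numeric example (values chosen for illustration):

# Max pooling [1, 8, 8, 8, 1] with ksize 2 and stride 2, VALID padding,
# halves planes, rows and cols; batch and channels pass through, matching
# the ksize/stride asserts in _Pool3DShape.
def pool3d_out_valid(in_dims, k, s):
  return tuple((d - k) // s + 1 for d in in_dims)

assert pool3d_out_valid((8, 8, 8), k=2, s=2) == (4, 4, 4)
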
+def _ShapeOrUnknown(input_shape, ndims=5):
+ if input_shape == None: # pylint:disable=g-equals-none
+ return [tensor_shape.unknown_shape(ndims=ndims)]
+ else:
+ return [input_shape]
+
+
+@ops.RegisterShape("Conv3DBackpropFilter")
+def _Conv3DBackpropFilterShape(op):
+ """Shape function for the Conv3DBackpropFilter op."""
+ filter_shape = op.inputs[1].get_shape()
+ return _ShapeOrUnknown(filter_shape)
+
+
+@ops.RegisterShape("Conv3DBackpropInput")
+def _Conv3DBackpropInputShape(op):
+ """Shape function for the Conv3DBackpropInput op."""
+ input_shape = op.inputs[0].get_shape()
+ return _ShapeOrUnknown(input_shape)
+
+
+@ops.RegisterShape("AvgPool3DGrad")
+def _AvgPool3DGradShape(op):
+ """Shape function for the AvgPool3DGrad op."""
+ orig_input_shape = tensor_util.constant_value(op.inputs[0])
+ if orig_input_shape != None: # pylint:disable=g-equals-none
+ return [tensor_shape.TensorShape(orig_input_shape.tolist())]
+ else:
+ return [tensor_shape.unknown_shape(ndims=5)]
+
+
+@ops.RegisterShape("MaxPool3DGrad")
+def _MaxPool3DGradShape(op):
+ """Shape function for the MaxPoolGrad op."""
+ orig_input_shape = op.inputs[0].get_shape().with_rank(5)
+ return [orig_input_shape]
+
+
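
Note the asymmetry between the two grad shape functions above: AvgPool3DGrad receives the forward input's shape as a 1-D tensor, whose value may not be known at graph-construction time, while MaxPool3DGrad receives the forward input itself, so a rank-5 static shape is always recoverable. A hypothetical sketch of the two AvgPool3DGrad cases (names and values are for illustration only):

import tensorflow as tf

known = tf.constant([2, 4, 4, 4, 3])     # constant_value succeeds
dynamic = tf.placeholder(tf.int32, [5])  # constant_value returns None
# With `known`, the gradient's shape is inferred as [2, 4, 4, 4, 3];
# with `dynamic`, it is reported as an unknown shape of rank 5.
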
@ops.RegisterStatistics("BiasAdd", "flops")
def _calc_bias_add_flops(graph, node):
"""Calculates the computing needed for BiasAdd."""
@@ -846,15 +941,17 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
if isinstance(keep_prob, float) and not 0 < keep_prob <= 1:
raise ValueError("keep_prob must be a scalar tensor or a float in the "
"range (0, 1], got %g" % keep_prob)
- keep_prob = ops.convert_to_tensor(
- keep_prob, dtype=x.dtype, name="keep_prob")
+ keep_prob = ops.convert_to_tensor(keep_prob,
+ dtype=x.dtype,
+ name="keep_prob")
keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
- random_tensor += random_ops.random_uniform(
- noise_shape, seed=seed, dtype=x.dtype)
+ random_tensor += random_ops.random_uniform(noise_shape,
+ seed=seed,
+ dtype=x.dtype)
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor)
ret = x * math_ops.inv(keep_prob) * binary_tensor
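
The floor trick above deserves a spelled-out reading: a uniform sample in [keep_prob, 1 + keep_prob) falls in [1, 1 + keep_prob) with probability keep_prob, so math_ops.floor yields a Bernoulli(keep_prob) mask, and scaling by 1 / keep_prob keeps the expected activation unchanged. A minimal numeric sketch, assuming only NumPy:

import numpy as np

keep_prob = 0.8
u = keep_prob + np.random.uniform(size=10000)  # in [keep_prob, 1 + keep_prob)
mask = np.floor(u)                             # 1 w.p. keep_prob, else 0
ret = np.ones(10000) * (1.0 / keep_prob) * mask
assert abs(ret.mean() - 1.0) < 0.05            # expectation preserved
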
@@ -890,5 +987,4 @@ def top_k(input, k=1, sorted=True, name=None):
"""
return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name)
-
# pylint: enable=invalid-name